"vscode:/vscode.git/clone" did not exist on "3159e60d59819ae874ea3cdbd28e02d9e6c57321"
Commit d58b24b5 authored by myhloli's avatar myhloli
Browse files

fix: add conditional imports for torch and torch_npu in model_utils.py

parent bd5252d9
......@@ -6,6 +6,12 @@ import numpy as np
from mineru.utils.boxbase import get_minbox_if_overlap_by_ratio
try:
import torch
import torch_npu
except ImportError:
pass
def crop_img(input_res, input_img, crop_paste_x=0, crop_paste_y=0):
......@@ -297,14 +303,11 @@ def get_res_list_from_layout_res(layout_res, iou_threshold=0.7, overlap_threshol
def clean_memory(device='cuda'):
import torch
if device == 'cuda':
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
elif str(device).startswith("npu"):
import torch_npu
if torch_npu.npu.is_available():
torch_npu.npu.empty_cache()
elif str(device).startswith("mps"):
......@@ -322,13 +325,10 @@ def clean_vram(device, vram_threshold=8):
def get_vram(device):
import torch
if torch.cuda.is_available() and str(device).startswith("cuda"):
total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3) # 将字节转换为 GB
return total_memory
elif str(device).startswith("npu"):
import torch_npu
if torch_npu.npu.is_available():
total_memory = torch_npu.npu.get_device_properties(device).total_memory / (1024 ** 3) # 转为 GB
return total_memory
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment