"tests/vscode:/vscode.git/clone" did not exist on "44e410843fb10508c911022a80f15276c76d9e60"
Commit 7ae4f80d authored by myhloli's avatar myhloli
Browse files

feat: enhance device detection to support NPU and improve error handling

parent 5f1a509f
import os
import time
import numpy as np
from typing import List, Tuple
import PIL.Image
import torch
from .model_init import MineruPipelineModel
......@@ -150,7 +151,7 @@ def doc_analyze(
def batch_image_analyze(
images_with_extra_info: list[(np.ndarray, bool, str)],
images_with_extra_info: List[Tuple[PIL.Image.Image, bool, str]],
formula_enable=None,
table_enable=None):
# os.environ['CUDA_VISIBLE_DEVICES'] = str(idx)
......@@ -163,9 +164,15 @@ def batch_image_analyze(
device = get_device()
if str(device).startswith('npu'):
import torch_npu
if torch_npu.npu.is_available():
torch.npu.set_compile_mode(jit_compile=False)
try:
import torch_npu
if torch_npu.npu.is_available():
torch.npu.set_compile_mode(jit_compile=False)
except Exception as e:
raise RuntimeError(
"NPU is selected as device, but torch_npu is not available. "
"Please ensure that the torch_npu package is installed correctly."
) from e
if str(device).startswith('npu') or str(device).startswith('cuda'):
vram = get_vram(device)
......
......@@ -74,8 +74,15 @@ def get_device():
else:
if torch.cuda.is_available():
return "cuda"
if torch.backends.mps.is_available():
elif torch.backends.mps.is_available():
return "mps"
else:
try:
import torch_npu
if torch_npu.npu.is_available():
return "npu"
except Exception as e:
pass
return "cpu"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment