Commit bd5252d9 authored by myhloli's avatar myhloli
Browse files

fix: add conditional import for torch and torch_npu in config_reader.py

parent b398a2d2
# Copyright (c) Opendatalab. All rights reserved. # Copyright (c) Opendatalab. All rights reserved.
import json import json
import os import os
from loguru import logger from loguru import logger
try:
import torch
import torch_npu
except ImportError:
pass
# 定义配置文件名常量 # 定义配置文件名常量
CONFIG_FILE_NAME = os.getenv('MINERU_TOOLS_CONFIG_JSON', 'mineru.json') CONFIG_FILE_NAME = os.getenv('MINERU_TOOLS_CONFIG_JSON', 'mineru.json')
...@@ -71,15 +77,12 @@ def get_device(): ...@@ -71,15 +77,12 @@ def get_device():
if device_mode is not None: if device_mode is not None:
return device_mode return device_mode
else: else:
import torch
if torch.cuda.is_available(): if torch.cuda.is_available():
return "cuda" return "cuda"
elif torch.backends.mps.is_available(): elif torch.backends.mps.is_available():
return "mps" return "mps"
else: else:
try: try:
import torch_npu
if torch_npu.npu.is_available(): if torch_npu.npu.is_available():
return "npu" return "npu"
except Exception as e: except Exception as e:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment