Commit 7c8fb44b authored by myhloli's avatar myhloli
Browse files

refactor: enhance device mode detection to support CUDA and MPS

parent f407079b
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import json import json
import os import os
import torch
from loguru import logger from loguru import logger
# 定义配置文件名常量 # 定义配置文件名常量
...@@ -93,8 +94,11 @@ def get_device(): ...@@ -93,8 +94,11 @@ def get_device():
if device_mode is not None: if device_mode is not None:
return device_mode return device_mode
else: else:
logger.warning(f"not found 'MINERU_DEVICE_MODE' in environment variable, use 'cpu' as default.") if torch.cuda.is_available():
return 'cpu' return "cuda"
if torch.backends.mps.is_available():
return "mps"
return "cpu"
def get_table_recog_config(): def get_table_recog_config():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment