Commit 7c8fb44b authored by myhloli's avatar myhloli
Browse files

refactor: enhance device mode detection to support CUDA and MPS

parent f407079b
......@@ -2,6 +2,7 @@
import json
import os
import torch
from loguru import logger
# 定义配置文件名常量
......@@ -93,8 +94,11 @@ def get_device():
if device_mode is not None:
return device_mode
else:
logger.warning(f"not found 'MINERU_DEVICE_MODE' in environment variable, use 'cpu' as default.")
return 'cpu'
if torch.cuda.is_available():
return "cuda"
if torch.backends.mps.is_available():
return "mps"
return "cpu"
def get_table_recog_config():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment