Commit 5f1a509f authored by myhloli's avatar myhloli
Browse files

feat: refactor device mode retrieval to use get_device utility

parent a392f445
...@@ -4,6 +4,8 @@ import click ...@@ -4,6 +4,8 @@ import click
from pathlib import Path from pathlib import Path
import torch import torch
from loguru import logger from loguru import logger
from mineru.utils.config_reader import get_device
from mineru.utils.model_utils import get_vram from mineru.utils.model_utils import get_vram
from ..version import __version__ from ..version import __version__
from .common import do_parse, read_fn, pdf_suffixes, image_suffixes from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
...@@ -144,11 +146,8 @@ def main(input_path, output_dir, method, backend, lang, server_url, start_page_i ...@@ -144,11 +146,8 @@ def main(input_path, output_dir, method, backend, lang, server_url, start_page_i
def get_device_mode() -> str: def get_device_mode() -> str:
if device_mode is not None: if device_mode is not None:
return device_mode return device_mode
if torch.cuda.is_available(): else:
return "cuda" return get_device()
if torch.backends.mps.is_available():
return "mps"
return "cpu"
if os.getenv('MINERU_DEVICE_MODE', None) is None: if os.getenv('MINERU_DEVICE_MODE', None) is None:
os.environ['MINERU_DEVICE_MODE'] = get_device_mode() os.environ['MINERU_DEVICE_MODE'] = get_device_mode()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment