Commit 3eef1218 authored by myhloli's avatar myhloli
Browse files

refactor: improve environment variable initialization and enhance GPU memory handling

parent 44a9cf22
......@@ -110,7 +110,7 @@ from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
'--virtual-vram',
'virtual_vram',
type=int,
help='Device mode for model inference, e.g., "cpu", "cuda", "cuda:0", "npu", "npu:0", "mps". Default is "cpu". Adapted only for the case where the backend is set to "pipeline". ',
help='Upper limit of GPU memory occupied by a single process. Adapted only for the case where the backend is set to "pipeline". ',
default=None,
)
@click.option(
......@@ -127,8 +127,10 @@ from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
def main(input_path, output_dir, backend, lang, server_url, start_page_id, end_page_id, formula_enable, table_enable, device_mode, virtual_vram, model_source):
os.environ['MINERU_FORMULA_ENABLE'] = str(formula_enable).lower()
os.environ['MINERU_TABLE_ENABLE'] = str(table_enable).lower()
if os.getenv('MINERU_FORMULA_ENABLE', None) is None:
os.environ['MINERU_FORMULA_ENABLE'] = str(formula_enable).lower()
if os.getenv('MINERU_TABLE_ENABLE', None) is None:
os.environ['MINERU_TABLE_ENABLE'] = str(table_enable).lower()
def get_device_mode() -> str:
if device_mode is not None:
return device_mode
......@@ -137,7 +139,8 @@ def main(input_path, output_dir, backend, lang, server_url, start_page_id, end_p
if torch.backends.mps.is_available():
return "mps"
return "cpu"
os.environ['MINERU_DEVICE_MODE'] = get_device_mode()
if os.getenv('MINERU_DEVICE_MODE', None) is None:
os.environ['MINERU_DEVICE_MODE'] = get_device_mode()
def get_virtual_vram_size() -> int:
if virtual_vram is not None:
......@@ -145,10 +148,11 @@ def main(input_path, output_dir, backend, lang, server_url, start_page_id, end_p
if get_device_mode().startswith("cuda") or get_device_mode().startswith("npu"):
return round(get_vram(get_device_mode()))
return 1
if os.getenv('MINERU_VIRTUAL_VRAM_SIZE', None) is None:
os.environ['MINERU_VIRTUAL_VRAM_SIZE']= str(get_virtual_vram_size())
os.environ['MINERU_VIRTUAL_VRAM_SIZE']= str(get_virtual_vram_size())
os.environ['MINERU_MODEL_SOURCE'] = model_source
if os.getenv('MINERU_BACKEND', None) is None:
os.environ['MINERU_MODEL_SOURCE'] = model_source
os.makedirs(output_dir, exist_ok=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment