"vscode:/vscode.git/clone" did not exist on "407de2c1ec229629579cdb7f8518ba735a6604cf"
Commit 2688e3f7 authored by myhloli

refactor: enhance main function parameters and improve device handling logic

parent c18934a3
@@ -92,14 +92,14 @@ def doc_analyze(
     ocr_enabled_list = []
     for pdf_idx, pdf_bytes in enumerate(pdf_bytes_list):
         # Determine the OCR setting
-        _ocr = False
+        _ocr_enable = False
         if parse_method == 'auto':
             if classify(pdf_bytes) == 'ocr':
-                _ocr = True
+                _ocr_enable = True
         elif parse_method == 'ocr':
-            _ocr = True
-        ocr_enabled_list.append(_ocr)
+            _ocr_enable = True
+        ocr_enabled_list.append(_ocr_enable)
         _lang = lang_list[pdf_idx]
         # Collect the pages of each dataset
@@ -110,7 +110,7 @@ def doc_analyze(
             img_dict = images_list[page_idx]
             all_pages_info.append((
                 pdf_idx, page_idx,
-                img_dict['img_pil'], _ocr, _lang,
+                img_dict['img_pil'], _ocr_enable, _lang,
             ))
     # Prepare batch processing
...
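For context, the renamed flag is decided once per PDF before any pages are collected. Below is a minimal standalone sketch of that decision, assuming classify() returns 'ocr' for image-based PDFs; the helper name decide_ocr_enable and its signature are illustrative, not part of the change.

```python
# Illustrative sketch of the per-PDF OCR decision shown in the hunk above.
# `classify` is assumed to return 'ocr' for image-based PDFs and 'txt' otherwise.
def decide_ocr_enable(pdf_bytes: bytes, parse_method: str, classify) -> bool:
    """Return True when OCR should run for this PDF."""
    if parse_method == 'ocr':
        return True
    if parse_method == 'auto':
        return classify(pdf_bytes) == 'ocr'
    # Explicit text mode ('txt') skips OCR entirely.
    return False

# Usage: build the same per-PDF list that doc_analyze() does.
# ocr_enabled_list = [decide_ocr_enable(b, parse_method, classify) for b in pdf_bytes_list]
```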
@@ -2,7 +2,9 @@
 import os
 import click
 from pathlib import Path
+import torch
 from loguru import logger
+from mineru.utils.model_utils import get_vram
 from ..version import __version__
 from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
@@ -38,7 +40,7 @@ from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
     vlm-huggingface: More general.
     vlm-sglang-engine: Faster(engine).
     vlm-sglang-client: Faster(client).
-    without method specified, huggingface will be used by default.""",
+    without method specified, pipeline will be used by default.""",
     default='pipeline',
 )
 @click.option(
@@ -49,6 +51,7 @@ from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
     help="""
     Input the languages in the pdf (if known) to improve OCR accuracy. Optional.
     Without languages specified, 'ch' will be used by default.
+    Adapted only for the case where the backend is set to "pipeline".
     """,
     default='ch',
 )
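The next hunk adds several options that follow the same click pattern as the ones above: a short flag, a long flag, and an explicit destination name that becomes the keyword argument click passes to main(). A self-contained illustration of that wiring follows; the demo command is only an example and not part of this commit.

```python
import click

@click.command()
@click.option(
    '-f',
    '--formula',
    'formula_enable',  # click passes the parsed value as the `formula_enable` kwarg
    type=bool,
    help='Enable formula parsing.',
    default=True,
)
def demo(formula_enable):
    click.echo(f'formula parsing enabled: {formula_enable}')

if __name__ == '__main__':
    demo()  # e.g. `python demo.py -f false` prints "formula parsing enabled: False"
```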
@@ -78,8 +81,63 @@ from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
     help='The ending page for PDF parsing, beginning from 0.',
     default=None,
 )
+@click.option(
+    '-f',
+    '--formula',
+    'formula_enable',
+    type=bool,
+    help='Enable formula parsing. Default is True. Adapted only for the case where the backend is set to "pipeline".',
+    default=True,
+)
+@click.option(
+    '-t',
+    '--table',
+    'table_enable',
+    type=bool,
+    help='Enable table parsing. Default is True. Adapted only for the case where the backend is set to "pipeline".',
+    default=True,
+)
+@click.option(
+    '-d',
+    '--device',
+    'device_mode',
+    type=str,
+    help='Device mode for model inference, e.g., "cpu", "cuda", "cuda:0", "npu", "npu:0", "mps". Adapted only for the case where the backend is set to "pipeline".',
+    default=None,
+)
+@click.option(
+    '-vm',
+    '--virtual-vram',
+    'virtual_vram',
+    type=int,
+    help='Upper limit of GPU memory (in GB) available for model inference. Adapted only for the case where the backend is set to "pipeline".',
+    default=None,
+)
-def main(input_path, output_dir, backend, lang, server_url, start_page_id, end_page_id):
+def main(input_path, output_dir, backend, lang, server_url, start_page_id, end_page_id, formula_enable, table_enable, device_mode, virtual_vram):
+    os.environ['MINERU_FORMULA_ENABLE'] = str(formula_enable).lower()
+    os.environ['MINERU_TABLE_ENABLE'] = str(table_enable).lower()
+    def get_device_mode() -> str:
+        if device_mode is not None:
+            return device_mode
+        if torch.cuda.is_available():
+            return "cuda"
+        if torch.backends.mps.is_available():
+            return "mps"
+        return "cpu"
+    os.environ['MINERU_DEVICE_MODE'] = get_device_mode()
+    def get_virtual_vram_size() -> int:
+        if virtual_vram is not None:
+            return virtual_vram
+        if get_device_mode().startswith("cuda") or get_device_mode().startswith("npu"):
+            return round(get_vram(get_device_mode()))
+        return 1
+    os.environ['MINERU_VIRTUAL_VRAM_SIZE'] = str(get_virtual_vram_size())
     os.makedirs(output_dir, exist_ok=True)
     def parse_doc(path_list: list[Path]):
...
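The new main() hands its configuration to the rest of the pipeline through MINERU_* environment variables rather than function arguments. How downstream code reads them back is not shown in this diff; the sketch below is one way such a consumer might look, using the same fallbacks the CLI itself applies (the helper name read_mineru_runtime_config is hypothetical).

```python
import os

# Hypothetical consumer-side sketch (not part of this commit): read back the
# configuration that main() exported via environment variables.
def read_mineru_runtime_config() -> dict:
    """Collect the MINERU_* settings written by the CLI, with conservative fallbacks."""
    return {
        # 'true'/'false' strings produced by str(bool).lower() in main()
        'formula_enable': os.getenv('MINERU_FORMULA_ENABLE', 'true') == 'true',
        'table_enable': os.getenv('MINERU_TABLE_ENABLE', 'true') == 'true',
        # e.g. 'cpu', 'cuda', 'cuda:0', 'npu', 'npu:0', 'mps'
        'device_mode': os.getenv('MINERU_DEVICE_MODE', 'cpu'),
        # whole gigabytes, as written by get_virtual_vram_size()
        'virtual_vram_gb': int(os.getenv('MINERU_VIRTUAL_VRAM_SIZE', '1')),
    }

if __name__ == '__main__':
    print(read_mineru_runtime_config())
```

Passing the values through the process environment keeps the pipeline code free of click-specific plumbing, at the cost of stringly-typed configuration; the conversions above are the ones a consumer has to undo.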