Unverified commit 0c5f00fa, authored by Xiaomeng Zhao, committed by GitHub
Browse files

Merge pull request #2881 from myhloli/dev

parents 8aac6107 b943f04e
...@@ -47,9 +47,10 @@ def doc_analyze( ...@@ -47,9 +47,10 @@ def doc_analyze(
backend="transformers", backend="transformers",
model_path: str | None = None, model_path: str | None = None,
server_url: str | None = None, server_url: str | None = None,
**kwargs,
): ):
if predictor is None: if predictor is None:
predictor = ModelSingleton().get_model(backend, model_path, server_url) predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
# load_images_start = time.time() # load_images_start = time.time()
images_list, pdf_doc = load_images_from_pdf(pdf_bytes) images_list, pdf_doc = load_images_from_pdf(pdf_bytes)
...@@ -73,9 +74,10 @@ async def aio_doc_analyze( ...@@ -73,9 +74,10 @@ async def aio_doc_analyze(
backend="transformers", backend="transformers",
model_path: str | None = None, model_path: str | None = None,
server_url: str | None = None, server_url: str | None = None,
**kwargs,
): ):
if predictor is None: if predictor is None:
predictor = ModelSingleton().get_model(backend, model_path, server_url) predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
# load_images_start = time.time() # load_images_start = time.time()
images_list, pdf_doc = load_images_from_pdf(pdf_bytes) images_list, pdf_doc = load_images_from_pdf(pdf_bytes)
......
...@@ -9,7 +9,8 @@ from mineru.utils.model_utils import get_vram ...@@ -9,7 +9,8 @@ from mineru.utils.model_utils import get_vram
from ..version import __version__ from ..version import __version__
from .common import do_parse, read_fn, pdf_suffixes, image_suffixes from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
@click.command() @click.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
@click.pass_context
@click.version_option(__version__, @click.version_option(__version__,
'--version', '--version',
'-v', '-v',
...@@ -137,7 +138,49 @@ from .common import do_parse, read_fn, pdf_suffixes, image_suffixes ...@@ -137,7 +138,49 @@ from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
) )
def main(input_path, output_dir, method, backend, lang, server_url, start_page_id, end_page_id, formula_enable, table_enable, device_mode, virtual_vram, model_source): def main(
ctx,
input_path, output_dir, method, backend, lang, server_url,
start_page_id, end_page_id, formula_enable, table_enable,
device_mode, virtual_vram, model_source, **kwargs
):
# 解析额外参数
extra_kwargs = {}
i = 0
while i < len(ctx.args):
arg = ctx.args[i]
if arg.startswith('--'):
param_name = arg[2:].replace('-', '_') # 转换参数名格式
i += 1
if i < len(ctx.args) and not ctx.args[i].startswith('--'):
# 参数有值
try:
# 尝试转换为适当的类型
if ctx.args[i].lower() == 'true':
extra_kwargs[param_name] = True
elif ctx.args[i].lower() == 'false':
extra_kwargs[param_name] = False
elif '.' in ctx.args[i]:
try:
extra_kwargs[param_name] = float(ctx.args[i])
except ValueError:
extra_kwargs[param_name] = ctx.args[i]
else:
try:
extra_kwargs[param_name] = int(ctx.args[i])
except ValueError:
extra_kwargs[param_name] = ctx.args[i]
except:
extra_kwargs[param_name] = ctx.args[i]
else:
# 布尔型标志参数
extra_kwargs[param_name] = True
i -= 1
i += 1
# 将解析出的参数合并到kwargs
kwargs.update(extra_kwargs)
if not backend.endswith('-client'): if not backend.endswith('-client'):
def get_device_mode() -> str: def get_device_mode() -> str:
...@@ -184,7 +227,8 @@ def main(input_path, output_dir, method, backend, lang, server_url, start_page_i ...@@ -184,7 +227,8 @@ def main(input_path, output_dir, method, backend, lang, server_url, start_page_i
table_enable=table_enable, table_enable=table_enable,
server_url=server_url, server_url=server_url,
start_page_id=start_page_id, start_page_id=start_page_id,
end_page_id=end_page_id end_page_id=end_page_id,
**kwargs,
) )
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
......
...@@ -225,6 +225,7 @@ async def _async_process_vlm( ...@@ -225,6 +225,7 @@ async def _async_process_vlm(
f_dump_content_list, f_dump_content_list,
f_make_md_mode, f_make_md_mode,
server_url=None, server_url=None,
**kwargs,
): ):
"""异步处理VLM后端逻辑""" """异步处理VLM后端逻辑"""
parse_method = "vlm" parse_method = "vlm"
...@@ -238,7 +239,7 @@ async def _async_process_vlm( ...@@ -238,7 +239,7 @@ async def _async_process_vlm(
image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir) image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)
middle_json, infer_result = await aio_vlm_doc_analyze( middle_json, infer_result = await aio_vlm_doc_analyze(
pdf_bytes, image_writer=image_writer, backend=backend, server_url=server_url pdf_bytes, image_writer=image_writer, backend=backend, server_url=server_url, **kwargs,
) )
pdf_info = middle_json["pdf_info"] pdf_info = middle_json["pdf_info"]
...@@ -265,6 +266,7 @@ def _process_vlm( ...@@ -265,6 +266,7 @@ def _process_vlm(
f_dump_content_list, f_dump_content_list,
f_make_md_mode, f_make_md_mode,
server_url=None, server_url=None,
**kwargs,
): ):
"""同步处理VLM后端逻辑""" """同步处理VLM后端逻辑"""
parse_method = "vlm" parse_method = "vlm"
...@@ -278,7 +280,7 @@ def _process_vlm( ...@@ -278,7 +280,7 @@ def _process_vlm(
image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir) image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)
middle_json, infer_result = vlm_doc_analyze( middle_json, infer_result = vlm_doc_analyze(
pdf_bytes, image_writer=image_writer, backend=backend, server_url=server_url pdf_bytes, image_writer=image_writer, backend=backend, server_url=server_url, **kwargs,
) )
pdf_info = middle_json["pdf_info"] pdf_info = middle_json["pdf_info"]
...@@ -311,6 +313,7 @@ def do_parse( ...@@ -311,6 +313,7 @@ def do_parse(
f_make_md_mode=MakeMode.MM_MD, f_make_md_mode=MakeMode.MM_MD,
start_page_id=0, start_page_id=0,
end_page_id=None, end_page_id=None,
**kwargs,
): ):
# 预处理PDF字节数据 # 预处理PDF字节数据
pdf_bytes_list = _prepare_pdf_bytes(pdf_bytes_list, start_page_id, end_page_id) pdf_bytes_list = _prepare_pdf_bytes(pdf_bytes_list, start_page_id, end_page_id)
...@@ -333,7 +336,7 @@ def do_parse( ...@@ -333,7 +336,7 @@ def do_parse(
output_dir, pdf_file_names, pdf_bytes_list, backend, output_dir, pdf_file_names, pdf_bytes_list, backend,
f_draw_layout_bbox, f_draw_span_bbox, f_dump_md, f_dump_middle_json, f_draw_layout_bbox, f_draw_span_bbox, f_dump_md, f_dump_middle_json,
f_dump_model_output, f_dump_orig_pdf, f_dump_content_list, f_make_md_mode, f_dump_model_output, f_dump_orig_pdf, f_dump_content_list, f_make_md_mode,
server_url server_url, **kwargs,
) )
...@@ -357,6 +360,7 @@ async def aio_do_parse( ...@@ -357,6 +360,7 @@ async def aio_do_parse(
f_make_md_mode=MakeMode.MM_MD, f_make_md_mode=MakeMode.MM_MD,
start_page_id=0, start_page_id=0,
end_page_id=None, end_page_id=None,
**kwargs,
): ):
# 预处理PDF字节数据 # 预处理PDF字节数据
pdf_bytes_list = _prepare_pdf_bytes(pdf_bytes_list, start_page_id, end_page_id) pdf_bytes_list = _prepare_pdf_bytes(pdf_bytes_list, start_page_id, end_page_id)
...@@ -380,7 +384,7 @@ async def aio_do_parse( ...@@ -380,7 +384,7 @@ async def aio_do_parse(
output_dir, pdf_file_names, pdf_bytes_list, backend, output_dir, pdf_file_names, pdf_bytes_list, backend,
f_draw_layout_bbox, f_draw_span_bbox, f_dump_md, f_dump_middle_json, f_draw_layout_bbox, f_draw_span_bbox, f_dump_md, f_dump_middle_json,
f_dump_model_output, f_dump_orig_pdf, f_dump_content_list, f_make_md_mode, f_dump_model_output, f_dump_orig_pdf, f_dump_content_list, f_make_md_mode,
server_url server_url, **kwargs,
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment