Unverified Commit 0c5f00fa authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge pull request #2881 from myhloli/dev

parents 8aac6107 b943f04e
......@@ -47,9 +47,10 @@ def doc_analyze(
backend="transformers",
model_path: str | None = None,
server_url: str | None = None,
**kwargs,
):
if predictor is None:
predictor = ModelSingleton().get_model(backend, model_path, server_url)
predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
# load_images_start = time.time()
images_list, pdf_doc = load_images_from_pdf(pdf_bytes)
......@@ -73,9 +74,10 @@ async def aio_doc_analyze(
backend="transformers",
model_path: str | None = None,
server_url: str | None = None,
**kwargs,
):
if predictor is None:
predictor = ModelSingleton().get_model(backend, model_path, server_url)
predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
# load_images_start = time.time()
images_list, pdf_doc = load_images_from_pdf(pdf_bytes)
......
......@@ -9,7 +9,8 @@ from mineru.utils.model_utils import get_vram
from ..version import __version__
from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
@click.command()
@click.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
@click.pass_context
@click.version_option(__version__,
'--version',
'-v',
......@@ -137,7 +138,49 @@ from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
)
def main(input_path, output_dir, method, backend, lang, server_url, start_page_id, end_page_id, formula_enable, table_enable, device_mode, virtual_vram, model_source):
def main(
ctx,
input_path, output_dir, method, backend, lang, server_url,
start_page_id, end_page_id, formula_enable, table_enable,
device_mode, virtual_vram, model_source, **kwargs
):
# 解析额外参数
extra_kwargs = {}
i = 0
while i < len(ctx.args):
arg = ctx.args[i]
if arg.startswith('--'):
param_name = arg[2:].replace('-', '_') # 转换参数名格式
i += 1
if i < len(ctx.args) and not ctx.args[i].startswith('--'):
# 参数有值
try:
# 尝试转换为适当的类型
if ctx.args[i].lower() == 'true':
extra_kwargs[param_name] = True
elif ctx.args[i].lower() == 'false':
extra_kwargs[param_name] = False
elif '.' in ctx.args[i]:
try:
extra_kwargs[param_name] = float(ctx.args[i])
except ValueError:
extra_kwargs[param_name] = ctx.args[i]
else:
try:
extra_kwargs[param_name] = int(ctx.args[i])
except ValueError:
extra_kwargs[param_name] = ctx.args[i]
except:
extra_kwargs[param_name] = ctx.args[i]
else:
# 布尔型标志参数
extra_kwargs[param_name] = True
i -= 1
i += 1
# 将解析出的参数合并到kwargs
kwargs.update(extra_kwargs)
if not backend.endswith('-client'):
def get_device_mode() -> str:
......@@ -184,7 +227,8 @@ def main(input_path, output_dir, method, backend, lang, server_url, start_page_i
table_enable=table_enable,
server_url=server_url,
start_page_id=start_page_id,
end_page_id=end_page_id
end_page_id=end_page_id,
**kwargs,
)
except Exception as e:
logger.exception(e)
......
......@@ -225,6 +225,7 @@ async def _async_process_vlm(
f_dump_content_list,
f_make_md_mode,
server_url=None,
**kwargs,
):
"""异步处理VLM后端逻辑"""
parse_method = "vlm"
......@@ -238,7 +239,7 @@ async def _async_process_vlm(
image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)
middle_json, infer_result = await aio_vlm_doc_analyze(
pdf_bytes, image_writer=image_writer, backend=backend, server_url=server_url
pdf_bytes, image_writer=image_writer, backend=backend, server_url=server_url, **kwargs,
)
pdf_info = middle_json["pdf_info"]
......@@ -265,6 +266,7 @@ def _process_vlm(
f_dump_content_list,
f_make_md_mode,
server_url=None,
**kwargs,
):
"""同步处理VLM后端逻辑"""
parse_method = "vlm"
......@@ -278,7 +280,7 @@ def _process_vlm(
image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)
middle_json, infer_result = vlm_doc_analyze(
pdf_bytes, image_writer=image_writer, backend=backend, server_url=server_url
pdf_bytes, image_writer=image_writer, backend=backend, server_url=server_url, **kwargs,
)
pdf_info = middle_json["pdf_info"]
......@@ -311,6 +313,7 @@ def do_parse(
f_make_md_mode=MakeMode.MM_MD,
start_page_id=0,
end_page_id=None,
**kwargs,
):
# 预处理PDF字节数据
pdf_bytes_list = _prepare_pdf_bytes(pdf_bytes_list, start_page_id, end_page_id)
......@@ -333,7 +336,7 @@ def do_parse(
output_dir, pdf_file_names, pdf_bytes_list, backend,
f_draw_layout_bbox, f_draw_span_bbox, f_dump_md, f_dump_middle_json,
f_dump_model_output, f_dump_orig_pdf, f_dump_content_list, f_make_md_mode,
server_url
server_url, **kwargs,
)
......@@ -357,6 +360,7 @@ async def aio_do_parse(
f_make_md_mode=MakeMode.MM_MD,
start_page_id=0,
end_page_id=None,
**kwargs,
):
# 预处理PDF字节数据
pdf_bytes_list = _prepare_pdf_bytes(pdf_bytes_list, start_page_id, end_page_id)
......@@ -380,7 +384,7 @@ async def aio_do_parse(
output_dir, pdf_file_names, pdf_bytes_list, backend,
f_draw_layout_bbox, f_draw_span_bbox, f_dump_md, f_dump_middle_json,
f_dump_model_output, f_dump_orig_pdf, f_dump_content_list, f_make_md_mode,
server_url
server_url, **kwargs,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment