"docker/install/vscode:/vscode.git/clone" did not exist on "fa3fbbfb8f51f23369af3ebc8f5c9ab0587a02b8"
Commit 54913fc2 authored by myhloli's avatar myhloli
Browse files

refactor: enhance command-line argument parsing to support additional parameters in gradio_app.py

parent 09eecea4
......@@ -188,7 +188,8 @@ def update_interface(backend_choice):
pass
@click.command()
@click.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
@click.pass_context
@click.option(
'--enable-example',
'example_enable',
......@@ -204,20 +205,6 @@ def update_interface(backend_choice):
help="Enable SgLang engine backend for faster processing.",
default=False,
)
@click.option(
'--mem-fraction-static',
'mem_fraction_static',
type=float,
help="Set the static memory fraction for SgLang engine. ",
default=None,
)
@click.option(
'--enable-torch-compile',
'torch_compile_enable',
type=bool,
help="Enable torch compile for SgLang engine. ",
default=False,
)
@click.option(
'--enable-api',
'api_enable',
......@@ -246,28 +233,57 @@ def update_interface(backend_choice):
help="Set the server port for the Gradio app.",
default=None,
)
def main(
example_enable, sglang_engine_enable, mem_fraction_static, torch_compile_enable, api_enable, max_convert_pages,
server_name, server_port
def main(ctx,
example_enable, sglang_engine_enable, api_enable, max_convert_pages,
server_name, server_port, **kwargs
):
# 解析额外参数
extra_kwargs = {}
i = 0
while i < len(ctx.args):
arg = ctx.args[i]
if arg.startswith('--'):
param_name = arg[2:].replace('-', '_') # 转换参数名格式
i += 1
if i < len(ctx.args) and not ctx.args[i].startswith('--'):
# 参数有值
try:
# 尝试转换为适当的类型
if ctx.args[i].lower() == 'true':
extra_kwargs[param_name] = True
elif ctx.args[i].lower() == 'false':
extra_kwargs[param_name] = False
elif '.' in ctx.args[i]:
try:
extra_kwargs[param_name] = float(ctx.args[i])
except ValueError:
extra_kwargs[param_name] = ctx.args[i]
else:
try:
extra_kwargs[param_name] = int(ctx.args[i])
except ValueError:
extra_kwargs[param_name] = ctx.args[i]
except:
extra_kwargs[param_name] = ctx.args[i]
else:
# 布尔型标志参数
extra_kwargs[param_name] = True
i -= 1
i += 1
# 将解析出的参数合并到kwargs
kwargs.update(extra_kwargs)
if sglang_engine_enable:
try:
print("Start init SgLang engine...")
from mineru.backend.vlm.vlm_analyze import ModelSingleton
model_singleton = ModelSingleton()
model_params = {
"enable_torch_compile": torch_compile_enable
}
# 只有当mem_fraction_static不为None时才添加该参数
if mem_fraction_static is not None:
model_params["mem_fraction_static"] = mem_fraction_static
predictor = model_singleton.get_model(
"sglang-engine",
None,
None,
**model_params
**kwargs
)
print("SgLang engine init successfully.")
except Exception as e:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment