Unverified Commit 7560ae5c authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[8/N] enable cli flag without a space (#10529)


Signed-off-by: default avataryoukaichao <youkaichao@gmail.com>
parent e7a8341c
...@@ -103,7 +103,7 @@ def test_compile_correctness(test_setting: TestSetting): ...@@ -103,7 +103,7 @@ def test_compile_correctness(test_setting: TestSetting):
CompilationLevel.NO_COMPILATION, CompilationLevel.NO_COMPILATION,
CompilationLevel.PIECEWISE, CompilationLevel.PIECEWISE,
]: ]:
all_args.append(final_args + ["-O", str(level)]) all_args.append(final_args + [f"-O{level}"])
all_envs.append({}) all_envs.append({})
# inductor will change the output, so we only compare if the output # inductor will change the output, so we only compare if the output
...@@ -121,7 +121,7 @@ def test_compile_correctness(test_setting: TestSetting): ...@@ -121,7 +121,7 @@ def test_compile_correctness(test_setting: TestSetting):
CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_AS_IS,
CompilationLevel.DYNAMO_ONCE, CompilationLevel.DYNAMO_ONCE,
]: ]:
all_args.append(final_args + ["-O", str(level)]) all_args.append(final_args + [f"-O{level}"])
all_envs.append({}) all_envs.append({})
if level != CompilationLevel.DYNAMO_ONCE and not fullgraph: if level != CompilationLevel.DYNAMO_ONCE and not fullgraph:
# "DYNAMO_ONCE" will always use fullgraph # "DYNAMO_ONCE" will always use fullgraph
......
...@@ -31,6 +31,34 @@ def test_limit_mm_per_prompt_parser(arg, expected): ...@@ -31,6 +31,34 @@ def test_limit_mm_per_prompt_parser(arg, expected):
assert args.limit_mm_per_prompt == expected assert args.limit_mm_per_prompt == expected
def test_compilation_config():
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
# default value
args = parser.parse_args([])
assert args.compilation_config is None
# set to O3
args = parser.parse_args(["-O3"])
assert args.compilation_config.level == 3
# set to O 3 (space)
args = parser.parse_args(["-O", "3"])
assert args.compilation_config.level == 3
# set to O 3 (equals)
args = parser.parse_args(["-O=3"])
assert args.compilation_config.level == 3
# set to json
args = parser.parse_args(["--compilation-config", '{"level": 3}'])
assert args.compilation_config.level == 3
# set to json
args = parser.parse_args(['--compilation-config={"level": 3}'])
assert args.compilation_config.level == 3
def test_valid_pooling_config(): def test_valid_pooling_config():
parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args([ args = parser.parse_args([
......
...@@ -13,9 +13,10 @@ os.environ["VLLM_RPC_TIMEOUT"] = "30000" ...@@ -13,9 +13,10 @@ os.environ["VLLM_RPC_TIMEOUT"] = "30000"
def test_custom_dispatcher(): def test_custom_dispatcher():
compare_two_settings( compare_two_settings(
"google/gemma-2b", "google/gemma-2b",
arg1=["--enforce-eager", "-O", arg1=[
str(CompilationLevel.DYNAMO_ONCE)], "--enforce-eager",
arg2=["--enforce-eager", "-O", f"-O{CompilationLevel.DYNAMO_ONCE}",
str(CompilationLevel.DYNAMO_AS_IS)], ],
arg2=["--enforce-eager", f"-O{CompilationLevel.DYNAMO_AS_IS}"],
env1={}, env1={},
env2={}) env2={})
...@@ -882,7 +882,10 @@ class EngineArgs: ...@@ -882,7 +882,10 @@ class EngineArgs:
'testing only. level 3 is the recommended level ' 'testing only. level 3 is the recommended level '
'for production.\n' 'for production.\n'
'To specify the full compilation config, ' 'To specify the full compilation config, '
'use a JSON string.') 'use a JSON string.\n'
'Following the convention of traditional '
'compilers, using -O without space is also '
'supported. -O3 is equivalent to -O 3.')
return parser return parser
......
...@@ -1192,6 +1192,10 @@ class FlexibleArgumentParser(argparse.ArgumentParser): ...@@ -1192,6 +1192,10 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
else: else:
processed_args.append('--' + processed_args.append('--' +
arg[len('--'):].replace('_', '-')) arg[len('--'):].replace('_', '-'))
elif arg.startswith('-O') and arg != '-O' and len(arg) == 2:
# allow -O flag to be used without space, e.g. -O3
processed_args.append('-O')
processed_args.append(arg[2:])
else: else:
processed_args.append(arg) processed_args.append(arg)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment