Unverified commit 5df18348, authored by Andy Dai and committed by GitHub
Browse files

[Bugfix] Fix order of arguments matters in config.yaml (#8960)

parent cfadb9c6
...@@ -140,7 +140,7 @@ $ vllm serve SOME_MODEL --config config.yaml ...@@ -140,7 +140,7 @@ $ vllm serve SOME_MODEL --config config.yaml
``` ```
--- ---
**NOTE** **NOTE**
In case an argument is supplied using the command line and the config file, the value from the command line will take precedence. In case an argument is supplied simultaneously using the command line and the config file, the value from the command line will take precedence.
The order of priorities is `command line > config file values > defaults`. The order of priorities is `command line > config file values > defaults`.
--- ---
......
port: 12312 port: 12312
served_model_name: mymodel
tensor_parallel_size: 2 tensor_parallel_size: 2
...@@ -136,6 +136,8 @@ def parser(): ...@@ -136,6 +136,8 @@ def parser():
def parser_with_config(): def parser_with_config():
parser = FlexibleArgumentParser() parser = FlexibleArgumentParser()
parser.add_argument('serve') parser.add_argument('serve')
parser.add_argument('model_tag')
parser.add_argument('--served-model-name', type=str)
parser.add_argument('--config', type=str) parser.add_argument('--config', type=str)
parser.add_argument('--port', type=int) parser.add_argument('--port', type=int)
parser.add_argument('--tensor-parallel-size', type=int) parser.add_argument('--tensor-parallel-size', type=int)
...@@ -190,33 +192,47 @@ def test_missing_required_argument(parser): ...@@ -190,33 +192,47 @@ def test_missing_required_argument(parser):
def test_cli_override_to_config(parser_with_config): def test_cli_override_to_config(parser_with_config):
args = parser_with_config.parse_args([ args = parser_with_config.parse_args([
'serve', '--config', './data/test_config.yaml', 'serve', 'mymodel', '--config', './data/test_config.yaml',
'--tensor-parallel-size', '3' '--tensor-parallel-size', '3'
]) ])
assert args.tensor_parallel_size == 3 assert args.tensor_parallel_size == 3
args = parser_with_config.parse_args([ args = parser_with_config.parse_args([
'serve', '--tensor-parallel-size', '3', '--config', 'serve', 'mymodel', '--tensor-parallel-size', '3', '--config',
'./data/test_config.yaml' './data/test_config.yaml'
]) ])
assert args.tensor_parallel_size == 3 assert args.tensor_parallel_size == 3
assert args.port == 12312
args = parser_with_config.parse_args([
'serve', 'mymodel', '--tensor-parallel-size', '3', '--config',
'./data/test_config.yaml', '--port', '666'
])
assert args.tensor_parallel_size == 3
assert args.port == 666
def test_config_args(parser_with_config): def test_config_args(parser_with_config):
args = parser_with_config.parse_args( args = parser_with_config.parse_args(
['serve', '--config', './data/test_config.yaml']) ['serve', 'mymodel', '--config', './data/test_config.yaml'])
assert args.tensor_parallel_size == 2 assert args.tensor_parallel_size == 2
def test_config_file(parser_with_config): def test_config_file(parser_with_config):
with pytest.raises(FileNotFoundError): with pytest.raises(FileNotFoundError):
parser_with_config.parse_args(['serve', '--config', 'test_config.yml']) parser_with_config.parse_args(
['serve', 'mymodel', '--config', 'test_config.yml'])
with pytest.raises(ValueError): with pytest.raises(ValueError):
parser_with_config.parse_args( parser_with_config.parse_args(
['serve', '--config', './data/test_config.json']) ['serve', 'mymodel', '--config', './data/test_config.json'])
with pytest.raises(ValueError): with pytest.raises(ValueError):
parser_with_config.parse_args([ parser_with_config.parse_args([
'serve', '--tensor-parallel-size', '3', '--config', '--batch-size', 'serve', 'mymodel', '--tensor-parallel-size', '3', '--config',
'32' '--batch-size', '32'
]) ])
def test_no_model_tag(parser_with_config):
with pytest.raises(ValueError):
parser_with_config.parse_args(
['serve', '--config', './data/test_config.yaml'])
...@@ -1201,11 +1201,21 @@ class FlexibleArgumentParser(argparse.ArgumentParser): ...@@ -1201,11 +1201,21 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
config_args = FlexibleArgumentParser._load_config_file(file_path) config_args = FlexibleArgumentParser._load_config_file(file_path)
# 0th index is for {serve,chat,complete} # 0th index is for {serve,chat,complete}
# followed by model_tag (only for serve)
# followed by config args # followed by config args
# followed by rest of cli args. # followed by rest of cli args.
# maintaining this order will enforce the precedence # maintaining this order will enforce the precedence
# of cli > config > defaults # of cli > config > defaults
args = [args[0]] + config_args + args[1:index] + args[index + 2:] if args[0] == "serve":
if index == 1:
raise ValueError(
"No model_tag specified! Please check your command-line"
" arguments.")
args = [args[0]] + [
args[1]
] + config_args + args[2:index] + args[index + 2:]
else:
args = [args[0]] + config_args + args[1:index] + args[index + 2:]
return args return args
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment