Unverified Commit 28c5e69b authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Enforce that `model` is the first positional arg when `--served-model-name` is used (#34973)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 864167d3
......@@ -20,10 +20,22 @@ CHATML_JINJA_PATH = VLLM_PATH / "examples/template_chatml.jinja"
assert CHATML_JINJA_PATH.exists()
def _build_vllm_parsers():
vllm_parser = FlexibleArgumentParser()
subparsers = vllm_parser.add_subparsers()
serve_parser = subparsers.add_parser("serve")
make_arg_parser(serve_parser)
return {"vllm": vllm_parser, "vllm serve": serve_parser}
@pytest.fixture
def vllm_parser():
return _build_vllm_parsers()["vllm"]
@pytest.fixture
def serve_parser():
parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.")
return make_arg_parser(parser)
return _build_vllm_parsers()["vllm serve"]
### Test config parsing
......@@ -241,3 +253,41 @@ def test_default_chat_template_kwargs_invalid_json(serve_parser):
serve_parser.parse_args(
args=["--default-chat-template-kwargs", "not valid json"]
)
@pytest.mark.parametrize(
"args, raises",
[
(["user/model"], None),
(["user/model", "--served-model-name", "model"], None),
(["--served-model-name", "model", "user/model"], ValueError),
(["--served-model-name", "model", "--config", "config.yaml"], None),
(["--served-model-name", "model", "--config", "config.yaml"], ValueError),
],
ids=[
"model_tag_only",
"model_tag_with_served_model_name",
"served_model_name_before_model_tag",
"served_model_name_with_model_in_config",
"served_model_name_with_no_model_in_config",
],
)
def test_served_model_name_parsing(tmp_path, vllm_parser, args, raises):
"""Ensure that users don't misuse --served-model-name and end up with the default
model tag instead of the one they intended to serve."""
# Call the serve subparser
args.insert(0, "serve")
# Create a dummy config file if the test case includes it
if "config.yaml" in args:
# Create a dummy config file if the test case includes it
config_path = tmp_path / "config.yaml"
config_path.write_text("model: user/model" if raises is None else "port: 8000")
args[args.index("config.yaml")] = config_path.as_posix()
# Do the parsing and check for expected exceptions or values
if raises is None:
parsed_args = vllm_parser.parse_args(args=args)
expected = "user/model"
assert parsed_args.model_tag == expected or parsed_args.model == expected
else:
with pytest.raises(raises):
vllm_parser.parse_args(args=args)
......@@ -184,13 +184,11 @@ class FlexibleArgumentParser(ArgumentParser):
if args is None:
args = sys.argv[1:]
# Check for --model in command line arguments first
if args and args[0] == "serve":
# Check for --model in command line arguments first
try:
model_idx = next(
i
for i, arg in enumerate(args)
if arg == "--model" or arg.startswith("--model=")
i for i, arg in enumerate(args) if re.match(r"^--model(=.+|$)", arg)
)
logger.warning(
"With `vllm serve`, you should provide the model as a "
......@@ -219,6 +217,19 @@ class FlexibleArgumentParser(ArgumentParser):
]
except StopIteration:
pass
# Check for --served-model-name without a positional model argument
if (
len(args) > 1
and args[1].startswith("-")
and not any(re.match(r"^--config(=.+|$)", arg) for arg in args)
and any(
re.match(r"^--served[-_]model[-_]name(=.+|$)", arg) for arg in args
)
):
raise ValueError(
"`model` should be provided as the first positional argument when "
"using `vllm serve`. i.e. `vllm serve <model> --<arg> <value>`."
)
if "--config" in args:
args = self._pull_args_from_config(args)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment