Unverified Commit 1c1bb388 authored by Alex Brooks's avatar Alex Brooks Committed by GitHub
Browse files

[Frontend] Improve Nullable kv Arg Parsing (#8525)


Signed-off-by: default avatarAlex-Brooks <Alex.Brooks@ibm.com>
parent 546034b4
from argparse import ArgumentTypeError
import pytest import pytest
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs, nullable_kvs
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
...@@ -13,6 +15,10 @@ from vllm.utils import FlexibleArgumentParser ...@@ -13,6 +15,10 @@ from vllm.utils import FlexibleArgumentParser
"image": 16, "image": 16,
"video": 2 "video": 2
}), }),
("Image=16, Video=2", {
"image": 16,
"video": 2
}),
]) ])
def test_limit_mm_per_prompt_parser(arg, expected): def test_limit_mm_per_prompt_parser(arg, expected):
parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
...@@ -22,3 +28,15 @@ def test_limit_mm_per_prompt_parser(arg, expected): ...@@ -22,3 +28,15 @@ def test_limit_mm_per_prompt_parser(arg, expected):
args = parser.parse_args(["--limit-mm-per-prompt", arg]) args = parser.parse_args(["--limit-mm-per-prompt", arg])
assert args.limit_mm_per_prompt == expected assert args.limit_mm_per_prompt == expected
@pytest.mark.parametrize(
("arg"),
[
"image", # Missing =
"image=4,image=5", # Conflicting values
"image=video=4" # Too many = in tokenized arg
])
def test_bad_nullable_kvs(arg):
with pytest.raises(ArgumentTypeError):
nullable_kvs(arg)
...@@ -44,22 +44,36 @@ def nullable_str(val: str): ...@@ -44,22 +44,36 @@ def nullable_str(val: str):
def nullable_kvs(val: str) -> Optional[Mapping[str, int]]: def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
"""Parses a string containing comma separate key [str] to value [int]
pairs into a dictionary.
Args:
val: String value to be parsed.
Returns:
Dictionary with parsed values.
"""
if len(val) == 0: if len(val) == 0:
return None return None
out_dict: Dict[str, int] = {} out_dict: Dict[str, int] = {}
for item in val.split(","): for item in val.split(","):
try: kv_parts = [part.lower().strip() for part in item.split("=")]
key, value = item.split("=") if len(kv_parts) != 2:
except TypeError as exc: raise argparse.ArgumentTypeError(
msg = "Each item should be in the form KEY=VALUE" "Each item should be in the form KEY=VALUE")
raise ValueError(msg) from exc key, value = kv_parts
try: try:
out_dict[key] = int(value) parsed_value = int(value)
except ValueError as exc: except ValueError as exc:
msg = f"Failed to parse value of item {key}={value}" msg = f"Failed to parse value of item {key}={value}"
raise ValueError(msg) from exc raise argparse.ArgumentTypeError(msg) from exc
if key in out_dict and out_dict[key] != parsed_value:
raise argparse.ArgumentTypeError(
f"Conflicting values specified for key: {key}")
out_dict[key] = parsed_value
return out_dict return out_dict
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment