Unverified commit 6d42ce83, authored by Luka Govedič, committed by GitHub
Browse files

[CLI] Improve CLI arg parsing for `-O`/`--compilation-config` (#20156)


Signed-off-by: luka <luka@neuralmagic.com>
parent ded1fb63
...@@ -239,32 +239,40 @@ def test_compilation_config(): ...@@ -239,32 +239,40 @@ def test_compilation_config():
assert args.compilation_config == CompilationConfig() assert args.compilation_config == CompilationConfig()
# set to O3 # set to O3
args = parser.parse_args(["-O3"]) args = parser.parse_args(["-O0"])
assert args.compilation_config.level == 3 assert args.compilation_config.level == 0
# set to O 3 (space) # set to O 3 (space)
args = parser.parse_args(["-O", "3"]) args = parser.parse_args(["-O", "1"])
assert args.compilation_config.level == 3 assert args.compilation_config.level == 1
# set to O 3 (equals) # set to O 3 (equals)
args = parser.parse_args(["-O=3"]) args = parser.parse_args(["-O=2"])
assert args.compilation_config.level == 2
# set to O.level 3
args = parser.parse_args(["-O.level", "3"])
assert args.compilation_config.level == 3 assert args.compilation_config.level == 3
# set to string form of a dict # set to string form of a dict
args = parser.parse_args([ args = parser.parse_args([
"--compilation-config", "-O",
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}', '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
'"use_inductor": false}',
]) ])
assert (args.compilation_config.level == 3 and assert (args.compilation_config.level == 3 and
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]) args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
and not args.compilation_config.use_inductor)
# set to string form of a dict # set to string form of a dict
args = parser.parse_args([ args = parser.parse_args([
"--compilation-config=" "--compilation-config="
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}', '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
'"use_inductor": true}',
]) ])
assert (args.compilation_config.level == 3 and assert (args.compilation_config.level == 3 and
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]) args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
and args.compilation_config.use_inductor)
def test_prefix_cache_default(): def test_prefix_cache_default():
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
import asyncio import asyncio
import hashlib import hashlib
import json import json
import logging
import pickle import pickle
import socket import socket
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
...@@ -142,6 +143,7 @@ def parser(): ...@@ -142,6 +143,7 @@ def parser():
parser.add_argument('--batch-size', type=int) parser.add_argument('--batch-size', type=int)
parser.add_argument('--enable-feature', action='store_true') parser.add_argument('--enable-feature', action='store_true')
parser.add_argument('--hf-overrides', type=json.loads) parser.add_argument('--hf-overrides', type=json.loads)
parser.add_argument('-O', '--compilation-config', type=json.loads)
return parser return parser
...@@ -265,6 +267,11 @@ def test_dict_args(parser): ...@@ -265,6 +267,11 @@ def test_dict_args(parser):
"val2", "val2",
"--hf-overrides.key2.key4", "--hf-overrides.key2.key4",
"val3", "val3",
# Test compile config and compilation level
"-O.use_inductor=true",
"-O.backend",
"custom",
"-O1",
# Test = sign # Test = sign
"--hf-overrides.key5=val4", "--hf-overrides.key5=val4",
# Test underscore to dash conversion # Test underscore to dash conversion
...@@ -281,6 +288,13 @@ def test_dict_args(parser): ...@@ -281,6 +288,13 @@ def test_dict_args(parser):
"true", "true",
"--hf_overrides.key12.key13", "--hf_overrides.key12.key13",
"null", "null",
# Test '-' and '.' in value
"--hf_overrides.key14.key15",
"-minus.and.dot",
# Test array values
"-O.custom_ops+",
"-quant_fp8",
"-O.custom_ops+=+silu_mul,-rms_norm",
] ]
parsed_args = parser.parse_args(args) parsed_args = parser.parse_args(args)
assert parsed_args.model_name == "something.something" assert parsed_args.model_name == "something.something"
...@@ -301,9 +315,42 @@ def test_dict_args(parser): ...@@ -301,9 +315,42 @@ def test_dict_args(parser):
"key12": { "key12": {
"key13": None, "key13": None,
}, },
"key14": {
"key15": "-minus.and.dot",
}
}
assert parsed_args.compilation_config == {
"level": 1,
"use_inductor": True,
"backend": "custom",
"custom_ops": ["-quant_fp8", "+silu_mul", "-rms_norm"],
} }
def test_duplicate_dict_args(caplog_vllm, parser):
    """Check that repeated dotted keys keep the last value and are logged.

    The same ``--hf-overrides`` key is passed twice and the compilation
    level is passed three times in different spellings; the parsed result
    must hold the final value of each, and exactly one log record must
    mention the duplicated keys.
    """
    cli = [
        "--model-name=something.something",
        # Same --hf-overrides key given twice.
        "--hf-overrides.key1",
        "val1",
        "--hf-overrides.key1",
        "val2",
        # Compilation level given three times, in three spellings.
        "-O1",
        "-O.level",
        "2",
        "-O3",
    ]
    namespace = parser.parse_args(cli)

    # The last occurrence of each key is the one that is kept.
    assert namespace.hf_overrides == {"key1": "val2"}
    assert namespace.compilation_config == {"level": 3}

    # Every duplicate is reported in a single log record.
    assert len(caplog_vllm.records) == 1
    for expected in ("duplicate", "--hf-overrides.key1", "-O.level"):
        assert expected in caplog_vllm.text
# yapf: enable # yapf: enable
@pytest.mark.parametrize( @pytest.mark.parametrize(
"callable,kw_name,requires_kw_only,allow_var_kwargs,is_supported", "callable,kw_name,requires_kw_only,allow_var_kwargs,is_supported",
......
...@@ -4140,9 +4140,9 @@ class CompilationConfig: ...@@ -4140,9 +4140,9 @@ class CompilationConfig:
@classmethod @classmethod
def from_cli(cls, cli_value: str) -> "CompilationConfig": def from_cli(cls, cli_value: str) -> "CompilationConfig":
"""Parse the CLI value for the compilation config.""" """Parse the CLI value for the compilation config.
if cli_value in ["0", "1", "2", "3"]: -O1, -O2, -O3, etc. is handled in FlexibleArgumentParser.
return cls(level=int(cli_value)) """
return TypeAdapter(CompilationConfig).validate_json(cli_value) return TypeAdapter(CompilationConfig).validate_json(cli_value)
def __post_init__(self) -> None: def __post_init__(self) -> None:
...@@ -4303,17 +4303,16 @@ class VllmConfig: ...@@ -4303,17 +4303,16 @@ class VllmConfig:
"""Quantization configuration.""" """Quantization configuration."""
compilation_config: CompilationConfig = field( compilation_config: CompilationConfig = field(
default_factory=CompilationConfig) default_factory=CompilationConfig)
"""`torch.compile` configuration for the model. """`torch.compile` and cudagraph capture configuration for the model.
When it is a number (0, 1, 2, 3), it will be interpreted as the As a shorthand, `-O<n>` can be used to directly specify the compilation
optimization level. level `n`: `-O3` is equivalent to `-O.level=3` (same as `-O='{"level":3}'`).
Currently, -O <n> and -O=<n> are supported as well but this will likely be
removed in favor of clearer -O<n> syntax in the future.
NOTE: level 0 is the default level without any optimization. level 1 and 2 NOTE: level 0 is the default level without any optimization. level 1 and 2
are for internal testing only. level 3 is the recommended level for are for internal testing only. level 3 is the recommended level for
production. production, also default in V1.
Following the convention of traditional compilers, using `-O` without space
is also supported. `-O3` is equivalent to `-O 3`.
You can specify the full compilation config like so: You can specify the full compilation config like so:
`{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}` `{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}`
......
...@@ -202,7 +202,10 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]: ...@@ -202,7 +202,10 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
passed individually. For example, the following sets of arguments are passed individually. For example, the following sets of arguments are
equivalent:\n\n equivalent:\n\n
- `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`\n - `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`\n
- `--json-arg.key1 value1 --json-arg.key2.key3 value2`\n\n""" - `--json-arg.key1 value1 --json-arg.key2.key3 value2`\n
Additionally, list elements can be passed individually using '+':
- `--json-arg '{"key4": ["value3", "value4", "value5"]}'`\n
- `--json-arg.key4+ value3 --json-arg.key4+='value4,value5'`\n\n"""
if dataclass_cls is not None: if dataclass_cls is not None:
def parse_dataclass(val: str, cls=dataclass_cls) -> Any: def parse_dataclass(val: str, cls=dataclass_cls) -> Any:
......
...@@ -89,15 +89,15 @@ MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120 ...@@ -89,15 +89,15 @@ MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
STR_NOT_IMPL_ENC_DEC_SWA = \ STR_NOT_IMPL_ENC_DEC_SWA = \
"Sliding window attention for encoder/decoder models " + \ "Sliding window attention for encoder/decoder models " + \
"is not currently supported." "is not currently supported."
STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \ STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \
"Prefix caching for encoder/decoder models " + \ "Prefix caching for encoder/decoder models " + \
"is not currently supported." "is not currently supported."
STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \ STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \
"Chunked prefill for encoder/decoder models " + \ "Chunked prefill for encoder/decoder models " + \
"is not currently supported." "is not currently supported."
STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP = ( STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP = (
"Models with logits_soft_cap " "Models with logits_soft_cap "
...@@ -752,7 +752,7 @@ def _generate_random_fp8( ...@@ -752,7 +752,7 @@ def _generate_random_fp8(
# to generate random data for fp8 data. # to generate random data for fp8 data.
# For example, s.11111.00 in fp8e5m2 format represents Inf. # For example, s.11111.00 in fp8e5m2 format represents Inf.
# | E4M3 | E5M2 # | E4M3 | E5M2
#-----|-------------|------------------- # -----|-------------|-------------------
# Inf | N/A | s.11111.00 # Inf | N/A | s.11111.00
# NaN | s.1111.111 | s.11111.{01,10,11} # NaN | s.1111.111 | s.11111.{01,10,11}
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
...@@ -840,7 +840,6 @@ def create_kv_caches_with_random( ...@@ -840,7 +840,6 @@ def create_kv_caches_with_random(
seed: Optional[int] = None, seed: Optional[int] = None,
device: Optional[str] = "cuda", device: Optional[str] = "cuda",
) -> tuple[list[torch.Tensor], list[torch.Tensor]]: ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
if cache_dtype == "fp8" and head_size % 16: if cache_dtype == "fp8" and head_size % 16:
raise ValueError( raise ValueError(
f"Does not support key cache of type fp8 with head_size {head_size}" f"Does not support key cache of type fp8 with head_size {head_size}"
...@@ -1205,7 +1204,6 @@ def deprecate_args( ...@@ -1205,7 +1204,6 @@ def deprecate_args(
is_deprecated: Union[bool, Callable[[], bool]] = True, is_deprecated: Union[bool, Callable[[], bool]] = True,
additional_message: Optional[str] = None, additional_message: Optional[str] = None,
) -> Callable[[F], F]: ) -> Callable[[F], F]:
if not callable(is_deprecated): if not callable(is_deprecated):
is_deprecated = partial(identity, is_deprecated) is_deprecated = partial(identity, is_deprecated)
...@@ -1355,7 +1353,7 @@ def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]: ...@@ -1355,7 +1353,7 @@ def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]:
return weak_bound return weak_bound
#From: https://stackoverflow.com/a/4104188/2749989 # From: https://stackoverflow.com/a/4104188/2749989
def run_once(f: Callable[P, None]) -> Callable[P, None]: def run_once(f: Callable[P, None]) -> Callable[P, None]:
def wrapper(*args: P.args, **kwargs: P.kwargs) -> None: def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
...@@ -1474,7 +1472,7 @@ class FlexibleArgumentParser(ArgumentParser): ...@@ -1474,7 +1472,7 @@ class FlexibleArgumentParser(ArgumentParser):
# Convert underscores to dashes and vice versa in argument names # Convert underscores to dashes and vice versa in argument names
processed_args = list[str]() processed_args = list[str]()
for arg in args: for i, arg in enumerate(args):
if arg.startswith('--'): if arg.startswith('--'):
if '=' in arg: if '=' in arg:
key, value = arg.split('=', 1) key, value = arg.split('=', 1)
...@@ -1483,10 +1481,17 @@ class FlexibleArgumentParser(ArgumentParser): ...@@ -1483,10 +1481,17 @@ class FlexibleArgumentParser(ArgumentParser):
else: else:
key = pattern.sub(repl, arg, count=1) key = pattern.sub(repl, arg, count=1)
processed_args.append(key) processed_args.append(key)
elif arg.startswith('-O') and arg != '-O' and len(arg) == 2: elif arg.startswith('-O') and arg != '-O' and arg[2] != '.':
# allow -O flag to be used without space, e.g. -O3 # allow -O flag to be used without space, e.g. -O3 or -Odecode
processed_args.append('-O') # -O.<...> handled later
processed_args.append(arg[2:]) # also handle -O=<level> here
level = arg[3:] if arg[2] == '=' else arg[2:]
processed_args.append(f'-O.level={level}')
elif arg == '-O' and i + 1 < len(args) and args[i + 1] in {
"0", "1", "2", "3"
}:
# Convert -O <n> to -O.level <n>
processed_args.append('-O.level')
else: else:
processed_args.append(arg) processed_args.append(arg)
...@@ -1504,27 +1509,44 @@ class FlexibleArgumentParser(ArgumentParser): ...@@ -1504,27 +1509,44 @@ class FlexibleArgumentParser(ArgumentParser):
def recursive_dict_update( def recursive_dict_update(
original: dict[str, Any], original: dict[str, Any],
update: dict[str, Any], update: dict[str, Any],
): ) -> set[str]:
"""Recursively updates a dictionary with another dictionary.""" """Recursively updates a dictionary with another dictionary.
Returns a set of duplicate keys that were overwritten.
"""
duplicates = set[str]()
for k, v in update.items(): for k, v in update.items():
if isinstance(v, dict) and isinstance(original.get(k), dict): if isinstance(v, dict) and isinstance(original.get(k), dict):
recursive_dict_update(original[k], v) nested_duplicates = recursive_dict_update(original[k], v)
duplicates |= {f"{k}.{d}" for d in nested_duplicates}
elif isinstance(v, list) and isinstance(original.get(k), list):
original[k] += v
else: else:
if k in original:
duplicates.add(k)
original[k] = v original[k] = v
return duplicates
delete = set[int]() delete = set[int]()
dict_args = defaultdict[str, dict[str, Any]](dict) dict_args = defaultdict[str, dict[str, Any]](dict)
duplicates = set[str]()
for i, processed_arg in enumerate(processed_args): for i, processed_arg in enumerate(processed_args):
if processed_arg.startswith("--") and "." in processed_arg: if i in delete: # skip if value from previous arg
continue
if processed_arg.startswith("-") and "." in processed_arg:
if "=" in processed_arg: if "=" in processed_arg:
processed_arg, value_str = processed_arg.split("=", 1) processed_arg, value_str = processed_arg.split("=", 1)
if "." not in processed_arg: if "." not in processed_arg:
# False positive, . was only in the value # False positive, '.' was only in the value
continue continue
else: else:
value_str = processed_args[i + 1] value_str = processed_args[i + 1]
delete.add(i + 1) delete.add(i + 1)
if processed_arg.endswith("+"):
processed_arg = processed_arg[:-1]
value_str = json.dumps(list(value_str.split(",")))
key, *keys = processed_arg.split(".") key, *keys = processed_arg.split(".")
try: try:
value = json.loads(value_str) value = json.loads(value_str)
...@@ -1533,12 +1555,17 @@ class FlexibleArgumentParser(ArgumentParser): ...@@ -1533,12 +1555,17 @@ class FlexibleArgumentParser(ArgumentParser):
# Merge all values with the same key into a single dict # Merge all values with the same key into a single dict
arg_dict = create_nested_dict(keys, value) arg_dict = create_nested_dict(keys, value)
recursive_dict_update(dict_args[key], arg_dict) arg_duplicates = recursive_dict_update(dict_args[key],
arg_dict)
duplicates |= {f'{key}.{d}' for d in arg_duplicates}
delete.add(i) delete.add(i)
# Filter out the dict args we set to None # Filter out the dict args we set to None
processed_args = [ processed_args = [
a for i, a in enumerate(processed_args) if i not in delete a for i, a in enumerate(processed_args) if i not in delete
] ]
if duplicates:
logger.warning("Found duplicate keys %s", ", ".join(duplicates))
# Add the dict args back as if they were originally passed as JSON # Add the dict args back as if they were originally passed as JSON
for dict_arg, dict_value in dict_args.items(): for dict_arg, dict_value in dict_args.items():
processed_args.append(dict_arg) processed_args.append(dict_arg)
...@@ -2405,7 +2432,7 @@ def memory_profiling( ...@@ -2405,7 +2432,7 @@ def memory_profiling(
The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.). The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.).
The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.). The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.).
""" # noqa """ # noqa
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats() torch.cuda.reset_peak_memory_stats()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment