"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "eef921f45e7d3efb2ed2ccab80ee20ee2e4ebe38"
Unverified Commit 9c2492e5 authored by Nicolò Lucchesi's avatar Nicolò Lucchesi Committed by GitHub
Browse files

[Misc] Support Human-readable (k/K/m/M..) json cli arg (#40473)


Signed-off-by: default avatarNickLucche <nlucches@redhat.com>
parent fe57be78
...@@ -12,6 +12,7 @@ from pydantic import Field ...@@ -12,6 +12,7 @@ from pydantic import Field
from vllm.config import AttentionConfig, CompilationConfig, config from vllm.config import AttentionConfig, CompilationConfig, config
from vllm.engine.arg_utils import ( from vllm.engine.arg_utils import (
EngineArgs, EngineArgs,
_expand_json_human_readable_numbers,
contains_type, contains_type,
get_kwargs, get_kwargs,
get_type, get_type,
...@@ -563,3 +564,32 @@ def test_ir_op_priority(): ...@@ -563,3 +564,32 @@ def test_ir_op_priority():
ir_op_priority=ir_op_priority, ir_op_priority=ir_op_priority,
kernel_config=KernelConfig(ir_op_priority=ir_op_priority), kernel_config=KernelConfig(ir_op_priority=ir_op_priority),
).create_engine_config() ).create_engine_config()
@pytest.mark.parametrize(
("input_json", "expected_json"),
[
# Decimal suffixes (lowercase)
('{"x": 80g}', '{"x": 80000000000}'),
('{"x": 1k}', '{"x": 1000}'),
('{"x": 5m}', '{"x": 5000000}'),
('{"x": 2t}', '{"x": 2000000000000}'),
# Binary suffixes (uppercase)
('{"x": 1K}', f'{{"x": {2**10}}}'),
('{"x": 1G}', f'{{"x": {2**30}}}'),
# Decimal values
('{"x": 1.5g}', '{"x": 1500000000}'),
# Quoted strings must NOT be modified
('{"my_key": 80g}', '{"my_key": 80000000000}'),
('{"name": "80g"}', '{"name": "80g"}'),
('{"model_name": "foo_bar"}', '{"model_name": "foo_bar"}'),
# Multiple values
('{"a": 1k, "b": 2m}', '{"a": 1000, "b": 2000000}'),
# Plain numbers are untouched
('{"x": 42}', '{"x": 42}'),
# Nested JSON
('{"outer": {"inner": 10g}}', '{"outer": {"inner": 10000000000}}'),
],
)
def test_expand_json_human_readable_numbers(input_json, expected_json):
assert _expand_json_human_readable_numbers(input_json) == expected_json
...@@ -105,7 +105,11 @@ from vllm.transformers_utils.config import ( ...@@ -105,7 +105,11 @@ from vllm.transformers_utils.config import (
from vllm.transformers_utils.gguf_utils import is_gguf from vllm.transformers_utils.gguf_utils import is_gguf
from vllm.transformers_utils.repo_utils import get_model_path from vllm.transformers_utils.repo_utils import get_model_path
from vllm.transformers_utils.utils import is_cloud_storage from vllm.transformers_utils.utils import is_cloud_storage
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import (
FlexibleArgumentParser,
human_readable_int,
human_readable_int_or_auto,
)
from vllm.utils.mem_constants import GiB_bytes from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.network_utils import get_ip from vllm.utils.network_utils import get_ip
from vllm.utils.torch_utils import resolve_kv_cache_dtype_string from vllm.utils.torch_utils import resolve_kv_cache_dtype_string
...@@ -256,6 +260,28 @@ def _maybe_add_docs_url(cls: Any) -> str: ...@@ -256,6 +260,28 @@ def _maybe_add_docs_url(cls: Any) -> str:
return f"\n\nAPI docs: https://docs.vllm.ai/en/{version}/api/vllm/config/#vllm.config.{cls.__name__}" return f"\n\nAPI docs: https://docs.vllm.ai/en/{version}/api/vllm/config/#vllm.config.{cls.__name__}"
def _expand_json_human_readable_numbers(val: str) -> str:
"""Expand human-readable number suffixes in a JSON string.
Based on :func:`human_readable_int` so that the ``k/m/g/t`` (decimal) and
``K/M/G/T`` (binary) conventions work out the box.
Also works inside JSON config arguments such
as ``--kv-transfer-config '{"cpu_bytes_to_use": 80m}'``.
Only bare (unquoted) tokens are replaced so that JSON string values
like ``"model_name"`` are never modified.
"""
# Split on quoted strings so we only touch non-string regions.
parts = re.split(r'("(?:[^"\\]|\\.)*")', val)
for i in range(0, len(parts), 2): # even indices = outside strings
parts[i] = re.sub(
r"\b\d+(?:\.\d+)?[kKmMgGtT]\b",
lambda m: str(human_readable_int(m.group())),
parts[i],
)
return "".join(parts)
@functools.lru_cache(maxsize=30) @functools.lru_cache(maxsize=30)
def _compute_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]: def _compute_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]:
# Save time only getting attr docs if we're generating help text # Save time only getting attr docs if we're generating help text
...@@ -301,6 +327,7 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]: ...@@ -301,6 +327,7 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]:
def parse_dataclass(val: str, cls=dataclass_cls) -> Any: def parse_dataclass(val: str, cls=dataclass_cls) -> Any:
try: try:
val = _expand_json_human_readable_numbers(val)
return TypeAdapter(cls).validate_json(val) return TypeAdapter(cls).validate_json(val)
except ValidationError as e: except ValidationError as e:
raise argparse.ArgumentTypeError(repr(e)) from e raise argparse.ArgumentTypeError(repr(e)) from e
...@@ -2419,68 +2446,3 @@ def _raise_unsupported_error(feature_name: str): ...@@ -2419,68 +2446,3 @@ def _raise_unsupported_error(feature_name: str):
f"remove {feature_name} from your config." f"remove {feature_name} from your config."
) )
raise NotImplementedError(msg) raise NotImplementedError(msg)
def human_readable_int(value: str) -> int:
"""Parse human-readable integers like '1k', '2M', etc.
Including decimal values with decimal multipliers.
Examples:
- '1k' -> 1,000
- '1K' -> 1,024
- '25.6k' -> 25,600
"""
value = value.strip()
match = re.fullmatch(r"(\d+(?:\.\d+)?)([kKmMgGtT])", value)
if match:
decimal_multiplier = {
"k": 10**3,
"m": 10**6,
"g": 10**9,
"t": 10**12,
}
binary_multiplier = {
"K": 2**10,
"M": 2**20,
"G": 2**30,
"T": 2**40,
}
number, suffix = match.groups()
if suffix in decimal_multiplier:
mult = decimal_multiplier[suffix]
return int(float(number) * mult)
elif suffix in binary_multiplier:
mult = binary_multiplier[suffix]
# Do not allow decimals with binary multipliers
try:
return int(number) * mult
except ValueError as e:
raise argparse.ArgumentTypeError(
"Decimals are not allowed "
f"with binary suffixes like {suffix}. Did you mean to use "
f"{number}{suffix.lower()} instead?"
) from e
# Regular plain number.
return int(value)
def human_readable_int_or_auto(value: str) -> int:
"""Parse human-readable integers like '1k', '2M', etc.
Including decimal values with decimal multipliers.
Also accepts -1 or 'auto' as a special value for auto-detection.
Examples:
- '1k' -> 1,000
- '1K' -> 1,024
- '25.6k' -> 25,600
- '-1' or 'auto' -> -1 (special value for auto-detection)
"""
value = value.strip()
if value == "-1" or value.lower() == "auto":
return -1
return human_readable_int(value)
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Argument parsing utilities for vLLM.""" """Argument parsing utilities for vLLM."""
import argparse
import json import json
import sys import sys
import textwrap import textwrap
...@@ -25,6 +26,71 @@ from vllm.logger import init_logger ...@@ -25,6 +26,71 @@ from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
def human_readable_int(value: str) -> int:
"""Parse human-readable integers like '1k', '2M', etc.
Including decimal values with decimal multipliers.
Examples:
- '1k' -> 1,000
- '1K' -> 1,024
- '25.6k' -> 25,600
"""
value = value.strip()
match = re.fullmatch(r"(\d+(?:\.\d+)?)([kKmMgGtT])", value)
if match:
decimal_multiplier = {
"k": 10**3,
"m": 10**6,
"g": 10**9,
"t": 10**12,
}
binary_multiplier = {
"K": 2**10,
"M": 2**20,
"G": 2**30,
"T": 2**40,
}
number, suffix = match.groups()
if suffix in decimal_multiplier:
mult = decimal_multiplier[suffix]
return int(float(number) * mult)
elif suffix in binary_multiplier:
mult = binary_multiplier[suffix]
# Do not allow decimals with binary multipliers
try:
return int(number) * mult
except ValueError as e:
raise argparse.ArgumentTypeError(
"Decimals are not allowed "
f"with binary suffixes like {suffix}. Did you mean to use "
f"{number}{suffix.lower()} instead?"
) from e
# Regular plain number.
return int(value)
def human_readable_int_or_auto(value: str) -> int:
"""Parse human-readable integers like '1k', '2M', etc.
Including decimal values with decimal multipliers.
Also accepts -1 or 'auto' as a special value for auto-detection.
Examples:
- '1k' -> 1,000
- '1K' -> 1,024
- '25.6k' -> 25,600
- '-1' or 'auto' -> -1 (special value for auto-detection)
"""
value = value.strip()
if value == "-1" or value.lower() == "auto":
return -1
return human_readable_int(value)
class SortedHelpFormatter(ArgumentDefaultsHelpFormatter, RawDescriptionHelpFormatter): class SortedHelpFormatter(ArgumentDefaultsHelpFormatter, RawDescriptionHelpFormatter):
"""SortedHelpFormatter that sorts arguments by their option strings.""" """SortedHelpFormatter that sorts arguments by their option strings."""
...@@ -338,7 +404,12 @@ class FlexibleArgumentParser(ArgumentParser): ...@@ -338,7 +404,12 @@ class FlexibleArgumentParser(ArgumentParser):
try: try:
value = json.loads(value_str) value = json.loads(value_str)
except json.decoder.JSONDecodeError: except json.decoder.JSONDecodeError:
value = value_str # Support human-readable suffixes (e.g. 1k, 80g) for
# dotted config args like --config.field 80g
try:
value = human_readable_int(value_str) # type: ignore[assignment]
except (ValueError, ArgumentTypeError):
value = value_str
# Merge all values with the same key into a single dict # Merge all values with the same key into a single dict
arg_dict = create_nested_dict(keys, value) arg_dict = create_nested_dict(keys, value)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment