Unverified Commit 9c2492e5 authored by Nicolò Lucchesi's avatar Nicolò Lucchesi Committed by GitHub
Browse files

[Misc] Support Human-readable (k/K/m/M..) json cli arg (#40473)


Signed-off-by: default avatarNickLucche <nlucches@redhat.com>
parent fe57be78
......@@ -12,6 +12,7 @@ from pydantic import Field
from vllm.config import AttentionConfig, CompilationConfig, config
from vllm.engine.arg_utils import (
EngineArgs,
_expand_json_human_readable_numbers,
contains_type,
get_kwargs,
get_type,
......@@ -563,3 +564,32 @@ def test_ir_op_priority():
ir_op_priority=ir_op_priority,
kernel_config=KernelConfig(ir_op_priority=ir_op_priority),
).create_engine_config()
@pytest.mark.parametrize(
("input_json", "expected_json"),
[
# Decimal suffixes (lowercase)
('{"x": 80g}', '{"x": 80000000000}'),
('{"x": 1k}', '{"x": 1000}'),
('{"x": 5m}', '{"x": 5000000}'),
('{"x": 2t}', '{"x": 2000000000000}'),
# Binary suffixes (uppercase)
('{"x": 1K}', f'{{"x": {2**10}}}'),
('{"x": 1G}', f'{{"x": {2**30}}}'),
# Decimal values
('{"x": 1.5g}', '{"x": 1500000000}'),
# Quoted strings must NOT be modified
('{"my_key": 80g}', '{"my_key": 80000000000}'),
('{"name": "80g"}', '{"name": "80g"}'),
('{"model_name": "foo_bar"}', '{"model_name": "foo_bar"}'),
# Multiple values
('{"a": 1k, "b": 2m}', '{"a": 1000, "b": 2000000}'),
# Plain numbers are untouched
('{"x": 42}', '{"x": 42}'),
# Nested JSON
('{"outer": {"inner": 10g}}', '{"outer": {"inner": 10000000000}}'),
],
)
def test_expand_json_human_readable_numbers(input_json, expected_json):
assert _expand_json_human_readable_numbers(input_json) == expected_json
......@@ -105,7 +105,11 @@ from vllm.transformers_utils.config import (
from vllm.transformers_utils.gguf_utils import is_gguf
from vllm.transformers_utils.repo_utils import get_model_path
from vllm.transformers_utils.utils import is_cloud_storage
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.argparse_utils import (
FlexibleArgumentParser,
human_readable_int,
human_readable_int_or_auto,
)
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.network_utils import get_ip
from vllm.utils.torch_utils import resolve_kv_cache_dtype_string
......@@ -256,6 +260,28 @@ def _maybe_add_docs_url(cls: Any) -> str:
return f"\n\nAPI docs: https://docs.vllm.ai/en/{version}/api/vllm/config/#vllm.config.{cls.__name__}"
def _expand_json_human_readable_numbers(val: str) -> str:
"""Expand human-readable number suffixes in a JSON string.
Based on :func:`human_readable_int` so that the ``k/m/g/t`` (decimal) and
``K/M/G/T`` (binary) conventions work out the box.
Also works inside JSON config arguments such
as ``--kv-transfer-config '{"cpu_bytes_to_use": 80m}'``.
Only bare (unquoted) tokens are replaced so that JSON string values
like ``"model_name"`` are never modified.
"""
# Split on quoted strings so we only touch non-string regions.
parts = re.split(r'("(?:[^"\\]|\\.)*")', val)
for i in range(0, len(parts), 2): # even indices = outside strings
parts[i] = re.sub(
r"\b\d+(?:\.\d+)?[kKmMgGtT]\b",
lambda m: str(human_readable_int(m.group())),
parts[i],
)
return "".join(parts)
@functools.lru_cache(maxsize=30)
def _compute_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]:
# Save time only getting attr docs if we're generating help text
......@@ -301,6 +327,7 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]:
def parse_dataclass(val: str, cls=dataclass_cls) -> Any:
try:
val = _expand_json_human_readable_numbers(val)
return TypeAdapter(cls).validate_json(val)
except ValidationError as e:
raise argparse.ArgumentTypeError(repr(e)) from e
......@@ -2419,68 +2446,3 @@ def _raise_unsupported_error(feature_name: str):
f"remove {feature_name} from your config."
)
raise NotImplementedError(msg)
def human_readable_int(value: str) -> int:
"""Parse human-readable integers like '1k', '2M', etc.
Including decimal values with decimal multipliers.
Examples:
- '1k' -> 1,000
- '1K' -> 1,024
- '25.6k' -> 25,600
"""
value = value.strip()
match = re.fullmatch(r"(\d+(?:\.\d+)?)([kKmMgGtT])", value)
if match:
decimal_multiplier = {
"k": 10**3,
"m": 10**6,
"g": 10**9,
"t": 10**12,
}
binary_multiplier = {
"K": 2**10,
"M": 2**20,
"G": 2**30,
"T": 2**40,
}
number, suffix = match.groups()
if suffix in decimal_multiplier:
mult = decimal_multiplier[suffix]
return int(float(number) * mult)
elif suffix in binary_multiplier:
mult = binary_multiplier[suffix]
# Do not allow decimals with binary multipliers
try:
return int(number) * mult
except ValueError as e:
raise argparse.ArgumentTypeError(
"Decimals are not allowed "
f"with binary suffixes like {suffix}. Did you mean to use "
f"{number}{suffix.lower()} instead?"
) from e
# Regular plain number.
return int(value)
def human_readable_int_or_auto(value: str) -> int:
"""Parse human-readable integers like '1k', '2M', etc.
Including decimal values with decimal multipliers.
Also accepts -1 or 'auto' as a special value for auto-detection.
Examples:
- '1k' -> 1,000
- '1K' -> 1,024
- '25.6k' -> 25,600
- '-1' or 'auto' -> -1 (special value for auto-detection)
"""
value = value.strip()
if value == "-1" or value.lower() == "auto":
return -1
return human_readable_int(value)
......@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Argument parsing utilities for vLLM."""
import argparse
import json
import sys
import textwrap
......@@ -25,6 +26,71 @@ from vllm.logger import init_logger
logger = init_logger(__name__)
def human_readable_int(value: str) -> int:
"""Parse human-readable integers like '1k', '2M', etc.
Including decimal values with decimal multipliers.
Examples:
- '1k' -> 1,000
- '1K' -> 1,024
- '25.6k' -> 25,600
"""
value = value.strip()
match = re.fullmatch(r"(\d+(?:\.\d+)?)([kKmMgGtT])", value)
if match:
decimal_multiplier = {
"k": 10**3,
"m": 10**6,
"g": 10**9,
"t": 10**12,
}
binary_multiplier = {
"K": 2**10,
"M": 2**20,
"G": 2**30,
"T": 2**40,
}
number, suffix = match.groups()
if suffix in decimal_multiplier:
mult = decimal_multiplier[suffix]
return int(float(number) * mult)
elif suffix in binary_multiplier:
mult = binary_multiplier[suffix]
# Do not allow decimals with binary multipliers
try:
return int(number) * mult
except ValueError as e:
raise argparse.ArgumentTypeError(
"Decimals are not allowed "
f"with binary suffixes like {suffix}. Did you mean to use "
f"{number}{suffix.lower()} instead?"
) from e
# Regular plain number.
return int(value)
def human_readable_int_or_auto(value: str) -> int:
"""Parse human-readable integers like '1k', '2M', etc.
Including decimal values with decimal multipliers.
Also accepts -1 or 'auto' as a special value for auto-detection.
Examples:
- '1k' -> 1,000
- '1K' -> 1,024
- '25.6k' -> 25,600
- '-1' or 'auto' -> -1 (special value for auto-detection)
"""
value = value.strip()
if value == "-1" or value.lower() == "auto":
return -1
return human_readable_int(value)
class SortedHelpFormatter(ArgumentDefaultsHelpFormatter, RawDescriptionHelpFormatter):
"""SortedHelpFormatter that sorts arguments by their option strings."""
......@@ -338,6 +404,11 @@ class FlexibleArgumentParser(ArgumentParser):
try:
value = json.loads(value_str)
except json.decoder.JSONDecodeError:
# Support human-readable suffixes (e.g. 1k, 80g) for
# dotted config args like --config.field 80g
try:
value = human_readable_int(value_str) # type: ignore[assignment]
except (ValueError, ArgumentTypeError):
value = value_str
# Merge all values with the same key into a single dict
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment