Unverified Commit 71b1c8b6 authored by Yeshwanth N's avatar Yeshwanth N Committed by GitHub
Browse files

[Chore]:Extract math and argparse utilities to separate modules (#27188)


Signed-off-by: default avatarYeshwanth Surya <yeshsurya@gmail.com>
Signed-off-by: default avatarYeshwanth N <yeshsurya@gmail.com>
Signed-off-by: default avataryeshsurya <yeshsurya@gmail.com>
parent 8fb7b2fa
......@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.models import ModelRegistry
from vllm.utils import cdiv, round_up
from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
......
......@@ -32,7 +32,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv
from vllm.utils.math_utils import cdiv
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
from .utils import (
......
......@@ -22,7 +22,7 @@ from vllm.logger import init_logger
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import NestedTensors
from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import (
is_pin_memory_available,
is_uva_available,
......
......@@ -22,7 +22,7 @@ if TYPE_CHECKING:
from vllm.inputs import ProcessorInputs, PromptType
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.utils import FlexibleArgumentParser
from vllm.utils.argparse_utils import FlexibleArgumentParser
else:
FlexibleArgumentParser = object
......
......@@ -6,42 +6,56 @@ import datetime
import enum
import getpass
import inspect
import json
import multiprocessing
import os
import signal
import sys
import tempfile
import textwrap
import threading
import traceback
import uuid
import warnings
import weakref
from argparse import (
Action,
ArgumentDefaultsHelpFormatter,
ArgumentParser,
ArgumentTypeError,
RawDescriptionHelpFormatter,
_ArgumentGroup,
)
from collections import defaultdict
from collections.abc import Callable
from functools import partial, wraps
from functools import cache, partial, wraps
from typing import TYPE_CHECKING, Any, TypeVar
import cloudpickle
import psutil
import regex as re
import torch
import yaml
import vllm.envs as envs
from vllm.logger import enable_trace_function_call, init_logger
from vllm.ray.lazy_utils import is_in_ray_actor
# Import utilities from specialized modules for backward compatibility
from vllm.utils.argparse_utils import (
FlexibleArgumentParser,
SortedHelpFormatter,
StoreBoolean,
)
from vllm.utils.math_utils import (
cdiv,
next_power_of_2,
prev_power_of_2,
round_down,
round_up,
)
from vllm.utils.platform_utils import cuda_is_initialized, xpu_is_initialized
__all__ = [
# Argparse utilities
"FlexibleArgumentParser",
"SortedHelpFormatter",
"StoreBoolean",
# Math utilities
"cdiv",
"next_power_of_2",
"prev_power_of_2",
"round_down",
"round_up",
]
_DEPRECATED_MAPPINGS = {
"cprofile": "profiling",
"cprofile_context": "profiling",
......@@ -139,31 +153,31 @@ def random_uuid() -> str:
return str(uuid.uuid4().hex)
def cdiv(a: int, b: int) -> int:
"""Ceiling division."""
return -(a // -b)
def next_power_of_2(n) -> int:
"""The next power of 2 (inclusive)"""
if n < 1:
return 1
return 1 << (n - 1).bit_length()
def update_environment_variables(envs: dict[str, str]):
for k, v in envs.items():
if k in os.environ and os.environ[k] != v:
logger.warning(
"Overwriting environment variable %s from '%s' to '%s'",
k,
os.environ[k],
v,
)
os.environ[k] = v
def prev_power_of_2(n: int) -> int:
"""The previous power of 2 (inclusive)"""
if n <= 0:
return 0
return 1 << (n.bit_length() - 1)
@cache
def is_pin_memory_available() -> bool:
from vllm.platforms import current_platform
def round_up(x: int, y: int) -> int:
return ((x + y - 1) // y) * y
return current_platform.is_pin_memory_available()
def round_down(x: int, y: int) -> int:
return (x // y) * y
@cache
def is_uva_available() -> bool:
"""Check if Unified Virtual Addressing (UVA) is available."""
# UVA requires pinned memory.
# TODO: Add more requirements for UVA if needed.
return is_pin_memory_available()
# TODO: This function can be removed if transformer_modules classes are
......@@ -214,488 +228,6 @@ def weak_bind(
return weak_bound
class StoreBoolean(Action):
def __call__(self, parser, namespace, values, option_string=None):
if values.lower() == "true":
setattr(namespace, self.dest, True)
elif values.lower() == "false":
setattr(namespace, self.dest, False)
else:
raise ValueError(
f"Invalid boolean value: {values}. Expected 'true' or 'false'."
)
class SortedHelpFormatter(ArgumentDefaultsHelpFormatter, RawDescriptionHelpFormatter):
"""SortedHelpFormatter that sorts arguments by their option strings."""
def _split_lines(self, text, width):
"""
1. Sentences split across lines have their single newlines removed.
2. Paragraphs and explicit newlines are split into separate lines.
3. Each line is wrapped to the specified width (width of terminal).
"""
# The patterns also include whitespace after the newline
single_newline = re.compile(r"(?<!\n)\n(?!\n)\s*")
multiple_newlines = re.compile(r"\n{2,}\s*")
text = single_newline.sub(" ", text)
lines = re.split(multiple_newlines, text)
return sum([textwrap.wrap(line, width) for line in lines], [])
def add_arguments(self, actions):
actions = sorted(actions, key=lambda x: x.option_strings)
super().add_arguments(actions)
class FlexibleArgumentParser(ArgumentParser):
"""ArgumentParser that allows both underscore and dash in names."""
_deprecated: set[Action] = set()
_json_tip: str = (
"When passing JSON CLI arguments, the following sets of arguments "
"are equivalent:\n"
' --json-arg \'{"key1": "value1", "key2": {"key3": "value2"}}\'\n'
" --json-arg.key1 value1 --json-arg.key2.key3 value2\n\n"
"Additionally, list elements can be passed individually using +:\n"
' --json-arg \'{"key4": ["value3", "value4", "value5"]}\'\n'
" --json-arg.key4+ value3 --json-arg.key4+='value4,value5'\n\n"
)
_search_keyword: str | None = None
def __init__(self, *args, **kwargs):
# Set the default "formatter_class" to SortedHelpFormatter
if "formatter_class" not in kwargs:
kwargs["formatter_class"] = SortedHelpFormatter
# Pop kwarg "add_json_tip" to control whether to add the JSON tip
self.add_json_tip = kwargs.pop("add_json_tip", True)
super().__init__(*args, **kwargs)
if sys.version_info < (3, 13):
# Enable the deprecated kwarg for Python 3.12 and below
def parse_known_args(self, args=None, namespace=None):
if args is not None and "--disable-log-requests" in args:
# Special case warning because the warning below won't trigger
# if –-disable-log-requests because its value is default.
logger.warning_once(
"argument '--disable-log-requests' is deprecated and "
"replaced with '--enable-log-requests'. This will be "
"removed in v0.12.0."
)
namespace, args = super().parse_known_args(args, namespace)
for action in FlexibleArgumentParser._deprecated:
if (
hasattr(namespace, dest := action.dest)
and getattr(namespace, dest) != action.default
):
logger.warning_once("argument '%s' is deprecated", dest)
return namespace, args
def add_argument(self, *args, **kwargs):
deprecated = kwargs.pop("deprecated", False)
action = super().add_argument(*args, **kwargs)
if deprecated:
FlexibleArgumentParser._deprecated.add(action)
return action
class _FlexibleArgumentGroup(_ArgumentGroup):
def add_argument(self, *args, **kwargs):
deprecated = kwargs.pop("deprecated", False)
action = super().add_argument(*args, **kwargs)
if deprecated:
FlexibleArgumentParser._deprecated.add(action)
return action
def add_argument_group(self, *args, **kwargs):
group = self._FlexibleArgumentGroup(self, *args, **kwargs)
self._action_groups.append(group)
return group
def format_help(self):
# Only use custom help formatting for bottom level parsers
if self._subparsers is not None:
return super().format_help()
formatter = self._get_formatter()
# Handle keyword search of the args
if (search_keyword := self._search_keyword) is not None:
# Normalise the search keyword
search_keyword = search_keyword.lower().replace("_", "-")
# Return full help if searching for 'all'
if search_keyword == "all":
self.epilog = self._json_tip
return super().format_help()
# Return group help if searching for a group title
for group in self._action_groups:
if group.title and group.title.lower() == search_keyword:
formatter.start_section(group.title)
formatter.add_text(group.description)
formatter.add_arguments(group._group_actions)
formatter.end_section()
formatter.add_text(self._json_tip)
return formatter.format_help()
# Return matched args if searching for an arg name
matched_actions = []
for group in self._action_groups:
for action in group._group_actions:
# search option name
if any(
search_keyword in opt.lower() for opt in action.option_strings
):
matched_actions.append(action)
if matched_actions:
formatter.start_section(f"Arguments matching '{search_keyword}'")
formatter.add_arguments(matched_actions)
formatter.end_section()
formatter.add_text(self._json_tip)
return formatter.format_help()
# No match found
formatter.add_text(
f"No group or arguments matching '{search_keyword}'.\n"
"Use '--help' to see available groups or "
"'--help=all' to see all available parameters."
)
return formatter.format_help()
# usage
formatter.add_usage(self.usage, self._actions, self._mutually_exclusive_groups)
# description
formatter.add_text(self.description)
# positionals, optionals and user-defined groups
formatter.start_section("Config Groups")
config_groups = ""
for group in self._action_groups:
if not group._group_actions:
continue
title = group.title
description = group.description or ""
config_groups += f"{title: <24}{description}\n"
formatter.add_text(config_groups)
formatter.end_section()
# epilog
formatter.add_text(self.epilog)
# determine help from format above
return formatter.format_help()
def parse_args( # type: ignore[override]
self,
args: list[str] | None = None,
namespace: Namespace | None = None,
):
if args is None:
args = sys.argv[1:]
# Check for --model in command line arguments first
if args and args[0] == "serve":
try:
model_idx = next(
i
for i, arg in enumerate(args)
if arg == "--model" or arg.startswith("--model=")
)
logger.warning(
"With `vllm serve`, you should provide the model as a "
"positional argument or in a config file instead of via "
"the `--model` option. "
"The `--model` option will be removed in v0.13."
)
if args[model_idx] == "--model":
model_tag = args[model_idx + 1]
rest_start_idx = model_idx + 2
else:
model_tag = args[model_idx].removeprefix("--model=")
rest_start_idx = model_idx + 1
# Move <model> to the front, e,g:
# [Before]
# vllm serve -tp 2 --model <model> --enforce-eager --port 8001
# [After]
# vllm serve <model> -tp 2 --enforce-eager --port 8001
args = [
"serve",
model_tag,
*args[1:model_idx],
*args[rest_start_idx:],
]
print("args", args)
except StopIteration:
pass
if "--config" in args:
args = self._pull_args_from_config(args)
def repl(match: re.Match) -> str:
"""Replaces underscores with dashes in the matched string."""
return match.group(0).replace("_", "-")
# Everything between the first -- and the first .
pattern = re.compile(r"(?<=--)[^\.]*")
# Convert underscores to dashes and vice versa in argument names
processed_args = list[str]()
for i, arg in enumerate(args):
if arg.startswith("--help="):
FlexibleArgumentParser._search_keyword = arg.split("=", 1)[-1].lower()
processed_args.append("--help")
elif arg.startswith("--"):
if "=" in arg:
key, value = arg.split("=", 1)
key = pattern.sub(repl, key, count=1)
processed_args.append(f"{key}={value}")
else:
key = pattern.sub(repl, arg, count=1)
processed_args.append(key)
elif arg.startswith("-O") and arg != "-O" and arg[2] != ".":
# allow -O flag to be used without space, e.g. -O3 or -Odecode
# -O.<...> handled later
# also handle -O=<mode> here
mode = arg[3:] if arg[2] == "=" else arg[2:]
processed_args.append(f"-O.mode={mode}")
elif (
arg == "-O"
and i + 1 < len(args)
and args[i + 1] in {"0", "1", "2", "3"}
):
# Convert -O <n> to -O.mode <n>
processed_args.append("-O.mode")
else:
processed_args.append(arg)
def create_nested_dict(keys: list[str], value: str) -> dict[str, Any]:
"""Creates a nested dictionary from a list of keys and a value.
For example, `keys = ["a", "b", "c"]` and `value = 1` will create:
`{"a": {"b": {"c": 1}}}`
"""
nested_dict: Any = value
for key in reversed(keys):
nested_dict = {key: nested_dict}
return nested_dict
def recursive_dict_update(
original: dict[str, Any],
update: dict[str, Any],
) -> set[str]:
"""Recursively updates a dictionary with another dictionary.
Returns a set of duplicate keys that were overwritten.
"""
duplicates = set[str]()
for k, v in update.items():
if isinstance(v, dict) and isinstance(original.get(k), dict):
nested_duplicates = recursive_dict_update(original[k], v)
duplicates |= {f"{k}.{d}" for d in nested_duplicates}
elif isinstance(v, list) and isinstance(original.get(k), list):
original[k] += v
else:
if k in original:
duplicates.add(k)
original[k] = v
return duplicates
delete = set[int]()
dict_args = defaultdict[str, dict[str, Any]](dict)
duplicates = set[str]()
for i, processed_arg in enumerate(processed_args):
if i in delete: # skip if value from previous arg
continue
if processed_arg.startswith("-") and "." in processed_arg:
if "=" in processed_arg:
processed_arg, value_str = processed_arg.split("=", 1)
if "." not in processed_arg:
# False positive, '.' was only in the value
continue
else:
value_str = processed_args[i + 1]
delete.add(i + 1)
if processed_arg.endswith("+"):
processed_arg = processed_arg[:-1]
value_str = json.dumps(list(value_str.split(",")))
key, *keys = processed_arg.split(".")
try:
value = json.loads(value_str)
except json.decoder.JSONDecodeError:
value = value_str
# Merge all values with the same key into a single dict
arg_dict = create_nested_dict(keys, value)
arg_duplicates = recursive_dict_update(dict_args[key], arg_dict)
duplicates |= {f"{key}.{d}" for d in arg_duplicates}
delete.add(i)
# Filter out the dict args we set to None
processed_args = [a for i, a in enumerate(processed_args) if i not in delete]
if duplicates:
logger.warning("Found duplicate keys %s", ", ".join(duplicates))
# Add the dict args back as if they were originally passed as JSON
for dict_arg, dict_value in dict_args.items():
processed_args.append(dict_arg)
processed_args.append(json.dumps(dict_value))
return super().parse_args(processed_args, namespace)
def check_port(self, value):
try:
value = int(value)
except ValueError:
msg = "Port must be an integer"
raise ArgumentTypeError(msg) from None
if not (1024 <= value <= 65535):
raise ArgumentTypeError("Port must be between 1024 and 65535")
return value
def _pull_args_from_config(self, args: list[str]) -> list[str]:
"""Method to pull arguments specified in the config file
into the command-line args variable.
The arguments in config file will be inserted between
the argument list.
example:
```yaml
port: 12323
tensor-parallel-size: 4
```
```python
$: vllm {serve,chat,complete} "facebook/opt-12B" \
--config config.yaml -tp 2
$: args = [
"serve,chat,complete",
"facebook/opt-12B",
'--config', 'config.yaml',
'-tp', '2'
]
$: args = [
"serve,chat,complete",
"facebook/opt-12B",
'--port', '12323',
'--tensor-parallel-size', '4',
'-tp', '2'
]
```
Please note how the config args are inserted after the sub command.
this way the order of priorities is maintained when these are args
parsed by super().
"""
assert args.count("--config") <= 1, "More than one config file specified!"
index = args.index("--config")
if index == len(args) - 1:
raise ValueError(
"No config file specified! \
Please check your command-line arguments."
)
file_path = args[index + 1]
config_args = self.load_config_file(file_path)
# 0th index might be the sub command {serve,chat,complete,...}
# optionally followed by model_tag (only for serve)
# followed by config args
# followed by rest of cli args.
# maintaining this order will enforce the precedence
# of cli > config > defaults
if args[0].startswith("-"):
# No sub command (e.g., api_server entry point)
args = config_args + args[0:index] + args[index + 2 :]
elif args[0] == "serve":
model_in_cli = len(args) > 1 and not args[1].startswith("-")
model_in_config = any(arg == "--model" for arg in config_args)
if not model_in_cli and not model_in_config:
raise ValueError(
"No model specified! Please specify model either "
"as a positional argument or in a config file."
)
if model_in_cli:
# Model specified as positional arg, keep CLI version
args = (
[args[0]]
+ [args[1]]
+ config_args
+ args[2:index]
+ args[index + 2 :]
)
else:
# No model in CLI, use config if available
args = [args[0]] + config_args + args[1:index] + args[index + 2 :]
else:
args = [args[0]] + config_args + args[1:index] + args[index + 2 :]
return args
def load_config_file(self, file_path: str) -> list[str]:
"""Loads a yaml file and returns the key value pairs as a
flattened list with argparse like pattern
```yaml
port: 12323
tensor-parallel-size: 4
```
returns:
processed_args: list[str] = [
'--port': '12323',
'--tensor-parallel-size': '4'
]
"""
extension: str = file_path.split(".")[-1]
if extension not in ("yaml", "yml"):
raise ValueError(
"Config file must be of a yaml/yml type.\
%s supplied",
extension,
)
# only expecting a flat dictionary of atomic types
processed_args: list[str] = []
config: dict[str, int | str] = {}
try:
with open(file_path) as config_file:
config = yaml.safe_load(config_file)
except Exception as ex:
logger.error(
"Unable to read the config file at %s. \
Make sure path is correct",
file_path,
)
raise ex
store_boolean_arguments = [
action.dest for action in self._actions if isinstance(action, StoreBoolean)
]
for key, value in config.items():
if isinstance(value, bool) and key not in store_boolean_arguments:
if value:
processed_args.append("--" + key)
elif isinstance(value, list):
if value:
processed_args.append("--" + key)
for item in value:
processed_args.append(str(item))
else:
processed_args.append("--" + key)
processed_args.append(str(value))
return processed_args
class AtomicCounter:
"""An atomic, thread-safe counter"""
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Argument parsing utilities for vLLM."""
import json
import sys
import textwrap
from argparse import (
Action,
ArgumentDefaultsHelpFormatter,
ArgumentParser,
ArgumentTypeError,
RawDescriptionHelpFormatter,
_ArgumentGroup,
)
from collections import defaultdict
from typing import TYPE_CHECKING, Any
import regex as re
import yaml
from vllm.logger import init_logger
if TYPE_CHECKING:
from argparse import Namespace
else:
Namespace = object
logger = init_logger(__name__)
class StoreBoolean(Action):
def __call__(self, parser, namespace, values, option_string=None):
if values.lower() == "true":
setattr(namespace, self.dest, True)
elif values.lower() == "false":
setattr(namespace, self.dest, False)
else:
raise ValueError(
f"Invalid boolean value: {values}. Expected 'true' or 'false'."
)
class SortedHelpFormatter(ArgumentDefaultsHelpFormatter, RawDescriptionHelpFormatter):
"""SortedHelpFormatter that sorts arguments by their option strings."""
def _split_lines(self, text, width):
"""
1. Sentences split across lines have their single newlines removed.
2. Paragraphs and explicit newlines are split into separate lines.
3. Each line is wrapped to the specified width (width of terminal).
"""
# The patterns also include whitespace after the newline
single_newline = re.compile(r"(?<!\n)\n(?!\n)\s*")
multiple_newlines = re.compile(r"\n{2,}\s*")
text = single_newline.sub(" ", text)
lines = re.split(multiple_newlines, text)
return sum([textwrap.wrap(line, width) for line in lines], [])
def add_arguments(self, actions):
actions = sorted(actions, key=lambda x: x.option_strings)
super().add_arguments(actions)
class FlexibleArgumentParser(ArgumentParser):
"""ArgumentParser that allows both underscore and dash in names."""
_deprecated: set[Action] = set()
_json_tip: str = (
"When passing JSON CLI arguments, the following sets of arguments "
"are equivalent:\n"
' --json-arg \'{"key1": "value1", "key2": {"key3": "value2"}}\'\n'
" --json-arg.key1 value1 --json-arg.key2.key3 value2\n\n"
"Additionally, list elements can be passed individually using +:\n"
' --json-arg \'{"key4": ["value3", "value4", "value5"]}\'\n'
" --json-arg.key4+ value3 --json-arg.key4+='value4,value5'\n\n"
)
_search_keyword: str | None = None
def __init__(self, *args, **kwargs):
# Set the default "formatter_class" to SortedHelpFormatter
if "formatter_class" not in kwargs:
kwargs["formatter_class"] = SortedHelpFormatter
# Pop kwarg "add_json_tip" to control whether to add the JSON tip
self.add_json_tip = kwargs.pop("add_json_tip", True)
super().__init__(*args, **kwargs)
if sys.version_info < (3, 13):
# Enable the deprecated kwarg for Python 3.12 and below
def parse_known_args(self, args=None, namespace=None):
if args is not None and "--disable-log-requests" in args:
# Special case warning because the warning below won't trigger
# if –-disable-log-requests because its value is default.
logger.warning_once(
"argument '--disable-log-requests' is deprecated and "
"replaced with '--enable-log-requests'. This will be "
"removed in v0.12.0."
)
namespace, args = super().parse_known_args(args, namespace)
for action in FlexibleArgumentParser._deprecated:
if (
hasattr(namespace, dest := action.dest)
and getattr(namespace, dest) != action.default
):
logger.warning_once("argument '%s' is deprecated", dest)
return namespace, args
def add_argument(self, *args, **kwargs):
deprecated = kwargs.pop("deprecated", False)
action = super().add_argument(*args, **kwargs)
if deprecated:
FlexibleArgumentParser._deprecated.add(action)
return action
class _FlexibleArgumentGroup(_ArgumentGroup):
def add_argument(self, *args, **kwargs):
deprecated = kwargs.pop("deprecated", False)
action = super().add_argument(*args, **kwargs)
if deprecated:
FlexibleArgumentParser._deprecated.add(action)
return action
def add_argument_group(self, *args, **kwargs):
group = self._FlexibleArgumentGroup(self, *args, **kwargs)
self._action_groups.append(group)
return group
def format_help(self):
# Only use custom help formatting for bottom level parsers
if self._subparsers is not None:
return super().format_help()
formatter = self._get_formatter()
# Handle keyword search of the args
if (search_keyword := self._search_keyword) is not None:
# Normalise the search keyword
search_keyword = search_keyword.lower().replace("_", "-")
# Return full help if searching for 'all'
if search_keyword == "all":
self.epilog = self._json_tip
return super().format_help()
# Return group help if searching for a group title
for group in self._action_groups:
if group.title and group.title.lower() == search_keyword:
formatter.start_section(group.title)
formatter.add_text(group.description)
formatter.add_arguments(group._group_actions)
formatter.end_section()
formatter.add_text(self._json_tip)
return formatter.format_help()
# Return matched args if searching for an arg name
matched_actions = []
for group in self._action_groups:
for action in group._group_actions:
# search option name
if any(
search_keyword in opt.lower() for opt in action.option_strings
):
matched_actions.append(action)
if matched_actions:
formatter.start_section(f"Arguments matching '{search_keyword}'")
formatter.add_arguments(matched_actions)
formatter.end_section()
formatter.add_text(self._json_tip)
return formatter.format_help()
# No match found
formatter.add_text(
f"No group or arguments matching '{search_keyword}'.\n"
"Use '--help' to see available groups or "
"'--help=all' to see all available parameters."
)
return formatter.format_help()
# usage
formatter.add_usage(self.usage, self._actions, self._mutually_exclusive_groups)
# description
formatter.add_text(self.description)
# positionals, optionals and user-defined groups
formatter.start_section("Config Groups")
config_groups = ""
for group in self._action_groups:
if not group._group_actions:
continue
title = group.title
description = group.description or ""
config_groups += f"{title: <24}{description}\n"
formatter.add_text(config_groups)
formatter.end_section()
# epilog
formatter.add_text(self.epilog)
# determine help from format above
return formatter.format_help()
def parse_args( # type: ignore[override]
self,
args: list[str] | None = None,
namespace: Namespace | None = None,
):
if args is None:
args = sys.argv[1:]
# Check for --model in command line arguments first
if args and args[0] == "serve":
try:
model_idx = next(
i
for i, arg in enumerate(args)
if arg == "--model" or arg.startswith("--model=")
)
logger.warning(
"With `vllm serve`, you should provide the model as a "
"positional argument or in a config file instead of via "
"the `--model` option. "
"The `--model` option will be removed in v0.13."
)
if args[model_idx] == "--model":
model_tag = args[model_idx + 1]
rest_start_idx = model_idx + 2
else:
model_tag = args[model_idx].removeprefix("--model=")
rest_start_idx = model_idx + 1
# Move <model> to the front, e,g:
# [Before]
# vllm serve -tp 2 --model <model> --enforce-eager --port 8001
# [After]
# vllm serve <model> -tp 2 --enforce-eager --port 8001
args = [
"serve",
model_tag,
*args[1:model_idx],
*args[rest_start_idx:],
]
except StopIteration:
pass
if "--config" in args:
args = self._pull_args_from_config(args)
def repl(match: re.Match) -> str:
"""Replaces underscores with dashes in the matched string."""
return match.group(0).replace("_", "-")
# Everything between the first -- and the first .
pattern = re.compile(r"(?<=--)[^\.]*")
# Convert underscores to dashes and vice versa in argument names
processed_args = list[str]()
for i, arg in enumerate(args):
if arg.startswith("--help="):
FlexibleArgumentParser._search_keyword = arg.split("=", 1)[-1].lower()
processed_args.append("--help")
elif arg.startswith("--"):
if "=" in arg:
key, value = arg.split("=", 1)
key = pattern.sub(repl, key, count=1)
processed_args.append(f"{key}={value}")
else:
key = pattern.sub(repl, arg, count=1)
processed_args.append(key)
elif arg.startswith("-O") and arg != "-O" and arg[2] != ".":
# allow -O flag to be used without space, e.g. -O3 or -Odecode
# -O.<...> handled later
# also handle -O=<mode> here
mode = arg[3:] if arg[2] == "=" else arg[2:]
processed_args.append(f"-O.mode={mode}")
elif (
arg == "-O"
and i + 1 < len(args)
and args[i + 1] in {"0", "1", "2", "3"}
):
# Convert -O <n> to -O.mode <n>
processed_args.append("-O.mode")
else:
processed_args.append(arg)
def create_nested_dict(keys: list[str], value: str) -> dict[str, Any]:
"""Creates a nested dictionary from a list of keys and a value.
For example, `keys = ["a", "b", "c"]` and `value = 1` will create:
`{"a": {"b": {"c": 1}}}`
"""
nested_dict: Any = value
for key in reversed(keys):
nested_dict = {key: nested_dict}
return nested_dict
def recursive_dict_update(
original: dict[str, Any],
update: dict[str, Any],
) -> set[str]:
"""Recursively updates a dictionary with another dictionary.
Returns a set of duplicate keys that were overwritten.
"""
duplicates = set[str]()
for k, v in update.items():
if isinstance(v, dict) and isinstance(original.get(k), dict):
nested_duplicates = recursive_dict_update(original[k], v)
duplicates |= {f"{k}.{d}" for d in nested_duplicates}
elif isinstance(v, list) and isinstance(original.get(k), list):
original[k] += v
else:
if k in original:
duplicates.add(k)
original[k] = v
return duplicates
delete = set[int]()
dict_args = defaultdict[str, dict[str, Any]](dict)
duplicates = set[str]()
for i, processed_arg in enumerate(processed_args):
if i in delete: # skip if value from previous arg
continue
if processed_arg.startswith("-") and "." in processed_arg:
if "=" in processed_arg:
processed_arg, value_str = processed_arg.split("=", 1)
if "." not in processed_arg:
# False positive, '.' was only in the value
continue
else:
value_str = processed_args[i + 1]
delete.add(i + 1)
if processed_arg.endswith("+"):
processed_arg = processed_arg[:-1]
value_str = json.dumps(list(value_str.split(",")))
key, *keys = processed_arg.split(".")
try:
value = json.loads(value_str)
except json.decoder.JSONDecodeError:
value = value_str
# Merge all values with the same key into a single dict
arg_dict = create_nested_dict(keys, value)
arg_duplicates = recursive_dict_update(dict_args[key], arg_dict)
duplicates |= {f"{key}.{d}" for d in arg_duplicates}
delete.add(i)
# Filter out the dict args we set to None
processed_args = [a for i, a in enumerate(processed_args) if i not in delete]
if duplicates:
logger.warning("Found duplicate keys %s", ", ".join(duplicates))
# Add the dict args back as if they were originally passed as JSON
for dict_arg, dict_value in dict_args.items():
processed_args.append(dict_arg)
processed_args.append(json.dumps(dict_value))
return super().parse_args(processed_args, namespace)
def check_port(self, value):
try:
value = int(value)
except ValueError:
msg = "Port must be an integer"
raise ArgumentTypeError(msg) from None
if not (1024 <= value <= 65535):
raise ArgumentTypeError("Port must be between 1024 and 65535")
return value
def _pull_args_from_config(self, args: list[str]) -> list[str]:
"""Method to pull arguments specified in the config file
into the command-line args variable.
The arguments in config file will be inserted between
the argument list.
example:
```yaml
port: 12323
tensor-parallel-size: 4
```
```python
$: vllm {serve,chat,complete} "facebook/opt-12B" \
--config config.yaml -tp 2
$: args = [
"serve,chat,complete",
"facebook/opt-12B",
'--config', 'config.yaml',
'-tp', '2'
]
$: args = [
"serve,chat,complete",
"facebook/opt-12B",
'--port', '12323',
'--tensor-parallel-size', '4',
'-tp', '2'
]
```
Please note how the config args are inserted after the sub command.
this way the order of priorities is maintained when these are args
parsed by super().
"""
assert args.count("--config") <= 1, "More than one config file specified!"
index = args.index("--config")
if index == len(args) - 1:
raise ValueError(
"No config file specified! \
Please check your command-line arguments."
)
file_path = args[index + 1]
config_args = self.load_config_file(file_path)
# 0th index might be the sub command {serve,chat,complete,...}
# optionally followed by model_tag (only for serve)
# followed by config args
# followed by rest of cli args.
# maintaining this order will enforce the precedence
# of cli > config > defaults
if args[0].startswith("-"):
# No sub command (e.g., api_server entry point)
args = config_args + args[0:index] + args[index + 2 :]
elif args[0] == "serve":
model_in_cli = len(args) > 1 and not args[1].startswith("-")
model_in_config = any(arg == "--model" for arg in config_args)
if not model_in_cli and not model_in_config:
raise ValueError(
"No model specified! Please specify model either "
"as a positional argument or in a config file."
)
if model_in_cli:
# Model specified as positional arg, keep CLI version
args = (
[args[0]]
+ [args[1]]
+ config_args
+ args[2:index]
+ args[index + 2 :]
)
else:
# No model in CLI, use config if available
args = [args[0]] + config_args + args[1:index] + args[index + 2 :]
else:
args = [args[0]] + config_args + args[1:index] + args[index + 2 :]
return args
def load_config_file(self, file_path: str) -> list[str]:
"""Loads a yaml file and returns the key value pairs as a
flattened list with argparse like pattern
```yaml
port: 12323
tensor-parallel-size: 4
```
returns:
processed_args: list[str] = [
'--port': '12323',
'--tensor-parallel-size': '4'
]
"""
extension: str = file_path.split(".")[-1]
if extension not in ("yaml", "yml"):
raise ValueError(
f"Config file must be of a yaml/yml type. {extension} supplied"
)
# only expecting a flat dictionary of atomic types
processed_args: list[str] = []
config: dict[str, int | str] = {}
try:
with open(file_path) as config_file:
config = yaml.safe_load(config_file)
except Exception as ex:
logger.error(
"Unable to read the config file at %s. Check path correctness",
file_path,
)
raise ex
store_boolean_arguments = [
action.dest for action in self._actions if isinstance(action, StoreBoolean)
]
for key, value in config.items():
if isinstance(value, bool) and key not in store_boolean_arguments:
if value:
processed_args.append("--" + key)
elif isinstance(value, list):
if value:
processed_args.append("--" + key)
for item in value:
processed_args.append(str(item))
else:
processed_args.append("--" + key)
processed_args.append(str(value))
return processed_args
......@@ -16,8 +16,8 @@ import torch
import vllm.envs as envs
from vllm.logger import logger
from vllm.platforms import current_platform
from vllm.utils import cdiv
from vllm.utils.import_utils import has_deep_gemm
from vllm.utils.math_utils import cdiv
@functools.cache
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Math utility functions for vLLM."""
def cdiv(a: int, b: int) -> int:
"""Ceiling division."""
return -(a // -b)
def next_power_of_2(n: int) -> int:
"""The next power of 2 (inclusive)"""
if n < 1:
return 1
return 1 << (n - 1).bit_length()
def prev_power_of_2(n: int) -> int:
"""The previous power of 2 (inclusive)"""
if n <= 0:
return 0
return 1 << (n.bit_length() - 1)
def round_up(x: int, y: int) -> int:
"""Round up x to the nearest multiple of y."""
return ((x + y - 1) // y) * y
def round_down(x: int, y: int) -> int:
"""Round down x to the nearest multiple of y."""
return (x // y) * y
......@@ -37,7 +37,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.utils import cdiv
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
......
......@@ -34,12 +34,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv
from vllm.utils.flashinfer import (
can_use_trtllm_attention,
flashinfer_disable_q_quantization,
use_trtllm_attention,
)
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
......
......@@ -28,7 +28,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.utils import cdiv
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import is_torch_equal_or_newer
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
......
......@@ -7,7 +7,7 @@ import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
from vllm.utils import cdiv
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import (
PAD_SLOT_ID,
......
......@@ -220,8 +220,8 @@ from vllm.model_executor.layers.linear import (
UnquantizedLinearMethod,
)
from vllm.platforms import current_platform
from vllm.utils import cdiv, round_down
from vllm.utils.flashinfer import has_nvidia_artifactory
from vllm.utils.math_utils import cdiv, round_down
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
......
......@@ -22,7 +22,7 @@ from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
......
......@@ -10,7 +10,7 @@ import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionLayer
from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd
from vllm.config import VllmConfig
from vllm.utils import cdiv
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
MLACommonDecodeMetadata,
......
......@@ -13,7 +13,7 @@ from vllm.attention.backends.abstract import (
)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import cdiv, next_power_of_2
from vllm.utils.math_utils import cdiv, next_power_of_2
logger = init_logger(__name__)
......
......@@ -21,7 +21,7 @@ import torch
from typing_extensions import runtime_checkable
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.utils import cdiv
from vllm.utils.math_utils import cdiv
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionImpl
......
......@@ -12,8 +12,8 @@ from typing import Any, NewType, TypeAlias
from vllm import envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.utils.hashing import sha256_cbor
from vllm.utils.math_utils import cdiv
from vllm.utils.mem_constants import GiB_bytes
from vllm.v1.kv_cache_interface import (
ChunkedLocalAttentionSpec,
......
......@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Sequence
from vllm.utils import cdiv
from vllm.utils.math_utils import cdiv
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
from vllm.v1.kv_cache_interface import (
......
......@@ -29,10 +29,11 @@ from vllm.tracing import init_tracer
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device, cdiv
from vllm.utils import Device
from vllm.utils.async_utils import cancel_task_threadsafe
from vllm.utils.collection_utils import as_list
from vllm.utils.func_utils import deprecate_kwargs
from vllm.utils.math_utils import cdiv
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment