Unverified Commit 4b2ed792 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Improve configs - the rest! (#17562)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 7e357113
...@@ -9,7 +9,7 @@ import torch ...@@ -9,7 +9,7 @@ import torch
from tests.quantization.utils import is_quant_method_supported from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationLevel from vllm.config import CompilationConfig, CompilationLevel, PassConfig
from vllm.platforms import current_platform from vllm.platforms import current_platform
from ..utils import create_new_process_for_each_test from ..utils import create_new_process_for_each_test
...@@ -95,9 +95,6 @@ def test_full_graph( ...@@ -95,9 +95,6 @@ def test_full_graph(
run_model(optimization_level, model, model_kwargs) run_model(optimization_level, model, model_kwargs)
PassConfig = CompilationConfig.PassConfig
# TODO(luka) add other supported compilation config scenarios here # TODO(luka) add other supported compilation config scenarios here
@pytest.mark.parametrize( @pytest.mark.parametrize(
"compilation_config, model_info", "compilation_config, model_info",
......
...@@ -11,7 +11,7 @@ from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey, ...@@ -11,7 +11,7 @@ from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey,
kFp8DynamicTokenSym, kFp8StaticTensorSym) kFp8DynamicTokenSym, kFp8StaticTensorSym)
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig, VllmConfig from vllm.config import CompilationConfig, PassConfig, VllmConfig
from .backend import TestBackend from .backend import TestBackend
...@@ -53,9 +53,8 @@ def test_fix_functionalization(model: str, quant_key: QuantKey, ...@@ -53,9 +53,8 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
torch.set_default_device("cuda") torch.set_default_device("cuda")
vllm_config = VllmConfig() vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig(pass_config= \ vllm_config.compilation_config = CompilationConfig(
CompilationConfig.PassConfig(enable_fusion=do_fusion, pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True))
enable_noop=True))
noop_pass = NoOpEliminationPass(vllm_config) noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = FusionPass.instance(vllm_config) fusion_pass = FusionPass.instance(vllm_config)
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config) act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
......
...@@ -9,7 +9,8 @@ from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey, ...@@ -9,7 +9,8 @@ from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
FusionPass, QuantKey) FusionPass, QuantKey)
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
VllmConfig)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity) CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity)
...@@ -78,8 +79,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, ...@@ -78,8 +79,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
vllm_config = VllmConfig(compilation_config=CompilationConfig( vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"])) level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
vllm_config.compilation_config.pass_config = \ vllm_config.compilation_config.pass_config = \
CompilationConfig.PassConfig(enable_fusion=True, PassConfig(enable_fusion=True, enable_noop=True)
enable_noop=True)
with vllm.config.set_current_vllm_config(vllm_config): with vllm.config.set_current_vllm_config(vllm_config):
# Reshape pass is needed for the fusion pass to work # Reshape pass is needed for the fusion pass to work
noop_pass = NoOpEliminationPass(vllm_config) noop_pass = NoOpEliminationPass(vllm_config)
......
...@@ -10,7 +10,7 @@ from vllm.compilation.fx_utils import (find_auto_fn, find_auto_fn_maybe, ...@@ -10,7 +10,7 @@ from vllm.compilation.fx_utils import (find_auto_fn, find_auto_fn_maybe,
find_specified_fn_maybe, is_func) find_specified_fn_maybe, is_func)
from vllm.compilation.sequence_parallelism import SequenceParallelismPass from vllm.compilation.sequence_parallelism import SequenceParallelismPass
from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig, from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
VllmConfig) PassConfig, VllmConfig)
from vllm.distributed import tensor_model_parallel_all_reduce from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (init_distributed_environment, from vllm.distributed.parallel_state import (init_distributed_environment,
initialize_model_parallel) initialize_model_parallel)
...@@ -126,9 +126,8 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int, ...@@ -126,9 +126,8 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int,
# configure vllm config for SequenceParallelismPass # configure vllm config for SequenceParallelismPass
vllm_config = VllmConfig() vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig( vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig(
pass_config=CompilationConfig.PassConfig( enable_sequence_parallelism=True))
enable_sequence_parallelism=True, ), )
vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
# this is a fake model name to construct the model config # this is a fake model name to construct the model config
......
...@@ -6,7 +6,7 @@ import vllm.envs as envs ...@@ -6,7 +6,7 @@ import vllm.envs as envs
from vllm._custom_ops import scaled_fp8_quant from vllm._custom_ops import scaled_fp8_quant
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
from vllm.config import CompilationConfig, VllmConfig from vllm.config import CompilationConfig, PassConfig, VllmConfig
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from .backend import TestBackend from .backend import TestBackend
...@@ -36,8 +36,7 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size): ...@@ -36,8 +36,7 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
# Reshape pass is needed for the fusion pass to work # Reshape pass is needed for the fusion pass to work
config = VllmConfig() config = VllmConfig()
config.compilation_config = CompilationConfig( config.compilation_config = CompilationConfig(
pass_config=CompilationConfig.PassConfig(enable_fusion=True, pass_config=PassConfig(enable_fusion=True, enable_reshape=True))
enable_reshape=True))
fusion_pass = ActivationQuantFusionPass(config) fusion_pass = ActivationQuantFusionPass(config)
backend = TestBackend(fusion_pass) backend = TestBackend(fusion_pass)
......
...@@ -206,7 +206,7 @@ def _compare_sp( ...@@ -206,7 +206,7 @@ def _compare_sp(
'compile_sizes': [4, 8], 'compile_sizes': [4, 8],
'splitting_ops': [], 'splitting_ops': [],
'pass_config': { 'pass_config': {
'enable_sequence_parallism': sp_enabled, 'enable_sequence_parallelism': sp_enabled,
'enable_noop': True, 'enable_noop': True,
'enable_fusion': True, 'enable_fusion': True,
}, },
...@@ -223,7 +223,7 @@ def _compare_sp( ...@@ -223,7 +223,7 @@ def _compare_sp(
"--distributed-executor-backend", "--distributed-executor-backend",
distributed_backend, distributed_backend,
"--compilation_config", "--compilation_config",
str(compilation_config), json.dumps(compilation_config),
] ]
tp_env = { tp_env = {
......
...@@ -8,21 +8,18 @@ from typing import Literal, Optional ...@@ -8,21 +8,18 @@ from typing import Literal, Optional
import pytest import pytest
from vllm.config import config from vllm.config import CompilationConfig, config
from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs, from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
get_type, is_not_builtin, is_type, get_type, is_not_builtin, is_type,
literal_to_kwargs, nullable_kvs, literal_to_kwargs, nullable_kvs,
optional_type) optional_type, parse_type)
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
@pytest.mark.parametrize(("type", "value", "expected"), [ @pytest.mark.parametrize(("type", "value", "expected"), [
(int, "42", 42), (int, "42", 42),
(int, "None", None),
(float, "3.14", 3.14), (float, "3.14", 3.14),
(float, "None", None),
(str, "Hello World!", "Hello World!"), (str, "Hello World!", "Hello World!"),
(str, "None", None),
(json.loads, '{"foo":1,"bar":2}', { (json.loads, '{"foo":1,"bar":2}', {
"foo": 1, "foo": 1,
"bar": 2 "bar": 2
...@@ -31,15 +28,20 @@ from vllm.utils import FlexibleArgumentParser ...@@ -31,15 +28,20 @@ from vllm.utils import FlexibleArgumentParser
"foo": 1, "foo": 1,
"bar": 2 "bar": 2
}), }),
(json.loads, "None", None),
]) ])
def test_optional_type(type, value, expected): def test_parse_type(type, value, expected):
optional_type_func = optional_type(type) parse_type_func = parse_type(type)
context = nullcontext() context = nullcontext()
if value == "foo=1,bar=2": if value == "foo=1,bar=2":
context = pytest.warns(DeprecationWarning) context = pytest.warns(DeprecationWarning)
with context: with context:
assert optional_type_func(value) == expected assert parse_type_func(value) == expected
def test_optional_type():
optional_type_func = optional_type(int)
assert optional_type_func("None") is None
assert optional_type_func("42") == 42
@pytest.mark.parametrize(("type_hint", "type", "expected"), [ @pytest.mark.parametrize(("type_hint", "type", "expected"), [
...@@ -89,7 +91,40 @@ def test_literal_to_kwargs(type_hints, expected): ...@@ -89,7 +91,40 @@ def test_literal_to_kwargs(type_hints, expected):
@config @config
@dataclass @dataclass
class DummyConfigClass: class NestedConfig:
field: int = 1
"""field"""
@config
@dataclass
class FromCliConfig1:
field: int = 1
"""field"""
@classmethod
def from_cli(cls, cli_value: str):
inst = cls(**json.loads(cli_value))
inst.field += 1
return inst
@config
@dataclass
class FromCliConfig2:
field: int = 1
"""field"""
@classmethod
def from_cli(cls, cli_value: str):
inst = cls(**json.loads(cli_value))
inst.field += 2
return inst
@config
@dataclass
class DummyConfig:
regular_bool: bool = True regular_bool: bool = True
"""Regular bool with default True""" """Regular bool with default True"""
optional_bool: Optional[bool] = None optional_bool: Optional[bool] = None
...@@ -108,18 +143,24 @@ class DummyConfigClass: ...@@ -108,18 +143,24 @@ class DummyConfigClass:
"""Literal of literals with default 1""" """Literal of literals with default 1"""
json_tip: dict = field(default_factory=dict) json_tip: dict = field(default_factory=dict)
"""Dict which will be JSON in CLI""" """Dict which will be JSON in CLI"""
nested_config: NestedConfig = field(default_factory=NestedConfig)
"""Nested config"""
from_cli_config1: FromCliConfig1 = field(default_factory=FromCliConfig1)
"""Config with from_cli method"""
from_cli_config2: FromCliConfig2 = field(default_factory=FromCliConfig2)
"""Different config with from_cli method"""
@pytest.mark.parametrize(("type_hint", "expected"), [ @pytest.mark.parametrize(("type_hint", "expected"), [
(int, False), (int, False),
(DummyConfigClass, True), (DummyConfig, True),
]) ])
def test_is_not_builtin(type_hint, expected): def test_is_not_builtin(type_hint, expected):
assert is_not_builtin(type_hint) == expected assert is_not_builtin(type_hint) == expected
def test_get_kwargs(): def test_get_kwargs():
kwargs = get_kwargs(DummyConfigClass) kwargs = get_kwargs(DummyConfig)
print(kwargs) print(kwargs)
# bools should not have their type set # bools should not have their type set
...@@ -142,6 +183,11 @@ def test_get_kwargs(): ...@@ -142,6 +183,11 @@ def test_get_kwargs():
# dict should have json tip in help # dict should have json tip in help
json_tip = "\n\nShould be a valid JSON string." json_tip = "\n\nShould be a valid JSON string."
assert kwargs["json_tip"]["help"].endswith(json_tip) assert kwargs["json_tip"]["help"].endswith(json_tip)
# nested config should should construct the nested config
assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2)
# from_cli configs should be constructed with the correct method
assert kwargs["from_cli_config1"]["type"]('{"field": 2}').field == 3
assert kwargs["from_cli_config2"]["type"]('{"field": 2}').field == 4
@pytest.mark.parametrize(("arg", "expected"), [ @pytest.mark.parametrize(("arg", "expected"), [
...@@ -177,7 +223,7 @@ def test_compilation_config(): ...@@ -177,7 +223,7 @@ def test_compilation_config():
# default value # default value
args = parser.parse_args([]) args = parser.parse_args([])
assert args.compilation_config is None assert args.compilation_config == CompilationConfig()
# set to O3 # set to O3
args = parser.parse_args(["-O3"]) args = parser.parse_args(["-O3"])
...@@ -194,7 +240,7 @@ def test_compilation_config(): ...@@ -194,7 +240,7 @@ def test_compilation_config():
# set to string form of a dict # set to string form of a dict
args = parser.parse_args([ args = parser.parse_args([
"--compilation-config", "--compilation-config",
"{'level': 3, 'cudagraph_capture_sizes': [1, 2, 4, 8]}", '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}',
]) ])
assert (args.compilation_config.level == 3 and assert (args.compilation_config.level == 3 and
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]) args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8])
...@@ -202,7 +248,7 @@ def test_compilation_config(): ...@@ -202,7 +248,7 @@ def test_compilation_config():
# set to string form of a dict # set to string form of a dict
args = parser.parse_args([ args = parser.parse_args([
"--compilation-config=" "--compilation-config="
"{'level': 3, 'cudagraph_capture_sizes': [1, 2, 4, 8]}", '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}',
]) ])
assert (args.compilation_config.level == 3 and assert (args.compilation_config.level == 3 and
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]) args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8])
......
...@@ -4,7 +4,7 @@ import time ...@@ -4,7 +4,7 @@ import time
import torch import torch
from vllm.config import CompilationConfig, VllmConfig from vllm.config import PassConfig, VllmConfig
# yapf: disable # yapf: disable
from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank
from vllm.distributed import ( from vllm.distributed import (
...@@ -56,10 +56,7 @@ class VllmInductorPass(InductorPass): ...@@ -56,10 +56,7 @@ class VllmInductorPass(InductorPass):
class PrinterInductorPass(VllmInductorPass): class PrinterInductorPass(VllmInductorPass):
def __init__(self, def __init__(self, name: str, config: PassConfig, always=False):
name: str,
config: CompilationConfig.PassConfig,
always=False):
super().__init__(config) super().__init__(config)
self.name = name self.name = name
self.always = always self.always = always
......
This diff is collapsed.
...@@ -5,6 +5,7 @@ import threading ...@@ -5,6 +5,7 @@ import threading
import time import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import deque from collections import deque
from dataclasses import asdict
from itertools import count from itertools import count
from queue import Queue from queue import Queue
from typing import Any, Callable, Optional, Union from typing import Any, Callable, Optional, Union
...@@ -284,7 +285,7 @@ class EventPublisherFactory: ...@@ -284,7 +285,7 @@ class EventPublisherFactory:
if not config: if not config:
return NullEventPublisher() return NullEventPublisher()
config_dict = config.model_dump() config_dict = asdict(config)
kind = config_dict.pop("publisher", "null") kind = config_dict.pop("publisher", "null")
config_dict.pop("enable_kv_cache_events") config_dict.pop("enable_kv_cache_events")
......
...@@ -7,10 +7,10 @@ import json ...@@ -7,10 +7,10 @@ import json
import re import re
import threading import threading
import warnings import warnings
from dataclasses import MISSING, dataclass, fields from dataclasses import MISSING, dataclass, fields, is_dataclass
from itertools import permutations from itertools import permutations
from typing import (Any, Callable, Dict, List, Literal, Optional, Type, from typing import (Annotated, Any, Callable, Dict, List, Literal, Optional,
TypeVar, Union, cast, get_args, get_origin) Type, TypeVar, Union, cast, get_args, get_origin)
import torch import torch
from typing_extensions import TypeIs, deprecated from typing_extensions import TypeIs, deprecated
...@@ -36,7 +36,8 @@ from vllm.reasoning import ReasoningParserManager ...@@ -36,7 +36,8 @@ from vllm.reasoning import ReasoningParserManager
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3 from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
from vllm.transformers_utils.utils import check_gguf_file from vllm.transformers_utils.utils import check_gguf_file
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, GiB_bytes, is_in_ray_actor from vllm.utils import (FlexibleArgumentParser, GiB_bytes, is_in_doc_build,
is_in_ray_actor)
# yapf: enable # yapf: enable
...@@ -48,12 +49,9 @@ TypeHint = Union[type[Any], object] ...@@ -48,12 +49,9 @@ TypeHint = Union[type[Any], object]
TypeHintT = Union[type[T], object] TypeHintT = Union[type[T], object]
def optional_type( def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]:
def _optional_type(val: str) -> Optional[T]: def _parse_type(val: str) -> T:
if val == "" or val == "None":
return None
try: try:
if return_type is json.loads and not re.match("^{.*}$", val): if return_type is json.loads and not re.match("^{.*}$", val):
return cast(T, nullable_kvs(val)) return cast(T, nullable_kvs(val))
...@@ -62,13 +60,23 @@ def optional_type( ...@@ -62,13 +60,23 @@ def optional_type(
raise argparse.ArgumentTypeError( raise argparse.ArgumentTypeError(
f"Value {val} cannot be converted to {return_type}.") from e f"Value {val} cannot be converted to {return_type}.") from e
return _parse_type
def optional_type(
return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]:
def _optional_type(val: str) -> Optional[T]:
if val == "" or val == "None":
return None
return parse_type(return_type)(val)
return _optional_type return _optional_type
def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]: def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]:
if not re.match("^{.*}$", val): if not re.match("^{.*}$", val):
return str(val) return str(val)
else:
return optional_type(json.loads)(val) return optional_type(json.loads)(val)
...@@ -144,9 +152,24 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: ...@@ -144,9 +152,24 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
cls_docs = get_attr_docs(cls) cls_docs = get_attr_docs(cls)
kwargs = {} kwargs = {}
for field in fields(cls): for field in fields(cls):
# Get the set of possible types for the field
type_hints: set[TypeHint] = set()
if get_origin(field.type) in {Union, Annotated}:
type_hints.update(get_args(field.type))
else:
type_hints.add(field.type)
# If the field is a dataclass, we can use the model_validate_json
generator = (th for th in type_hints if is_dataclass(th))
dataclass_cls = next(generator, None)
# Get the default value of the field # Get the default value of the field
if field.default is not MISSING:
default = field.default default = field.default
if field.default_factory is not MISSING: elif field.default_factory is not MISSING:
if is_dataclass(field.default_factory) and is_in_doc_build():
default = {}
else:
default = field.default_factory() default = field.default_factory()
# Get the help text for the field # Get the help text for the field
...@@ -158,16 +181,17 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: ...@@ -158,16 +181,17 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
# Initialise the kwargs dictionary for the field # Initialise the kwargs dictionary for the field
kwargs[name] = {"default": default, "help": help} kwargs[name] = {"default": default, "help": help}
# Get the set of possible types for the field
type_hints: set[TypeHint] = set()
if get_origin(field.type) is Union:
type_hints.update(get_args(field.type))
else:
type_hints.add(field.type)
# Set other kwargs based on the type hints # Set other kwargs based on the type hints
json_tip = "\n\nShould be a valid JSON string." json_tip = "\n\nShould be a valid JSON string."
if contains_type(type_hints, bool): if dataclass_cls is not None:
dataclass_init = lambda x, f=dataclass_cls: f(**json.loads(x))
# Special case for configs with a from_cli method
if hasattr(dataclass_cls, "from_cli"):
from_cli = dataclass_cls.from_cli
dataclass_init = lambda x, f=from_cli: f(x)
kwargs[name]["type"] = dataclass_init
kwargs[name]["help"] += json_tip
elif contains_type(type_hints, bool):
# Creates --no-<name> and --<name> flags # Creates --no-<name> and --<name> flags
kwargs[name]["action"] = argparse.BooleanOptionalAction kwargs[name]["action"] = argparse.BooleanOptionalAction
elif contains_type(type_hints, Literal): elif contains_type(type_hints, Literal):
...@@ -202,7 +226,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: ...@@ -202,7 +226,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
kwargs[name]["type"] = union_dict_and_str kwargs[name]["type"] = union_dict_and_str
elif contains_type(type_hints, dict): elif contains_type(type_hints, dict):
# Dict arguments will always be optional # Dict arguments will always be optional
kwargs[name]["type"] = optional_type(json.loads) kwargs[name]["type"] = parse_type(json.loads)
kwargs[name]["help"] += json_tip kwargs[name]["help"] += json_tip
elif (contains_type(type_hints, str) elif (contains_type(type_hints, str)
or any(is_not_builtin(th) for th in type_hints)): or any(is_not_builtin(th) for th in type_hints)):
...@@ -771,63 +795,20 @@ class EngineArgs: ...@@ -771,63 +795,20 @@ class EngineArgs:
scheduler_group.add_argument("--scheduler-cls", scheduler_group.add_argument("--scheduler-cls",
**scheduler_kwargs["scheduler_cls"]) **scheduler_kwargs["scheduler_cls"])
# Compilation arguments
# compilation_kwargs = get_kwargs(CompilationConfig)
compilation_group = parser.add_argument_group(
title="CompilationConfig",
description=CompilationConfig.__doc__,
)
compilation_group.add_argument(
"--compilation-config",
"-O",
type=CompilationConfig.from_cli,
default=None,
help="torch.compile configuration for the model. "
"When it is a number (0, 1, 2, 3), it will be "
"interpreted as the optimization level.\n"
"NOTE: level 0 is the default level without "
"any optimization. level 1 and 2 are for internal "
"testing only. level 3 is the recommended level "
"for production.\n"
"To specify the full compilation config, "
"use a JSON string, e.g. ``{\"level\": 3, "
"\"cudagraph_capture_sizes\": [1, 2, 4, 8]}``\n"
"Following the convention of traditional "
"compilers, using ``-O`` without space is also "
"supported. ``-O3`` is equivalent to ``-O 3``.")
# KVTransfer arguments
# kv_transfer_kwargs = get_kwargs(KVTransferConfig)
kv_transfer_group = parser.add_argument_group(
title="KVTransferConfig",
description=KVTransferConfig.__doc__,
)
kv_transfer_group.add_argument(
"--kv-transfer-config",
type=KVTransferConfig.from_cli,
default=None,
help="The configurations for distributed KV cache "
"transfer. Should be a JSON string.")
kv_transfer_group.add_argument(
'--kv-events-config',
type=KVEventsConfig.from_cli,
default=None,
help='The configurations for event publishing.')
# vLLM arguments # vLLM arguments
# vllm_kwargs = get_kwargs(VllmConfig) vllm_kwargs = get_kwargs(VllmConfig)
vllm_group = parser.add_argument_group( vllm_group = parser.add_argument_group(
title="VllmConfig", title="VllmConfig",
description=VllmConfig.__doc__, description=VllmConfig.__doc__,
) )
vllm_group.add_argument( vllm_group.add_argument("--kv-transfer-config",
"--additional-config", **vllm_kwargs["kv_transfer_config"])
type=json.loads, vllm_group.add_argument('--kv-events-config',
default=None, **vllm_kwargs["kv_events_config"])
help="Additional config for specified platform in JSON format. " vllm_group.add_argument("--compilation-config", "-O",
"Different platforms may support different configs. Make sure the " **vllm_kwargs["compilation_config"])
"configs are valid for the platform you are using. The input format" vllm_group.add_argument("--additional-config",
" is like '{\"config_key\":\"config_value\"}'") **vllm_kwargs["additional_config"])
# Other arguments # Other arguments
parser.add_argument('--use-v2-block-manager', parser.add_argument('--use-v2-block-manager',
......
...@@ -13,7 +13,8 @@ from typing_extensions import TypeVar, deprecated ...@@ -13,7 +13,8 @@ from typing_extensions import TypeVar, deprecated
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
BeamSearchSequence, get_beam_search_score) BeamSearchSequence, get_beam_search_score)
from vllm.config import CompilationConfig, ModelDType, TokenizerMode from vllm.config import (CompilationConfig, ModelDType, TokenizerMode,
is_init_field)
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig, from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
TaskOption) TaskOption)
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
...@@ -204,9 +205,13 @@ class LLM: ...@@ -204,9 +205,13 @@ class LLM:
kwargs["worker_cls"] = cloudpickle.dumps(worker_cls) kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)
if compilation_config is not None: if compilation_config is not None:
if isinstance(compilation_config, (int, dict)): if isinstance(compilation_config, int):
compilation_config_instance = CompilationConfig.from_cli( compilation_config_instance = CompilationConfig(
str(compilation_config)) level=compilation_config)
elif isinstance(compilation_config, dict):
predicate = lambda x: is_init_field(CompilationConfig, x[0])
compilation_config_instance = CompilationConfig(
**dict(filter(predicate, compilation_config.items())))
else: else:
compilation_config_instance = compilation_config compilation_config_instance = compilation_config
else: else:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, Optional, Tuple, Union from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
import torch import torch
from tpu_info import device from tpu_info import device
...@@ -13,9 +13,10 @@ from vllm.sampling_params import SamplingParams, SamplingType ...@@ -13,9 +13,10 @@ from vllm.sampling_params import SamplingParams, SamplingType
from .interface import Platform, PlatformEnum, _Backend from .interface import Platform, PlatformEnum, _Backend
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig from vllm.config import BlockSize, ModelConfig, VllmConfig
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
else: else:
BlockSize = None
ModelConfig = None ModelConfig = None
VllmConfig = None VllmConfig = None
PoolingParams = None PoolingParams = None
...@@ -94,7 +95,7 @@ class TpuPlatform(Platform): ...@@ -94,7 +95,7 @@ class TpuPlatform(Platform):
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
# For v0, the default block size is 16. # For v0, the default block size is 16.
if cache_config and cache_config.block_size is None: if cache_config and cache_config.block_size is None:
cache_config.block_size = 16 cache_config.block_size = cast(BlockSize, 16)
compilation_config = vllm_config.compilation_config compilation_config = vllm_config.compilation_config
# TPU only supports DYNAMO_ONCE compilation level # TPU only supports DYNAMO_ONCE compilation level
...@@ -118,7 +119,7 @@ class TpuPlatform(Platform): ...@@ -118,7 +119,7 @@ class TpuPlatform(Platform):
from vllm.v1.attention.backends.pallas import ( from vllm.v1.attention.backends.pallas import (
PallasAttentionBackend) PallasAttentionBackend)
cache_config.block_size = PallasAttentionBackend.get_page_size( cache_config.block_size = PallasAttentionBackend.get_page_size(
vllm_config) vllm_config) # type: ignore[assignment]
min_page_size = PallasAttentionBackend.get_min_page_size( min_page_size = PallasAttentionBackend.get_min_page_size(
vllm_config) vllm_config)
if min_page_size > cache_config.block_size: if min_page_size > cache_config.block_size:
...@@ -128,7 +129,7 @@ class TpuPlatform(Platform): ...@@ -128,7 +129,7 @@ class TpuPlatform(Platform):
cache_config.block_size, cache_config.block_size,
min_page_size, min_page_size,
) )
cache_config.block_size = min_page_size cache_config.block_size = min_page_size # type: ignore[assignment]
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
scheduler_config = vllm_config.scheduler_config scheduler_config = vllm_config.scheduler_config
......
...@@ -1820,6 +1820,14 @@ def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor: ...@@ -1820,6 +1820,14 @@ def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor:
return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor) return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor)
def is_in_doc_build() -> bool:
try:
from sphinx.ext.autodoc.mock import _MockModule
return isinstance(zmq, _MockModule)
except ModuleNotFoundError:
return False
def import_from_path(module_name: str, file_path: Union[str, os.PathLike]): def import_from_path(module_name: str, file_path: Union[str, os.PathLike]):
""" """
Import a Python file according to its file path. Import a Python file according to its file path.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment