Unverified Commit 61e632ae authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Turn `@config` into a `dataclass_transform` (#31541)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent b1bb18de
...@@ -3,11 +3,11 @@ ...@@ -3,11 +3,11 @@
import json import json
from argparse import ArgumentError from argparse import ArgumentError
from contextlib import nullcontext from contextlib import AbstractContextManager, nullcontext
from dataclasses import dataclass, field
from typing import Annotated, Literal from typing import Annotated, Literal
import pytest import pytest
from pydantic import Field
from vllm.config import AttentionConfig, CompilationConfig, config from vllm.config import AttentionConfig, CompilationConfig, config
from vllm.engine.arg_utils import ( from vllm.engine.arg_utils import (
...@@ -96,7 +96,7 @@ def test_get_type(type_hints, type, expected): ...@@ -96,7 +96,7 @@ def test_get_type(type_hints, type, expected):
], ],
) )
def test_literal_to_kwargs(type_hints, expected): def test_literal_to_kwargs(type_hints, expected):
context = nullcontext() context: AbstractContextManager[object] = nullcontext()
if expected is Exception: if expected is Exception:
context = pytest.raises(expected) context = pytest.raises(expected)
with context: with context:
...@@ -104,14 +104,12 @@ def test_literal_to_kwargs(type_hints, expected): ...@@ -104,14 +104,12 @@ def test_literal_to_kwargs(type_hints, expected):
@config @config
@dataclass
class NestedConfig: class NestedConfig:
field: int = 1 field: int = 1
"""field""" """field"""
@config @config
@dataclass
class DummyConfig: class DummyConfig:
regular_bool: bool = True regular_bool: bool = True
"""Regular bool with default True""" """Regular bool with default True"""
...@@ -119,23 +117,23 @@ class DummyConfig: ...@@ -119,23 +117,23 @@ class DummyConfig:
"""Optional bool with default None""" """Optional bool with default None"""
optional_literal: Literal["x", "y"] | None = None optional_literal: Literal["x", "y"] | None = None
"""Optional literal with default None""" """Optional literal with default None"""
tuple_n: tuple[int, ...] = field(default_factory=lambda: (1, 2, 3)) tuple_n: tuple[int, ...] = Field(default_factory=lambda: (1, 2, 3))
"""Tuple with variable length""" """Tuple with variable length"""
tuple_2: tuple[int, int] = field(default_factory=lambda: (1, 2)) tuple_2: tuple[int, int] = Field(default_factory=lambda: (1, 2))
"""Tuple with fixed length""" """Tuple with fixed length"""
list_n: list[int] = field(default_factory=lambda: [1, 2, 3]) list_n: list[int] = Field(default_factory=lambda: [1, 2, 3])
"""List with variable length""" """List with variable length"""
list_literal: list[Literal[1, 2]] = field(default_factory=list) list_literal: list[Literal[1, 2]] = Field(default_factory=list)
"""List with literal choices""" """List with literal choices"""
list_union: list[str | type[object]] = field(default_factory=list) list_union: list[str | type[object]] = Field(default_factory=list)
"""List with union type""" """List with union type"""
set_n: set[int] = field(default_factory=lambda: {1, 2, 3}) set_n: set[int] = Field(default_factory=lambda: {1, 2, 3})
"""Set with variable length""" """Set with variable length"""
literal_literal: Literal[Literal[1], Literal[2]] = 1 literal_literal: Literal[Literal[1], Literal[2]] = 1
"""Literal of literals with default 1""" """Literal of literals with default 1"""
json_tip: dict = field(default_factory=dict) json_tip: dict = Field(default_factory=dict)
"""Dict which will be JSON in CLI""" """Dict which will be JSON in CLI"""
nested_config: NestedConfig = field(default_factory=NestedConfig) nested_config: NestedConfig = Field(default_factory=NestedConfig)
"""Nested config""" """Nested config"""
...@@ -195,7 +193,7 @@ def test_get_kwargs(): ...@@ -195,7 +193,7 @@ def test_get_kwargs():
json_tip = "Should either be a valid JSON string or JSON keys" json_tip = "Should either be a valid JSON string or JSON keys"
assert json_tip in kwargs["json_tip"]["help"] assert json_tip in kwargs["json_tip"]["help"]
# nested config should construct the nested config # nested config should construct the nested config
assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2) assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2) # type: ignore[call-arg]
@pytest.mark.parametrize( @pytest.mark.parametrize(
......
...@@ -66,9 +66,6 @@ class _TestConfigFields: ...@@ -66,9 +66,6 @@ class _TestConfigFields:
def test_get_field(): def test_get_field():
with pytest.raises(ValueError):
get_field(_TestConfigFields, "a")
b = get_field(_TestConfigFields, "b") b = get_field(_TestConfigFields, "b")
assert isinstance(b, Field) assert isinstance(b, Field)
assert b.default is MISSING assert b.default is MISSING
...@@ -188,7 +185,7 @@ def test_get_pooling_config(): ...@@ -188,7 +185,7 @@ def test_get_pooling_config():
) )
def test_get_pooling_config_from_args(): def test_get_pooling_config_from_args():
model_id = "sentence-transformers/all-MiniLM-L12-v2" model_id = "sentence-transformers/all-MiniLM-L12-v2"
pooler_config = PoolerConfig(seq_pooling_type="CLS", normalize=True) pooler_config = PoolerConfig(seq_pooling_type="CLS", use_activation=False)
model_config = ModelConfig(model_id, pooler_config=pooler_config) model_config = ModelConfig(model_id, pooler_config=pooler_config)
assert asdict(model_config.pooler_config) == asdict(pooler_config) assert asdict(model_config.pooler_config) == asdict(pooler_config)
......
...@@ -7,31 +7,22 @@ import pytest ...@@ -7,31 +7,22 @@ import pytest
from tools.pre_commit.validate_config import validate_ast from tools.pre_commit.validate_config import validate_ast
_TestConfig1 = """ _TestConfig1 = '''
@config @config
class _TestConfig1: class _TestConfig1:
pass
"""
_TestConfig2 = '''
@config
@dataclass
class _TestConfig2:
a: int a: int
"""docstring""" """docstring"""
''' '''
_TestConfig3 = """ _TestConfig2 = """
@config @config
@dataclass class _TestConfig2:
class _TestConfig3:
a: int = 1 a: int = 1
""" """
_TestConfig4 = ''' _TestConfig3 = '''
@config @config
@dataclass class _TestConfig3:
class _TestConfig4:
a: Union[Literal[1], Literal[2]] = 1 a: Union[Literal[1], Literal[2]] = 1
"""docstring""" """docstring"""
''' '''
...@@ -40,10 +31,9 @@ class _TestConfig4: ...@@ -40,10 +31,9 @@ class _TestConfig4:
@pytest.mark.parametrize( @pytest.mark.parametrize(
("test_config", "expected_error"), ("test_config", "expected_error"),
[ [
(_TestConfig1, "must be a dataclass"), (_TestConfig1, "must have a default"),
(_TestConfig2, "must have a default"), (_TestConfig2, "must have a docstring"),
(_TestConfig3, "must have a docstring"), (_TestConfig3, "must use a single Literal"),
(_TestConfig4, "must use a single Literal"),
], ],
) )
def test_config(test_config, expected_error): def test_config(test_config, expected_error):
......
...@@ -766,8 +766,8 @@ def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool): ...@@ -766,8 +766,8 @@ def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
"max_model_len": args.max_model_len, "max_model_len": args.max_model_len,
"enforce_eager": enforce_eager, "enforce_eager": enforce_eager,
"draft_tensor_parallel_size": args.draft_tensor_parallel_size, "draft_tensor_parallel_size": args.draft_tensor_parallel_size,
"max_num_seqs": 100, # limit cudagraph capture runtime
}, },
max_num_seqs=100, # limit cudagraph capture runtime
max_model_len=args.max_model_len, max_model_len=args.max_model_len,
gpu_memory_utilization=args.gpu_memory_utilization, gpu_memory_utilization=args.gpu_memory_utilization,
tensor_parallel_size=args.target_tensor_parallel_size, tensor_parallel_size=args.target_tensor_parallel_size,
......
...@@ -26,11 +26,8 @@ def test_backend_guidance_rollback_terminated(): ...@@ -26,11 +26,8 @@ def test_backend_guidance_rollback_terminated():
# guidance backend. In that case we are in a stopped state, but # guidance backend. In that case we are in a stopped state, but
# it should be reverted in case EOS is not accepted by the target # it should be reverted in case EOS is not accepted by the target
# model. # model.
vllm_config = VllmConfig( structured_outputs_config = StructuredOutputsConfig(backend="guidance")
decoding_config=StructuredOutputsConfig( vllm_config = VllmConfig(structured_outputs_config=structured_outputs_config)
backend="guidance",
)
)
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER) tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
backend = GuidanceBackend( backend = GuidanceBackend(
......
...@@ -54,24 +54,18 @@ class ConfigValidator(ast.NodeVisitor): ...@@ -54,24 +54,18 @@ class ConfigValidator(ast.NodeVisitor):
def __init__(self): ... def __init__(self): ...
def visit_ClassDef(self, node): def visit_ClassDef(self, node):
# Validate class with both @config and @dataclass decorators # Validate classes with a @config decorator
decorators = [ decorators = set()
id for decorator in node.decorator_list:
for d in node.decorator_list if isinstance(decorator, ast.Call):
if ( decorator = decorator.func
isinstance(d, ast.Name) if isinstance(decorator, ast.Name) and decorator.id == "config":
and ((id := d.id) == "config" or id == "dataclass") decorators.add(decorator.id)
)
or ( if decorators == {"config"}:
isinstance(d, ast.Call)
and (isinstance(d.func, ast.Name) and (id := d.func.id) == "dataclass")
)
]
if set(decorators) == {"config", "dataclass"}:
validate_class(node) validate_class(node)
elif set(decorators) == {"config"}: elif "config" in decorators:
fail(f"Class {node.name} with config decorator must be a dataclass.", node) fail(f"config decorator for {node.name} should be used alone", node)
self.generic_visit(node) self.generic_visit(node)
......
...@@ -36,6 +36,7 @@ from vllm.config.utils import ( ...@@ -36,6 +36,7 @@ from vllm.config.utils import (
config, config,
get_attr_docs, get_attr_docs,
is_init_field, is_init_field,
replace,
update_config, update_config,
) )
from vllm.config.vllm import ( from vllm.config.vllm import (
...@@ -101,6 +102,7 @@ __all__ = [ ...@@ -101,6 +102,7 @@ __all__ = [
"config", "config",
"get_attr_docs", "get_attr_docs",
"is_init_field", "is_init_field",
"replace",
"update_config", "update_config",
# From vllm.config.vllm # From vllm.config.vllm
"VllmConfig", "VllmConfig",
......
...@@ -4,14 +4,12 @@ ...@@ -4,14 +4,12 @@
from typing import Any, Literal from typing import Any, Literal
from pydantic import field_validator from pydantic import field_validator
from pydantic.dataclasses import dataclass
from vllm.config.utils import config from vllm.config.utils import config
from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.registry import AttentionBackendEnum
@config @config
@dataclass
class AttentionConfig: class AttentionConfig:
"""Configuration for attention mechanisms in vLLM.""" """Configuration for attention mechanisms in vLLM."""
......
...@@ -6,7 +6,6 @@ from dataclasses import field ...@@ -6,7 +6,6 @@ from dataclasses import field
from typing import TYPE_CHECKING, Any, Literal from typing import TYPE_CHECKING, Any, Literal
from pydantic import Field, SkipValidation, field_validator from pydantic import Field, SkipValidation, field_validator
from pydantic.dataclasses import dataclass
from vllm.config.utils import config from vllm.config.utils import config
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -37,7 +36,6 @@ KVOffloadingBackend = Literal["native", "lmcache"] ...@@ -37,7 +36,6 @@ KVOffloadingBackend = Literal["native", "lmcache"]
@config @config
@dataclass
class CacheConfig: class CacheConfig:
"""Configuration for the KV cache.""" """Configuration for the KV cache."""
......
...@@ -8,8 +8,7 @@ from dataclasses import field ...@@ -8,8 +8,7 @@ from dataclasses import field
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, Literal from typing import TYPE_CHECKING, Any, ClassVar, Literal
from pydantic import ConfigDict, Field, TypeAdapter, field_validator from pydantic import Field, TypeAdapter, field_validator
from pydantic.dataclasses import dataclass
import vllm.envs as envs import vllm.envs as envs
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
...@@ -96,7 +95,6 @@ class CUDAGraphMode(enum.Enum): ...@@ -96,7 +95,6 @@ class CUDAGraphMode(enum.Enum):
@config @config
@dataclass(config=ConfigDict(extra="forbid"))
class PassConfig: class PassConfig:
"""Configuration for custom Inductor passes. """Configuration for custom Inductor passes.
...@@ -267,7 +265,6 @@ class DynamicShapesType(str, enum.Enum): ...@@ -267,7 +265,6 @@ class DynamicShapesType(str, enum.Enum):
@config @config
@dataclass(config=ConfigDict(extra="forbid"))
class DynamicShapesConfig: class DynamicShapesConfig:
"""Configuration to control/debug torch compile dynamic shapes.""" """Configuration to control/debug torch compile dynamic shapes."""
...@@ -311,7 +308,6 @@ class DynamicShapesConfig: ...@@ -311,7 +308,6 @@ class DynamicShapesConfig:
@config @config
@dataclass(config=ConfigDict(extra="forbid"))
class CompilationConfig: class CompilationConfig:
"""Configuration for compilation. """Configuration for compilation.
......
...@@ -6,7 +6,6 @@ from typing import Any, Literal ...@@ -6,7 +6,6 @@ from typing import Any, Literal
import torch import torch
from pydantic import ConfigDict, SkipValidation from pydantic import ConfigDict, SkipValidation
from pydantic.dataclasses import dataclass
from vllm.config.utils import config from vllm.config.utils import config
from vllm.utils.hashing import safe_hash from vllm.utils.hashing import safe_hash
...@@ -14,8 +13,7 @@ from vllm.utils.hashing import safe_hash ...@@ -14,8 +13,7 @@ from vllm.utils.hashing import safe_hash
Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"] Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"]
@config @config(config=ConfigDict(arbitrary_types_allowed=True))
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class DeviceConfig: class DeviceConfig:
"""Configuration for the device to use for vLLM execution.""" """Configuration for the device to use for vLLM execution."""
......
...@@ -5,8 +5,6 @@ import uuid ...@@ -5,8 +5,6 @@ import uuid
from dataclasses import field from dataclasses import field
from typing import Any, Literal, get_args from typing import Any, Literal, get_args
from pydantic.dataclasses import dataclass
from vllm.config.utils import config from vllm.config.utils import config
ECProducer = Literal["ec_producer"] ECProducer = Literal["ec_producer"]
...@@ -15,7 +13,6 @@ ECRole = Literal[ECProducer, ECConsumer] ...@@ -15,7 +13,6 @@ ECRole = Literal[ECProducer, ECConsumer]
@config @config
@dataclass
class ECTransferConfig: class ECTransferConfig:
"""Configuration for distributed EC cache transfer.""" """Configuration for distributed EC cache transfer."""
......
...@@ -5,13 +5,11 @@ ...@@ -5,13 +5,11 @@
from typing import Literal from typing import Literal
from pydantic import Field from pydantic import Field
from pydantic.dataclasses import dataclass
from vllm.config.utils import config from vllm.config.utils import config
@config @config
@dataclass
class KVEventsConfig: class KVEventsConfig:
"""Configuration for KV event publishing.""" """Configuration for KV event publishing."""
......
...@@ -5,8 +5,6 @@ import uuid ...@@ -5,8 +5,6 @@ import uuid
from dataclasses import field from dataclasses import field
from typing import Any, Literal, get_args from typing import Any, Literal, get_args
from pydantic.dataclasses import dataclass
from vllm.config.utils import config from vllm.config.utils import config
from vllm.utils.hashing import safe_hash from vllm.utils.hashing import safe_hash
...@@ -16,7 +14,6 @@ KVRole = Literal[KVProducer, KVConsumer] ...@@ -16,7 +14,6 @@ KVRole = Literal[KVProducer, KVConsumer]
@config @config
@dataclass
class KVTransferConfig: class KVTransferConfig:
"""Configuration for distributed KV cache transfer.""" """Configuration for distributed KV cache transfer."""
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from pydantic import Field, field_validator from pydantic import Field, field_validator
from pydantic.dataclasses import dataclass
from vllm.config.utils import config from vllm.config.utils import config
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -21,7 +20,6 @@ logger = init_logger(__name__) ...@@ -21,7 +20,6 @@ logger = init_logger(__name__)
@config @config
@dataclass
class LoadConfig: class LoadConfig:
"""Configuration for loading the model weights.""" """Configuration for loading the model weights."""
......
...@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Any, Literal ...@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Any, Literal
import torch import torch
from pydantic import ConfigDict, Field, model_validator from pydantic import ConfigDict, Field, model_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self from typing_extensions import Self
from vllm.config.utils import config from vllm.config.utils import config
...@@ -26,8 +25,7 @@ MaxLoRARanks = Literal[1, 8, 16, 32, 64, 128, 256, 320, 512] ...@@ -26,8 +25,7 @@ MaxLoRARanks = Literal[1, 8, 16, 32, 64, 128, 256, 320, 512]
LoRAExtraVocabSize = Literal[256, 512] LoRAExtraVocabSize = Literal[256, 512]
@config @config(config=ConfigDict(arbitrary_types_allowed=True))
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class LoRAConfig: class LoRAConfig:
"""Configuration for LoRA.""" """Configuration for LoRA."""
......
...@@ -9,7 +9,6 @@ from typing import TYPE_CHECKING, Any, Literal, cast, get_args ...@@ -9,7 +9,6 @@ from typing import TYPE_CHECKING, Any, Literal, cast, get_args
import torch import torch
from pydantic import ConfigDict, Field, field_validator, model_validator from pydantic import ConfigDict, Field, field_validator, model_validator
from pydantic.dataclasses import dataclass
import vllm.envs as envs import vllm.envs as envs
from vllm.config.model_arch import ( from vllm.config.model_arch import (
...@@ -97,8 +96,7 @@ AttnTypeStr = Literal[ ...@@ -97,8 +96,7 @@ AttnTypeStr = Literal[
] ]
@config @config(config=ConfigDict(arbitrary_types_allowed=True))
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class ModelConfig: class ModelConfig:
"""Configuration for the model.""" """Configuration for the model."""
......
...@@ -51,7 +51,6 @@ DummyOptions: TypeAlias = ( ...@@ -51,7 +51,6 @@ DummyOptions: TypeAlias = (
@config @config
@dataclass
class MultiModalConfig: class MultiModalConfig:
"""Controls the behavior of multimodal models.""" """Controls the behavior of multimodal models."""
......
...@@ -6,7 +6,6 @@ from typing import Any, Literal, cast ...@@ -6,7 +6,6 @@ from typing import Any, Literal, cast
from packaging.version import parse from packaging.version import parse
from pydantic import Field, field_validator, model_validator from pydantic import Field, field_validator, model_validator
from pydantic.dataclasses import dataclass
from vllm import version from vllm import version
from vllm.config.utils import config from vllm.config.utils import config
...@@ -16,7 +15,6 @@ DetailedTraceModules = Literal["model", "worker", "all"] ...@@ -16,7 +15,6 @@ DetailedTraceModules = Literal["model", "worker", "all"]
@config @config
@dataclass
class ObservabilityConfig: class ObservabilityConfig:
"""Configuration for observability - metrics and tracing.""" """Configuration for observability - metrics and tracing."""
......
...@@ -3,12 +3,10 @@ ...@@ -3,12 +3,10 @@
import os import os
from collections.abc import Callable from collections.abc import Callable
from dataclasses import replace
from typing import TYPE_CHECKING, Any, Literal from typing import TYPE_CHECKING, Any, Literal
import torch import torch
from pydantic import Field, field_validator, model_validator from pydantic import Field, field_validator, model_validator
from pydantic.dataclasses import dataclass
from torch.distributed import ProcessGroup, ReduceOp from torch.distributed import ProcessGroup, ReduceOp
from typing_extensions import Self from typing_extensions import Self
...@@ -50,7 +48,6 @@ All2AllBackend = Literal[ ...@@ -50,7 +48,6 @@ All2AllBackend = Literal[
@config @config
@dataclass
class EPLBConfig: class EPLBConfig:
"""Configuration for Expert Parallel Load Balancing (EP).""" """Configuration for Expert Parallel Load Balancing (EP)."""
...@@ -94,7 +91,6 @@ class EPLBConfig: ...@@ -94,7 +91,6 @@ class EPLBConfig:
@config @config
@dataclass
class ParallelConfig: class ParallelConfig:
"""Configuration for the distributed execution.""" """Configuration for the distributed execution."""
...@@ -715,6 +711,3 @@ class ParallelConfig: ...@@ -715,6 +711,3 @@ class ParallelConfig:
) )
return self return self
def replace(self, **kwargs) -> Self:
return replace(self, **kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment