Unverified Commit 61e632ae authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Turn `@config` into a `dataclass_transform` (#31541)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent b1bb18de
......@@ -3,11 +3,11 @@
import json
from argparse import ArgumentError
from contextlib import nullcontext
from dataclasses import dataclass, field
from contextlib import AbstractContextManager, nullcontext
from typing import Annotated, Literal
import pytest
from pydantic import Field
from vllm.config import AttentionConfig, CompilationConfig, config
from vllm.engine.arg_utils import (
......@@ -96,7 +96,7 @@ def test_get_type(type_hints, type, expected):
],
)
def test_literal_to_kwargs(type_hints, expected):
context = nullcontext()
context: AbstractContextManager[object] = nullcontext()
if expected is Exception:
context = pytest.raises(expected)
with context:
......@@ -104,14 +104,12 @@ def test_literal_to_kwargs(type_hints, expected):
@config
@dataclass
class NestedConfig:
field: int = 1
"""field"""
@config
@dataclass
class DummyConfig:
regular_bool: bool = True
"""Regular bool with default True"""
......@@ -119,23 +117,23 @@ class DummyConfig:
"""Optional bool with default None"""
optional_literal: Literal["x", "y"] | None = None
"""Optional literal with default None"""
tuple_n: tuple[int, ...] = field(default_factory=lambda: (1, 2, 3))
tuple_n: tuple[int, ...] = Field(default_factory=lambda: (1, 2, 3))
"""Tuple with variable length"""
tuple_2: tuple[int, int] = field(default_factory=lambda: (1, 2))
tuple_2: tuple[int, int] = Field(default_factory=lambda: (1, 2))
"""Tuple with fixed length"""
list_n: list[int] = field(default_factory=lambda: [1, 2, 3])
list_n: list[int] = Field(default_factory=lambda: [1, 2, 3])
"""List with variable length"""
list_literal: list[Literal[1, 2]] = field(default_factory=list)
list_literal: list[Literal[1, 2]] = Field(default_factory=list)
"""List with literal choices"""
list_union: list[str | type[object]] = field(default_factory=list)
list_union: list[str | type[object]] = Field(default_factory=list)
"""List with union type"""
set_n: set[int] = field(default_factory=lambda: {1, 2, 3})
set_n: set[int] = Field(default_factory=lambda: {1, 2, 3})
"""Set with variable length"""
literal_literal: Literal[Literal[1], Literal[2]] = 1
"""Literal of literals with default 1"""
json_tip: dict = field(default_factory=dict)
json_tip: dict = Field(default_factory=dict)
"""Dict which will be JSON in CLI"""
nested_config: NestedConfig = field(default_factory=NestedConfig)
nested_config: NestedConfig = Field(default_factory=NestedConfig)
"""Nested config"""
......@@ -195,7 +193,7 @@ def test_get_kwargs():
json_tip = "Should either be a valid JSON string or JSON keys"
assert json_tip in kwargs["json_tip"]["help"]
# nested config should construct the nested config
assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2)
assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2) # type: ignore[call-arg]
@pytest.mark.parametrize(
......
......@@ -66,9 +66,6 @@ class _TestConfigFields:
def test_get_field():
with pytest.raises(ValueError):
get_field(_TestConfigFields, "a")
b = get_field(_TestConfigFields, "b")
assert isinstance(b, Field)
assert b.default is MISSING
......@@ -188,7 +185,7 @@ def test_get_pooling_config():
)
def test_get_pooling_config_from_args():
model_id = "sentence-transformers/all-MiniLM-L12-v2"
pooler_config = PoolerConfig(seq_pooling_type="CLS", normalize=True)
pooler_config = PoolerConfig(seq_pooling_type="CLS", use_activation=False)
model_config = ModelConfig(model_id, pooler_config=pooler_config)
assert asdict(model_config.pooler_config) == asdict(pooler_config)
......
......@@ -7,31 +7,22 @@ import pytest
from tools.pre_commit.validate_config import validate_ast
_TestConfig1 = """
_TestConfig1 = '''
@config
class _TestConfig1:
pass
"""
_TestConfig2 = '''
@config
@dataclass
class _TestConfig2:
a: int
"""docstring"""
'''
_TestConfig3 = """
_TestConfig2 = """
@config
@dataclass
class _TestConfig3:
class _TestConfig2:
a: int = 1
"""
_TestConfig4 = '''
_TestConfig3 = '''
@config
@dataclass
class _TestConfig4:
class _TestConfig3:
a: Union[Literal[1], Literal[2]] = 1
"""docstring"""
'''
......@@ -40,10 +31,9 @@ class _TestConfig4:
@pytest.mark.parametrize(
("test_config", "expected_error"),
[
(_TestConfig1, "must be a dataclass"),
(_TestConfig2, "must have a default"),
(_TestConfig3, "must have a docstring"),
(_TestConfig4, "must use a single Literal"),
(_TestConfig1, "must have a default"),
(_TestConfig2, "must have a docstring"),
(_TestConfig3, "must use a single Literal"),
],
)
def test_config(test_config, expected_error):
......
......@@ -766,8 +766,8 @@ def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
"max_model_len": args.max_model_len,
"enforce_eager": enforce_eager,
"draft_tensor_parallel_size": args.draft_tensor_parallel_size,
"max_num_seqs": 100, # limit cudagraph capture runtime
},
max_num_seqs=100, # limit cudagraph capture runtime
max_model_len=args.max_model_len,
gpu_memory_utilization=args.gpu_memory_utilization,
tensor_parallel_size=args.target_tensor_parallel_size,
......
......@@ -26,11 +26,8 @@ def test_backend_guidance_rollback_terminated():
# guidance backend. In that case we are in a stopped state, but
# it should be reverted in case EOS is not accepted by the target
# model.
vllm_config = VllmConfig(
decoding_config=StructuredOutputsConfig(
backend="guidance",
)
)
structured_outputs_config = StructuredOutputsConfig(backend="guidance")
vllm_config = VllmConfig(structured_outputs_config=structured_outputs_config)
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
backend = GuidanceBackend(
......
......@@ -54,24 +54,18 @@ class ConfigValidator(ast.NodeVisitor):
def __init__(self): ...
def visit_ClassDef(self, node):
# Validate class with both @config and @dataclass decorators
decorators = [
id
for d in node.decorator_list
if (
isinstance(d, ast.Name)
and ((id := d.id) == "config" or id == "dataclass")
)
or (
isinstance(d, ast.Call)
and (isinstance(d.func, ast.Name) and (id := d.func.id) == "dataclass")
)
]
if set(decorators) == {"config", "dataclass"}:
# Validate classes with a @config decorator
decorators = set()
for decorator in node.decorator_list:
if isinstance(decorator, ast.Call):
decorator = decorator.func
if isinstance(decorator, ast.Name) and decorator.id == "config":
decorators.add(decorator.id)
if decorators == {"config"}:
validate_class(node)
elif set(decorators) == {"config"}:
fail(f"Class {node.name} with config decorator must be a dataclass.", node)
elif "config" in decorators:
fail(f"config decorator for {node.name} should be used alone", node)
self.generic_visit(node)
......
......@@ -36,6 +36,7 @@ from vllm.config.utils import (
config,
get_attr_docs,
is_init_field,
replace,
update_config,
)
from vllm.config.vllm import (
......@@ -101,6 +102,7 @@ __all__ = [
"config",
"get_attr_docs",
"is_init_field",
"replace",
"update_config",
# From vllm.config.vllm
"VllmConfig",
......
......@@ -4,14 +4,12 @@
from typing import Any, Literal
from pydantic import field_validator
from pydantic.dataclasses import dataclass
from vllm.config.utils import config
from vllm.v1.attention.backends.registry import AttentionBackendEnum
@config
@dataclass
class AttentionConfig:
"""Configuration for attention mechanisms in vLLM."""
......
......@@ -6,7 +6,6 @@ from dataclasses import field
from typing import TYPE_CHECKING, Any, Literal
from pydantic import Field, SkipValidation, field_validator
from pydantic.dataclasses import dataclass
from vllm.config.utils import config
from vllm.logger import init_logger
......@@ -37,7 +36,6 @@ KVOffloadingBackend = Literal["native", "lmcache"]
@config
@dataclass
class CacheConfig:
"""Configuration for the KV cache."""
......
......@@ -8,8 +8,7 @@ from dataclasses import field
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, Literal
from pydantic import ConfigDict, Field, TypeAdapter, field_validator
from pydantic.dataclasses import dataclass
from pydantic import Field, TypeAdapter, field_validator
import vllm.envs as envs
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
......@@ -96,7 +95,6 @@ class CUDAGraphMode(enum.Enum):
@config
@dataclass(config=ConfigDict(extra="forbid"))
class PassConfig:
"""Configuration for custom Inductor passes.
......@@ -267,7 +265,6 @@ class DynamicShapesType(str, enum.Enum):
@config
@dataclass(config=ConfigDict(extra="forbid"))
class DynamicShapesConfig:
"""Configuration to control/debug torch compile dynamic shapes."""
......@@ -311,7 +308,6 @@ class DynamicShapesConfig:
@config
@dataclass(config=ConfigDict(extra="forbid"))
class CompilationConfig:
"""Configuration for compilation.
......
......@@ -6,7 +6,6 @@ from typing import Any, Literal
import torch
from pydantic import ConfigDict, SkipValidation
from pydantic.dataclasses import dataclass
from vllm.config.utils import config
from vllm.utils.hashing import safe_hash
......@@ -14,8 +13,7 @@ from vllm.utils.hashing import safe_hash
Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"]
@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
@config(config=ConfigDict(arbitrary_types_allowed=True))
class DeviceConfig:
"""Configuration for the device to use for vLLM execution."""
......
......@@ -5,8 +5,6 @@ import uuid
from dataclasses import field
from typing import Any, Literal, get_args
from pydantic.dataclasses import dataclass
from vllm.config.utils import config
ECProducer = Literal["ec_producer"]
......@@ -15,7 +13,6 @@ ECRole = Literal[ECProducer, ECConsumer]
@config
@dataclass
class ECTransferConfig:
"""Configuration for distributed EC cache transfer."""
......
......@@ -5,13 +5,11 @@
from typing import Literal
from pydantic import Field
from pydantic.dataclasses import dataclass
from vllm.config.utils import config
@config
@dataclass
class KVEventsConfig:
"""Configuration for KV event publishing."""
......
......@@ -5,8 +5,6 @@ import uuid
from dataclasses import field
from typing import Any, Literal, get_args
from pydantic.dataclasses import dataclass
from vllm.config.utils import config
from vllm.utils.hashing import safe_hash
......@@ -16,7 +14,6 @@ KVRole = Literal[KVProducer, KVConsumer]
@config
@dataclass
class KVTransferConfig:
"""Configuration for distributed KV cache transfer."""
......
......@@ -4,7 +4,6 @@
from typing import TYPE_CHECKING, Any
from pydantic import Field, field_validator
from pydantic.dataclasses import dataclass
from vllm.config.utils import config
from vllm.logger import init_logger
......@@ -21,7 +20,6 @@ logger = init_logger(__name__)
@config
@dataclass
class LoadConfig:
"""Configuration for loading the model weights."""
......
......@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Any, Literal
import torch
from pydantic import ConfigDict, Field, model_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self
from vllm.config.utils import config
......@@ -26,8 +25,7 @@ MaxLoRARanks = Literal[1, 8, 16, 32, 64, 128, 256, 320, 512]
LoRAExtraVocabSize = Literal[256, 512]
@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
@config(config=ConfigDict(arbitrary_types_allowed=True))
class LoRAConfig:
"""Configuration for LoRA."""
......
......@@ -9,7 +9,6 @@ from typing import TYPE_CHECKING, Any, Literal, cast, get_args
import torch
from pydantic import ConfigDict, Field, field_validator, model_validator
from pydantic.dataclasses import dataclass
import vllm.envs as envs
from vllm.config.model_arch import (
......@@ -97,8 +96,7 @@ AttnTypeStr = Literal[
]
@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
@config(config=ConfigDict(arbitrary_types_allowed=True))
class ModelConfig:
"""Configuration for the model."""
......
......@@ -51,7 +51,6 @@ DummyOptions: TypeAlias = (
@config
@dataclass
class MultiModalConfig:
"""Controls the behavior of multimodal models."""
......
......@@ -6,7 +6,6 @@ from typing import Any, Literal, cast
from packaging.version import parse
from pydantic import Field, field_validator, model_validator
from pydantic.dataclasses import dataclass
from vllm import version
from vllm.config.utils import config
......@@ -16,7 +15,6 @@ DetailedTraceModules = Literal["model", "worker", "all"]
@config
@dataclass
class ObservabilityConfig:
"""Configuration for observability - metrics and tracing."""
......
......@@ -3,12 +3,10 @@
import os
from collections.abc import Callable
from dataclasses import replace
from typing import TYPE_CHECKING, Any, Literal
import torch
from pydantic import Field, field_validator, model_validator
from pydantic.dataclasses import dataclass
from torch.distributed import ProcessGroup, ReduceOp
from typing_extensions import Self
......@@ -50,7 +48,6 @@ All2AllBackend = Literal[
@config
@dataclass
class EPLBConfig:
"""Configuration for Expert Parallel Load Balancing (EP)."""
......@@ -94,7 +91,6 @@ class EPLBConfig:
@config
@dataclass
class ParallelConfig:
"""Configuration for the distributed execution."""
......@@ -715,6 +711,3 @@ class ParallelConfig:
)
return self
def replace(self, **kwargs) -> Self:
return replace(self, **kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment