Unverified Commit 61e632ae authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Turn `@config` into a `dataclass_transform` (#31541)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent b1bb18de
...@@ -3,8 +3,6 @@ ...@@ -3,8 +3,6 @@
from typing import Any, Literal, get_args from typing import Any, Literal, get_args
from pydantic.dataclasses import dataclass
from vllm.config.utils import config from vllm.config.utils import config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash from vllm.utils.hashing import safe_hash
...@@ -19,7 +17,6 @@ TOK_POOLING_TYPES: tuple[TokenPoolingType, ...] = get_args(TokenPoolingType) ...@@ -19,7 +17,6 @@ TOK_POOLING_TYPES: tuple[TokenPoolingType, ...] = get_args(TokenPoolingType)
@config @config
@dataclass
class PoolerConfig: class PoolerConfig:
"""Controls the behavior of output pooling in pooling models.""" """Controls the behavior of output pooling in pooling models."""
......
...@@ -5,7 +5,6 @@ import os ...@@ -5,7 +5,6 @@ import os
from typing import Any, Literal from typing import Any, Literal
from pydantic import Field, model_validator from pydantic import Field, model_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self from typing_extensions import Self
from vllm.config.utils import config from vllm.config.utils import config
...@@ -32,7 +31,6 @@ def _is_uri_path(path: str) -> bool: ...@@ -32,7 +31,6 @@ def _is_uri_path(path: str) -> bool:
@config @config
@dataclass
class ProfilerConfig: class ProfilerConfig:
"""Dataclass which contains profiler config for the engine.""" """Dataclass which contains profiler config for the engine."""
......
...@@ -6,7 +6,6 @@ from dataclasses import InitVar ...@@ -6,7 +6,6 @@ from dataclasses import InitVar
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast
from pydantic import Field, field_validator from pydantic import Field, field_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self from typing_extensions import Self
from vllm.config.utils import config from vllm.config.utils import config
...@@ -24,7 +23,6 @@ SchedulerPolicy = Literal["fcfs", "priority"] ...@@ -24,7 +23,6 @@ SchedulerPolicy = Literal["fcfs", "priority"]
@config @config
@dataclass
class SchedulerConfig: class SchedulerConfig:
"""Scheduler configuration.""" """Scheduler configuration."""
......
...@@ -5,7 +5,6 @@ import ast ...@@ -5,7 +5,6 @@ import ast
from typing import TYPE_CHECKING, Any, Literal, get_args from typing import TYPE_CHECKING, Any, Literal, get_args
from pydantic import Field, SkipValidation, model_validator from pydantic import Field, SkipValidation, model_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self from typing_extensions import Self
from vllm.config.model import ModelConfig from vllm.config.model import ModelConfig
...@@ -55,7 +54,6 @@ SpeculativeMethod = Literal[ ...@@ -55,7 +54,6 @@ SpeculativeMethod = Literal[
@config @config
@dataclass
class SpeculativeConfig: class SpeculativeConfig:
"""Configuration for speculative decoding.""" """Configuration for speculative decoding."""
......
...@@ -2,13 +2,10 @@ ...@@ -2,13 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from pydantic.dataclasses import dataclass
from vllm.config.utils import config from vllm.config.utils import config
@config @config
@dataclass
class SpeechToTextConfig: class SpeechToTextConfig:
"""Configuration for speech-to-text models.""" """Configuration for speech-to-text models."""
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
from typing import Any, Literal from typing import Any, Literal
from pydantic import model_validator from pydantic import model_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self from typing_extensions import Self
from vllm.config.utils import config from vllm.config.utils import config
...@@ -16,7 +15,6 @@ StructuredOutputsBackend = Literal[ ...@@ -16,7 +15,6 @@ StructuredOutputsBackend = Literal[
@config @config
@dataclass
class StructuredOutputsConfig: class StructuredOutputsConfig:
"""Dataclass which contains structured outputs config for the engine.""" """Dataclass which contains structured outputs config for the engine."""
......
...@@ -10,14 +10,17 @@ import json ...@@ -10,14 +10,17 @@ import json
import pathlib import pathlib
import textwrap import textwrap
from collections.abc import Callable, Mapping, Sequence, Set from collections.abc import Callable, Mapping, Sequence, Set
from dataclasses import MISSING, Field, dataclass, field, fields, is_dataclass, replace from dataclasses import MISSING, Field, field, fields, is_dataclass
from itertools import pairwise from itertools import pairwise
from typing import TYPE_CHECKING, Any, Protocol, TypeVar from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast
import regex as re import regex as re
import torch import torch
from pydantic import ConfigDict
from pydantic.dataclasses import dataclass
from pydantic.fields import Field as PydanticField
from pydantic.fields import FieldInfo from pydantic.fields import FieldInfo
from typing_extensions import runtime_checkable from typing_extensions import dataclass_transform, runtime_checkable
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -29,23 +32,39 @@ else: ...@@ -29,23 +32,39 @@ else:
DataclassInstance = Any DataclassInstance = Any
ConfigType = type[DataclassInstance] ConfigType = type[DataclassInstance]
ConfigT = TypeVar("ConfigT", bound=ConfigType) ConfigT = TypeVar("ConfigT", bound=DataclassInstance)
def config(cls: ConfigT) -> ConfigT: @dataclass_transform(field_specifiers=(PydanticField,))
""" def config(
A decorator that ensures all fields in a dataclass have default values cls: type[ConfigT] | None = None,
and that each field has a docstring. *,
config: ConfigDict | None = None,
**kwargs: Any,
) -> type[ConfigT] | Callable[[type[ConfigT]], type[ConfigT]]:
"""Decorator to create a pydantic dataclass with default config. The default config
for the dataclass forbids extra fields.
If a `ConfigT` is used as a CLI argument itself, the `type` keyword argument All config classes in vLLM should use this decorator.
provided by `get_kwargs` will be
`pydantic.TypeAdapter(ConfigT).validate_json(cli_arg)` which treats the
`cli_arg` as a JSON string which gets validated by `pydantic`.
Config validation is performed by the tools/pre_commit/validate_config.py Args:
script, which is invoked during the pre-commit checks. cls: The class to decorate
""" config: The pydantic ConfigDict to use. If provided, it will be merged with
return cls the default config.
**kwargs: Additional arguments to pass to pydantic.dataclass."""
# Extra fields are forbidden by default
merged_config = ConfigDict(extra="forbid")
if config is not None:
merged_config.update(config)
def decorator(cls):
return dataclass(cls, config=merged_config, **kwargs)
# Called with arguments: @config(config=...)
if cls is None:
return decorator
# Called without arguments: @config
return decorator(cls)
def get_field(cls: ConfigType, name: str) -> Field: def get_field(cls: ConfigType, name: str) -> Field:
...@@ -53,24 +72,46 @@ def get_field(cls: ConfigType, name: str) -> Field: ...@@ -53,24 +72,46 @@ def get_field(cls: ConfigType, name: str) -> Field:
default factory fields in `EngineArgs`.""" default factory fields in `EngineArgs`."""
if not is_dataclass(cls): if not is_dataclass(cls):
raise TypeError("The given class is not a dataclass.") raise TypeError("The given class is not a dataclass.")
cls_fields = {f.name: f for f in fields(cls)} try:
if name not in cls_fields: named_field = next(f for f in fields(cls) if f.name == name)
raise ValueError(f"Field '{name}' not found in {cls.__name__}.") except StopIteration as e:
named_field: Field = cls_fields[name] raise ValueError(f"Field '{name}' not found in {cls.__name__}.") from e
if (default_factory := named_field.default_factory) is not MISSING:
return field(default_factory=default_factory) # The arguments to copy to the new field
if (default := named_field.default) is not MISSING: default = named_field.default
if isinstance(default, FieldInfo): default_factory = named_field.default_factory
# Handle pydantic.Field defaults init = named_field.init
if default.default_factory is not None:
return field(default_factory=default.default_factory) # Handle pydantic.Field
else: if isinstance(default, FieldInfo):
default = default.default if default.init is not None:
return field(default=default) init = default.init
if default.default_factory is not None:
raise ValueError( default_factory = cast(Callable[[], Any], default.default_factory)
f"{cls.__name__}.{name} must have a default value or default factory." default = MISSING
) else:
default = default.default
if default is MISSING and default_factory is MISSING:
logger.warning_once(
"%s.%s has no default or default factory.", cls.__name__, name
)
return field(default=default, default_factory=default_factory, init=init)
def is_init_field(cls: ConfigType, name: str) -> bool:
return get_field(cls, name).init
def replace(dataclass_instance: ConfigT, /, **kwargs) -> ConfigT:
"""Like [`dataclasses.replace`](https://docs.python.org/3/library/dataclasses.html#dataclasses.replace),
but compatible with Pydantic dataclasses which use `pydantic.fields.Field` instead
of `dataclasses.field`"""
cls = type(dataclass_instance)
dataclass_dict = dataclass_instance.__dict__
dataclass_dict = {k: v for k, v in dataclass_dict.items() if is_init_field(cls, k)}
dataclass_dict.update(kwargs)
return cls(**dataclass_dict)
def getattr_iter( def getattr_iter(
...@@ -172,10 +213,6 @@ def get_attr_docs(cls: type[Any]) -> dict[str, str]: ...@@ -172,10 +213,6 @@ def get_attr_docs(cls: type[Any]) -> dict[str, str]:
return out return out
def is_init_field(cls: ConfigType, name: str) -> bool:
return next(f for f in fields(cls) if f.name == name).init
@runtime_checkable @runtime_checkable
class SupportsHash(Protocol): class SupportsHash(Protocol):
def compute_hash(self) -> str: ... def compute_hash(self) -> str: ...
......
...@@ -9,7 +9,7 @@ import tempfile ...@@ -9,7 +9,7 @@ import tempfile
import threading import threading
import time import time
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import is_dataclass, replace from dataclasses import is_dataclass
from datetime import datetime from datetime import datetime
from enum import IntEnum from enum import IntEnum
from functools import lru_cache from functools import lru_cache
...@@ -18,10 +18,8 @@ from typing import TYPE_CHECKING, Any, TypeVar, get_args ...@@ -18,10 +18,8 @@ from typing import TYPE_CHECKING, Any, TypeVar, get_args
import torch import torch
from pydantic import ConfigDict, Field, model_validator from pydantic import ConfigDict, Field, model_validator
from pydantic.dataclasses import dataclass
import vllm.envs as envs import vllm.envs as envs
from vllm.config.speculative import EagleModelTypes
from vllm.logger import enable_trace_function_call, init_logger from vllm.logger import enable_trace_function_call, init_logger
from vllm.transformers_utils.runai_utils import is_runai_obj_uri from vllm.transformers_utils.runai_utils import is_runai_obj_uri
from vllm.utils import random_uuid from vllm.utils import random_uuid
...@@ -41,9 +39,9 @@ from .observability import ObservabilityConfig ...@@ -41,9 +39,9 @@ from .observability import ObservabilityConfig
from .parallel import ParallelConfig from .parallel import ParallelConfig
from .profiler import ProfilerConfig from .profiler import ProfilerConfig
from .scheduler import SchedulerConfig from .scheduler import SchedulerConfig
from .speculative import SpeculativeConfig from .speculative import EagleModelTypes, SpeculativeConfig
from .structured_outputs import StructuredOutputsConfig from .structured_outputs import StructuredOutputsConfig
from .utils import SupportsHash, config from .utils import SupportsHash, config, replace
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PretrainedConfig from transformers import PretrainedConfig
...@@ -187,8 +185,7 @@ OPTIMIZATION_LEVEL_TO_CONFIG = { ...@@ -187,8 +185,7 @@ OPTIMIZATION_LEVEL_TO_CONFIG = {
} }
@config @config(config=ConfigDict(arbitrary_types_allowed=True))
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class VllmConfig: class VllmConfig:
"""Dataclass which contains all vllm-related configuration. This """Dataclass which contains all vllm-related configuration. This
simplifies passing around the distinct configurations in the codebase. simplifies passing around the distinct configurations in the codebase.
...@@ -1395,14 +1392,6 @@ class VllmConfig: ...@@ -1395,14 +1392,6 @@ class VllmConfig:
path = self.compilation_config.debug_dump_path / append_path path = self.compilation_config.debug_dump_path / append_path
return path return path
def replace(self, **kwargs):
"""
Replace attributes of the config, and 'recompute' the config.
dataclass.replace() calls __init__() and __post_init__(), source:
https://docs.python.org/3/library/dataclasses.html#dataclasses.replace
"""
return replace(self, **kwargs)
def __str__(self): def __str__(self):
return ( return (
f"model={self.model_config.model!r}, " f"model={self.model_config.model!r}, "
......
...@@ -13,8 +13,6 @@ from collections.abc import Sequence ...@@ -13,8 +13,6 @@ from collections.abc import Sequence
from dataclasses import field from dataclasses import field
from typing import Any, Literal from typing import Any, Literal
from pydantic.dataclasses import dataclass
import vllm.envs as envs import vllm.envs as envs
from vllm.config import config from vllm.config import config
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
...@@ -69,7 +67,6 @@ class LoRAParserAction(argparse.Action): ...@@ -69,7 +67,6 @@ class LoRAParserAction(argparse.Action):
@config @config
@dataclass
class FrontendArgs: class FrontendArgs:
"""Arguments for the OpenAI-compatible frontend server.""" """Arguments for the OpenAI-compatible frontend server."""
......
...@@ -13,9 +13,10 @@ def register_speculator(name): ...@@ -13,9 +13,10 @@ def register_speculator(name):
@register_speculator("eagle3") @register_speculator("eagle3")
def update_eagle3(config_dict: dict, vllm_config: dict) -> None: def update_eagle3(config_dict: dict, pre_trained_config: dict) -> None:
""" """
Apply Eagle-3 specific configuration transformations. Apply Eagle-3 specific configuration transformations to the `dict` used to
construct the Transformers PreTrainedConfig.
Eagle-3 specific fields: Eagle-3 specific fields:
- draft_vocab_size: Size of the draft model's vocabulary - draft_vocab_size: Size of the draft model's vocabulary
...@@ -27,12 +28,14 @@ def update_eagle3(config_dict: dict, vllm_config: dict) -> None: ...@@ -27,12 +28,14 @@ def update_eagle3(config_dict: dict, vllm_config: dict) -> None:
predictions. This is the standard field used in Eagle3 checkpoints. predictions. This is the standard field used in Eagle3 checkpoints.
""" """
vllm_config["draft_vocab_size"] = config_dict.get("draft_vocab_size") pre_trained_config["draft_vocab_size"] = config_dict.get("draft_vocab_size")
if config_dict.get("target_hidden_size") is not None: if config_dict.get("target_hidden_size") is not None:
vllm_config["target_hidden_size"] = config_dict["target_hidden_size"] pre_trained_config["target_hidden_size"] = config_dict["target_hidden_size"]
vllm_config["norm_before_residual"] = config_dict.get("norm_before_residual", True) pre_trained_config["norm_before_residual"] = config_dict.get(
vllm_config["architectures"] = ["Eagle3LlamaForCausalLM"] "norm_before_residual", True
)
pre_trained_config["architectures"] = ["Eagle3LlamaForCausalLM"]
if config_dict.get("eagle_aux_hidden_state_layer_ids"): if config_dict.get("eagle_aux_hidden_state_layer_ids"):
vllm_config["eagle_aux_hidden_state_layer_ids"] = config_dict[ pre_trained_config["eagle_aux_hidden_state_layer_ids"] = config_dict[
"eagle_aux_hidden_state_layer_ids" "eagle_aux_hidden_state_layer_ids"
] ]
...@@ -24,13 +24,16 @@ class SpeculatorsConfig(PretrainedConfig): ...@@ -24,13 +24,16 @@ class SpeculatorsConfig(PretrainedConfig):
"""Load speculators Eagle config and convert to vLLM format.""" """Load speculators Eagle config and convert to vLLM format."""
config_dict, _ = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) config_dict, _ = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
vllm_config = cls.extract_vllm_speculative_config(config_dict) vllm_config = cls.extract_transformers_pre_trained_config(config_dict)
return cls(**vllm_config) return cls(**vllm_config)
@classmethod @classmethod
def extract_vllm_speculative_config( def extract_transformers_pre_trained_config(
cls, config_dict: dict[str, Any] cls, config_dict: dict[str, Any]
) -> dict[str, Any]: ) -> dict[str, Any]:
"""
Extract standard Transformers PreTrainedConfig config from speculators config.
"""
speculators_model_type = config_dict.get("speculators_model_type") speculators_model_type = config_dict.get("speculators_model_type")
if speculators_model_type not in SUPPORTED_SPECULATORS_TYPES: if speculators_model_type not in SUPPORTED_SPECULATORS_TYPES:
raise ValueError( raise ValueError(
...@@ -38,15 +41,23 @@ class SpeculatorsConfig(PretrainedConfig): ...@@ -38,15 +41,23 @@ class SpeculatorsConfig(PretrainedConfig):
"Please ensure you're loading a speculators-format model." "Please ensure you're loading a speculators-format model."
) )
# Start with transformer layer configuration if present
pre_trained_config = config_dict.get("transformer_layer_config", {})
# Apply anything specific to the supported algorithm
algo_updater = SUPPORTED_SPECULATORS_TYPES[speculators_model_type]
algo_updater(config_dict=config_dict, pre_trained_config=pre_trained_config)
return pre_trained_config
@classmethod
def extract_vllm_speculative_config(
cls, config_dict: dict[str, Any]
) -> dict[str, Any]:
"""Extract vLLM speculative config from speculators config."""
# validate fields # validate fields
# TODO: @dsikka - use speculators pydantic model to validate # TODO: @dsikka - use speculators pydantic model to validate
cls.validate_speculators_config(config_dict=config_dict) cls.validate_speculators_config(config_dict=config_dict)
# Convert from speculators config -> format that can be ingested by vLLM # Convert from speculators config -> format that can be ingested by vLLM
vllm_config = cls.build_vllm_speculative_config(config_dict=config_dict) return cls.build_vllm_speculative_config(config_dict=config_dict)
# Apply anything specific to the supported algorithm
algo_updater = SUPPORTED_SPECULATORS_TYPES[speculators_model_type]
algo_updater(config_dict=config_dict, vllm_config=vllm_config)
return vllm_config
@classmethod @classmethod
def validate_speculators_config(cls, config_dict: dict[str, Any]) -> None: def validate_speculators_config(cls, config_dict: dict[str, Any]) -> None:
...@@ -101,14 +112,7 @@ class SpeculatorsConfig(PretrainedConfig): ...@@ -101,14 +112,7 @@ class SpeculatorsConfig(PretrainedConfig):
) )
# Build base vLLM speculative configuration # Build base vLLM speculative configuration
vllm_config = { return {
"method": config_dict.get("speculators_model_type"), "method": config_dict.get("speculators_model_type"),
"num_speculative_tokens": num_speculative_tokens, "num_speculative_tokens": num_speculative_tokens,
"target_model": spec_config.get("verifier")["name_or_path"],
} }
# Merge transformer layer configuration if present
transformer_config = config_dict.get("transformer_layer_config", {})
vllm_config.update(transformer_config)
return vllm_config
...@@ -4,7 +4,7 @@ from typing import Any ...@@ -4,7 +4,7 @@ from typing import Any
import torch import torch
from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.config import VllmConfig, get_layers_from_vllm_config, replace
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
...@@ -191,10 +191,12 @@ def create_vllm_config_for_draft_model( ...@@ -191,10 +191,12 @@ def create_vllm_config_for_draft_model(
old = target_model_vllm_config old = target_model_vllm_config
assert old.speculative_config is not None, "speculative_config is not set" assert old.speculative_config is not None, "speculative_config is not set"
old_spec_config = old.speculative_config old_spec_config = old.speculative_config
new_parallel_config = old_spec_config.draft_parallel_config.replace( new_parallel_config = replace(
rank=old.parallel_config.rank old_spec_config.draft_parallel_config,
rank=old.parallel_config.rank,
) )
new: VllmConfig = old.replace( new: VllmConfig = replace(
old,
quant_config=None, # quant_config is recomputed in __init__() quant_config=None, # quant_config is recomputed in __init__()
model_config=old_spec_config.draft_model_config, model_config=old_spec_config.draft_model_config,
parallel_config=new_parallel_config, parallel_config=new_parallel_config,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment