Unverified Commit 61e632ae authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Turn `@config` into a `dataclass_transform` (#31541)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent b1bb18de
......@@ -3,8 +3,6 @@
from typing import Any, Literal, get_args
from pydantic.dataclasses import dataclass
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
......@@ -19,7 +17,6 @@ TOK_POOLING_TYPES: tuple[TokenPoolingType, ...] = get_args(TokenPoolingType)
@config
@dataclass
class PoolerConfig:
"""Controls the behavior of output pooling in pooling models."""
......
......@@ -5,7 +5,6 @@ import os
from typing import Any, Literal
from pydantic import Field, model_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self
from vllm.config.utils import config
......@@ -32,7 +31,6 @@ def _is_uri_path(path: str) -> bool:
@config
@dataclass
class ProfilerConfig:
"""Dataclass which contains profiler config for the engine."""
......
......@@ -6,7 +6,6 @@ from dataclasses import InitVar
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast
from pydantic import Field, field_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self
from vllm.config.utils import config
......@@ -24,7 +23,6 @@ SchedulerPolicy = Literal["fcfs", "priority"]
@config
@dataclass
class SchedulerConfig:
"""Scheduler configuration."""
......
......@@ -5,7 +5,6 @@ import ast
from typing import TYPE_CHECKING, Any, Literal, get_args
from pydantic import Field, SkipValidation, model_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self
from vllm.config.model import ModelConfig
......@@ -55,7 +54,6 @@ SpeculativeMethod = Literal[
@config
@dataclass
class SpeculativeConfig:
"""Configuration for speculative decoding."""
......
......@@ -2,13 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from pydantic.dataclasses import dataclass
from vllm.config.utils import config
@config
@dataclass
class SpeechToTextConfig:
"""Configuration for speech-to-text models."""
......
......@@ -4,7 +4,6 @@
from typing import Any, Literal
from pydantic import model_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self
from vllm.config.utils import config
......@@ -16,7 +15,6 @@ StructuredOutputsBackend = Literal[
@config
@dataclass
class StructuredOutputsConfig:
"""Dataclass which contains structured outputs config for the engine."""
......
......@@ -10,14 +10,17 @@ import json
import pathlib
import textwrap
from collections.abc import Callable, Mapping, Sequence, Set
from dataclasses import MISSING, Field, dataclass, field, fields, is_dataclass, replace
from dataclasses import MISSING, Field, field, fields, is_dataclass
from itertools import pairwise
from typing import TYPE_CHECKING, Any, Protocol, TypeVar
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast
import regex as re
import torch
from pydantic import ConfigDict
from pydantic.dataclasses import dataclass
from pydantic.fields import Field as PydanticField
from pydantic.fields import FieldInfo
from typing_extensions import runtime_checkable
from typing_extensions import dataclass_transform, runtime_checkable
from vllm.logger import init_logger
......@@ -29,23 +32,39 @@ else:
DataclassInstance = Any
ConfigType = type[DataclassInstance]
ConfigT = TypeVar("ConfigT", bound=ConfigType)
ConfigT = TypeVar("ConfigT", bound=DataclassInstance)
def config(cls: ConfigT) -> ConfigT:
"""
A decorator that ensures all fields in a dataclass have default values
and that each field has a docstring.
@dataclass_transform(field_specifiers=(PydanticField,))
def config(
cls: type[ConfigT] | None = None,
*,
config: ConfigDict | None = None,
**kwargs: Any,
) -> type[ConfigT] | Callable[[type[ConfigT]], type[ConfigT]]:
"""Decorator to create a pydantic dataclass with default config. The default config
for the dataclass forbids extra fields.
If a `ConfigT` is used as a CLI argument itself, the `type` keyword argument
provided by `get_kwargs` will be
`pydantic.TypeAdapter(ConfigT).validate_json(cli_arg)` which treats the
`cli_arg` as a JSON string which gets validated by `pydantic`.
All config classes in vLLM should use this decorator.
Config validation is performed by the tools/pre_commit/validate_config.py
script, which is invoked during the pre-commit checks.
"""
return cls
Args:
cls: The class to decorate
config: The pydantic ConfigDict to use. If provided, it will be merged with
the default config.
**kwargs: Additional arguments to pass to pydantic.dataclass."""
# Extra fields are forbidden by default
merged_config = ConfigDict(extra="forbid")
if config is not None:
merged_config.update(config)
def decorator(cls):
return dataclass(cls, config=merged_config, **kwargs)
# Called with arguments: @config(config=...)
if cls is None:
return decorator
# Called without arguments: @config
return decorator(cls)
def get_field(cls: ConfigType, name: str) -> Field:
......@@ -53,24 +72,46 @@ def get_field(cls: ConfigType, name: str) -> Field:
default factory fields in `EngineArgs`."""
if not is_dataclass(cls):
raise TypeError("The given class is not a dataclass.")
cls_fields = {f.name: f for f in fields(cls)}
if name not in cls_fields:
raise ValueError(f"Field '{name}' not found in {cls.__name__}.")
named_field: Field = cls_fields[name]
if (default_factory := named_field.default_factory) is not MISSING:
return field(default_factory=default_factory)
if (default := named_field.default) is not MISSING:
if isinstance(default, FieldInfo):
# Handle pydantic.Field defaults
if default.default_factory is not None:
return field(default_factory=default.default_factory)
else:
default = default.default
return field(default=default)
raise ValueError(
f"{cls.__name__}.{name} must have a default value or default factory."
)
try:
named_field = next(f for f in fields(cls) if f.name == name)
except StopIteration as e:
raise ValueError(f"Field '{name}' not found in {cls.__name__}.") from e
# The arguments to copy to the new field
default = named_field.default
default_factory = named_field.default_factory
init = named_field.init
# Handle pydantic.Field
if isinstance(default, FieldInfo):
if default.init is not None:
init = default.init
if default.default_factory is not None:
default_factory = cast(Callable[[], Any], default.default_factory)
default = MISSING
else:
default = default.default
if default is MISSING and default_factory is MISSING:
logger.warning_once(
"%s.%s has no default or default factory.", cls.__name__, name
)
return field(default=default, default_factory=default_factory, init=init)
def is_init_field(cls: ConfigType, name: str) -> bool:
return get_field(cls, name).init
def replace(dataclass_instance: ConfigT, /, **kwargs) -> ConfigT:
"""Like [`dataclasses.replace`](https://docs.python.org/3/library/dataclasses.html#dataclasses.replace),
but compatible with Pydantic dataclasses which use `pydantic.fields.Field` instead
of `dataclasses.field`"""
cls = type(dataclass_instance)
dataclass_dict = dataclass_instance.__dict__
dataclass_dict = {k: v for k, v in dataclass_dict.items() if is_init_field(cls, k)}
dataclass_dict.update(kwargs)
return cls(**dataclass_dict)
def getattr_iter(
......@@ -172,10 +213,6 @@ def get_attr_docs(cls: type[Any]) -> dict[str, str]:
return out
def is_init_field(cls: ConfigType, name: str) -> bool:
return next(f for f in fields(cls) if f.name == name).init
@runtime_checkable
class SupportsHash(Protocol):
def compute_hash(self) -> str: ...
......
......@@ -9,7 +9,7 @@ import tempfile
import threading
import time
from contextlib import contextmanager
from dataclasses import is_dataclass, replace
from dataclasses import is_dataclass
from datetime import datetime
from enum import IntEnum
from functools import lru_cache
......@@ -18,10 +18,8 @@ from typing import TYPE_CHECKING, Any, TypeVar, get_args
import torch
from pydantic import ConfigDict, Field, model_validator
from pydantic.dataclasses import dataclass
import vllm.envs as envs
from vllm.config.speculative import EagleModelTypes
from vllm.logger import enable_trace_function_call, init_logger
from vllm.transformers_utils.runai_utils import is_runai_obj_uri
from vllm.utils import random_uuid
......@@ -41,9 +39,9 @@ from .observability import ObservabilityConfig
from .parallel import ParallelConfig
from .profiler import ProfilerConfig
from .scheduler import SchedulerConfig
from .speculative import SpeculativeConfig
from .speculative import EagleModelTypes, SpeculativeConfig
from .structured_outputs import StructuredOutputsConfig
from .utils import SupportsHash, config
from .utils import SupportsHash, config, replace
if TYPE_CHECKING:
from transformers import PretrainedConfig
......@@ -187,8 +185,7 @@ OPTIMIZATION_LEVEL_TO_CONFIG = {
}
@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
@config(config=ConfigDict(arbitrary_types_allowed=True))
class VllmConfig:
"""Dataclass which contains all vllm-related configuration. This
simplifies passing around the distinct configurations in the codebase.
......@@ -1395,14 +1392,6 @@ class VllmConfig:
path = self.compilation_config.debug_dump_path / append_path
return path
def replace(self, **kwargs):
"""
Replace attributes of the config, and 'recompute' the config.
dataclass.replace() calls __init__() and __post_init__(), source:
https://docs.python.org/3/library/dataclasses.html#dataclasses.replace
"""
return replace(self, **kwargs)
def __str__(self):
return (
f"model={self.model_config.model!r}, "
......
......@@ -13,8 +13,6 @@ from collections.abc import Sequence
from dataclasses import field
from typing import Any, Literal
from pydantic.dataclasses import dataclass
import vllm.envs as envs
from vllm.config import config
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
......@@ -69,7 +67,6 @@ class LoRAParserAction(argparse.Action):
@config
@dataclass
class FrontendArgs:
"""Arguments for the OpenAI-compatible frontend server."""
......
......@@ -13,9 +13,10 @@ def register_speculator(name):
@register_speculator("eagle3")
def update_eagle3(config_dict: dict, vllm_config: dict) -> None:
def update_eagle3(config_dict: dict, pre_trained_config: dict) -> None:
"""
Apply Eagle-3 specific configuration transformations.
Apply Eagle-3 specific configuration transformations to the `dict` used to
construct the Transformers PreTrainedConfig.
Eagle-3 specific fields:
- draft_vocab_size: Size of the draft model's vocabulary
......@@ -27,12 +28,14 @@ def update_eagle3(config_dict: dict, vllm_config: dict) -> None:
predictions. This is the standard field used in Eagle3 checkpoints.
"""
vllm_config["draft_vocab_size"] = config_dict.get("draft_vocab_size")
pre_trained_config["draft_vocab_size"] = config_dict.get("draft_vocab_size")
if config_dict.get("target_hidden_size") is not None:
vllm_config["target_hidden_size"] = config_dict["target_hidden_size"]
vllm_config["norm_before_residual"] = config_dict.get("norm_before_residual", True)
vllm_config["architectures"] = ["Eagle3LlamaForCausalLM"]
pre_trained_config["target_hidden_size"] = config_dict["target_hidden_size"]
pre_trained_config["norm_before_residual"] = config_dict.get(
"norm_before_residual", True
)
pre_trained_config["architectures"] = ["Eagle3LlamaForCausalLM"]
if config_dict.get("eagle_aux_hidden_state_layer_ids"):
vllm_config["eagle_aux_hidden_state_layer_ids"] = config_dict[
pre_trained_config["eagle_aux_hidden_state_layer_ids"] = config_dict[
"eagle_aux_hidden_state_layer_ids"
]
......@@ -24,13 +24,16 @@ class SpeculatorsConfig(PretrainedConfig):
"""Load speculators Eagle config and convert to vLLM format."""
config_dict, _ = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
vllm_config = cls.extract_vllm_speculative_config(config_dict)
vllm_config = cls.extract_transformers_pre_trained_config(config_dict)
return cls(**vllm_config)
@classmethod
def extract_vllm_speculative_config(
def extract_transformers_pre_trained_config(
cls, config_dict: dict[str, Any]
) -> dict[str, Any]:
"""
Extract standard Transformers PreTrainedConfig config from speculators config.
"""
speculators_model_type = config_dict.get("speculators_model_type")
if speculators_model_type not in SUPPORTED_SPECULATORS_TYPES:
raise ValueError(
......@@ -38,15 +41,23 @@ class SpeculatorsConfig(PretrainedConfig):
"Please ensure you're loading a speculators-format model."
)
# Start with transformer layer configuration if present
pre_trained_config = config_dict.get("transformer_layer_config", {})
# Apply anything specific to the supported algorithm
algo_updater = SUPPORTED_SPECULATORS_TYPES[speculators_model_type]
algo_updater(config_dict=config_dict, pre_trained_config=pre_trained_config)
return pre_trained_config
@classmethod
def extract_vllm_speculative_config(
cls, config_dict: dict[str, Any]
) -> dict[str, Any]:
"""Extract vLLM speculative config from speculators config."""
# validate fields
# TODO: @dsikka - use speculators pydantic model to validate
cls.validate_speculators_config(config_dict=config_dict)
# Convert from speculators config -> format that can be ingested by vLLM
vllm_config = cls.build_vllm_speculative_config(config_dict=config_dict)
# Apply anything specific to the supported algorithm
algo_updater = SUPPORTED_SPECULATORS_TYPES[speculators_model_type]
algo_updater(config_dict=config_dict, vllm_config=vllm_config)
return vllm_config
return cls.build_vllm_speculative_config(config_dict=config_dict)
@classmethod
def validate_speculators_config(cls, config_dict: dict[str, Any]) -> None:
......@@ -101,14 +112,7 @@ class SpeculatorsConfig(PretrainedConfig):
)
# Build base vLLM speculative configuration
vllm_config = {
return {
"method": config_dict.get("speculators_model_type"),
"num_speculative_tokens": num_speculative_tokens,
"target_model": spec_config.get("verifier")["name_or_path"],
}
# Merge transformer layer configuration if present
transformer_config = config_dict.get("transformer_layer_config", {})
vllm_config.update(transformer_config)
return vllm_config
......@@ -4,7 +4,7 @@ from typing import Any
import torch
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.config import VllmConfig, get_layers_from_vllm_config, replace
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.model_loader import get_model
......@@ -191,10 +191,12 @@ def create_vllm_config_for_draft_model(
old = target_model_vllm_config
assert old.speculative_config is not None, "speculative_config is not set"
old_spec_config = old.speculative_config
new_parallel_config = old_spec_config.draft_parallel_config.replace(
rank=old.parallel_config.rank
new_parallel_config = replace(
old_spec_config.draft_parallel_config,
rank=old.parallel_config.rank,
)
new: VllmConfig = old.replace(
new: VllmConfig = replace(
old,
quant_config=None, # quant_config is recomputed in __init__()
model_config=old_spec_config.draft_model_config,
parallel_config=new_parallel_config,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment