utils.py 5.97 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Utility functions for vLLM config dataclasses."""
4

5
6
7
import ast
import inspect
import textwrap
8
from collections.abc import Iterable
9
from dataclasses import MISSING, Field, field, fields, is_dataclass, replace
10
from itertools import pairwise
11
from typing import TYPE_CHECKING, Any, Protocol, TypeVar
12
13

import regex as re
14
from typing_extensions import runtime_checkable
15
16
17
18

if TYPE_CHECKING:
    from _typeshed import DataclassInstance
else:
    # `_typeshed` is only importable by type checkers, so at runtime the
    # name is bound to `Any` to keep the aliases below evaluable.
    DataclassInstance = Any

# A dataclass *class* (not an instance).
ConfigType = type[DataclassInstance]
# Type variable constrained to `ConfigType`, used so that helpers can return
# the same type they were given.
ConfigT = TypeVar("ConfigT", bound=ConfigType)


def config(cls: ConfigT) -> ConfigT:
    """
    Mark a dataclass as a vLLM config class.

    At runtime this decorator is a no-op: it hands back `cls` untouched. The
    requirements it stands for (every field has a default value and every
    field has a docstring) are enforced by the tools/validate_config.py
    script, which is invoked during the pre-commit checks.

    If a `ConfigT` is used as a CLI argument itself, the `type` keyword
    argument provided by `get_kwargs` will be
    `pydantic.TypeAdapter(ConfigT).validate_json(cli_arg)`, i.e. `cli_arg` is
    treated as a JSON string and validated by `pydantic`.

    Args:
        cls: The dataclass being decorated.

    Returns:
        The very same class object that was passed in.
    """
    return cls
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54


def get_field(cls: ConfigType, name: str) -> Field:
    """Build a fresh dataclass `Field` carrying the default of `cls.name`.

    Used for getting default (factory) fields in `EngineArgs`. Only the
    default is copied over: the returned object is a new `field(...)` that
    holds either the original `default_factory` or the original `default`.

    Args:
        cls: The dataclass to look the field up on.
        name: The name of the field.

    Returns:
        A new `Field` created with the same default factory (preferred) or
        default value as the named field.

    Raises:
        TypeError: If `cls` is not a dataclass.
        ValueError: If `cls` has no field called `name`, or the field has
            neither a default value nor a default factory.
    """
    if not is_dataclass(cls):
        raise TypeError("The given class is not a dataclass.")
    cls_fields = {f.name: f for f in fields(cls)}
    if name not in cls_fields:
        raise ValueError(f"Field '{name}' not found in {cls.__name__}.")
    named_field: Field = cls_fields[name]
    # A default factory takes precedence over a plain default value.
    if (default_factory := named_field.default_factory) is not MISSING:
        return field(default_factory=default_factory)
    if (default := named_field.default) is not MISSING:
        return field(default=default)
    raise ValueError(
        f"{cls.__name__}.{name} must have a default value or default factory."
    )
57
58


59
60
61
62
63
64
65
66
67
68
69
70
def getattr_iter(object: object, names: Iterable[str], default: Any) -> Any:
    """
    Return the first attribute of `object` found under any of `names`.

    The candidate names are tried in the order given and the search stops at
    the first one that `object` has. This is useful when fetching attributes
    from arbitrary `transformers.PretrainedConfig` instances, where the same
    setting may be stored under different names.

    Args:
        object: The object to read the attribute from.
        names: Candidate attribute names, in priority order.
        default: Value to return when none of the names exist on `object`.

    Returns:
        The value of the first matching attribute, or `default`.
    """
    # The generator is lazy, so only the first hit is actually fetched.
    hits = (getattr(object, candidate) for candidate in names if hasattr(object, candidate))
    return next(hits, default)


71
72
73
74
75
76
77
78
79
80
81
82
83
def contains_object_print(text: str) -> bool:
    """
    Check if the text looks like a printed Python object, e.g.
    contains any substring matching the pattern: "at 0xFFFFFFF>"
    We match against 0x followed by 2-16 hex chars (there's
    a max of 16 on a 64-bit system).

    Args:
        text (str): The text to check

    Returns:
        result (bool): `True` if a match is found, `False` otherwise.
    """
    # Searches anywhere in `text`; the trailing ">" is part of the pattern.
    return re.search(r"at 0x[a-fA-F0-9]{2,16}>", text) is not None


def assert_hashable(text: str) -> bool:
    """
    Verify that `text` is safe to hash, i.e. it does not look like it
    contains the printed form of a Python object (see
    `contains_object_print`).

    Args:
        text (str): The text that is about to be hashed.

    Returns:
        result (bool): `True` when no printed-object pattern is found.

    Raises:
        AssertionError: If `text` contains a printed-object pattern.
    """
    if not contains_object_print(text):
        return True
    raise AssertionError(
        "vLLM tried to hash some configs that may have Python objects ids "
        "in them. This is a bug, please file an issue. "
        f"Text being hashed: {text}"
    )
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125


def get_attr_docs(cls: type[Any]) -> dict[str, str]:
    """
    Get any docstrings placed after attribute assignments in a class body.

    https://davidism.com/mit-license/
    """

    try:
        cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0]
    except (OSError, KeyError, TypeError):
        # HACK: Python 3.13+ workaround - set missing __firstlineno__
        # Workaround can be removed after we upgrade to pydantic==2.12.0
        with open(inspect.getfile(cls)) as f:
            for i, line in enumerate(f):
                if f"class {cls.__name__}" in line and ":" in line:
                    cls.__firstlineno__ = i + 1
                    break
        cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0]

    if not isinstance(cls_node, ast.ClassDef):
        raise TypeError("Given object was not a class.")

    out = {}

    # Consider each pair of nodes.
    for a, b in pairwise(cls_node.body):
        # Must be an assignment then a constant string.
126
127
128
129
130
131
        if (
            not isinstance(a, (ast.Assign, ast.AnnAssign))
            or not isinstance(b, ast.Expr)
            or not isinstance(b.value, ast.Constant)
            or not isinstance(b.value.value, str)
        ):
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
            continue

        doc = inspect.cleandoc(b.value.value)

        # An assignment can have multiple targets (a = b = v), but an
        # annotated assignment only has one target.
        targets = a.targets if isinstance(a, ast.Assign) else [a.target]

        for target in targets:
            # Must be assigning to a plain name.
            if not isinstance(target, ast.Name):
                continue

            out[target.id] = doc

    return out


def is_init_field(cls: ConfigType, name: str) -> bool:
    """Report whether the dataclass field `name` of `cls` has `init=True`.

    Args:
        cls: The dataclass to inspect.
        name: The name of the field to look up.

    Returns:
        The `init` flag of the first field whose name equals `name`.

    Raises:
        StopIteration: If `cls` has no field called `name`.
    """
    matching = (candidate for candidate in fields(cls) if candidate.name == name)
    # No default is passed to `next`, so a missing field surfaces as
    # StopIteration.
    found = next(matching)
    return found.init
152
153
154
155


@runtime_checkable
class SupportsHash(Protocol):
    """Structural type for objects that provide a `compute_hash()` method.

    The protocol is runtime-checkable, so `isinstance(obj, SupportsHash)`
    can be used to test whether an object has a `compute_hash` attribute.
    """

    def compute_hash(self) -> str: ...
157
158
159


class SupportsMetricsInfo(Protocol):
    """Structural type for objects that provide a `metrics_info()` method
    returning a `dict[str, str]`.

    This protocol is not decorated with `runtime_checkable`, so it is meant
    for static type checking only and cannot be used with `isinstance`.
    """

    def metrics_info(self) -> dict[str, str]: ...
161
162
163
164
165


def update_config(config: ConfigT, overrides: dict[str, Any]) -> ConfigT:
    """Return a copy of `config` with `overrides` applied.

    The input `config` is not modified; the result is built with
    `dataclasses.replace`. When the current value of an overridden field is
    itself a dataclass and the override is not, the override must be a dict
    and is applied recursively to that nested dataclass.

    Args:
        config: The dataclass instance to derive the new config from.
        overrides: Mapping from field name to the new value (or, for nested
            dataclass fields, to a dict of nested overrides).

    Returns:
        A new config of the same type as `config` with the overrides applied.

    Raises:
        AssertionError: If an override names an attribute that `config` does
            not have, or a nested dataclass field is overridden with
            something that is neither a dataclass nor a dict.
    """
    processed_overrides = {}
    for field_name, value in overrides.items():
        # Raised explicitly (rather than via `assert`) so the validation is
        # not stripped when Python runs with -O; the exception type is kept
        # for existing callers.
        if not hasattr(config, field_name):
            raise AssertionError(f"{type(config)} has no field `{field_name}`")
        current_value = getattr(config, field_name)
        if is_dataclass(current_value) and not is_dataclass(value):
            if not isinstance(value, dict):
                raise AssertionError(
                    f"Overrides to {type(config)}.{field_name} must be a dict"
                    f"  or {type(current_value)}, but got {type(value)}"
                )
            # Merge the dict into the existing nested dataclass.
            value = update_config(
                current_value,  # type: ignore[type-var]
                value,
            )
        processed_overrides[field_name] = value
    return replace(config, **processed_overrides)