# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import types
from typing import (Annotated, Any, Optional, Union, get_args, get_origin,
                    get_type_hints)

import torch

from vllm.logger import init_logger

logger = init_logger(__name__)


class TensorShape:
    """Symbolic description of the shape a tensor is expected to have.

    Each dimension is either a fixed ``int`` size or a ``str`` name. Names
    listed in ``dynamic_dims`` are rendered with a trailing ``*`` by
    ``__str__``.
    """

    def __init__(
        self,
        *dims: Union[int, str],
        dynamic_dims: Optional[set[str]] = None,
    ) -> None:
        """Store the dimensions and the set of dynamic dimension names.

        Args:
            *dims: Dimensions in order; ``int`` for a fixed size, ``str``
                for a named dimension.
            dynamic_dims: Names of dimensions marked as dynamic. ``None``
                (or an empty set) means there are none.
        """
        super().__init__()

        self.dims = dims
        self.dynamic_dims = dynamic_dims if dynamic_dims else set()

    def resolve(self, **bindings: int) -> tuple[Union[int, str], ...]:
        """Return the dims with every bound name replaced by its value.

        Args:
            **bindings: Mapping from dimension name to the value that
                should be substituted for it.

        Returns:
            A tuple with the same length as ``self.dims``. Named dims that
            have no binding, and all ``int`` dims, are returned unchanged.
        """
        return tuple(bindings[dim] if isinstance(dim, str)
                     and dim in bindings else dim for dim in self.dims)

    def __str__(self) -> str:
        """Return a string representation of the tensor shape."""
        # Only named (str) dimensions can be marked dynamic with "*".
        dim_strs = [
            f"{dim}*" if isinstance(dim, str) and dim in self.dynamic_dims
            else str(dim) for dim in self.dims
        ]
        return f"({', '.join(dim_strs)})"


class TensorSchema:

51
52
53
54
55
56
57
58
59
    def __init__(
        self,
        *,
        validate: bool = True,
        resolve_bindings: Optional[dict[str, int]] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__()

60
61
62
63
64
65
66
67
        self._resolve_bindings = resolve_bindings if resolve_bindings else {}

        for key, value in kwargs.items():
            setattr(self, key, value)

        if validate:
            self.validate()

68
69
    def __getitem__(self, key: str) -> Any:
        return getattr(self, key)
70

71
72
    def get(self, key: str, default: Any = None) -> Any:
        return getattr(self, key, default)
73

74
75
76
77
78
79
80
    def _match_shape_with_dynamic(
        self,
        actual: tuple[int, ...],
        reference: tuple[int, ...],
        expected_shape: tuple[Union[int, str], ...],
        dynamic_dims: set[str],
    ) -> bool:
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
        if len(actual) != len(reference) or len(actual) > len(expected_shape):
            return False

        for i, (a, r) in enumerate(zip(actual, reference)):
            # When validating list inputs, we match shape suffixes only
            # (e.g. "p", 3, "h", "w"), assuming the list length corresponds
            # to the leading symbolic dim (e.g. "bn"). This allows comparing
            # only the trailing dimensions of each element in the list.
            dim = expected_shape[-len(actual) + i]
            # Skip this dimension if it's marked dynamic
            if dim in dynamic_dims:
                continue
            if a != r:
                return False
        return True

97
98
99
100
101
102
103
104
105
106
107
108
109
    def _fmt_indexer(self, idxs: tuple[int, ...]) -> str:
        if not idxs:
            return ""

        return str(list(idxs))

    def _validate_field(
            self,
            value: object,
            field_name: str,
            expected_shape: tuple[Union[int, str], ...],
            dynamic_dims: set[str],
            leading_idxs: tuple[int, ...] = (),
110
    ) -> tuple[int, ...]:
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
        """Validate a field and return the actual shape."""
        if isinstance(value, (int, float)):
            return ()  # Scalar
        if isinstance(value, torch.Tensor):
            return value.shape

        if not isinstance(value, (list, tuple)):
            raise TypeError(
                f"{field_name}{self._fmt_indexer(leading_idxs)} is not "
                f"one of the expected types: int, float, Tensor, list, tuple. "
                f"Got: {type(value)}")

        if len(value) == 0:
            raise ValueError(f"{field_name}{self._fmt_indexer(leading_idxs)} "
                             f"is an empty sequence")

127
128
129
        # Ensure all tensors in the list have the same
        # shape, besides dynamic dimensions
        for i, v in enumerate(value):
130
131
132
133
134
135
136
137
138
139
140
141
142
            shape = self._validate_field(
                v,
                field_name,
                expected_shape[1:],
                dynamic_dims,
                leading_idxs=leading_idxs + (i, ),
            )

            if i == 0:
                first_shape = shape
            elif not self._match_shape_with_dynamic(
                    shape,
                    first_shape,
143
144
145
                    expected_shape,
                    dynamic_dims,
            ):
146
147
148
149
                raise ValueError(
                    f"{field_name}{self._fmt_indexer(leading_idxs)} "
                    f"contains inconsistent shapes: {first_shape} "
                    f"(index 0) vs {shape} (index {i})")
150
151
152

        # Treat the list as a stacked tensor:
        # shape = (len(list), *tensor.shape)
153
        return (len(value), ) + first_shape
154

155
156
157
158
159
160
161
162
    def _validate_tensor_shape_expected(
        self,
        actual_shape: tuple[int, ...],
        expected_shape: tuple[Union[int, str], ...],
        field_name: str,
        shape_env: dict[str, int],
        dynamic_dims: set[str],
    ) -> None:
163
        """Validate that the actual tensor shape matches the expected shape."""
164

165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
        if len(actual_shape) != len(expected_shape):
            raise ValueError(f"{field_name} has rank {len(actual_shape)} "
                             f"but expected {len(expected_shape)}")

        for i, dim in enumerate(expected_shape):
            if dim in dynamic_dims:
                continue
            elif isinstance(dim, int):
                if actual_shape[i] != dim:
                    raise ValueError(f"{field_name} dim[{i}] expected "
                                     f"{dim}, got {actual_shape[i]}")
            elif isinstance(dim, str):
                if dim in shape_env:
                    if actual_shape[i] != shape_env[dim]:
                        raise ValueError(f"{field_name} dim[{i}] expected "
                                         f"'{dim}'={shape_env[dim]}, got "
                                         f"{actual_shape[i]}")
                else:
                    shape_env[dim] = actual_shape[i]
            else:
                raise TypeError(f"{field_name} dim[{i}] has unsupported "
                                f"type: {type(dim)}")

    def validate(self) -> None:
        type_hints = get_type_hints(self.__class__, include_extras=True)
190
        shape_env = dict[str, int]()
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207

        for field_name, field_type in type_hints.items():
            # Check if field is missing
            if (not hasattr(self, field_name)
                    or getattr(self, field_name) is None):
                # Check if field is marked as optional
                actual_type = field_type
                if get_origin(field_type) is Annotated:
                    args = get_args(field_type)
                    actual_type = args[0]

                # Check arg was provided as Union
                if get_origin(actual_type) is Union:
                    args = get_args(actual_type)
                    # Skip validation when Union contains None
                    if type(None) in args:
                        continue
208
                # Otherwise field is required, raise error
209
210
211
212
213
214
215
216
217
218
                raise ValueError(f"Required field '{field_name}' is missing")

            # Field exists, proceed with validation
            value = getattr(self, field_name)
            if get_origin(field_type) is not None:
                args = get_args(field_type)

                for arg in args:
                    if isinstance(arg, TensorShape):
                        expected_shape = arg.resolve(**self._resolve_bindings)
219
220
221
222
223
224
                        actual_shape = self._validate_field(
                            value,
                            field_name,
                            expected_shape,
                            arg.dynamic_dims,
                        )
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240

                        self._validate_tensor_shape_expected(
                            actual_shape, expected_shape, field_name,
                            shape_env, arg.dynamic_dims)

    def print_shapes(self) -> None:
        """Print TensorShape annotations for debugging."""
        logger.debug("Shapes in %s:", self.__class__.__name__)
        type_hints = get_type_hints(self.__class__, include_extras=True)

        for field_name, field_type in type_hints.items():
            if get_origin(field_type) is not None:
                args = get_args(field_type)
                for arg in args:
                    if isinstance(arg, TensorShape):
                        logger.debug("  %s: %s", field_name, str(arg))