# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import (Annotated, Any, Optional, Union, get_args, get_origin,
                    get_type_hints)

import torch

from vllm.logger import init_logger

logger = init_logger(__name__)


class TensorShape:
    """Symbolic description of the shape a tensor is expected to have.

    Each entry of ``dims`` is either an ``int`` (a fixed size) or a ``str``
    (a named, symbolic dimension). Names listed in ``dynamic_dims`` are
    rendered with a trailing ``*`` by ``__str__``.
    """

    def __init__(
        self,
        *dims: Union[int, str],
        dynamic_dims: Optional[set[str]] = None,
    ) -> None:
        """Store the dimension specification.

        Args:
            *dims: Dimensions in order; ints are fixed sizes, strings are
                symbolic names.
            dynamic_dims: Names of the dimensions considered dynamic.
                ``None`` (or an empty set) means no dynamic dimensions.
        """
        super().__init__()

        self.dims = dims
        # Never share a default set between instances: build a fresh one.
        self.dynamic_dims = dynamic_dims if dynamic_dims else set()

    def resolve(self, **bindings: int) -> tuple[Union[int, str], ...]:
        """Return ``dims`` with bound symbolic names replaced by their values.

        Args:
            **bindings: Mapping of symbolic dimension name to concrete size.

        Returns:
            A tuple of the same length as ``dims``. String dims that have a
            binding are replaced by the bound value; every other dim is
            returned unchanged.
        """
        return tuple(
            bindings[dim] if isinstance(dim, str) and dim in bindings else dim
            for dim in self.dims)

    def __str__(self) -> str:
        """Return a string representation of the tensor shape."""
        dim_strs = []
        for dim in self.dims:
            if isinstance(dim, str):
                # Mark dynamic dimensions with *
                dim_strs.append(
                    f"{dim}*" if dim in self.dynamic_dims else dim)
            else:
                dim_strs.append(str(dim))
        return f"({', '.join(dim_strs)})"


class TensorSchema:
    """Base class whose annotated fields are shape-checked on construction.

    Subclasses declare fields as class-level type hints. A hint that carries
    a ``TensorShape`` instance among its type arguments (e.g. through
    ``Annotated[..., TensorShape(...)]``) has the shape of the corresponding
    attribute checked by ``validate()``. Field values are passed to
    ``__init__`` as keyword arguments and stored as instance attributes.
    """

    def __init__(
        self,
        *,
        validate: bool = True,
        resolve_bindings: Optional[dict[str, int]] = None,
        **kwargs: Any,
    ) -> None:
        """Store the given fields and optionally validate them.

        Args:
            validate: When true, ``validate()`` is run after the fields
                have been stored.
            resolve_bindings: Values for symbolic dimension names; passed to
                ``TensorShape.resolve`` during validation.
            **kwargs: Field values; each one is set as an attribute.
        """
        super().__init__()

        self._resolve_bindings = resolve_bindings if resolve_bindings else {}

        for key, value in kwargs.items():
            setattr(self, key, value)

        if validate:
            self.validate()

    def __getitem__(self, key: str) -> Any:
        """Return the attribute named ``key`` (dict-style access)."""
        return getattr(self, key)

    def get(self, key: str, default: Any = None) -> Any:
        """Return the attribute named ``key``, or ``default`` if unset."""
        return getattr(self, key, default)

    def _match_shape_with_dynamic(
        self,
        actual: tuple[int, ...],
        reference: tuple[int, ...],
        expected_shape: tuple[Union[int, str], ...],
        dynamic_dims: set[str],
    ) -> bool:
        """Return whether two shapes agree on all non-dynamic dimensions.

        Args:
            actual: Shape being checked.
            reference: Shape to compare against.
            expected_shape: Expected shape; its trailing ``len(actual)``
                entries name the dimensions of ``actual``.
            dynamic_dims: Dimension names whose sizes are allowed to differ.

        Returns:
            ``False`` if the ranks differ, if ``actual`` has more dimensions
            than ``expected_shape``, or if any non-dynamic dimension
            differs; ``True`` otherwise.
        """
        if len(actual) != len(reference) or len(actual) > len(expected_shape):
            return False

        for i, (a, r) in enumerate(zip(actual, reference)):
            # When validating list inputs, we match shape suffixes only
            # (e.g. "p", 3, "h", "w"), assuming the list length corresponds
            # to the leading symbolic dim (e.g. "bn"). This allows comparing
            # only the trailing dimensions of each element in the list.
            dim = expected_shape[-len(actual) + i]
            # Skip this dimension if it's marked dynamic
            if dim in dynamic_dims:
                continue
            if a != r:
                return False
        return True

    def _fmt_indexer(self, idxs: tuple[int, ...]) -> str:
        """Format ``idxs`` as a list literal (``[0, 1]``); empty -> ``""``."""
        if not idxs:
            return ""

        return str(list(idxs))

    def _validate_field(
        self,
        value: object,
        field_name: str,
        expected_shape: tuple[Union[int, str], ...],
        dynamic_dims: set[str],
        leading_idxs: tuple[int, ...] = (),
    ) -> tuple[int, ...]:
        """Validate a field and return the actual shape.

        Args:
            value: An int, float, ``torch.Tensor``, or a (possibly nested)
                list/tuple of those.
            field_name: Name of the field; used in error messages.
            expected_shape: Expected shape for ``value``.
            dynamic_dims: Dimension names whose sizes may differ between
                elements of a sequence.
            leading_idxs: Indices leading to ``value`` inside the enclosing
                sequences; used in error messages.

        Returns:
            ``()`` for an int or float, ``value.shape`` for a tensor, and
            ``(len(value), *shape_of_first_element)`` for a list or tuple.

        Raises:
            TypeError: If ``value`` is none of the supported types.
            ValueError: If a sequence is empty or its elements have
                inconsistent shapes.
        """
        if isinstance(value, (int, float)):
            return ()  # Scalar
        if isinstance(value, torch.Tensor):
            return value.shape

        if not isinstance(value, (list, tuple)):
            raise TypeError(
                f"{field_name}{self._fmt_indexer(leading_idxs)} is not "
                f"one of the expected types: int, float, Tensor, list, tuple. "
                f"Got: {type(value)}")

        if len(value) == 0:
            raise ValueError(f"{field_name}{self._fmt_indexer(leading_idxs)} "
                             f"is an empty sequence")

        # Ensure all tensors in the list have the same
        # shape, besides dynamic dimensions
        for i, v in enumerate(value):
            # Each element is checked against the expected shape with the
            # leading (list-length) dimension dropped.
            shape = self._validate_field(
                v,
                field_name,
                expected_shape[1:],
                dynamic_dims,
                leading_idxs=leading_idxs + (i, ),
            )

            if i == 0:
                first_shape = shape
            elif not self._match_shape_with_dynamic(
                    shape,
                    first_shape,
                    expected_shape,
                    dynamic_dims,
            ):
                raise ValueError(
                    f"{field_name}{self._fmt_indexer(leading_idxs)} "
                    f"contains inconsistent shapes: {first_shape} "
                    f"(index 0) vs {shape} (index {i})")

        # Treat the list as a stacked tensor:
        # shape = (len(list), *tensor.shape)
        return (len(value), ) + first_shape

    def _validate_tensor_shape_expected(
        self,
        actual_shape: tuple[int, ...],
        expected_shape: tuple[Union[int, str], ...],
        field_name: str,
        shape_env: dict[str, int],
        dynamic_dims: set[str],
    ) -> None:
        """Validate that the actual tensor shape matches the expected shape.

        Args:
            actual_shape: Shape found for the field.
            expected_shape: Expected shape; ints must match exactly, strings
                are symbolic names.
            field_name: Name of the field; used in error messages.
            shape_env: Sizes already bound to symbolic names. Mutated in
                place: the first time a name is seen, its size is recorded.
            dynamic_dims: Dimension names that are not checked at all.

        Raises:
            ValueError: If the ranks differ, a fixed dimension has the wrong
                size, or a symbolic dimension disagrees with the size
                already recorded in ``shape_env``.
            TypeError: If an expected dimension is neither int nor str.
        """
        if len(actual_shape) != len(expected_shape):
            raise ValueError(f"{field_name} has rank {len(actual_shape)} "
                             f"but expected {len(expected_shape)}. "
                             f"Expected shape: {expected_shape}, "
                             f"but got {actual_shape}")

        for i, dim in enumerate(expected_shape):
            if dim in dynamic_dims:
                continue
            elif isinstance(dim, int):
                if actual_shape[i] != dim:
                    raise ValueError(f"{field_name} dim[{i}] expected "
                                     f"{dim}, got {actual_shape[i]}. "
                                     f"Expected shape: {expected_shape}, "
                                     f"but got {actual_shape}")
            elif isinstance(dim, str):
                if dim in shape_env:
                    if actual_shape[i] != shape_env[dim]:
                        raise ValueError(f"{field_name} dim[{i}] expected "
                                         f"'{dim}'={shape_env[dim]}, got "
                                         f"{actual_shape[i]}")
                else:
                    # First occurrence of this name: bind it so that later
                    # fields using the same name must agree with it.
                    shape_env[dim] = actual_shape[i]
            else:
                raise TypeError(f"{field_name} dim[{i}] has unsupported "
                                f"type: {type(dim)}")

    def validate(self) -> None:
        """Check every annotated field against its ``TensorShape``.

        A single ``shape_env`` is shared by all fields, so a symbolic
        dimension name must have the same size everywhere it appears
        (unless it is dynamic).

        Raises:
            ValueError: If a field is missing (or ``None``) and its hint is
                not a ``Union`` containing ``None``, or if a shape check
                fails.
            TypeError: If a field value or an expected dimension has an
                unsupported type.
        """
        type_hints = get_type_hints(self.__class__, include_extras=True)
        shape_env = dict[str, int]()

        for field_name, field_type in type_hints.items():
            # An unset attribute and an explicit None are treated alike.
            value = getattr(self, field_name, None)

            if value is None:
                # Check if field is marked as optional
                actual_type = field_type
                if get_origin(field_type) is Annotated:
                    actual_type = get_args(field_type)[0]

                # Skip validation when the hint is a Union containing None
                if (get_origin(actual_type) is Union
                        and type(None) in get_args(actual_type)):
                    continue

                # Otherwise field is required, raise error
                raise ValueError(f"Required field '{field_name}' is missing")

            # Field exists, proceed with validation
            if get_origin(field_type) is not None:
                for arg in get_args(field_type):
                    if isinstance(arg, TensorShape):
                        expected_shape = arg.resolve(**self._resolve_bindings)
                        actual_shape = self._validate_field(
                            value,
                            field_name,
                            expected_shape,
                            arg.dynamic_dims,
                        )

                        self._validate_tensor_shape_expected(
                            actual_shape, expected_shape, field_name,
                            shape_env, arg.dynamic_dims)

    def print_shapes(self) -> None:
        """Print TensorShape annotations for debugging."""
        logger.debug("Shapes in %s:", self.__class__.__name__)
        type_hints = get_type_hints(self.__class__, include_extras=True)

        for field_name, field_type in type_hints.items():
            if get_origin(field_type) is not None:
                for arg in get_args(field_type):
                    if isinstance(arg, TensorShape):
                        logger.debug("  %s: %s", field_name, str(arg))