_core_ext.py 8.97 KB
Newer Older
1
2
import importlib.util
from enum import Enum
3
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33

import torch

from vllm.logger import init_logger

logger = init_logger(__name__)
core_C_available = importlib.util.find_spec('._core_C', 'vllm') is not None


# Mirrors enum in `core/scalar_type.hpp`
class NanRepr(Enum):
    NONE = 0  # nans are not supported
    IEEE_754 = 1  # nans are: Exp all 1s, mantissa not all 0s
    EXTD_RANGE_MAX_MIN = 2  # nans are: Exp all 1s, mantissa all 1s


if TYPE_CHECKING or not core_C_available:
    # On platforms were we cannot use/build the C++ core extension (i.e. namely
    # neuron and tpu), we define the mock ScalarType class here that partially
    # mimics the C++ ScalarType class.
    #
    # We also use this provide type signatures to the Python LSP for the methods
    # in the C++ ScalarType class. So these type signatures should be kept
    # in sync with csrc/core/scalar_type.hpp

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class ScalarType:
        """
34
35
36
        ScalarType can represent a wide range of floating point and integer
        types, in particular it can be used to represent sub-byte data types
        (something that torch.dtype currently does not support). It is also
37
        capable of  representing types with a bias, i.e.:
38
39
40
41
          `stored_value = value + bias`,
        this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
        of 8). The implementation for this class can be found in
        csrc/core/scalar_type.hpp, these type signatures should be kept in sync
42
43
44
45
46
47
48
49
50
51
52
53
        with that file.
        """

        exponent: int
        """
        Number of bits in the exponent if this is a floating point type
        (zero if this an integer type)
        """

        mantissa: int
        """
        Number of bits in the mantissa if this is a floating point type,
54
        or the number bits representing an integer excluding the sign bit if
55
56
57
58
59
        this an integer type.
        """

        bias: int
        """
60
61
62
        bias used to encode the values in this scalar type
        (value = stored_value - bias, default 0) for example if we store the
        type as an unsigned integer with a bias of 128 then the value 0 will be
63
64
65
66
67
68
69
70
71
72
73
74
75
        stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
        """

        signed: bool
        "If the type is signed (i.e. has a sign bit)"

        _finite_values_only: bool = False
        """
        Private: if NANs are supported, used `has_infs()` instead.
        """

        nan_repr: int = NanRepr.IEEE_754.value
        """
76
        How NaNs are represent in this scalar type, returns NanRepr value.
77
78
79
80
81
82
83
84
85
        (not applicable for integer types)
        """

        @property
        def size_bits(self):
            return self.exponent + self.mantissa + int(self.signed)

        def min(self) -> Union[int, float]:
            """
86
            Min representable value for this scalar type.
87
88
89
90
91
92
            (accounting for bias if there is one)
            """
            raise NotImplementedError

        def max(self) -> Union[int, float]:
            """
93
            Max representable value for this scalar type.
94
95
96
97
98
99
100
101
102
103
104
105
            (accounting for bias if there is one)
            """
            raise NotImplementedError

        def is_signed(self) -> bool:
            """
            If the type is signed (i.e. has a sign bit), same as `signed`
            added for consistency with:
            https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
            """
            ...

106
        def is_floating_point(self) -> bool:
107
108
109
            "If the type is a floating point type"
            return self.exponent != 0

110
        def is_integer(self) -> bool:
111
112
113
            "If the type is an integer type"
            return self.exponent == 0

114
        def has_bias(self) -> bool:
115
116
117
            "If the type has a non-zero bias"
            return self.bias != 0

118
        def has_infs(self) -> bool:
119
120
121
            "If the type is floating point and supports infinity"
            return not self._finite_values_only

122
        def has_nans(self) -> bool:
123
124
125
126
            return self.nan_repr != NanRepr.NONE.value

        def is_ieee_754(self) -> bool:
            """
127
            If the type is a floating point type that follows IEEE 754
128
129
130
131
132
133
134
135
136
137
138
            conventions
            """
            return self.nan_repr == NanRepr.IEEE_754.value and \
                not self._finite_values_only

        def __str__(self) -> str:
            raise NotImplementedError

        def __repr__(self) -> str:
            raise NotImplementedError

139
140
141
142
143
        # __len__ needs to be defined (and has to throw TypeError) for pytorch's
        # opcheck to work.
        def __len__(self) -> int:
            raise TypeError

144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
        #
        # Convenience Constructors
        #

        @classmethod
        def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
            "Create a signed integer scalar type (size_bits includes sign-bit)."
            return cls(size_bits - 1, size_bits, bias if bias else 0, True)

        @classmethod
        def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
            """Create a unsigned integer scalar type."""
            return cls(size_bits, size_bits, bias if bias else 0, False)

        @classmethod
        def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
            """
161
            Create a standard floating point type
162
163
164
165
166
167
            (i.e. follows IEEE 754 conventions).
            """
            return cls(exponent, mantissa, 0, True)

        @classmethod
        def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
168
                   nan_repr: int) -> 'ScalarType':
169
            """
170
            Create a non-standard floating point type
171
172
173
174
175
176
177
178
179
180
181
182
            (i.e. does not follow IEEE 754 conventions).
            """
            return cls(exponent, mantissa, 0, True, finite_values_only,
                       nan_repr)

elif core_C_available:
    try:
        import vllm._core_C  # noqa: F401
    except ImportError as e:
        logger.warning("Failed to import from vllm._core_C with %r", e)

    ScalarType = torch.classes._core_C.ScalarType
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272

    # Needed for dynamo support of ScalarType.
    @torch._library.register_fake_class("_core_C::ScalarType")
    class FakeScalarType:

        def __init__(self, scalar_type):
            self.ScalarType = scalar_type

        def bias_getter(self) -> int:
            return self.ScalarType.bias

        def exponent_getter(self) -> int:
            return self.ScalarType.exponent

        def mantissa_getter(self) -> int:
            return self.ScalarType.mantissa

        def signed_getter(self) -> bool:
            return self.ScalarType.signed

        def size_bits_getter(self) -> int:
            return self.ScalarType.size_bits

        @property
        def size_bits(self) -> int:
            return self.ScalarType.size_bits

        def min(self) -> Union[int, float]:
            return self.ScalarType.min()

        def max(self) -> Union[int, float]:
            return self.ScalarType.max()

        def is_signed(self) -> bool:
            return self.ScalarType.is_signed()

        def is_floating_point(self) -> bool:
            return self.ScalarType.is_floating_point()

        def is_integer(self) -> bool:
            return self.ScalarType.is_integer()

        def has_bias(self) -> bool:
            return self.ScalarType.has_bias()

        def has_infs(self) -> bool:
            return self.ScalarType.has_infs()

        def has_nans(self) -> bool:
            return self.ScalarType.has_nans()

        def is_ieee_754(self) -> bool:
            return self.ScalarType.is_ieee_754()

        def __str__(self) -> str:
            return self.ScalarType.__str__()

        def __repr__(self) -> str:
            return self.ScalarType.__repr__()

        def __len__(self) -> int:
            return self.ScalarType.__len__()

        def __obj_flatten__(self) -> Tuple[Tuple[str, Any], ...]:
            return torch.classes._core_C.ScalarType.__obj_flatten__(
                self.ScalarType)

        @classmethod
        def __obj_unflatten__(
                cls, flat_type: Tuple[Tuple[str, Any], ...]) -> 'ScalarType':
            return cls(
                torch.classes._core_C.ScalarType.__obj_unflatten__(flat_type))

        @classmethod
        def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
            return ScalarType.int_(size_bits, bias)

        @classmethod
        def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
            return ScalarType.uint(size_bits, bias)

        @classmethod
        def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
            return ScalarType.float_IEEE754(exponent, mantissa)

        @classmethod
        def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
                   nan_repr: int) -> 'ScalarType':
            return ScalarType.float_(exponent, mantissa, finite_values_only,
                                     nan_repr)