_core_ext.py 9.43 KB
Newer Older
1
2
import importlib.util
from enum import Enum
3
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33

import torch

from vllm.logger import init_logger

# Module-level logger obtained from vLLM's logging helper.
logger = init_logger(__name__)
# True when a module spec for `vllm._core_C` can be found (the relative name
# '._core_C' is resolved against the 'vllm' package). The flag selects which
# ScalarType definition is used further down in this module.
core_C_available = importlib.util.find_spec('._core_C', 'vllm') is not None


# Mirrors enum in `core/scalar_type.hpp`
class NanRepr(Enum):
    """Enumerates the ways a scalar type may encode NaN values."""

    # NaNs are not supported.
    NONE = 0
    # NaNs are: exponent all 1s, mantissa not all 0s.
    IEEE_754 = 1
    # NaNs are: exponent all 1s, mantissa all 1s.
    EXTD_RANGE_MAX_MIN = 2


if TYPE_CHECKING or not core_C_available:
    # On platforms where we cannot use/build the C++ core extension (namely
    # neuron and tpu), we define the mock ScalarType class here that partially
    # mimics the C++ ScalarType class.
    #
    # We also use this to provide type signatures to the Python LSP for the
    # methods in the C++ ScalarType class. So these type signatures should be
    # kept in sync with csrc/core/scalar_type.hpp

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class ScalarType:
        """
        ScalarType can represent a wide range of floating point and integer
        types, in particular it can be used to represent sub-byte data types
        (something that torch.dtype currently does not support). It is also
        capable of representing types with a bias, i.e.:
          `stored_value = value + bias`,
        this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
        of 8). The implementation for this class can be found in
        csrc/core/scalar_type.hpp, these type signatures should be kept in sync
        with that file.
        """

        exponent: int
        """
        Number of bits in the exponent if this is a floating point type
        (zero if this is an integer type)
        """

        mantissa: int
        """
        Number of bits in the mantissa if this is a floating point type,
        or the number of bits representing an integer excluding the sign bit
        if this is an integer type.
        """

        bias: int
        """
        bias used to encode the values in this scalar type
        (value = stored_value - bias, default 0) for example if we store the
        type as an unsigned integer with a bias of 128 then the value 0 will be
        stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
        """

        signed: bool
        "If the type is signed (i.e. has a sign bit)"

        _finite_values_only: bool = False
        """
        Private: True if the type has no infs, use `has_infs()` instead.
        """

        nan_repr: int = NanRepr.IEEE_754.value
        """
        How NaNs are represented in this scalar type, holds a NanRepr value.
        (not applicable for integer types)
        """

        @property
        def size_bits(self) -> int:
            "Total number of bits: exponent + mantissa + sign bit (if signed)"
            return self.exponent + self.mantissa + int(self.signed)

        def min(self) -> Union[int, float]:
            """
            Min representable value for this scalar type.
            (accounting for bias if there is one)

            Not implemented by this mock; always raises NotImplementedError.
            """
            raise NotImplementedError

        def max(self) -> Union[int, float]:
            """
            Max representable value for this scalar type.
            (accounting for bias if there is one)

            Not implemented by this mock; always raises NotImplementedError.
            """
            raise NotImplementedError

        def is_signed(self) -> bool:
            """
            If the type is signed (i.e. has a sign bit), same as `signed`
            added for consistency with:
            https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
            """
            return self.signed

        def is_floating_point(self) -> bool:
            "If the type is a floating point type"
            return self.exponent != 0

        def is_integer(self) -> bool:
            "If the type is an integer type"
            return self.exponent == 0

        def has_bias(self) -> bool:
            "If the type has a non-zero bias"
            return self.bias != 0

        def has_infs(self) -> bool:
            "If the type is floating point and supports infinity"
            return not self._finite_values_only

        def has_nans(self) -> bool:
            "If the type supports NaNs (i.e. `nan_repr` is not NanRepr.NONE)"
            return self.nan_repr != NanRepr.NONE.value

        def is_ieee_754(self) -> bool:
            """
            If the type is a floating point type that follows IEEE 754
            conventions
            """
            return self.nan_repr == NanRepr.IEEE_754.value and \
                not self._finite_values_only

        def __str__(self) -> str:
            raise NotImplementedError

        def __repr__(self) -> str:
            raise NotImplementedError

        # __len__ needs to be defined (and has to throw TypeError) for pytorch's
        # opcheck to work.
        def __len__(self) -> int:
            raise TypeError

        #
        # Convenience Constructors
        #

        @classmethod
        def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
            "Create a signed integer scalar type (size_bits includes sign-bit)."
            # Integer types have no exponent bits, and one of the `size_bits`
            # is the sign bit, so only `size_bits - 1` bits go to the mantissa.
            return cls(0, size_bits - 1, bias if bias else 0, True)

        @classmethod
        def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
            """Create a unsigned integer scalar type."""
            # Integer types have no exponent bits; with no sign bit all of
            # `size_bits` go to the mantissa.
            return cls(0, size_bits, bias if bias else 0, False)

        @classmethod
        def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
            """
            Create a standard floating point type
            (i.e. follows IEEE 754 conventions).
            """
            return cls(exponent, mantissa, 0, True)

        @classmethod
        def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
                   nan_repr: int) -> 'ScalarType':
            """
            Create a non-standard floating point type
            (i.e. does not follow IEEE 754 conventions).
            """
            return cls(exponent, mantissa, 0, True, finite_values_only,
                       nan_repr)

elif core_C_available:
    try:
        # Imported only for its side effects (hence the F401 suppression);
        # presumably loading the extension is what makes
        # `torch.classes._core_C` available below -- TODO confirm.
        import vllm._core_C  # noqa: F401
    except ImportError as e:
        logger.warning("Failed to import from vllm._core_C with %r", e)

    # NOTE(review): this lookup still runs when the import above failed and
    # was only logged -- confirm that this is the intended behaviour.
    ScalarType = torch.classes._core_C.ScalarType
183

184
185
186
187
188
    if (hasattr(torch, "_library")
            and hasattr(torch._library, "register_fake_class")):
        # Needed for dynamo support of ScalarType.
        @torch._library.register_fake_class("_core_C::ScalarType")
        class FakeScalarType:
            """
            Python-side class registered as the fake class for
            `_core_C::ScalarType`. It wraps a scalar type object and forwards
            every attribute read and method call to it.
            """

            def __init__(self, scalar_type):
                # The wrapped object that all calls below are forwarded to.
                self.ScalarType = scalar_type

            #
            # Attribute getters
            #

            def bias_getter(self) -> int:
                "Return `bias` of the wrapped scalar type."
                return self.ScalarType.bias

            def exponent_getter(self) -> int:
                "Return `exponent` of the wrapped scalar type."
                return self.ScalarType.exponent

            def mantissa_getter(self) -> int:
                "Return `mantissa` of the wrapped scalar type."
                return self.ScalarType.mantissa

            def signed_getter(self) -> bool:
                "Return `signed` of the wrapped scalar type."
                return self.ScalarType.signed

            def size_bits_getter(self) -> int:
                "Return `size_bits` of the wrapped scalar type."
                return self.ScalarType.size_bits

            @property
            def size_bits(self) -> int:
                "Property form of `size_bits_getter`."
                return self.ScalarType.size_bits

            #
            # Forwarded methods
            #

            def min(self) -> Union[int, float]:
                "Forward to `min()` of the wrapped scalar type."
                return self.ScalarType.min()

            def max(self) -> Union[int, float]:
                "Forward to `max()` of the wrapped scalar type."
                return self.ScalarType.max()

            def is_signed(self) -> bool:
                "Forward to `is_signed()` of the wrapped scalar type."
                return self.ScalarType.is_signed()

            def is_floating_point(self) -> bool:
                "Forward to `is_floating_point()` of the wrapped scalar type."
                return self.ScalarType.is_floating_point()

            def is_integer(self) -> bool:
                "Forward to `is_integer()` of the wrapped scalar type."
                return self.ScalarType.is_integer()

            def has_bias(self) -> bool:
                "Forward to `has_bias()` of the wrapped scalar type."
                return self.ScalarType.has_bias()

            def has_infs(self) -> bool:
                "Forward to `has_infs()` of the wrapped scalar type."
                return self.ScalarType.has_infs()

            def has_nans(self) -> bool:
                "Forward to `has_nans()` of the wrapped scalar type."
                return self.ScalarType.has_nans()

            def is_ieee_754(self) -> bool:
                "Forward to `is_ieee_754()` of the wrapped scalar type."
                return self.ScalarType.is_ieee_754()

            def __str__(self) -> str:
                return self.ScalarType.__str__()

            def __repr__(self) -> str:
                return self.ScalarType.__repr__()

            def __len__(self) -> int:
                return self.ScalarType.__len__()

            #
            # Flatten / unflatten
            #

            def __obj_flatten__(self) -> Tuple[Tuple[str, Any], ...]:
                "Flatten the wrapped object via the `_core_C` ScalarType class."
                real_cls = torch.classes._core_C.ScalarType
                return real_cls.__obj_flatten__(self.ScalarType)

            @classmethod
            def __obj_unflatten__(
                    cls, flat_type: Tuple[Tuple[str, Any],
                                          ...]) -> 'ScalarType':
                """
                Rebuild an object from `flat_type` via the `_core_C`
                ScalarType class and wrap it in a new instance of `cls`.
                """
                real_cls = torch.classes._core_C.ScalarType
                unflattened = real_cls.__obj_unflatten__(flat_type)
                return cls(unflattened)

            #
            # Convenience constructors, forwarded to the module's ScalarType
            #

            @classmethod
            def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
                "Forward to `ScalarType.int_`."
                return ScalarType.int_(size_bits, bias)

            @classmethod
            def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
                "Forward to `ScalarType.uint`."
                return ScalarType.uint(size_bits, bias)

            @classmethod
            def float_IEEE754(cls, exponent: int,
                              mantissa: int) -> 'ScalarType':
                "Forward to `ScalarType.float_IEEE754`."
                return ScalarType.float_IEEE754(exponent, mantissa)

            @classmethod
            def float_(cls, exponent: int, mantissa: int,
                       finite_values_only: bool,
                       nan_repr: int) -> 'ScalarType':
                "Forward to `ScalarType.float_`."
                return ScalarType.float_(exponent, mantissa,
                                         finite_values_only, nan_repr)