# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Dequantization utilities for TE/JAX.

This module provides utilities for dequantizing tensors that have been quantized
using various scaling modes, including delayed scaling and block scaling.
"""
import math
from dataclasses import dataclass
from abc import ABC, abstractmethod

import jax
import jax.numpy as jnp

from .scaling_modes import ScalingMode
from .hadamard import apply_rht, should_use_rht

__all__ = ["ScalingModeToDequantizerMap"]
@dataclass
class Dequantizer(ABC):
    """Abstract interface for converting quantized (scaled) tensors back to high precision.

    Concrete subclasses implement one scaling scheme each; callers select the
    right subclass for a tensor's scaling mode.
    """

    @staticmethod
    @abstractmethod
    def _dequantize_func(data, scale_inv, dq_dtype, **kwargs):
        """Core dequantization kernel operating on raw data and inverse-scale arrays."""

    @staticmethod
    @abstractmethod
    def dequantize(scaled_tensor):
        """Dequantizing given tensor to higher precision."""


@dataclass
class NoopDequantizer(Dequantizer):
    """Pass-through dequantizer: no scaling was applied, so data is returned as-is."""

    @staticmethod
    def _dequantize_func(data, *_ignored, **_ignored_kw):
        """Return *data* untouched; all scaling arguments are accepted and discarded."""
        return data

    @staticmethod
    def dequantize(scaled_tensor):
        """Return the raw data array stored on *scaled_tensor* without modification."""
        return scaled_tensor.data


class TensorScaleDequantizer(Dequantizer):
    """Dequantizer for per-tensor scaling modes.

    A single inverse scale factor covers the whole tensor. This handles both
    delayed-scaling and current-scaling quantization, which share the same
    dequantization math.
    """

    @staticmethod
    def _dequantize_func(data, scale_inv, dq_dtype, **kwargs):
        """Multiply *data* by *scale_inv* in float32, then cast to *dq_dtype*."""
        # Extra keyword arguments exist only for signature compatibility with
        # the block-scaling dequantizers.
        del kwargs
        upscaled = data.astype(jnp.float32) * scale_inv.astype(jnp.float32)
        return jnp.asarray(upscaled, dq_dtype)

    @staticmethod
    def dequantize(scaled_tensor):
        """Dequantize a tensor quantized with per-tensor scaling.

        Args:
            scaled_tensor: The quantized tensor to dequantize

        Returns:
            The dequantized tensor in ``scaled_tensor.dq_dtype``
        """
        return TensorScaleDequantizer._dequantize_func(
            scaled_tensor.data,
            scaled_tensor.scale_inv,
            scaled_tensor.dq_dtype,
        )

class BlockScaleDequantizer(Dequantizer):
    """BlockScaling Dequantizer Class.

    This class provides static methods for dequantizing tensors that have been
    quantized using block scaling modes, where each block of values shares one
    E8M0-encoded (biased power-of-two) inverse scale.
    """

    @staticmethod
    def _dequantize_func(data, scale_inv, dq_dtype, scaling_mode, is_colwise, flatten_axis):
        """Dequantize a tensor using block scaling.

        Args:
            data: The quantized tensor data
            scale_inv: The inverse scaling factors, stored as raw E8M0 bytes
            dq_dtype: The data type for dequantized values
            scaling_mode: The scaling mode used for quantization
            is_colwise: Whether the scaling is column-wise
            flatten_axis: The axis along which the tensor could be flattened to 2D

        Returns:
            The dequantized tensor, reshaped back to the original data shape
        """
        data = data.astype(jnp.float32)
        # Reinterpret the scale bytes as uint8 biased exponents before widening.
        scale_inv = scale_inv.view(jnp.uint8).astype(jnp.float32)

        data_shape = data.shape
        # Normalize a negative flatten_axis to its positive equivalent.
        flatten_axis = len(data_shape) + flatten_axis if flatten_axis < 0 else flatten_axis
        assert (
            0 < flatten_axis < len(data_shape)
        ), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}"
        scale_shape = scaling_mode.get_scale_shape(
            data_shape, is_colwise=is_colwise, is_padded=False, flatten_axis=flatten_axis
        )

        # Split each scaled axis into (num_blocks, block_size) so that one scale
        # value can broadcast over its whole block.
        data = data.reshape(
            *data_shape[: flatten_axis - 1],
            scale_shape[flatten_axis - 1],
            data_shape[flatten_axis - 1] // scale_shape[flatten_axis - 1],
            *data_shape[flatten_axis:-1],
            scale_shape[-1],
            data_shape[-1] // scale_shape[-1],
        )

        # Insert size-1 axes at the block-size positions so scale_inv broadcasts
        # against the reshaped data.
        scale_inv = jnp.expand_dims(scale_inv, axis=(flatten_axis, -1))

        # E8M0 does not have a bit for sign. So 0 - 127 represent negative numbers.
        return jnp.asarray(data * jnp.power(2, scale_inv - 127), dq_dtype).reshape(data_shape)

    @staticmethod
    def dequantize(scaled_tensor):
        """Dequantize a block-scaled tensor.

        Args:
            scaled_tensor: The quantized tensor to dequantize

        Returns:
            The dequantized tensor in ``scaled_tensor.dq_dtype``
        """
        return BlockScaleDequantizer._dequantize_func(
            scaled_tensor.data,
            scaled_tensor.scale_inv,
            scaled_tensor.dq_dtype,
            scaled_tensor.scaling_mode,
            scaled_tensor.is_colwise,
            scaled_tensor.flatten_axis,
        )


class NVFP4Dequantizer(Dequantizer):
    """NVFP4 Dequantizer Class.

    This class provides static methods for dequantizing tensors that have been
    quantized using NVFP4 scaling modes. Dequantization combines per-block
    inverse scales with a per-tensor scale derived from ``amax``.
    """

    @staticmethod
    def _dequantize_func(data, scale_inv, amax, dq_dtype, scaling_mode, is_colwise, flatten_axis):
        """Dequantize an NVFP4 tensor using block scales plus a per-tensor amax scale.

        Args:
            data: The quantized tensor data
            scale_inv: The inverse scaling factors (one per block)
            amax: The maximum absolute value of the tensor
            dq_dtype: The data type for dequantized values
            scaling_mode: The scaling mode used for quantization
            is_colwise: Whether the scaling is column-wise
            flatten_axis: The axis along which the tensor could be flattened to 2D

        Returns:
            The dequantized tensor
        """

        # Per-tensor scale: amax over the largest representable product of a
        # data value and a block-scale value in their respective dtypes.
        DATA_DTYPE_MAX = jnp.finfo(data.dtype).max.astype(jnp.float32)
        SCALE_DTYPE_MAX = jnp.finfo(scale_inv.dtype).max.astype(jnp.float32)
        tensor_scale_inv = amax / (DATA_DTYPE_MAX * SCALE_DTYPE_MAX)

        data = data.astype(jnp.float32)
        # Fold the per-tensor scale into the per-block scales up front.
        scale_inv = scale_inv.astype(jnp.float32) * tensor_scale_inv
        data_layout = "T" if is_colwise else "N"

        data_shape = data.shape
        # Normalize a negative flatten_axis to its positive equivalent.
        flatten_axis = len(data_shape) + flatten_axis if flatten_axis < 0 else flatten_axis
        assert (
            0 < flatten_axis < len(data_shape)
        ), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}"
        scale_shape = scaling_mode.get_scale_shape(
            data_shape,
            data_layout=data_layout,
            is_colwise=is_colwise,
            is_padded=False,
            # expect the flatten_axis wrt the N layout
            flatten_axis=flatten_axis if data_layout == "N" else len(data_shape) - flatten_axis,
            broadcast_2d_scale_shape_to_1d=True,
        )

        # Split each scaled axis into (num_blocks, block_size) so one scale
        # broadcasts over its block.
        data = data.reshape(
            *data_shape[: flatten_axis - 1],
            scale_shape[flatten_axis - 1],
            int(data_shape[flatten_axis - 1] / scale_shape[flatten_axis - 1]),
            *data_shape[flatten_axis:-1],
            scale_shape[-1],
            int(data_shape[-1] / scale_shape[-1]),
        )

        # Insert size-1 axes at the block-size positions for broadcasting.
        scale_inv = jnp.expand_dims(scale_inv, axis=(flatten_axis + 2 - 2, -1))
        out = jnp.asarray(data * scale_inv, dq_dtype).reshape(data_shape)

        # Apply inverse of RHT if needed
        use_rht = should_use_rht(scaling_mode, is_colwise=is_colwise)
        if use_rht:
            out = apply_rht(out, inverse=True)

        return out

    @staticmethod
    def dequantize(scaled_tensor):
        """Dequantize an NVFP4-scaled tensor.

        Args:
            scaled_tensor: The quantized tensor to dequantize

        Returns:
            The dequantized tensor
        """
        return NVFP4Dequantizer._dequantize_func(
            scaled_tensor.data,
            scaled_tensor.scale_inv,
            scaled_tensor.amax,
            scaled_tensor.dq_dtype,
            scaled_tensor.scaling_mode,
            scaled_tensor.is_colwise,
            scaled_tensor.flatten_axis,
        )


# Dispatch table from scaling mode to the dequantizer implementation that
# understands that mode's scale-inverse layout.
ScalingModeToDequantizerMap = {
    # Per-tensor scaling: one inverse scale for the whole tensor.
    ScalingMode.DELAYED_TENSOR_SCALING: TensorScaleDequantizer,
    ScalingMode.CURRENT_TENSOR_SCALING: TensorScaleDequantizer,
    # Block scaling: E8M0 power-of-two scale per block.
    ScalingMode.MXFP8_1D_SCALING: BlockScaleDequantizer,
    # NVFP4: per-block scales combined with a per-tensor amax-derived scale.
    ScalingMode.NVFP4_1D_SCALING: NVFP4Dequantizer,
    ScalingMode.NVFP4_2D_SCALING: NVFP4Dequantizer,
    # No quantization applied: data passes through unchanged.
    ScalingMode.NO_SCALING: NoopDequantizer,
}


# Wrapped in @staticmethod so the class-attribute assignment at the bottom of
# this module exposes it on Dequantizer without instance binding.
@staticmethod
def _grouped_dequantize(grouped_scaled_tensor):
    """Dequantize a grouped tensor.

    The flat ``data`` buffer stores every group's values back to back, and the
    flat ``scale_inv`` buffer stores each group's padded scale block back to
    back; each group is sliced out, reshaped, and dequantized independently.

    Args:
        grouped_scaled_tensor: The grouped scaled tensor to dequantize

    Returns:
        List of dequantized tensors for each group
    """
    data = grouped_scaled_tensor.data
    scale_inv = grouped_scaled_tensor.scale_inv
    group_sizes = grouped_scaled_tensor.group_sizes
    flatten_axis = grouped_scaled_tensor.flatten_axis
    scaling_mode = grouped_scaled_tensor.scaling_mode
    original_shape = grouped_scaled_tensor.original_shape
    group_axis = grouped_scaled_tensor.group_axis

    # Normalize a negative flatten_axis against the ungrouped shape.
    flatten_axis = len(original_shape) + flatten_axis if flatten_axis < 0 else flatten_axis

    output = []
    # Elements contributed per unit of group size: product of all non-grouped dims.
    non_group_shape = tuple(
        original_shape[i] for i in range(len(original_shape)) if i != group_axis
    )
    matrix_sizes = group_sizes * math.prod(non_group_shape)

    # Split the flat data buffer at each group boundary.
    data = jnp.split(data, jnp.cumulative_sum(matrix_sizes)[:-1])

    scale_inv_ptr = 0  # running offset into the flat (padded) scale buffer
    for i, data_i in enumerate(data):
        data_shape_i = (
            *original_shape[:group_axis],
            group_sizes[i],
            *original_shape[group_axis + 1 :],
        )
        assert math.prod(data_shape_i) == data_i.size, (
            f"math.prod({data_shape_i}) = {math.prod(data_shape_i)} which is not equal to"
            f" {data_i.size}"
        )
        # Scales are stored padded in the flat buffer: fetch the padded block,
        # then trim it down to the unpadded shape before use.
        padded_scale_shape_i = scaling_mode.get_scale_shape(
            data_shape_i,
            is_colwise=grouped_scaled_tensor.is_colwise,
            is_padded=True,
            flatten_axis=flatten_axis,
        )
        unpadded_scale_shape_i = scaling_mode.get_scale_shape(
            data_shape_i,
            is_colwise=grouped_scaled_tensor.is_colwise,
            is_padded=False,
            flatten_axis=flatten_axis,
        )
        scale_inv_i = scale_inv[
            scale_inv_ptr : scale_inv_ptr + math.prod(padded_scale_shape_i)
        ].reshape(padded_scale_shape_i)
        scale_inv_i = jax.lax.slice(
            scale_inv_i, [0] * len(unpadded_scale_shape_i), unpadded_scale_shape_i
        )
        # NOTE(review): NVFP4Dequantizer._dequantize_func expects an extra
        # `amax` argument before dq_dtype, which this call does not supply —
        # presumably grouped dequantize only targets tensor/block scaling
        # modes; confirm with callers.
        dequantizer_type = ScalingModeToDequantizerMap.get(grouped_scaled_tensor.scaling_mode)
        if len(data_i) == 0:
            out_i = []
        else:
            out_i = dequantizer_type._dequantize_func(
                data_i.reshape(data_shape_i),
                scale_inv_i,
                grouped_scaled_tensor.dq_dtype,
                scaling_mode=grouped_scaled_tensor.scaling_mode,
                is_colwise=grouped_scaled_tensor.is_colwise,
                flatten_axis=grouped_scaled_tensor.flatten_axis,
            )
        output.append(out_i)
        # Always advance by the padded block size, regardless of the trimmed shape.
        scale_inv_ptr += math.prod(padded_scale_shape_i)

    return output


# Attach as a static method on the base class so all dequantizers share it.
Dequantizer.grouped_dequantize = _grouped_dequantize