# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Dequantization utilities for TE/JAX.

This module provides utilities for dequantizing tensors that have been quantized
using various scaling modes, including delayed scaling and block scaling.
"""
import math
from dataclasses import dataclass
from abc import ABC, abstractmethod

import jax
import jax.numpy as jnp

from .scaling_modes import ScalingMode

__all__ = ["ScalingModeToDequantizerMap"]


@dataclass
class Dequantizer(ABC):
    """Abstract interface for converting quantized tensors back to high precision.

    Concrete subclasses supply two static methods: an array-level kernel
    (``_dequantize_func``) operating on raw data/scale arrays, and a
    convenience wrapper (``dequantize``) that extracts the required fields
    from a scaled-tensor object.
    """

    @staticmethod
    @abstractmethod
    def _dequantize_func(data, scale_inv, dq_dtype, **kwargs):
        """Array-level dequantization kernel; implemented by subclasses."""

    @staticmethod
    @abstractmethod
    def dequantize(scaled_tensor):
        """Dequantize the given scaled tensor to higher precision."""


@dataclass
class NoopDequantizer(Dequantizer):
    """Pass-through dequantizer for tensors quantized with NO_SCALING."""

    @staticmethod
    def _dequantize_func(data, *args, **kwargs):
        """Return ``data`` unchanged; all scale-related arguments are ignored."""
        del args, kwargs
        return data

    @staticmethod
    def dequantize(scaled_tensor):
        """Return the raw data array stored in ``scaled_tensor`` as-is."""
        return scaled_tensor.data


55
56
57
class TensorScaleDequantizer(Dequantizer):
    """Dequantizer for per-tensor scaling modes.

    Handles tensors quantized with a single inverse scale per tensor, which
    covers both delayed scaling and current scaling modes.
    """

    @staticmethod
    def _dequantize_func(data, scale_inv, dq_dtype, **kwargs):
        """Multiply the quantized data by its inverse scale and cast to dq_dtype.

        The multiply is performed in float32 to avoid low-precision rounding
        before the final cast.
        """
        del kwargs  # per-tensor scaling needs no extra metadata
        upcast_data = data.astype(jnp.float32)
        upcast_scale = scale_inv.astype(jnp.float32)
        return jnp.asarray(upcast_data * upcast_scale, dq_dtype)

    @staticmethod
    def dequantize(scaled_tensor):
        """Dequantize a tensor quantized with per-tensor scaling.

        Args:
            scaled_tensor: The quantized tensor to dequantize; must carry
                ``data``, ``scale_inv`` and ``dq_dtype`` attributes.

        Returns:
            The dequantized tensor in ``scaled_tensor.dq_dtype``.
        """
        return TensorScaleDequantizer._dequantize_func(
            scaled_tensor.data,
            scaled_tensor.scale_inv,
            scaled_tensor.dq_dtype,
        )


class BlockScaleDequantizer(Dequantizer):
    """BlockScaling Dequantizer Class.

    This class provides static methods for dequantizing tensors that have been
    quantized using block scaling modes (e.g. MXFP8 with E8M0 block scales).
    """

    @staticmethod
    def _dequantize_func(data, scale_inv, dq_dtype, scaling_mode, is_colwise, flatten_axis):
        """Dequantize a tensor using block scaling.

        Args:
            data: The quantized tensor data
            scale_inv: The inverse scaling factors (E8M0 biased exponents,
                one per block — see the final return line)
            dq_dtype: The data type for dequantized values
            scaling_mode: The scaling mode used for quantization
            is_colwise: Whether the scaling is column-wise
            flatten_axis: The axis along which the tensor could be flattened to 2D

        Returns:
            The dequantized tensor, reshaped back to ``data``'s original shape
        """
        data = data.astype(jnp.float32)
        # Reinterpret the raw scale bytes as uint8 exponent codes before
        # converting to float for the power-of-two computation below.
        scale_inv = scale_inv.view(jnp.uint8).astype(jnp.float32)

        data_shape = data.shape
        # Normalize a negative flatten_axis against the tensor rank.
        flatten_axis = len(data_shape) + flatten_axis if flatten_axis < 0 else flatten_axis
        assert (
            0 < flatten_axis < len(data_shape)
        ), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}"

        # Unpadded scale layout: one scale per block along the scaled dims.
        scale_shape = scaling_mode.get_scale_shape(
            data_shape, is_colwise, is_padded=False, flatten_axis=flatten_axis
        )

        # Split each scaled dimension into (num_blocks, block_size) so the
        # per-block scales can broadcast over the block_size axes.
        data = data.reshape(
            *data_shape[: flatten_axis - 1],
            scale_shape[flatten_axis - 1],
            int(data_shape[flatten_axis - 1] / scale_shape[flatten_axis - 1]),
            *data_shape[flatten_axis:-1],
            scale_shape[-1],
            int(data_shape[-1] / scale_shape[-1]),
        )

        # Insert singleton axes aligned with the two block_size axes above.
        # NOTE(review): "flatten_axis + 2 - 2" equals flatten_axis — presumably
        # left over from an earlier variant; behavior is unaffected.
        scale_inv = jnp.expand_dims(scale_inv, axis=(flatten_axis + 2 - 2, -1))

        # E8M0 does not have a bit for sign. So 0 - 127 represent negative numbers.
        return jnp.asarray(data * jnp.power(2, scale_inv - 127), dq_dtype).reshape(data_shape)

    @staticmethod
    def dequantize(scaled_tensor):
        """Dequantize a block-scaled tensor.

        Args:
            scaled_tensor: The quantized tensor to dequantize; must carry
                ``data``, ``scale_inv``, ``dq_dtype``, ``scaling_mode``,
                ``is_colwise`` and ``flatten_axis`` attributes

        Returns:
            The dequantized tensor in ``scaled_tensor.dq_dtype``
        """
        return BlockScaleDequantizer._dequantize_func(
            scaled_tensor.data,
            scaled_tensor.scale_inv,
            scaled_tensor.dq_dtype,
            scaled_tensor.scaling_mode,
            scaled_tensor.is_colwise,
            scaled_tensor.flatten_axis,
        )


# Registry mapping each ScalingMode to the Dequantizer that understands the
# corresponding scale_inv encoding and layout.
ScalingModeToDequantizerMap = {
    # Both tensor-scaling flavors use a single per-tensor inverse scale.
    ScalingMode.DELAYED_TENSOR_SCALING: TensorScaleDequantizer,
    ScalingMode.CURRENT_TENSOR_SCALING: TensorScaleDequantizer,
    # MXFP8 stores per-block E8M0 exponent scales.
    ScalingMode.MXFP8_1D_SCALING: BlockScaleDequantizer,
    ScalingMode.NO_SCALING: NoopDequantizer,
}


# The staticmethod wrapper ensures that, once attached to Dequantizer below,
# the function is not bound to instances (no implicit self).
@staticmethod
def _grouped_dequantize(grouped_scaled_tensor):
    """Dequantize a grouped tensor.

    Slices each group out of the flat ``data`` buffer, pairs it with its
    (padded) slice of the flat ``scale_inv`` buffer, and dispatches to the
    Dequantizer registered for the tensor's scaling mode.

    Args:
        grouped_scaled_tensor: The grouped scaled tensor to dequantize

    Returns:
        List of dequantized tensors for each group (an empty list is emitted
        for zero-sized groups)
    """
    data = grouped_scaled_tensor.data
    scale_inv = grouped_scaled_tensor.scale_inv
    group_sizes = grouped_scaled_tensor.group_sizes
    flatten_axis = grouped_scaled_tensor.flatten_axis
    scaling_mode = grouped_scaled_tensor.scaling_mode
    original_shape = grouped_scaled_tensor.original_shape
    group_axis = grouped_scaled_tensor.group_axis

    # Normalize a negative flatten_axis against the rank of the original shape.
    flatten_axis = len(original_shape) + flatten_axis if flatten_axis < 0 else flatten_axis

    output = []
    non_group_shape = tuple(
        original_shape[i] for i in range(len(original_shape)) if i != group_axis
    )
    # Number of flat data elements belonging to each group.
    matrix_sizes = group_sizes * math.prod(non_group_shape)

    # Split the flat data buffer into one chunk per group.
    data = jnp.split(data, jnp.cumulative_sum(matrix_sizes)[:-1])

    # Running offset into the flat scale_inv buffer (which stores PADDED
    # per-group scale blocks back to back).
    scale_inv_ptr = 0
    for i, data_i in enumerate(data):
        # Full shape of group i: the group axis is replaced by its group size.
        data_shape_i = (
            *original_shape[:group_axis],
            group_sizes[i],
            *original_shape[group_axis + 1 :],
        )
        assert math.prod(data_shape_i) == data_i.size, (
            f"math.prod({data_shape_i}) = {math.prod(data_shape_i)} which is not equal to"
            f" {data_i.size}"
        )
        # Padded shape: how the scales are physically stored for this group.
        padded_scale_shape_i = scaling_mode.get_scale_shape(
            data_shape_i,
            grouped_scaled_tensor.is_colwise,
            is_padded=True,
            flatten_axis=flatten_axis,
        )
        # Unpadded shape: the logically meaningful portion of those scales.
        unpadded_scale_shape_i = scaling_mode.get_scale_shape(
            data_shape_i,
            grouped_scaled_tensor.is_colwise,
            is_padded=False,
            flatten_axis=flatten_axis,
        )
        scale_inv_i = scale_inv[
            scale_inv_ptr : scale_inv_ptr + math.prod(padded_scale_shape_i)
        ].reshape(padded_scale_shape_i)
        # Trim the padding off every axis.
        scale_inv_i = jax.lax.slice(
            scale_inv_i, [0] * len(unpadded_scale_shape_i), unpadded_scale_shape_i
        )
        dequantizer_type = ScalingModeToDequantizerMap.get(grouped_scaled_tensor.scaling_mode)
        if len(data_i) == 0:
            # Empty group: nothing to dequantize.
            out_i = []
        else:
            # Extra kwargs are ignored by dequantizers that don't need them.
            out_i = dequantizer_type._dequantize_func(
                data_i.reshape(data_shape_i),
                scale_inv_i,
                grouped_scaled_tensor.dq_dtype,
                scaling_mode=grouped_scaled_tensor.scaling_mode,
                is_colwise=grouped_scaled_tensor.is_colwise,
                flatten_axis=grouped_scaled_tensor.flatten_axis,
            )
        output.append(out_i)
        # Advance by the PADDED size — that is the stride of the stored layout.
        scale_inv_ptr += math.prod(padded_scale_shape_i)

    return output


# Attach as a method on the base class so all dequantizers share it.
Dequantizer.grouped_dequantize = _grouped_dequantize