dequantizer.py 7.81 KB
Newer Older
1
2
3
4
5
6
7
8
9
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Dequantization utilities for TE/JAX.

This module provides utilities for dequantizing tensors that have been quantized
using various scaling modes, including delayed scaling and block scaling.
"""
10
11
12
13
import math
from dataclasses import dataclass
from abc import ABC, abstractmethod

14
15
16
17
18
import jax
import jax.numpy as jnp

from .scaling_modes import ScalingMode

19
__all__ = ["ScalingModeToDequantizerMap"]
20
21


22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
@dataclass
class Dequantizer(ABC):
    """Abstract base class for all dequantizers.

    Concrete subclasses provide two static methods: ``_dequantize_func``, which
    operates on raw arrays, and ``dequantize``, which takes a scaled tensor.
    """

    @staticmethod
    @abstractmethod
    def _dequantize_func(data, scale_inv, dq_dtype, **kwargs):
        """Dequantize raw ``data`` using ``scale_inv``.

        Args:
            data: The quantized tensor data
            scale_inv: The inverse scaling factor(s)
            dq_dtype: The data type for dequantized values
            **kwargs: Additional implementation-specific arguments

        Returns:
            ``None`` in this base class; subclasses return the dequantized tensor.
        """
        return None

    @staticmethod
    @abstractmethod
    def dequantize(scaled_tensor):
        """Dequantize ``scaled_tensor`` to higher precision.

        Args:
            scaled_tensor: The quantized tensor to dequantize

        Returns:
            ``None`` in this base class; subclasses return the dequantized tensor.
        """
        return None


Alp Dener's avatar
Alp Dener committed
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
@dataclass
class NoopDequantizer(Dequantizer):
    """Pass-through dequantizer that hands the stored data back unchanged."""

    @staticmethod
    def _dequantize_func(data, *args, **kwargs):
        """Return ``data`` as-is.

        Any further positional or keyword arguments are accepted and ignored.
        """
        del args, kwargs  # accepted only so the call signature stays permissive
        return data

    @staticmethod
    def dequantize(scaled_tensor):
        """Return the ``data`` array of ``scaled_tensor`` without modification."""
        return NoopDequantizer._dequantize_func(scaled_tensor.data)


55
56
57
class TensorScaleDequantizer(Dequantizer):
    """
    TensorScaling Dequantizer Class

    This class provides static methods for dequantizing tensors that have been
    quantized using different tensor scaling modes. It supports both delayed scaling
    and current scaling modes.
    """

    @staticmethod
    def _dequantize_func(data, scale_inv, dq_dtype, **kwargs):
        """Multiply quantized data by its inverse scale and cast the result.

        Args:
            data: The quantized tensor data
            scale_inv: The inverse scaling factor
            dq_dtype: The data type for dequantized values
            **kwargs: Accepted and ignored

        Returns:
            The dequantized tensor in ``dq_dtype``
        """
        del kwargs
        # Both operands are promoted to float32 for the multiply; the product is
        # only converted to dq_dtype afterwards.
        return jnp.asarray(
            data.astype(jnp.float32) * scale_inv.astype(jnp.float32),
            dq_dtype,
        )

    @staticmethod
    def dequantize(scaled_tensor):
        """Dequantize a tensor that was quantized with tensor scaling.

        This function dequantizes a tensor that was quantized using delayed or
        current tensor scaling by multiplying the quantized data with the
        inverse scaling factor.

        Args:
            scaled_tensor: The quantized tensor to dequantize; its ``data``,
                ``scale_inv`` and ``dq_dtype`` attributes are read

        Returns:
            The dequantized tensor in the specified data type
        """
        return TensorScaleDequantizer._dequantize_func(
            scaled_tensor.data, scaled_tensor.scale_inv, scaled_tensor.dq_dtype
        )

89
90
91
92
93
94
95
96

class BlockScaleDequantizer(Dequantizer):
    """BlockScaling Dequantizer Class.

    This class provides static methods for dequantizing tensors that have been
    quantized using block scaling modes.
    """

    @staticmethod
    def _dequantize_func(data, scale_inv, dq_dtype, scaling_mode, is_colwise, flatten_axis):
        """Dequantize a tensor using block scaling.

        Args:
            data: The quantized tensor data
            scale_inv: The inverse scaling factors; may carry padding beyond the
                unpadded scale shape, which is sliced away here
            dq_dtype: The data type for dequantized values
            scaling_mode: The scaling mode used for quantization; its
                ``get_scale_shape`` method supplies the unpadded scale shape
            is_colwise: Whether the scaling is column-wise
            flatten_axis: The axis along which the tensor could be flattened to 2D.
                Negative values count from the end of the shape.

        Returns:
            The dequantized tensor, with the same shape as ``data`` and dtype
            ``dq_dtype``

        Raises:
            AssertionError: If the normalized ``flatten_axis`` is not strictly
                between 0 and ``data.ndim``.
        """
        data = data.astype(jnp.float32)
        # Reinterpret the stored scale bytes as unsigned integers before doing
        # arithmetic on them as floats.
        scale_inv = scale_inv.view(jnp.uint8).astype(jnp.float32)

        data_shape = data.shape
        flatten_axis = len(data_shape) + flatten_axis if flatten_axis < 0 else flatten_axis
        assert (
            0 < flatten_axis < len(data_shape)
        ), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}"
        scale_shape = scaling_mode.get_scale_shape(
            data_shape, is_colwise, is_padded=False, flatten_axis=flatten_axis
        )
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(scale_shape), scale_shape
        )  # slice out the padding

        # Split the two scaled axes (flatten_axis - 1 and the last one) into
        # (number_of_blocks, block_size) pairs so that each block lines up with
        # exactly one scale entry. Shapes are integers, so use integer division.
        data = data.reshape(
            *data_shape[: flatten_axis - 1],
            scale_shape[flatten_axis - 1],
            data_shape[flatten_axis - 1] // scale_shape[flatten_axis - 1],
            *data_shape[flatten_axis:-1],
            scale_shape[-1],
            data_shape[-1] // scale_shape[-1],
        )

        # Insert singleton axes where the block_size axes were created above so
        # the scales broadcast across every element of their block.
        scale_inv = jnp.expand_dims(scale_inv, axis=(flatten_axis, -1))

        # E8M0 has no sign bit: the stored byte is an exponent biased by 127, so
        # byte values below 127 encode negative exponents.
        return jnp.asarray(data * jnp.power(2, scale_inv - 127), dq_dtype).reshape(data_shape)

    @staticmethod
    def dequantize(scaled_tensor):
        """Dequantize a tensor using block scaling.

        Args:
            scaled_tensor: The quantized tensor to dequantize. Its ``data``,
                ``scale_inv``, ``dq_dtype``, ``scaling_mode``, ``is_colwise`` and
                ``flatten_axis`` attributes are forwarded to ``_dequantize_func``.

        Returns:
            The dequantized tensor
        """
        return BlockScaleDequantizer._dequantize_func(
            scaled_tensor.data,
            scaled_tensor.scale_inv,
            scaled_tensor.dq_dtype,
            scaled_tensor.scaling_mode,
            scaled_tensor.is_colwise,
            scaled_tensor.flatten_axis,
        )


# Maps each supported ScalingMode to the Dequantizer subclass that handles it.
# Both tensor-scaling modes share TensorScaleDequantizer.
ScalingModeToDequantizerMap = {
    ScalingMode.DELAYED_TENSOR_SCALING: TensorScaleDequantizer,
    ScalingMode.CURRENT_TENSOR_SCALING: TensorScaleDequantizer,
    ScalingMode.MXFP8_1D_SCALING: BlockScaleDequantizer,
    ScalingMode.NO_SCALING: NoopDequantizer,
}


@staticmethod
def _grouped_dequantize(grouped_scaled_tensor):
    """Dequantize every group of a grouped scaled tensor.

    The tensor's ``data`` is split into one chunk per entry of ``group_sizes``,
    and ``scale_inv`` is consumed as consecutive slices whose lengths are the
    padded scale sizes reported by the scaling mode. Each chunk is dequantized
    with the dequantizer registered for the tensor's scaling mode.

    Args:
        grouped_scaled_tensor: The grouped scaled tensor to dequantize

    Returns:
        List with one entry per group: the dequantized tensor, or an empty
        Python list for a group whose data chunk has length zero
    """
    gst = grouped_scaled_tensor
    scale_inv = gst.scale_inv
    group_sizes = gst.group_sizes
    scaling_mode = gst.scaling_mode
    original_shape = gst.original_shape
    group_axis = gst.group_axis

    # Normalised copy for get_scale_shape; the raw attribute is what gets
    # forwarded to the per-group dequantize call below.
    flatten_axis = gst.flatten_axis
    if flatten_axis < 0:
        flatten_axis += len(original_shape)

    # Number of elements contributed by one unit along the group axis.
    elems_per_group_unit = math.prod(
        dim for axis, dim in enumerate(original_shape) if axis != group_axis
    )
    chunk_sizes = group_sizes * elems_per_group_unit
    chunks = jnp.split(gst.data, jnp.cumulative_sum(chunk_sizes)[:-1])

    dequantizer_type = ScalingModeToDequantizerMap.get(gst.scaling_mode)

    results = []
    scale_start = 0
    for group_idx, chunk in enumerate(chunks):
        group_shape = (
            *original_shape[:group_axis],
            group_sizes[group_idx],
            *original_shape[group_axis + 1 :],
        )
        assert math.prod(group_shape) == chunk.size, (
            f"math.prod({group_shape}) = {math.prod(group_shape)} which is not equal to"
            f" {chunk.size}"
        )

        group_scale_shape = scaling_mode.get_scale_shape(
            group_shape,
            gst.is_colwise,
            is_padded=True,
            flatten_axis=flatten_axis,
        )
        scale_stop = scale_start + math.prod(group_scale_shape)
        group_scale_inv = scale_inv[scale_start:scale_stop]

        if len(chunk) == 0:
            results.append([])
        else:
            results.append(
                dequantizer_type._dequantize_func(
                    chunk.reshape(group_shape),
                    group_scale_inv.reshape(group_scale_shape),
                    gst.dq_dtype,
                    scaling_mode=gst.scaling_mode,
                    is_colwise=gst.is_colwise,
                    flatten_axis=gst.flatten_axis,
                )
            )
        # Advance past this group's scales even when the group was empty.
        scale_start = scale_stop

    return results


# Attach the module-level grouped routine to the base class so it is reachable
# as ``grouped_dequantize`` from Dequantizer and from every subclass.
Dequantizer.grouped_dequantize = _grouped_dequantize