# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Dequantization utilities for TE/JAX.

This module provides utilities for dequantizing tensors that have been quantized
using various scaling modes, including delayed scaling and block scaling.
"""
import math
from dataclasses import dataclass
from abc import ABC, abstractmethod

import jax
import jax.numpy as jnp

from .scaling_modes import ScalingMode

__all__ = ["ScalingModeToDequantizerMap"]


@dataclass
class Dequantizer(ABC):
    """
    Base Dequantizer Class
    """

    @staticmethod
    @abstractmethod
    def _dequantize_func(data, scale_inv, dq_dtype, **kwargs):
        pass

    @staticmethod
    @abstractmethod
    def dequantize(scaled_tensor):
        """Dequantizing given tensor to higher precision."""


class TensorScaleDequantizer(Dequantizer):
    """
    TensorScaling Dequantizer Class

    This class provides static methods for dequantizing tensors that have been
    quantized using different tensor scaling modes. It supports both delayed scaling
    and current scaling modes.
    """

    @staticmethod
    def _dequantize_func(data, scale_inv, dq_dtype, **kwargs):
        del kwargs
        return jnp.asarray(
            data.astype(jnp.float32) * scale_inv.astype(jnp.float32),
            dq_dtype,
        )

    @staticmethod
    def dequantize(scaled_tensor):
        """Dequantize a tensor using delayed scaling.

        This function dequantizes a tensor that was quantized using delayed scaling
        by multiplying the quantized data with the inverse scaling factor.

        Args:
            scaled_tensor: The quantized tensor to dequantize

        Returns:
            The dequantized tensor in the specified data type
        """
        return TensorScaleDequantizer._dequantize_func(
            scaled_tensor.data, scaled_tensor.scale_inv, scaled_tensor.dq_dtype
        )
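

# A minimal usage sketch (illustrative only, not part of the API):
# ``SimpleNamespace`` stands in for a real ScaledTensor produced by the
# TE/JAX quantize path, carrying just the three attributes that
# ``TensorScaleDequantizer.dequantize`` reads.
def _example_tensor_scale_dequantize():
    """Recover bfloat16 values from FP8 data and a per-tensor inverse scale."""
    from types import SimpleNamespace

    st = SimpleNamespace(
        data=jnp.ones((4, 4), jnp.float8_e4m3fn),  # quantized payload
        scale_inv=jnp.float32(0.5),  # per-tensor inverse scale
        dq_dtype=jnp.bfloat16,
    )
    return TensorScaleDequantizer.dequantize(st)  # every entry == 0.5
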


class BlockScaleDequantizer(Dequantizer):
    """BlockScaling Dequantizer Class.

    This class provides static methods for dequantizing tensors that have been
    quantized using block scaling modes.
    """

    @staticmethod
    def _dequantize_func(data, scale_inv, dq_dtype, scaling_mode, is_colwise, flatten_axis):
        """Dequantize a tensor using block scaling.

        Args:
            data: The quantized tensor data
            scale_inv: The inverse scaling factors
            dq_dtype: The data type for dequantized values
            scaling_mode: The scaling mode used for quantization
            is_colwise: Whether the scaling is column-wise
            flatten_axis: The axis along which the tensor could be flattened to 2D

        Returns:
            The dequantized tensor
        """

        data = data.astype(jnp.float32)
        # Reinterpret the E8M0 scale bytes as their raw biased exponents.
        scale_inv = scale_inv.view(jnp.uint8).astype(jnp.float32)

        data_shape = data.shape
        flatten_axis = len(data_shape) + flatten_axis if flatten_axis < 0 else flatten_axis
        assert (
            0 < flatten_axis < len(data_shape)
        ), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}"
        scale_shape = scaling_mode.get_scale_shape(
            data_shape, is_colwise, is_padded=False, flatten_axis=flatten_axis
        )
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(scale_shape), scale_shape
        )  # slice out the padding

        # Split each scaled dimension into (num_blocks, block_size) so the
        # per-block scales can broadcast against the data.
        data = data.reshape(
            *data_shape[: flatten_axis - 1],
            scale_shape[flatten_axis - 1],
            data_shape[flatten_axis - 1] // scale_shape[flatten_axis - 1],
            *data_shape[flatten_axis:-1],
            scale_shape[-1],
            data_shape[-1] // scale_shape[-1],
        )

        # Insert singleton block axes so the scales broadcast against the
        # (num_blocks, block_size) layout created above.
        scale_inv = jnp.expand_dims(scale_inv, axis=(flatten_axis, -1))

        # E8M0 has no sign bit: each byte is a biased exponent (bias 127), so
        # the decoded scale is 2^(stored_byte - 127).
        return jnp.asarray(data * jnp.power(2, scale_inv - 127), dq_dtype).reshape(data_shape)

    @staticmethod
    def dequantize(scaled_tensor):
        """Dequantize a tensor using block scaling.

        This function dequantizes a tensor that was quantized using block scaling
        by applying the inverse scaling factor to each block of data.

        Args:
            scaled_tensor: The quantized tensor to dequantize

        Returns:
            The dequantized tensor in the specified data type
        """
        return BlockScaleDequantizer._dequantize_func(
            scaled_tensor.data,
            scaled_tensor.scale_inv,
            scaled_tensor.dq_dtype,
            scaled_tensor.scaling_mode,
            scaled_tensor.is_colwise,
            scaled_tensor.flatten_axis,
        )
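

# A worked sketch of the E8M0 decode used above (illustrative only): stored
# bytes are biased exponents, so 126 decodes to 2**-1 = 0.5 and 128 to 2.0.
def _example_e8m0_decode():
    """Decode raw E8M0 bytes into float32 scales via 2**(byte - 127)."""
    raw = jnp.array([126, 127, 128], jnp.uint8).astype(jnp.float32)
    return jnp.power(2, raw - 127)  # [0.5, 1.0, 2.0]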


ScalingModeToDequantizerMap = {
    ScalingMode.DELAYED_TENSOR_SCALING: TensorScaleDequantizer,
    ScalingMode.CURRENT_TENSOR_SCALING: TensorScaleDequantizer,
    ScalingMode.MXFP8_1D_SCALING: BlockScaleDequantizer,
}
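

# A minimal dispatch sketch: pick the dequantizer for a tensor's scaling mode.
# ``scaled_tensor`` is assumed to be a ScaledTensor-like object produced by
# the TE/JAX quantize path.
def _example_dispatch(scaled_tensor):
    """Look up the dequantizer for ``scaled_tensor.scaling_mode`` and apply it."""
    dequantizer = ScalingModeToDequantizerMap[scaled_tensor.scaling_mode]
    return dequantizer.dequantize(scaled_tensor)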


# This helper is defined at module level and attached to ``Dequantizer``
# below; the ``staticmethod`` wrapper keeps it from binding as an instance
# method when accessed through instances.
@staticmethod
def _grouped_dequantize(grouped_scaled_tensor):
    """Dequantize a grouped tensor.

    Args:
        grouped_scaled_tensor: The grouped scaled tensor to dequantize

    Returns:
        List of dequantized tensors for each group
    """
    data = grouped_scaled_tensor.data
    scale_inv = grouped_scaled_tensor.scale_inv
    group_sizes = grouped_scaled_tensor.group_sizes
    flatten_axis = grouped_scaled_tensor.flatten_axis
    scaling_mode = grouped_scaled_tensor.scaling_mode
    original_shape = grouped_scaled_tensor.original_shape
    group_axis = grouped_scaled_tensor.group_axis

    flatten_axis = len(original_shape) + flatten_axis if flatten_axis < 0 else flatten_axis

    output = []
    non_group_shape = tuple(
        original_shape[i] for i in range(len(original_shape)) if i != group_axis
    )
    matrix_sizes = group_sizes * math.prod(non_group_shape)

    data = jnp.split(data, jnp.cumulative_sum(matrix_sizes)[:-1])

    scale_inv_ptr = 0
    for i, data_i in enumerate(data):
        data_shape_i = (
            *original_shape[:group_axis],
            group_sizes[i],
            *original_shape[group_axis + 1 :],
        )
        assert math.prod(data_shape_i) == data_i.size, (
            f"math.prod({data_shape_i}) = {math.prod(data_shape_i)} which is not equal to"
            f" {data_i.size}"
        )
        scale_shape_i = scaling_mode.get_scale_shape(
            data_shape_i,
            grouped_scaled_tensor.is_colwise,
            is_padded=True,
            flatten_axis=flatten_axis,
        )
        scale_shape_i_size = math.prod(scale_shape_i)
        scale_inv_i = scale_inv[scale_inv_ptr : scale_inv_ptr + scale_shape_i_size]
        dequantizer_type = ScalingModeToDequantizerMap.get(grouped_scaled_tensor.scaling_mode)
        if len(data_i) == 0:
            out_i = []
        else:
            out_i = dequantizer_type._dequantize_func(
                data_i.reshape(data_shape_i),
                scale_inv_i.reshape(scale_shape_i),
                grouped_scaled_tensor.dq_dtype,
                scaling_mode=grouped_scaled_tensor.scaling_mode,
                is_colwise=grouped_scaled_tensor.is_colwise,
                flatten_axis=grouped_scaled_tensor.flatten_axis,
            )
        output.append(out_i)
        scale_inv_ptr += scale_shape_i_size

    return output


Dequantizer.grouped_dequantize = _grouped_dequantize
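

# A minimal sketch of the splitting step inside ``_grouped_dequantize``: a
# flat buffer holding three groups is cut at the cumulative group sizes. The
# sizes here are made up for illustration.
def _example_group_split():
    """Split a flat 9-element buffer into groups of 2, 3, and 4 elements."""
    flat = jnp.arange(9)
    group_sizes = jnp.array([2, 3, 4])
    # Result: [Array([0, 1]), Array([2, 3, 4]), Array([5, 6, 7, 8])]
    return jnp.split(flat, jnp.cumulative_sum(group_sizes)[:-1])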