# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Tensor classes for TE/JAX

This module provides tensor classes for handling quantized tensors in JAX, including
both single-scale (1x) and double-scale (2x) quantization schemes. It supports
rowwise and colwise quantization modes with proper scaling and dequantization.
"""
from dataclasses import dataclass
from typing import Callable, Tuple
from abc import ABC, abstractmethod

import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

from transformer_engine_jax import QuantizeLayout

from .scaling_modes import ScalingMode
from .dequantizer import Dequantizer
from ..sharding import (
    with_sharding_constraint_by_logical_axes as original_with_sharding_constraint_by_logical_axes,
)

__all__ = [
    "ScaledTensor",
    "ScaledTensor1x",
    "ScaledTensor2x",
    "ScaledTensorFactory",
    "with_sharding_constraint_by_logical_axes",
]


@register_pytree_node_class
@dataclass
class ScaledTensor(ABC):
    """Abstract base class for scaled tensors.

    This class defines the interface for all scaled tensor implementations,
    providing methods for dequantization and accessing row/column-wise components.
    """

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Reconstructs the tensor from its flattened representation.

        Args:
            aux_data: Auxiliary data needed for reconstruction
            children: The flattened tensor components

        Returns:
            A reconstructed tensor instance
        """
        return cls(*children, *aux_data)

    @abstractmethod
    def dequantize(self):
        """Dequantizes the tensor back to its original precision.

        Returns:
            The dequantized tensor
        """

    @abstractmethod
    def get_rowwise_tensor(self):
        """Returns the row-wise component of the tensor.

        Returns:
            The row-wise tensor component

        Raises:
            ValueError: If called on a tensor that doesn't support row-wise access
        """

    @abstractmethod
    def get_colwise_tensor(self):
        """Returns the column-wise component of the tensor.

        Returns:
            The column-wise tensor component

        Raises:
            ValueError: If called on a tensor that doesn't support column-wise access
        """

    @abstractmethod
    def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
        """Applies sharding constraints to a tensor based on logical axis names.

        Args:
            logical_axis_names: Tuple of logical axis names for sharding

        Returns:
            The tensor with applied sharding constraints
        """


@register_pytree_node_class
@dataclass
class ScaledTensor1x(ScaledTensor):
    """Single-scale quantized tensor implementation.

    This class represents a tensor quantized with a single scaling factor,
    supporting both row-wise and column-wise quantization modes.

    Attributes:
        data: The quantized tensor data
        scale_inv: The inverse scaling factors
        scaling_mode: The scaling mode used for quantization
        dq_dtype: The data type for dequantized values
        _dq_func: The dequantization function
        is_colwise: Whether the tensor uses column-wise quantization
        data_layout: The data layout, "N" (normal) or "T" (transposed)
        flatten_axis: The axis along which the tensor is flattened for quantization
    """

    data: jnp.ndarray
    scale_inv: jnp.ndarray
    scaling_mode: ScalingMode
    dq_dtype: jnp.dtype
    _dq_func: Callable
    is_colwise: bool
    data_layout: str
    flatten_axis: int = -1

    def __post_init__(self):
        """Validates and adjusts the scale_inv shape after initialization.

        Ensures the scale_inv shape matches the expected shape based on the scaling mode
        and quantization direction. Pads the scale_inv if necessary.
        """
        flatten_axis = (
            len(self.data.shape) + self.flatten_axis if self.flatten_axis < 0 else self.flatten_axis
        )
        assert (
            0 < flatten_axis < len(self.data.shape)
        ), f"flatten_axis {flatten_axis} is out of bounds for shape {self.data.shape}"

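        # Transposed ("T") data has its dimensions rolled around the flatten axis,
        # so mirror the axis to count from the other end of the shape.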
        if self.data_layout == "T":
            flatten_axis = self.data.ndim - flatten_axis
        self.flatten_axis = flatten_axis

        expected_scale_shape = self.scaling_mode.get_scale_shape(
            self.data.shape, self.is_colwise, is_padded=True, flatten_axis=flatten_axis
        )
        expected_unpadded_scale_shape = self.scaling_mode.get_scale_shape(
            self.data.shape, self.is_colwise, is_padded=False, flatten_axis=flatten_axis
        )
        if self.scale_inv.shape != expected_scale_shape:
            assert self.scale_inv.shape == expected_unpadded_scale_shape, (
                f"Unexpected scale_inv shape! \nExpect {expected_scale_shape} for padded"
                f" scale_inv or {expected_unpadded_scale_shape} for unpadded scale_inv, got"
                f" {self.scale_inv.shape}"
            )
            pad_width = tuple(
                (0, a - b) for a, b in zip(expected_scale_shape, expected_unpadded_scale_shape)
            )
            # Note: this pads scale_inv with zeros; should we pad with 127 (E8M0 for 1.0) instead?
            self.scale_inv = jnp.pad(
                self.scale_inv, pad_width=pad_width, mode="constant", constant_values=0
            )

    def tree_flatten(self):
        """Flattens the tensor for JAX tree operations.

        Returns:
            A tuple containing (children, aux_data) for tree operations
        """
        children = (self.data, self.scale_inv)
        aux_data = (
            self.scaling_mode,
            self.dq_dtype,
            self._dq_func,
            self.is_colwise,
            self.data_layout,
            self.flatten_axis,
        )
        return (children, aux_data)

    def dequantize(self):
        """Dequantizes the tensor using the stored dequantization function.

        Returns:
            The dequantized tensor
        """
        return self._dq_func(self)

    def get_rowwise_tensor(self):
        """Returns the tensor if it's row-wise quantized.

        Returns:
            The row-wise tensor

        Raises:
            ValueError: If called on a column-wise quantized tensor
        """
        if not self.is_colwise:
            return self

        raise ValueError("Calling get_rowwise_tensor() from a colwise ScaledTensor1x!")

    def get_colwise_tensor(self):
        """Returns the tensor if it's column-wise quantized.

        Returns:
            The column-wise tensor

        Raises:
            ValueError: If called on a row-wise quantized tensor
        """
        if self.is_colwise:
            return self

        raise ValueError("Calling get_colwise_tensor() from a rowwise ScaledTensor1x!")

    def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
        """Applies sharding constraints to a tensor based on logical axis names.

        Args:
            logical_axis_names: Tuple of logical axis names for sharding

        Returns:
            The tensor with applied sharding constraints
        """
        if not logical_axis_names:
            return self

        # axis_names are given for the "N" layout, so they must be rotated for "T" layout data,
        # e.g. ("a", "b", "c", "d") with flatten_axis=2 becomes ("c", "d", "a", "b")
        if self.data_layout == "T":
            assert self.flatten_axis > 0
            flatten_axis = -self.flatten_axis
            axis_names = (*logical_axis_names[flatten_axis:], *logical_axis_names[:flatten_axis])
        else:
            axis_names = logical_axis_names

        data = with_sharding_constraint_by_logical_axes(self.data, axis_names)

        if self.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
            # TODO(Phuong): Handle padding !?
            scale_inv = with_sharding_constraint_by_logical_axes(self.scale_inv, axis_names)
        else:
            scale_inv = self.scale_inv

        # TODO(Phuong): constrain padded scale_inv?
        return ScaledTensor1x(
            data=data,
            scale_inv=scale_inv,
            scaling_mode=self.scaling_mode,
            dq_dtype=self.dq_dtype,
            _dq_func=self._dq_func,
            is_colwise=self.is_colwise,
            data_layout=self.data_layout,
            flatten_axis=self.flatten_axis,
        )

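
# --- Illustrative usage (a minimal sketch, not part of the public API) ---
# Assumes ScalingMode.NVTE_DELAYED_TENSOR_SCALING exists and uses a single
# per-tensor scale of shape (1, 1); block-scaled modes such as
# NVTE_MXFP8_1D_SCALING instead expect the shapes reported by
# ScalingMode.get_scale_shape.
def _example_scaled_tensor_1x_roundtrip():
    data = jnp.zeros((128, 64), dtype=jnp.float8_e4m3fn)
    scale_inv = jnp.ones((1, 1), dtype=jnp.float32)
    tensor = ScaledTensorFactory.create_1x(
        data, scale_inv, ScalingMode.NVTE_DELAYED_TENSOR_SCALING, dq_dtype=jnp.bfloat16
    )
    assert tensor.get_rowwise_tensor() is tensor  # rowwise by default
    return tensor.dequantize()  # bfloat16 result via the stored _dq_func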

@register_pytree_node_class
@dataclass
class ScaledTensor2x(ScaledTensor):
    """Double-scale quantized tensor implementation.

    This class represents a tensor quantized with both row-wise and column-wise scaling factors.

    Attributes:
        rowwise_tensor: The row-wise quantized component
        colwise_tensor: The column-wise quantized component
    """

    rowwise_tensor: ScaledTensor1x
    colwise_tensor: ScaledTensor1x

    def tree_flatten(self):
        """Flattens the tensor for JAX tree operations.

        Returns:
            A tuple containing (children, aux_data) for tree operations
        """
        children = (self.rowwise_tensor, self.colwise_tensor)
        aux_data = ()
        return (children, aux_data)

    def dequantize(self):
        """Dequantizes the tensor using the row-wise component's dequantization.

        Returns:
            The dequantized tensor
        """
        return self.rowwise_tensor.dequantize()

    def get_rowwise_tensor(self):
        """Returns the row-wise quantized component.

        Returns:
            The row-wise tensor component
        """
        return self.rowwise_tensor

    def get_colwise_tensor(self):
        """Returns the column-wise quantized component.

        Returns:
            The column-wise tensor component
        """
        return self.colwise_tensor

    def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
        """Applies sharding constraints to a tensor based on logical axis names.

        Args:
            logical_axis_names: Tuple of logical axis names for sharding

        Returns:
            The tensor with applied sharding constraints
        """
        if not logical_axis_names:
            return self

        rowwise_tensor = self.rowwise_tensor.apply_sharding_constraint_by_logical_axes(
            logical_axis_names
        )
        colwise_tensor = self.colwise_tensor.apply_sharding_constraint_by_logical_axes(
            logical_axis_names
        )

        return ScaledTensor2x(rowwise_tensor, colwise_tensor)

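
# --- Illustrative usage (a minimal sketch) ---
# A ScaledTensor2x simply pairs two ScaledTensor1x components so consumers can
# pick whichever layout a kernel needs; `tensor_2x` stands for any value
# produced by ScaledTensorFactory.create_2x.
def _example_scaled_tensor_2x_access(tensor_2x: ScaledTensor2x):
    rowwise = tensor_2x.get_rowwise_tensor()  # ScaledTensor1x with is_colwise=False
    colwise = tensor_2x.get_colwise_tensor()  # ScaledTensor1x with is_colwise=True
    # dequantize() delegates to the rowwise component:
    return rowwise.dequantize(), colwise.data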

@dataclass
class ScaledTensorFactory:
    """Factory class for creating scaled tensor instances.

    Provides static methods to create both single-scale (1x) and double-scale (2x)
    quantized tensors with various configurations.
    """

    @staticmethod
    def create_1x(
        data,
        scale_inv,
        scaling_mode,
        dq_dtype=jnp.bfloat16,
        is_colwise=False,
        data_layout="N",
        flatten_axis=-1,
    ):
        """Creates a single-scale quantized tensor.

        Args:
            data: The quantized tensor data
            scale_inv: The inverse scaling factors
            scaling_mode: The scaling mode for quantization
            dq_dtype: The data type for dequantized values (default: bfloat16)
            is_colwise: Whether to use column-wise quantization (default: False)
            data_layout: The data layout, "N" (normal) or "T" (transposed) (default: "N")
            flatten_axis: The axis along which the tensor is flattened for quantization (default: -1)

        Returns:
            A ScaledTensor1x instance
        """
        dq_func = Dequantizer.funcs.get(scaling_mode)
        return ScaledTensor1x(
            data, scale_inv, scaling_mode, dq_dtype, dq_func, is_colwise, data_layout, flatten_axis
        )

    @staticmethod
    def create_2x(
        data,
        scale_inv,
        colwise_data,
        colwise_scale_inv,
        scaling_mode,
        dq_dtype=jnp.bfloat16,
        data_layout="NN",
        flatten_axis=-1,
    ):
        """Creates a double-scale quantized tensor.

        Args:
            data: The row-wise quantized data
            scale_inv: The row-wise inverse scaling factors
            colwise_data: The column-wise quantized data
            colwise_scale_inv: The column-wise inverse scaling factors
            scaling_mode: The scaling mode for quantization
            dq_dtype: The data type for dequantized values (default: bfloat16)
            data_layout: Two-character layout string, one "N"/"T" character per component (default: "NN")
            flatten_axis: The axis along which the tensors are flattened for quantization (default: -1)

        Returns:
            A ScaledTensor2x instance
        """
        dq_func = Dequantizer.funcs.get(scaling_mode)
        rowwise_tensor = ScaledTensor1x(
            data,
            scale_inv,
            scaling_mode,
            dq_dtype,
            dq_func,
            is_colwise=False,
            data_layout=data_layout[0],
            flatten_axis=flatten_axis,
        )
        colwise_tensor = ScaledTensor1x(
            colwise_data,
            colwise_scale_inv,
            scaling_mode,
            dq_dtype,
            dq_func,
            is_colwise=True,
            data_layout=data_layout[1],
            flatten_axis=flatten_axis,
        )
        return ScaledTensor2x(rowwise_tensor, colwise_tensor)

    @staticmethod
    def create(
        data: jnp.ndarray,
        scale_inv: jnp.ndarray,
        colwise_data: jnp.ndarray,
        colwise_scale_inv: jnp.ndarray,
        scaling_mode: ScalingMode,
        dq_dtype: jnp.dtype = jnp.bfloat16,
        data_layout: str = "NN",
        q_layout: QuantizeLayout = QuantizeLayout.ROWWISE,
        flatten_axis: int = -1,
    ):
        """Creates a scaled tensor based on the quantization axis.

        Args:
            data: The quantized tensor data
            scale_inv: The inverse scaling factors
            colwise_data: The column-wise quantized data
            colwise_scale_inv: The column-wise inverse scaling factors
            scaling_mode: The scaling mode for quantization
            dq_dtype: The data type for dequantized values (default: bfloat16)
            data_layout: Two-character layout string, one "N"/"T" character per component (default: "NN")
            q_layout: The quantization layout (default: ROWWISE)
            flatten_axis: The axis along which the tensor is flattened for quantization (default: -1)

        Returns:
            Either a ScaledTensor1x or ScaledTensor2x instance depending on q_layout
        """
        if q_layout == QuantizeLayout.ROWWISE_COLWISE:
            return ScaledTensorFactory.create_2x(
                data,
                scale_inv,
                colwise_data,
                colwise_scale_inv,
                scaling_mode,
                dq_dtype,
                data_layout=data_layout,
                flatten_axis=flatten_axis,
            )

        is_colwise = q_layout == QuantizeLayout.COLWISE
        return ScaledTensorFactory.create_1x(
            data,
            scale_inv,
            scaling_mode,
            dq_dtype,
            is_colwise=is_colwise,
            data_layout=data_layout[0],
            flatten_axis=flatten_axis,
        )

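
# --- Illustrative usage (a minimal sketch) ---
# `create` dispatches on q_layout: ROWWISE_COLWISE builds a ScaledTensor2x from
# both data/scale pairs, while ROWWISE or COLWISE build a single ScaledTensor1x
# and ignore the colwise_* arguments. The "NT" layout below is an example value
# meaning the rowwise copy is kept as-is and the colwise copy is transposed.
def _example_factory_dispatch(data, scale_inv, colwise_data, colwise_scale_inv):
    both = ScaledTensorFactory.create(
        data,
        scale_inv,
        colwise_data,
        colwise_scale_inv,
        ScalingMode.NVTE_MXFP8_1D_SCALING,
        data_layout="NT",
        q_layout=QuantizeLayout.ROWWISE_COLWISE,
    )
    rowwise_only = ScaledTensorFactory.create(
        data,
        scale_inv,
        None,
        None,
        ScalingMode.NVTE_MXFP8_1D_SCALING,
        q_layout=QuantizeLayout.ROWWISE,
    )
    return both, rowwise_only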

def with_sharding_constraint_by_logical_axes(x, logical_axis_names: Tuple[str, ...]):
    """Applies sharding constraints to a tensor based on logical axis names.

    Args:
        x: The tensor to apply sharding constraints to
        logical_axis_names: Tuple of logical axis names for sharding

    Returns:
        The tensor with applied sharding constraints
    """
    if isinstance(x, ScaledTensor):
        return x.apply_sharding_constraint_by_logical_axes(logical_axis_names)

    return original_with_sharding_constraint_by_logical_axes(x, logical_axis_names)