# tensor.py 22 KB
# Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Tensor classes for TE/JAX

This module provides tensor classes for handling quantized tensors in JAX, including
both single-scale (1x) and double-scale (2x) quantization schemes. It supports
rowwise and colwise quantization modes with proper scaling and dequantization.
"""
from dataclasses import dataclass
from typing import Callable, Tuple
from abc import ABC, abstractmethod

import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

18
from transformer_engine_jax import QuantizeLayout
19
20

from .scaling_modes import ScalingMode
21
from .dequantizer import ScalingModeToDequantizerMap
22
23
24
25
26
27
28
29
from ..sharding import (
    with_sharding_constraint_by_logical_axes as original_with_sharding_constraint_by_logical_axes,
)

# Names exported by ``from <this module> import *``.
__all__ = [
    "ScaledTensor",
    "ScaledTensor1x",
    "ScaledTensor2x",
    "GroupedScaledTensor1x",
    "ScaledTensorFactory",
    "with_sharding_constraint_by_logical_axes",
]


@register_pytree_node_class
@dataclass
class ScaledTensor(ABC):
    """Abstract base class for scaled tensors.

    This class defines the interface for all scaled tensor implementations:
    dequantization, access to the row-wise / column-wise components, and
    application of sharding constraints. It also supplies the shared
    ``tree_unflatten`` used for JAX pytree reconstruction.
    """

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Reconstructs the tensor from its flattened representation.

        Args:
            aux_data: Auxiliary data needed for reconstruction
            children: The flattened tensor components

        Returns:
            A reconstructed tensor instance
        """
        # The constructor receives the children first and the auxiliary data
        # second, so a subclass must accept its arguments in that order.
        return cls(*children, *aux_data)

    @abstractmethod
    def dequantize(self):
        """Dequantizes the tensor back to its original precision.

        Returns:
            The dequantized tensor
        """

    @abstractmethod
    def get_rowwise_tensor(self):
        """Returns the row-wise component of the tensor.

        Returns:
            The row-wise tensor component

        Raises:
            ValueError: If called on a tensor that doesn't support row-wise access
        """

    @abstractmethod
    def get_colwise_tensor(self):
        """Returns the column-wise component of the tensor.

        Returns:
            The column-wise tensor component

        Raises:
            ValueError: If called on a tensor that doesn't support column-wise access
        """

    @abstractmethod
    def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
        """Applies sharding constraints to a tensor based on logical axis names.

        Args:
            logical_axis_names: Tuple of logical axis names for sharding

        Returns:
            The tensor with applied sharding constraints
        """

99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114

@register_pytree_node_class
@dataclass
class ScaledTensor1x(ScaledTensor):
    """Single-scale quantized tensor implementation.

    This class represents a tensor quantized with a single scaling factor,
    supporting both row-wise and column-wise quantization modes.

    Attributes:
        data: The quantized tensor data
        scale_inv: The inverse scaling factors
        scaling_mode: The scaling mode used for quantization
        dq_dtype: The data type for dequantized values
        _dq_func: The dequantization function
        is_colwise: Whether the tensor uses column-wise quantization
        data_layout: The data_layout specification for the tensor ("T" is
            handled as the transposed layout; other values are used as-is)
        flatten_axis: The quantization axis for the tensor
    """

    data: jnp.ndarray
    scale_inv: jnp.ndarray
    scaling_mode: ScalingMode
    dq_dtype: jnp.dtype
    _dq_func: Callable
    is_colwise: bool
    data_layout: str
    flatten_axis: int

    def __post_init__(self):
        """Validates and adjusts flatten_axis and scale_inv after initialization.

        A negative flatten_axis is normalized against the rank of ``data`` and
        must end up strictly between 0 and that rank. For the "T" layout the
        axis is re-expressed as ``data.ndim - flatten_axis``. The scale_inv
        shape must match either the padded or the unpadded shape reported by
        the scaling mode; an unpadded scale_inv is zero-padded to the padded
        shape.
        """
        flatten_axis = (
            len(self.data.shape) + self.flatten_axis if self.flatten_axis < 0 else self.flatten_axis
        )
        assert (
            0 < flatten_axis < len(self.data.shape)
        ), f"flatten_axis {flatten_axis} is out of bounds for shape {self.data.shape}"

        if self.data_layout == "T":
            flatten_axis = self.data.ndim - flatten_axis
        # NOTE(review): the stored flatten_axis is overwritten with the value
        # computed above. Any re-construction from the stored value (e.g. via
        # tree_flatten/tree_unflatten) runs this method again and, for the "T"
        # layout, applies the same mapping a second time -- confirm this is
        # intended for tensors where data.ndim != 2 * flatten_axis.
        self.flatten_axis = flatten_axis

        expected_scale_shape = self.scaling_mode.get_scale_shape(
            self.data.shape, self.is_colwise, is_padded=True, flatten_axis=flatten_axis
        )
        expected_unpadded_scale_shape = self.scaling_mode.get_scale_shape(
            self.data.shape, self.is_colwise, is_padded=False, flatten_axis=flatten_axis
        )
        if self.scale_inv.shape != expected_scale_shape:
            assert self.scale_inv.shape == expected_unpadded_scale_shape, (
                f"Unexpected scale_inv shape! \nExpect {expected_scale_shape} for padded"
                f" scale_inv or {expected_unpadded_scale_shape} for unpadded scale_inv, got"
                f" {self.scale_inv.shape}"
            )
            # Grow every dimension at its trailing end up to the padded size.
            pad_width = tuple(
                (0, a - b) for a, b in zip(expected_scale_shape, expected_unpadded_scale_shape)
            )
            # NOTE(review): open question carried over from the original author:
            # this padding value reportedly ends up being read as NaN for the
            # scale dtype -- should 127 be used directly instead? (unverified)
            self.scale_inv = jnp.pad(
                self.scale_inv, pad_width=pad_width, mode="constant", constant_values=0
            )

    def tree_flatten(self):
        """Flattens the tensor for JAX tree operations.

        Returns:
            A tuple containing (children, aux_data) for tree operations, where
            children are (data, scale_inv) and aux_data holds the remaining
            fields in constructor order
        """
        children = (self.data, self.scale_inv)
        aux_data = (
            self.scaling_mode,
            self.dq_dtype,
            self._dq_func,
            self.is_colwise,
            self.data_layout,
            self.flatten_axis,
        )
        return (children, aux_data)

    def dequantize(self):
        """Dequantizes the tensor using the stored dequantization function.

        Returns:
            The dequantized tensor
        """
        return self._dq_func(self)

    def get_rowwise_tensor(self):
        """Returns the tensor if it's row-wise quantized.

        Returns:
            The row-wise tensor

        Raises:
            ValueError: If called on a column-wise quantized tensor
        """
        if not self.is_colwise:
            return self

        raise ValueError("Calling get_rowwise_tensor() from a colwise ScaledTensor1x!")

    def get_colwise_tensor(self):
        """Returns the tensor if it's column-wise quantized.

        Returns:
            The column-wise tensor

        Raises:
            ValueError: If called on a row-wise quantized tensor
        """
        if self.is_colwise:
            return self

        raise ValueError("Calling get_colwise_tensor() from a rowwise ScaledTensor1x!")

    def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
        """Applies sharding constraints to a tensor based on logical axis names.

        Args:
            logical_axis_names: Tuple of logical axis names for sharding

        Returns:
            The tensor with applied sharding constraints; ``self`` is returned
            unchanged when no axis names are given
        """
        if not logical_axis_names:
            return self

        # axis_names were given for N layout, so needs to be transpose for T layout
        if self.data_layout == "T":
            assert self.flatten_axis > 0
            assert len(logical_axis_names) == self.data.ndim
            # Undo the mapping applied in __post_init__ to find where the
            # names have to be split and swapped.
            flatten_axis = self.data.ndim - self.flatten_axis
            axis_names = (
                *logical_axis_names[flatten_axis:],
                *logical_axis_names[:flatten_axis],
            )
        else:
            axis_names = logical_axis_names

        data = with_sharding_constraint_by_logical_axes(self.data, axis_names)

        # Only the MXFP8 1D scaling mode gets its scale_inv constrained; for
        # every other mode scale_inv is passed through untouched.
        if self.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
            # TODO(Phuong): Handle padding !?
            scale_inv = with_sharding_constraint_by_logical_axes(self.scale_inv, axis_names)
        else:
            scale_inv = self.scale_inv

        return ScaledTensor1x(
            data=data,
            scale_inv=scale_inv,
            scaling_mode=self.scaling_mode,
            dq_dtype=self.dq_dtype,
            _dq_func=self._dq_func,
            is_colwise=self.is_colwise,
            data_layout=self.data_layout,
            flatten_axis=self.flatten_axis,
        )

261

262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
@register_pytree_node_class
@dataclass
class GroupedScaledTensor1x(ScaledTensor1x):
    """Single-scale quantized tensor whose contents are quantized group by group.

    This class extends ScaledTensor1x to hold an array quantized in grouped manner,
    where elements are grouped along a specified axis. Both ``data`` and
    ``scale_inv`` are required to be flattened (1D); the logical shape is kept in
    ``original_shape``.

    Attributes:
        group_sizes: Array containing the size of each group
        original_shape: The original shape of the tensor before grouping
        group_axis: The axis along which grouping is performed (default: 0)
    """

    group_sizes: jnp.ndarray
    original_shape: Tuple
    group_axis: int

    def __init__(
        self,
        data,
        scale_inv,
        group_sizes,
        scaling_mode,
        dq_dtype,
        _dq_func,
        is_colwise,
        data_layout,
        flatten_axis,
        original_shape,
        group_axis=0,
    ):
        """Stores the grouping attributes, then defers to the parent initializer.

        The argument order matches what ``tree_flatten`` emits (children first,
        then auxiliary data).
        """
        # The grouping attributes are assigned first because __post_init__
        # reads them.
        self.group_sizes = group_sizes
        self.original_shape = original_shape
        self.group_axis = group_axis
        super().__init__(
            data=data,
            scale_inv=scale_inv,
            scaling_mode=scaling_mode,
            dq_dtype=dq_dtype,
            _dq_func=_dq_func,
            is_colwise=is_colwise,
            data_layout=data_layout,
            flatten_axis=flatten_axis,
        )

    def __post_init__(self):
        """Validates the axes and the scale_inv shape after initialization.

        Negative flatten_axis / group_axis values are checked against the rank of
        ``original_shape``. For the "T" layout, ``original_shape`` is permuted and
        the axes are re-expressed for the permuted shape. The scale_inv shape must
        equal the padded grouped scale shape reported by the scaling mode.
        """
        assert self.scale_inv.ndim == 1, "Only support flattened scale_inv"
        assert self.data.ndim == 1, "Only support flattened data"

        rank = len(self.original_shape)

        axis = self.flatten_axis
        if axis < 0:
            axis += rank
        assert 0 < axis < rank, f"flatten_axis {axis} is out of bounds for data.ndim = {rank}"

        # The normalized group axis is only used for this bounds check; the
        # stored self.group_axis is what gets passed on below.
        grouped_dim = self.group_axis
        if grouped_dim < 0:
            grouped_dim += len(self.original_shape)
        assert (
            0 <= grouped_dim < rank
        ), f"group_axis {grouped_dim} is out of bounds for shape {self.original_shape}"

        if self.data_layout == "T":
            shape = self.original_shape
            if shape[0] == self.group_sizes.size:
                # Leading dimension equals the number of groups: keep it in
                # front and swap the two blocks of dimensions behind it.
                self.original_shape = (shape[0], *shape[axis:], *shape[1:axis])
                axis = len(self.original_shape) - axis + 1
            else:
                # Otherwise swap the blocks before and after flatten_axis; the
                # group axis is set to the pre-swap flatten axis.
                self.original_shape = (*shape[axis:], *shape[:axis])
                self.group_axis = axis
                axis = len(self.original_shape) - axis

        self.flatten_axis = axis

        num_groups = self.group_sizes.size
        expected_scale_shape = self.scaling_mode.get_grouped_scale_shape(
            self.original_shape,
            num_groups,
            self.group_axis,
            self.is_colwise,
            is_padded=True,
            flatten_axis=axis,
        )

        assert self.scale_inv.shape == expected_scale_shape, (
            f"Unexpected scale_inv shape! \nExpect {expected_scale_shape} for padded"
            f" scale_inv, got {self.scale_inv.shape}"
        )

    def tree_flatten(self):
        """Flattens the tensor for JAX tree operations.

        Returns:
            A tuple containing (children, aux_data) for tree operations, where
            children are (data, scale_inv, group_sizes)
        """
        leaves = (self.data, self.scale_inv, self.group_sizes)
        static_fields = (
            self.scaling_mode,
            self.dq_dtype,
            self._dq_func,
            self.is_colwise,
            self.data_layout,
            self.flatten_axis,
            self.original_shape,
            self.group_axis,
        )
        return leaves, static_fields

    def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
        """Not supported for grouped tensors.

        Raises:
            NotImplementedError: Always
        """
        raise NotImplementedError


372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
@register_pytree_node_class
@dataclass
class ScaledTensor2x(ScaledTensor):
    """Double-scale quantized tensor implementation.

    This class represents a tensor quantized with both row-wise and column-wise scaling factors.

    Attributes:
        rowwise_tensor: The row-wise quantized component
        colwise_tensor: The column-wise quantized component
    """

    rowwise_tensor: ScaledTensor1x
    colwise_tensor: ScaledTensor1x

    def tree_flatten(self):
        """Flattens the tensor for JAX tree operations.

        Returns:
            A tuple containing (children, aux_data) for tree operations; the two
            component tensors are the children and aux_data is empty
        """
        children = (self.rowwise_tensor, self.colwise_tensor)
        aux_data = ()
        return (children, aux_data)

    def dequantize(self):
        """Dequantizes the tensor using the row-wise component's dequantization.

        Returns:
            The dequantized tensor
        """
        return self.rowwise_tensor.dequantize()

    def get_rowwise_tensor(self):
        """Returns the row-wise quantized component.

        Returns:
            The row-wise tensor component
        """
        return self.rowwise_tensor

    def get_colwise_tensor(self):
        """Returns the column-wise quantized component.

        Returns:
            The column-wise tensor component
        """
        return self.colwise_tensor

    def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
        """Applies sharding constraints to a tensor based on logical axis names.

        The same axis names are forwarded to both components.

        Args:
            logical_axis_names: Tuple of logical axis names for sharding

        Returns:
            The tensor with applied sharding constraints; ``self`` is returned
            unchanged when no axis names are given
        """
        if not logical_axis_names:
            return self

        rowwise_tensor = self.rowwise_tensor.apply_sharding_constraint_by_logical_axes(
            logical_axis_names
        )
        colwise_tensor = self.colwise_tensor.apply_sharding_constraint_by_logical_axes(
            logical_axis_names
        )

        return ScaledTensor2x(rowwise_tensor, colwise_tensor)

442
443
444
445
446
447
448
449
450
451
452

@dataclass
class ScaledTensorFactory:
    """Factory class for creating scaled tensor instances.

    Provides static methods to create both single-scale (1x) and double-scale (2x)
    quantized tensors with various configurations.
    """

    @staticmethod
    def create_1x(
        data,
        scale_inv,
        scaling_mode,
        dq_dtype=jnp.bfloat16,
        is_colwise=False,
        data_layout="N",
        flatten_axis=-1,
        group_sizes=None,
        original_shape=None,
        group_axis=0,
    ):
        """Creates a single-scale quantized tensor.

        Args:
            data: The quantized tensor data
            scale_inv: The inverse scaling factors
            scaling_mode: The scaling mode for quantization
            dq_dtype: The data type for dequantized values (default: bfloat16)
            is_colwise: Whether to use column-wise quantization (default: False)
            data_layout: The data_layout specification (default: "N")
            flatten_axis: The quantization axis for the tensor (default: -1)
            group_sizes: Array of ints containing the size of each group (default: None)
            original_shape: The original shape of the tensor before grouping (default: None)
            group_axis: The axis along which grouping is performed (default: 0)

        Returns:
            A ScaledTensor1x or GroupedScaledTensor1x instance depending on whether
            group_sizes is provided
        """
        dequantizer = ScalingModeToDequantizerMap.get(scaling_mode)
        if group_sizes is not None:
            assert (
                original_shape is not None
            ), "original_shape is not given for GroupedScaledTensor1x"
            return GroupedScaledTensor1x(
                data=data,
                scale_inv=scale_inv,
                scaling_mode=scaling_mode,
                dq_dtype=dq_dtype,
                _dq_func=dequantizer.grouped_dequantize,
                is_colwise=is_colwise,
                data_layout=data_layout,
                flatten_axis=flatten_axis,
                group_sizes=group_sizes,
                original_shape=original_shape,
                group_axis=group_axis,
            )

        return ScaledTensor1x(
            data,
            scale_inv,
            scaling_mode,
            dq_dtype,
            dequantizer.dequantize,
            is_colwise,
            data_layout,
            flatten_axis,
        )

    @staticmethod
    def create_2x(
        data,
        scale_inv,
        colwise_data,
        colwise_scale_inv,
        scaling_mode,
        dq_dtype=jnp.bfloat16,
        data_layout="NN",
        flatten_axis=-1,
        group_sizes=None,
        original_shape=None,
        group_axis=0,
    ):
        """Creates a double-scale quantized tensor.

        Args:
            data: The row-wise quantized data
            scale_inv: The row-wise inverse scaling factors
            colwise_data: The column-wise quantized data
            colwise_scale_inv: The column-wise inverse scaling factors
            scaling_mode: The scaling mode for quantization
            dq_dtype: The data type for dequantized values (default: bfloat16)
            data_layout: The data_layout specification, one character for the
                row-wise component followed by one for the column-wise
                component (default: "NN")
            flatten_axis: The quantization axis for the tensor (default: -1)
            group_sizes: Array containing the size of each group (default: None)
            original_shape: The original shape of the tensor before grouping (default: None)
            group_axis: The axis along which grouping is performed (default: 0)

        Returns:
            A ScaledTensor2x instance
        """
        assert len(data_layout) == 2, f"Expect 2 layouts, got {data_layout}"
        rowwise_tensor = ScaledTensorFactory.create_1x(
            data,
            scale_inv,
            scaling_mode,
            dq_dtype,
            is_colwise=False,
            data_layout=data_layout[0],
            flatten_axis=flatten_axis,
            group_sizes=group_sizes,
            original_shape=original_shape,
            group_axis=group_axis,
        )
        colwise_tensor = ScaledTensorFactory.create_1x(
            colwise_data,
            colwise_scale_inv,
            scaling_mode,
            dq_dtype,
            is_colwise=True,
            data_layout=data_layout[1],
            flatten_axis=flatten_axis,
            group_sizes=group_sizes,
            original_shape=original_shape,
            group_axis=group_axis,
        )
        return ScaledTensor2x(rowwise_tensor, colwise_tensor)

    @staticmethod
    def create(
        data: jnp.ndarray,
        scale_inv: jnp.ndarray,
        colwise_data: jnp.ndarray,
        colwise_scale_inv: jnp.ndarray,
        scaling_mode: ScalingMode,
        dq_dtype: jnp.dtype = jnp.bfloat16,
        data_layout: str = "NN",
        q_layout: QuantizeLayout = QuantizeLayout.ROWWISE,
        flatten_axis: int = -1,
        group_sizes: jnp.ndarray = None,
        original_shape: Tuple[int] = None,
        group_axis: int = 0,
    ):
        """Creates a scaled tensor based on the quantization axis.

        Args:
            data: The quantized tensor data
            scale_inv: The inverse scaling factors
            colwise_data: The column-wise quantized data
            colwise_scale_inv: The column-wise inverse scaling factors
            scaling_mode: The scaling mode for quantization
            dq_dtype: The data type for dequantized values (default: bfloat16)
            data_layout: The data_layout specification (default: "NN")
            q_layout: The quantization axis (default: ROWWISE). ROWWISE_COLWISE
                produces a 2x tensor; COLWISE produces a 1x tensor built from
                colwise_data / colwise_scale_inv; any other value produces a
                1x tensor built from data / scale_inv
            flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)
            group_sizes: Array containing the size of each group (default: None)
            original_shape: The original shape of the tensor before grouping (default: None)
            group_axis: The axis along which grouping is performed (default: 0)

        Returns:
            Either a ScaledTensor1x or ScaledTensor2x instance depending on q_layout
        """
        if q_layout == QuantizeLayout.ROWWISE_COLWISE:
            return ScaledTensorFactory.create_2x(
                data,
                scale_inv,
                colwise_data,
                colwise_scale_inv,
                scaling_mode,
                dq_dtype,
                data_layout=data_layout,
                flatten_axis=flatten_axis,
                group_sizes=group_sizes,
                original_shape=original_shape,
                group_axis=group_axis,
            )

        # Single-component case: pick which pair of arrays backs the tensor.
        is_colwise = q_layout == QuantizeLayout.COLWISE
        if is_colwise:
            tensor_data, tensor_scale_inv = colwise_data, colwise_scale_inv
        else:
            tensor_data, tensor_scale_inv = data, scale_inv

        # NOTE(review): both the row-wise and the column-wise 1x paths use the
        # first character of data_layout -- presumably callers pass a
        # one-character layout here; confirm against callers.
        return ScaledTensorFactory.create_1x(
            tensor_data,
            tensor_scale_inv,
            scaling_mode,
            dq_dtype,
            is_colwise=is_colwise,
            data_layout=data_layout[0],
            flatten_axis=flatten_axis,
            group_sizes=group_sizes,
            original_shape=original_shape,
            group_axis=group_axis,
        )


def with_sharding_constraint_by_logical_axes(x, logical_axis_names: Tuple[str, ...]):
    """Applies sharding constraints to a tensor based on logical axis names.

    ScaledTensor instances apply the constraint through their own
    ``apply_sharding_constraint_by_logical_axes`` method; every other input is
    forwarded to the sharding module's implementation.

    Args:
        x: The tensor to apply sharding constraints to
        logical_axis_names: Tuple of logical axis names for sharding

    Returns:
        The tensor with applied sharding constraints

    Raises:
        NotImplementedError: If x is a GroupedScaledTensor1x
    """
    # Grouped tensors are rejected up front, before the generic ScaledTensor
    # dispatch below.
    if isinstance(x, GroupedScaledTensor1x):
        raise NotImplementedError

    if isinstance(x, ScaledTensor):
        return x.apply_sharding_constraint_by_logical_axes(logical_axis_names)

    return original_with_sharding_constraint_by_logical_axes(x, logical_axis_names)