# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Tensor classes for TE/JAX

This module provides tensor classes for handling quantized tensors in JAX, including
both single-scale (1x) and double-scale (2x) quantization schemes. It supports
rowwise and colwise quantization modes with proper scaling and dequantization.
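
Example (a minimal sketch; the import path and FP8 dtype are illustrative, and
the quantized data/scale_inv buffers would normally come from a quantizer in
this package):

    import jax.numpy as jnp
    from transformer_engine.jax.quantize import ScaledTensorFactory, ScalingMode

    data = jnp.zeros((128, 256), dtype=jnp.float8_e4m3fn)  # quantized values
    scale_inv = jnp.ones((1,), dtype=jnp.float32)          # inverse scale factor
    tensor = ScaledTensorFactory.create_1x(
        data, scale_inv, scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING
    )
    hp = tensor.dequantize()  # back to high precision (dq_dtype, bfloat16 default)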
"""
from dataclasses import dataclass
from typing import Callable, Tuple
from abc import ABC, abstractmethod

import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

from transformer_engine_jax import QuantizeLayout

from .scaling_modes import ScalingMode, TensorUsage
from .dequantizer import ScalingModeToDequantizerMap
from ..sharding import (
    with_sharding_constraint_by_logical_axes as original_with_sharding_constraint_by_logical_axes,
)

__all__ = [
    "TensorUsage",
    "ScaledTensor",
    "ScaledTensor1x",
    "ScaledTensor2x",
    "GroupedScaledTensor1x",
    "ScaledTensorFactory",
    "with_sharding_constraint_by_logical_axes",
]


@register_pytree_node_class
@dataclass
class ScaledTensor(ABC):
    """Abstract base class for scaled tensors.

    This class defines the interface for all scaled tensor implementations,
    providing methods for dequantization and accessing row/column-wise components.
    """

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Reconstructs the tensor from its flattened representation.

        Args:
            aux_data: Auxiliary data needed for reconstruction
            children: The flattened tensor components

        Returns:
            A reconstructed tensor instance
        """
        return cls(*children, *aux_data)

    @property
    @abstractmethod
    def ndim(self):
        """Number of dimensions of the underlying quantized array."""

    @abstractmethod
    def dequantize(self):
        """Dequantizes the tensor back to its original precision.

        Returns:
            The dequantized tensor
        """

    @abstractmethod
    def get_tensor(self, usage: TensorUsage):
        """Returns the appropriate tensor based on the tensor usage and the scaling mode.
        If the tensor usage is not valid for the scaling mode, an error is raised.

        Args:
            usage: The usage of the tensor

        Returns:
            The tensor based on the usage
        """

    @abstractmethod
    def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
        """Applies sharding constraints to a tensor based on logical axis names.

        Args:
            logical_axis_names: Tuple of logical axis names for sharding

        Returns:
            The tensor with applied sharding constraints
        """


@register_pytree_node_class
@dataclass
class ScaledTensor1x(ScaledTensor):
    """Single-scale quantized tensor implementation.

    This class represents a tensor quantized with a single scaling factor,
    supporting both row-wise and column-wise quantization modes.

    Attributes:
        data: The quantized tensor data
        scale_inv: The inverse scaling factors
        amax: The maximum absolute value of the tensor
        scaling_mode: The scaling mode used for quantization
        dq_dtype: The data type for dequantized values
        _dq_func: The dequantization function
        is_colwise: Whether the tensor uses column-wise quantization
        data_layout: The data_layout specification for the tensor
        flatten_axis: The axis along which the tensor could be flattened to 2D
    """

    data: jnp.ndarray
    scale_inv: jnp.ndarray
    amax: jnp.ndarray
    scaling_mode: ScalingMode
    dq_dtype: jnp.dtype
    _dq_func: Callable
    is_colwise: bool
    data_layout: str
    flatten_axis: int

    def __post_init__(self):
        """Validates and adjusts the scale_inv shape after initialization.

        Ensures the scale_inv shape matches the expected unpadded shape based on the
        scaling mode and quantization direction.
        """
        assert self.flatten_axis > 0
        assert (
            0 < self.flatten_axis < len(self.data.shape)
        ), f"flatten_axis {self.flatten_axis} is out of bounds for shape {self.data.shape}"

        if self.scaling_mode == ScalingMode.NO_SCALING:
            self.scale_inv = jnp.empty((0,), dtype=jnp.float32)
        else:
            unpadded_scale_shape = self.scaling_mode.get_scale_shape(
                self.data.shape,
                is_colwise=self.is_colwise,
                is_padded=False,
                flatten_axis=self.flatten_axis,
            )
            assert self.scale_inv.shape == unpadded_scale_shape, (
                "Unpadded inverse scale factor has wrong shape, expected"
                f" {unpadded_scale_shape} but got {self.scale_inv.shape}."
            )

    def tree_flatten(self):
        """Flattens the tensor for JAX tree operations.

        Returns:
            A tuple containing (children, aux_data) for tree operations
        """
        children = (self.data, self.scale_inv, self.amax)
        aux_data = (
            self.scaling_mode,
            self.dq_dtype,
            self._dq_func,
            self.is_colwise,
            self.data_layout,
            self.flatten_axis,
        )
        return (children, aux_data)

    @property
    def ndim(self):
        return self.data.ndim

    def dequantize(self):
        """Dequantizes the tensor using the stored dequantization function.

        Returns:
            The dequantized tensor
        """
        return self._dq_func(self)

    def get_tensor(self, usage: TensorUsage):
        """Returns the tensor based on the tensor usage."""
        q_layout = self.scaling_mode.get_quantize_layout(usage)
        colwise_usage_valid = q_layout == QuantizeLayout.COLWISE and self.is_colwise
        rowwise_usage_valid = q_layout == QuantizeLayout.ROWWISE and not self.is_colwise

        if colwise_usage_valid or rowwise_usage_valid:
            return self

        raise ValueError(
            f"Calling get_tensor() with usage {usage} is not valid for this tensor as"
            f" self.is_colwise={self.is_colwise}!"
        )

    def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
        """Applies sharding constraints to a tensor based on logical axis names.

        Args:
            logical_axis_names: Tuple of logical axis names for sharding

        Returns:
            The tensor with applied sharding constraints
        """
        if not logical_axis_names:
            return self

        # axis_names were given for the "N" layout, so they need to be transposed for the "T" layout
        if self.data_layout == "T":
            assert self.flatten_axis > 0
            assert len(logical_axis_names) == self.data.ndim
            flatten_axis = self.data.ndim - self.flatten_axis
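            # Rotate the logical names so the dims that were moved to the front in
            # the "T" layout keep the names they had in the "N" layout.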
            axis_names = (
                *logical_axis_names[flatten_axis:],
                *logical_axis_names[:flatten_axis],
            )
        else:
            axis_names = logical_axis_names

        data = with_sharding_constraint_by_logical_axes(self.data, axis_names)

        if self.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
            # TODO(Phuong): Handle padding !?
            scale_inv = with_sharding_constraint_by_logical_axes(self.scale_inv, axis_names)
        else:
            scale_inv = self.scale_inv

        return ScaledTensor1x(
            data=data,
            scale_inv=scale_inv,
            amax=self.amax,
            scaling_mode=self.scaling_mode,
            dq_dtype=self.dq_dtype,
            _dq_func=self._dq_func,
            is_colwise=self.is_colwise,
            data_layout=self.data_layout,
            flatten_axis=self.flatten_axis,
        )


@register_pytree_node_class
@dataclass
class GroupedScaledTensor1x(ScaledTensor1x):
    """Grouped Quantizer for an array.

    This class extends ScaledTensor1x to support quantization of an array in grouped manner,
    where elements are grouped along a specified axis.

    Attributes:
        group_sizes: Array containing the size of each group
        original_shape: The original shape of the tensor before grouping
        group_axis: The axis along which grouping is performed (default: 0)
    """

    group_sizes: jnp.ndarray
    original_shape: Tuple
    group_axis: int

    def __init__(
        self,
        data,
        scale_inv,
        amax,
        group_sizes,
        scaling_mode,
        dq_dtype,
        _dq_func,
        is_colwise,
        data_layout,
        flatten_axis,
        original_shape,
        group_axis=0,
    ):
        self.flatten_axis = flatten_axis
        self.group_sizes = group_sizes
        self.original_shape = original_shape
        self.group_axis = group_axis
        super().__init__(
            data,
            scale_inv,
            amax,
            scaling_mode,
            dq_dtype,
            _dq_func,
            is_colwise,
            data_layout,
            flatten_axis,
        )

    def __post_init__(self):
        assert self.scale_inv.ndim == 1, "Only support flattened scale_inv"
        assert self.data.ndim == 1, "Only support flattened data"
        assert self.group_axis >= 0
        assert self.flatten_axis > 0

        data_ndim = len(self.original_shape)
        assert (
            0 < self.flatten_axis < data_ndim
        ), f"flatten_axis {self.flatten_axis} is out of bounds for data.ndim = {data_ndim}"

        assert (
            0 <= self.group_axis < data_ndim
        ), f"group_axis {self.group_axis} is out of bounds for shape {self.original_shape}"

        expected_scale_shape = self.scaling_mode.get_grouped_scale_shape(
            self.original_shape,
            self.group_sizes.size,
            self.group_axis,
            self.is_colwise,
            is_padded=True,
            flatten_axis=self.flatten_axis,
        )

        assert self.scale_inv.shape == expected_scale_shape, (
            f"Unexpected scale_inv shape! \nExpect {expected_scale_shape} for padded"
            f" scale_inv, got {self.scale_inv.shape}"
        )

    def tree_flatten(self):
        """Flattens the tensor for JAX tree operations.

        Returns:
            A tuple containing (children, aux_data) for tree operations
        """
        children = (self.data, self.scale_inv, self.amax, self.group_sizes)
        aux_data = (
            self.scaling_mode,
            self.dq_dtype,
            self._dq_func,
            self.is_colwise,
            self.data_layout,
            self.flatten_axis,
            self.original_shape,
            self.group_axis,
        )
        return (children, aux_data)

    def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
        raise NotImplementedError


@register_pytree_node_class
@dataclass
class ScaledTensor2x(ScaledTensor):
    """Double-scale quantized tensor implementation.

    This class represents a tensor quantized with both row-wise and column-wise scaling factors.

    Attributes:
        rowwise_tensor: The row-wise quantized component
        colwise_tensor: The column-wise quantized component
    """

    rowwise_tensor: ScaledTensor1x
    colwise_tensor: ScaledTensor1x

    def tree_flatten(self):
        """Flattens the tensor for JAX tree operations.

        Returns:
            A tuple containing (children, aux_data) for tree operations
        """
        children = (self.rowwise_tensor, self.colwise_tensor)
        aux_data = ()
        return (children, aux_data)

    @property
    def ndim(self):
        """Number of dimensions of the underlying row-wise tensor."""
        return self.rowwise_tensor.ndim

    def dequantize(self):
        """Dequantizes the tensor using the row-wise component's dequantization.

        Returns:
            The dequantized tensor
        """
        return self.rowwise_tensor.dequantize()

    def get_tensor(self, usage: TensorUsage):
        """Returns the tensor based on the tensor usage."""
        q_layout_rowwise = self.rowwise_tensor.scaling_mode.get_quantize_layout(usage)
        q_layout_colwise = self.colwise_tensor.scaling_mode.get_quantize_layout(usage)

        if q_layout_rowwise == QuantizeLayout.ROWWISE:
            return self.rowwise_tensor

        if q_layout_colwise == QuantizeLayout.COLWISE:
            return self.colwise_tensor

        raise ValueError(
            f"Calling get_tensor() with usage {usage} is not valid for this tensor as"
            f" q_layout_rowwise={q_layout_rowwise} and q_layout_colwise={q_layout_colwise}!"
        )

    def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
        """Applies sharding constraints to a tensor based on logical axis names.

        Args:
            logical_axis_names: Tuple of logical axis names for sharding

        Returns:
            The tensor with applied sharding constraints
        """
        if not logical_axis_names:
            return self

        rowwise_tensor = self.rowwise_tensor.apply_sharding_constraint_by_logical_axes(
            logical_axis_names
        )
        colwise_tensor = self.colwise_tensor.apply_sharding_constraint_by_logical_axes(
            logical_axis_names
        )

        return ScaledTensor2x(rowwise_tensor, colwise_tensor)


@dataclass
class ScaledTensorFactory:
    """Factory class for creating scaled tensor instances.

    Provides static methods to create both single-scale (1x) and double-scale (2x)
    quantized tensors with various configurations.
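
    Example (illustrative sketch; assumes rowwise/colwise quantized buffers and
    their inverse scale factors were already produced by a quantizer):

        t2x = ScaledTensorFactory.create(
            data, scale_inv, colwise_data, colwise_scale_inv,
            scaling_mode=ScalingMode.MXFP8_1D_SCALING,
            q_layout=QuantizeLayout.ROWWISE_COLWISE,
        )
        rowwise = t2x.get_tensor(TensorUsage.LHS)  # e.g. the GEMM LHS operand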
    """

    @staticmethod
    def create_1x(
        data,
        scale_inv,
        amax=None,
        scaling_mode=ScalingMode.NO_SCALING,
        dq_dtype=jnp.bfloat16,
        is_colwise=False,
        data_layout="N",
        flatten_axis=-1,
        group_sizes=None,
        original_shape=None,
        group_axis=0,
    ):
        """Creates a single-scale quantized tensor.

        Args:
            data: The quantized tensor data
            scale_inv: The inverse scaling factors
            amax: The maximum absolute value of the tensor
            scaling_mode: The scaling mode for quantization
            dq_dtype: The data type for dequantized values (default: bfloat16)
            is_colwise: Whether to use column-wise quantization (default: False)
            data_layout: The data_layout specification (default: "N")
            flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)
            group_sizes: Array of ints containing the size of each group (default: None)
            original_shape: The original shape of the tensor before grouping (default: None)
            group_axis: The axis along which grouping is performed (default: 0)

        Returns:
            A ScaledTensor1x or GroupedScaledTensor1x instance depending on whether group_sizes is provided
        """
        if amax is None:
            amax = jnp.empty((1,), dtype=jnp.float32)

        dequantizer = ScalingModeToDequantizerMap.get(scaling_mode)

        if group_sizes is not None:
            assert (
                original_shape is not None
            ), "original_shape is not given for GroupedScaledTensor1x"
            flatten_axis = len(original_shape) + flatten_axis if flatten_axis < 0 else flatten_axis

            # Handling attrs of transposed tensors
            group_axis = len(original_shape) + group_axis if group_axis < 0 else group_axis
            if data_layout == "T":
                if original_shape[0] == group_sizes.size:
                    original_shape = (
                        original_shape[0],
                        *original_shape[flatten_axis:],
                        *original_shape[1:flatten_axis],
                    )
                    flatten_axis = len(original_shape) - flatten_axis + 1
                else:
                    original_shape = (
                        *original_shape[flatten_axis:],
                        *original_shape[:flatten_axis],
                    )
                    group_axis = flatten_axis
                    flatten_axis = len(original_shape) - flatten_axis

            return GroupedScaledTensor1x(
                data=data,
                scale_inv=scale_inv,
                amax=amax,
                scaling_mode=scaling_mode,
                dq_dtype=dq_dtype,
                _dq_func=dequantizer.grouped_dequantize,
                is_colwise=is_colwise,
                data_layout=data_layout,
                flatten_axis=flatten_axis,
                group_sizes=group_sizes,
                original_shape=original_shape,
                group_axis=group_axis,
            )

        # Handling attrs of transposed tensors
        flatten_axis = data.ndim + flatten_axis if flatten_axis < 0 else flatten_axis
        if data_layout == "T":
            flatten_axis = data.ndim - flatten_axis

        return ScaledTensor1x(
            data,
            scale_inv,
            amax,
            scaling_mode,
            dq_dtype,
            dequantizer.dequantize,
            is_colwise,
            data_layout,
            flatten_axis,
        )

    @staticmethod
    def create_2x(
        data,
        scale_inv,
        colwise_data,
        colwise_scale_inv,
        amax=None,
        scaling_mode=ScalingMode.NO_SCALING,
        dq_dtype=jnp.bfloat16,
        data_layout="NN",
        flatten_axis=-1,
        group_sizes=None,
        original_shape=None,
        group_axis=0,
    ):
        """Creates a double-scale quantized tensor.

        Args:
            data: The row-wise quantized data
            scale_inv: The row-wise inverse scaling factors
            colwise_data: The column-wise quantized data
            colwise_scale_inv: The column-wise inverse scaling factors
            amax: The maximum absolute value of the tensor
            scaling_mode: The scaling mode for quantization
            dq_dtype: The data type for dequantized values (default: bfloat16)
            data_layout: The data_layout specification (default: "NN")
            flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)
            group_sizes: Array containing the size of each group (default: None)
            original_shape: The original shape of the tensor before grouping (default: None)
            group_axis: The axis along which grouping is performed (default: 0)

        Returns:
            A ScaledTensor2x instance
        """
        if amax is None:
            amax = jnp.empty((1,), dtype=jnp.float32)

        assert len(data_layout) == 2, f"Expect 2 layouts, got {data_layout}"
        rowwise_tensor = ScaledTensorFactory.create_1x(
            data,
            scale_inv,
            amax,
            scaling_mode,
            dq_dtype,
            is_colwise=False,
            data_layout=data_layout[0],
            flatten_axis=flatten_axis,
            group_sizes=group_sizes,
            original_shape=original_shape,
            group_axis=group_axis,
        )
        colwise_tensor = ScaledTensorFactory.create_1x(
            colwise_data,
            colwise_scale_inv,
            amax,
            scaling_mode,
            dq_dtype,
            is_colwise=True,
            data_layout=data_layout[1],
            flatten_axis=flatten_axis,
            group_sizes=group_sizes,
            original_shape=original_shape,
            group_axis=group_axis,
        )
        return ScaledTensor2x(rowwise_tensor, colwise_tensor)

    @staticmethod
    def create(
        data: jnp.ndarray,
        scale_inv: jnp.ndarray,
        colwise_data: jnp.ndarray,
        colwise_scale_inv: jnp.ndarray,
        amax=None,
        scaling_mode: ScalingMode = ScalingMode.NO_SCALING,
        dq_dtype: jnp.dtype = jnp.bfloat16,
        data_layout: str = "NN",
        q_layout: QuantizeLayout = QuantizeLayout.ROWWISE,
        flatten_axis: int = -1,
        group_sizes: jnp.ndarray = None,
        original_shape: Tuple[int] = None,
        group_axis: int = 0,
    ):
        """Creates a scaled tensor based on the quantization axis.

        Args:
            data: The quantized tensor data
            scale_inv: The inverse scaling factors
            colwise_data: The column-wise quantized data
            colwise_scale_inv: The column-wise inverse scaling factors
            amax: The maximum absolute value of the tensor (default: None)
            scaling_mode: The scaling mode for quantization
            dq_dtype: The data type for dequantized values (default: bfloat16)
            data_layout: The data_layout specification (default: "NN")
            q_layout: The quantization layout (default: ROWWISE)
            flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)
            group_sizes: Array containing the size of each group (default: None)
            original_shape: The original shape of the tensor before grouping (default: None)
            group_axis: The axis along which grouping is performed (default: 0)

        Returns:
            Either a ScaledTensor1x or ScaledTensor2x instance depending on q_layout
        """
        if q_layout == QuantizeLayout.ROWWISE_COLWISE:
            return ScaledTensorFactory.create_2x(
                data,
                scale_inv,
                colwise_data,
                colwise_scale_inv,
                amax,
                scaling_mode,
                dq_dtype,
                data_layout=data_layout,
                flatten_axis=flatten_axis,
                group_sizes=group_sizes,
                original_shape=original_shape,
                group_axis=group_axis,
            )

        is_colwise = q_layout == QuantizeLayout.COLWISE
        if is_colwise:
            return ScaledTensorFactory.create_1x(
                colwise_data,
                colwise_scale_inv,
                amax,
                scaling_mode,
                dq_dtype,
                is_colwise=is_colwise,
                data_layout=data_layout[0],
                flatten_axis=flatten_axis,
                group_sizes=group_sizes,
                original_shape=original_shape,
                group_axis=group_axis,
            )

        return ScaledTensorFactory.create_1x(
            data,
            scale_inv,
            amax,
            scaling_mode,
            dq_dtype,
            is_colwise=is_colwise,
            data_layout=data_layout[0],
            flatten_axis=flatten_axis,
            group_sizes=group_sizes,
            original_shape=original_shape,
            group_axis=group_axis,
        )


def with_sharding_constraint_by_logical_axes(x, logical_axis_names: Tuple[str, ...]):
    """Applies sharding constraints to a tensor based on logical axis names.

    Args:
        x: The tensor to apply sharding constraints to
        logical_axis_names: Tuple of logical axis names for sharding

    Returns:
        The tensor with applied sharding constraints
    """
    if isinstance(x, GroupedScaledTensor1x):
        raise NotImplementedError

    if isinstance(x, ScaledTensor):
        return x.apply_sharding_constraint_by_logical_axes(logical_axis_names)

    return original_with_sharding_constraint_by_logical_axes(x, logical_axis_names)