# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Tensor classes for TE/JAX

This module provides tensor classes for handling quantized tensors in JAX, including
both single-scale (1x) and double-scale (2x) quantization schemes. It supports
rowwise and colwise quantization modes with proper scaling and dequantization.
"""
from dataclasses import dataclass
from typing import Callable, Tuple
from abc import ABC, abstractmethod

import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

from transformer_engine_jax import QuantizeLayout

from .helper import apply_padding_to_scale_inv
from .scaling_modes import ScalingMode, TensorUsage
from .dequantizer import ScalingModeToDequantizerMap
from ..sharding import (
    with_sharding_constraint_by_logical_axes as original_with_sharding_constraint_by_logical_axes,
)

# Public API of this module. TensorUsage is re-exported from .scaling_modes.
__all__ = [
    "TensorUsage",
    "ScaledTensor",
    "ScaledTensor1x",
    "ScaledTensor2x",
    "GroupedScaledTensor1x",
    "ScaledTensorFactory",
    "with_sharding_constraint_by_logical_axes",
]


@register_pytree_node_class
@dataclass
class ScaledTensor(ABC):
    """Abstract base class for scaled tensors.

    This class defines the interface for all scaled tensor implementations,
    providing methods for dequantization and accessing row/column-wise components.

    Subclasses are expected to supply ``tree_flatten`` returning
    ``(children, aux_data)`` such that ``cls(*children, *aux_data)`` rebuilds
    the instance; ``tree_unflatten`` below relies on that ordering.
    """

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Reconstructs the tensor from its flattened representation.

        Args:
            aux_data: Auxiliary data needed for reconstruction
            children: The flattened tensor components

        Returns:
            A reconstructed tensor instance, built by passing the children
            followed by the auxiliary data positionally to the constructor
        """
        return cls(*children, *aux_data)

    @property
    @abstractmethod
    def ndim(self):
        """Number of dimensions of the underlying quantized array."""

    @abstractmethod
    def dequantize(self):
        """Dequantizes the tensor back to its original precision.

        Returns:
            The dequantized tensor
        """

    @abstractmethod
    def get_tensor(self, usage: TensorUsage):
        """Returns the appropriate tensor based on the tensor usage and the scaling mode.
        If the tensor usage is not valid for the scaling mode, an error is raised.

        Args:
            usage: The usage of the tensor

        Returns:
            The tensor based on the usage
        """

    @abstractmethod
    def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
        """Applies sharding constraints to a tensor based on logical axis names.

        Args:
            logical_axis_names: Tuple of logical axis names for sharding

        Returns:
            The tensor with applied sharding constraints
        """

96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111

@register_pytree_node_class
@dataclass
class ScaledTensor1x(ScaledTensor):
    """Single-scale quantized tensor implementation.

    This class represents a tensor quantized with a single scaling factor,
    supporting both row-wise and column-wise quantization modes.

    Attributes:
        data: The quantized tensor data
        scale_inv: The inverse scaling factors
        scaling_mode: The scaling mode used for quantization
        dq_dtype: The data type for dequantized values
        _dq_func: The dequantization function
        is_colwise: Whether the tensor uses column-wise quantization
        data_layout: The data_layout specification for the tensor
        flatten_axis: The quantization axis for the tensor
    """

    data: jnp.ndarray
    scale_inv: jnp.ndarray
    scaling_mode: ScalingMode
    dq_dtype: jnp.dtype
    _dq_func: Callable
    is_colwise: bool
    data_layout: str
    flatten_axis: int

    def __post_init__(self):
        """Validates flatten_axis and normalizes scale_inv after initialization.

        For ScalingMode.NO_SCALING the stored scale_inv is replaced by an empty
        float32 array of shape (0,). For every other mode scale_inv is passed
        through apply_padding_to_scale_inv together with the data shape,
        is_colwise and flatten_axis.

        Raises:
            AssertionError: If flatten_axis is not strictly between 0 and data.ndim
        """
        # A single bounded check covers the former bare `flatten_axis > 0`
        # assertion as well, so every failure reports the offending value.
        assert (
            0 < self.flatten_axis < len(self.data.shape)
        ), f"flatten_axis {self.flatten_axis} is out of bounds for shape {self.data.shape}"

        if self.scaling_mode == ScalingMode.NO_SCALING:
            self.scale_inv = jnp.empty((0,), dtype=jnp.float32)
        else:
            self.scale_inv = apply_padding_to_scale_inv(
                self.scale_inv,
                self.scaling_mode,
                self.data.shape,
                is_colwise=self.is_colwise,
                flatten_axis=self.flatten_axis,
            )

    def tree_flatten(self):
        """Flattens the tensor for JAX tree operations.

        Returns:
            A tuple containing (children, aux_data) for tree operations. The
            arrays (data, scale_inv) are the children; all remaining fields are
            carried as aux_data in constructor order.
        """
        children = (self.data, self.scale_inv)
        aux_data = (
            self.scaling_mode,
            self.dq_dtype,
            self._dq_func,
            self.is_colwise,
            self.data_layout,
            self.flatten_axis,
        )
        return (children, aux_data)

    @property
    def ndim(self):
        """Number of dimensions of the quantized data array."""
        return self.data.ndim

    def dequantize(self):
        """Dequantizes the tensor using the stored dequantization function.

        Returns:
            The dequantized tensor
        """
        return self._dq_func(self)

    def get_tensor(self, usage: TensorUsage):
        """Returns this tensor if it is valid for the requested usage.

        The usage is mapped to a quantize layout through the scaling mode. The
        tensor is valid when that layout is COLWISE and this tensor is
        column-wise, or ROWWISE and this tensor is row-wise.

        Args:
            usage: The usage of the tensor

        Returns:
            This tensor

        Raises:
            ValueError: If the resolved layout does not match is_colwise
        """
        q_layout = self.scaling_mode.get_quantize_layout(usage)
        colwise_usage_valid = q_layout == QuantizeLayout.COLWISE and self.is_colwise
        rowwise_usage_valid = q_layout == QuantizeLayout.ROWWISE and not self.is_colwise

        if colwise_usage_valid or rowwise_usage_valid:
            return self

        raise ValueError(
            f"Calling get_tensor() with usage {usage} is not valid for this tensor as"
            f" self.is_colwise={self.is_colwise}!"
        )

    def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
        """Applies sharding constraints to a tensor based on logical axis names.

        Args:
            logical_axis_names: Tuple of logical axis names for sharding

        Returns:
            The tensor with applied sharding constraints; self is returned
            unchanged when logical_axis_names is empty or None
        """
        if not logical_axis_names:
            return self

        # axis_names were given for N layout, so needs to be transpose for T layout
        if self.data_layout == "T":
            assert self.flatten_axis > 0, "flatten_axis must be positive for T layout"
            assert len(logical_axis_names) == self.data.ndim, (
                f"Expect {self.data.ndim} logical axis names, got {len(logical_axis_names)}"
            )
            # Rotate the names at the split point measured in the N layout.
            flatten_axis = self.data.ndim - self.flatten_axis
            axis_names = (
                *logical_axis_names[flatten_axis:],
                *logical_axis_names[:flatten_axis],
            )
        else:
            axis_names = logical_axis_names

        data = with_sharding_constraint_by_logical_axes(self.data, axis_names)

        if self.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
            # TODO(Phuong): Handle padding !?
            scale_inv = with_sharding_constraint_by_logical_axes(self.scale_inv, axis_names)
        else:
            scale_inv = self.scale_inv

        return ScaledTensor1x(
            data=data,
            scale_inv=scale_inv,
            scaling_mode=self.scaling_mode,
            dq_dtype=self.dq_dtype,
            _dq_func=self._dq_func,
            is_colwise=self.is_colwise,
            data_layout=self.data_layout,
            flatten_axis=self.flatten_axis,
        )

234

235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
@register_pytree_node_class
@dataclass
class GroupedScaledTensor1x(ScaledTensor1x):
    """Grouped Quantizer for an array.

    This class extends ScaledTensor1x to support quantization of an array in grouped manner,
    where elements are grouped along a specified axis.

    Attributes:
        group_sizes: Array containing the size of each group
        original_shape: The original shape of the tensor before grouping
        group_axis: The axis along which grouping is performed (default: 0)
    """

    group_sizes: jnp.ndarray
    original_shape: Tuple
    group_axis: int

    def __init__(
        self,
        data,
        scale_inv,
        group_sizes,
        scaling_mode,
        dq_dtype,
        _dq_func,
        is_colwise,
        data_layout,
        flatten_axis,
        original_shape,
        group_axis=0,
    ):
        """Initializes the grouped tensor.

        The positional order (data, scale_inv, group_sizes, then the auxiliary
        fields) matches what tree_flatten emits, so the inherited
        tree_unflatten can rebuild an instance with cls(*children, *aux_data).
        """
        # The dataclass-generated parent __init__ invokes __post_init__, which
        # reads the group attributes, so they must exist before delegating.
        # flatten_axis is assigned by the parent __init__ itself.
        self.group_sizes = group_sizes
        self.original_shape = original_shape
        self.group_axis = group_axis
        super().__init__(
            data, scale_inv, scaling_mode, dq_dtype, _dq_func, is_colwise, data_layout, flatten_axis
        )

    def __post_init__(self):
        """Validates the flattened data/scale_inv and the axis attributes.

        Unlike the parent class, scale_inv is not modified here; it is only
        checked against the padded grouped scale shape reported by the scaling
        mode.

        Raises:
            AssertionError: If data or scale_inv is not 1D, if flatten_axis or
                group_axis is out of bounds for original_shape, or if scale_inv
                does not have the expected padded shape
        """
        assert self.scale_inv.ndim == 1, "Only support flattened scale_inv"
        assert self.data.ndim == 1, "Only support flattened data"

        # Axes are validated against the original shape, since data is flattened.
        data_ndim = len(self.original_shape)
        assert (
            0 < self.flatten_axis < data_ndim
        ), f"flatten_axis {self.flatten_axis} is out of bounds for data.ndim = {data_ndim}"

        assert (
            0 <= self.group_axis < data_ndim
        ), f"group_axis {self.group_axis} is out of bounds for shape {self.original_shape}"

        expected_scale_shape = self.scaling_mode.get_grouped_scale_shape(
            self.original_shape,
            self.group_sizes.size,
            self.group_axis,
            self.is_colwise,
            is_padded=True,
            flatten_axis=self.flatten_axis,
        )

        assert self.scale_inv.shape == expected_scale_shape, (
            f"Unexpected scale_inv shape! \nExpect {expected_scale_shape} for padded"
            f" scale_inv, got {self.scale_inv.shape}"
        )

    def tree_flatten(self):
        """Flattens the tensor for JAX tree operations.

        Returns:
            A tuple containing (children, aux_data) for tree operations. The
            arrays (data, scale_inv, group_sizes) are the children; the
            remaining fields are aux_data in constructor order.
        """
        children = (self.data, self.scale_inv, self.group_sizes)
        aux_data = (
            self.scaling_mode,
            self.dq_dtype,
            self._dq_func,
            self.is_colwise,
            self.data_layout,
            self.flatten_axis,
            self.original_shape,
            self.group_axis,
        )
        return (children, aux_data)

    def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
        """Not supported for grouped tensors.

        Raises:
            NotImplementedError: Always
        """
        raise NotImplementedError(
            "Sharding constraints are not implemented for GroupedScaledTensor1x"
        )


327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
@register_pytree_node_class
@dataclass
class ScaledTensor2x(ScaledTensor):
    """Double-scale quantized tensor implementation.

    This class represents a tensor quantized with both row-wise and column-wise scaling factors.

    Attributes:
        rowwise_tensor: The row-wise quantized component
        colwise_tensor: The column-wise quantized component
    """

    rowwise_tensor: ScaledTensor1x
    colwise_tensor: ScaledTensor1x

    def tree_flatten(self):
        """Flattens the tensor for JAX tree operations.

        Returns:
            A tuple containing (children, aux_data) for tree operations. Both
            component tensors are children; there is no auxiliary data.
        """
        children = (self.rowwise_tensor, self.colwise_tensor)
        aux_data = ()
        return (children, aux_data)

    @property
    def ndim(self):
        """Number of dimensions of the underlying row-wise tensor."""
        return self.rowwise_tensor.ndim

    def dequantize(self):
        """Dequantizes the tensor using the row-wise component's dequantization.

        Returns:
            The dequantized tensor
        """
        return self.rowwise_tensor.dequantize()

    def get_tensor(self, usage: TensorUsage):
        """Returns the component tensor matching the tensor usage.

        The usage is resolved to a quantize layout through each component's
        scaling mode. The row-wise component is checked first.

        Args:
            usage: The usage of the tensor

        Returns:
            rowwise_tensor if its resolved layout is ROWWISE, otherwise
            colwise_tensor if its resolved layout is COLWISE

        Raises:
            ValueError: If neither component matches the usage
        """
        q_layout_rowwise = self.rowwise_tensor.scaling_mode.get_quantize_layout(usage)
        q_layout_colwise = self.colwise_tensor.scaling_mode.get_quantize_layout(usage)

        if q_layout_rowwise == QuantizeLayout.ROWWISE:
            return self.rowwise_tensor

        if q_layout_colwise == QuantizeLayout.COLWISE:
            return self.colwise_tensor

        raise ValueError(
            f"Calling get_tensor() with usage {usage} is not valid for this tensor as"
            f" q_layout_rowwise={q_layout_rowwise} and q_layout_colwise={q_layout_colwise}!"
        )

    def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
        """Applies sharding constraints to a tensor based on logical axis names.

        The same logical axis names are forwarded to both components.

        Args:
            logical_axis_names: Tuple of logical axis names for sharding

        Returns:
            The tensor with applied sharding constraints; self is returned
            unchanged when logical_axis_names is empty or None
        """
        if not logical_axis_names:
            return self

        rowwise_tensor = self.rowwise_tensor.apply_sharding_constraint_by_logical_axes(
            logical_axis_names
        )
        colwise_tensor = self.colwise_tensor.apply_sharding_constraint_by_logical_axes(
            logical_axis_names
        )

        return ScaledTensor2x(rowwise_tensor, colwise_tensor)

402
403
404
405
406
407
408
409
410
411
412

@dataclass
class ScaledTensorFactory:
    """Factory class for creating scaled tensor instances.

    Provides static methods to create both single-scale (1x) and double-scale (2x)
    quantized tensors with various configurations.
    """

    @staticmethod
    def create_1x(
        data,
        scale_inv,
        scaling_mode,
        dq_dtype=jnp.bfloat16,
        is_colwise=False,
        data_layout="N",
        flatten_axis=-1,
        group_sizes=None,
        original_shape=None,
        group_axis=0,
    ):
        """Creates a single-scale quantized tensor.

        Args:
            data: The quantized tensor data
            scale_inv: The inverse scaling factors
            scaling_mode: The scaling mode for quantization
            dq_dtype: The data type for dequantized values (default: bfloat16)
            is_colwise: Whether to use column-wise quantization (default: False)
            data_layout: The data_layout specification (default: "N")
            flatten_axis: The quantization axis for the tensor
            group_sizes: Array of ints containing the size of each group (default: None)
            original_shape: The original shape of the tensor before grouping (default: None)
            group_axis: The axis along which grouping is performed (default: 0)

        Returns:
            A ScaledTensor1x or GroupedScaledTensor1x instance depending on whether group_sizes is provided

        Raises:
            AssertionError: If group_sizes is given without original_shape
        """
        dequantizer = ScalingModeToDequantizerMap.get(scaling_mode)

        if group_sizes is not None:
            # Check original_shape first: the axis normalization below calls
            # len(original_shape) and would otherwise fail with a TypeError
            # before this message could ever be reported.
            assert (
                original_shape is not None
            ), "original_shape is not given for GroupedScaledTensor1x"

            # Normalize negative axes against the original (ungrouped) shape.
            flatten_axis = len(original_shape) + flatten_axis if flatten_axis < 0 else flatten_axis
            group_axis = len(original_shape) + group_axis if group_axis < 0 else group_axis

            # Handling attrs of transposed tensors
            if data_layout == "T":
                if original_shape[0] == group_sizes.size:
                    # Leading dim equals the number of groups: keep it in front
                    # and rotate only the remaining dims around flatten_axis.
                    original_shape = (
                        original_shape[0],
                        *original_shape[flatten_axis:],
                        *original_shape[1:flatten_axis],
                    )
                    flatten_axis = len(original_shape) - flatten_axis + 1
                else:
                    # Rotate the whole shape around flatten_axis; the group
                    # axis becomes the pre-rotation flatten_axis.
                    original_shape = (
                        *original_shape[flatten_axis:],
                        *original_shape[:flatten_axis],
                    )
                    group_axis = flatten_axis
                    flatten_axis = len(original_shape) - flatten_axis

            return GroupedScaledTensor1x(
                data=data,
                scale_inv=scale_inv,
                scaling_mode=scaling_mode,
                dq_dtype=dq_dtype,
                _dq_func=dequantizer.grouped_dequantize,
                is_colwise=is_colwise,
                data_layout=data_layout,
                flatten_axis=flatten_axis,
                group_sizes=group_sizes,
                original_shape=original_shape,
                group_axis=group_axis,
            )

        # Handling attrs of transposed tensors
        flatten_axis = data.ndim + flatten_axis if flatten_axis < 0 else flatten_axis
        if data_layout == "T":
            flatten_axis = data.ndim - flatten_axis

        return ScaledTensor1x(
            data,
            scale_inv,
            scaling_mode,
            dq_dtype,
            dequantizer.dequantize,
            is_colwise,
            data_layout,
            flatten_axis,
        )

    @staticmethod
    def create_2x(
        data,
        scale_inv,
        colwise_data,
        colwise_scale_inv,
        scaling_mode,
        dq_dtype=jnp.bfloat16,
        data_layout="NN",
        flatten_axis=-1,
        group_sizes=None,
        original_shape=None,
        group_axis=0,
    ):
        """Creates a double-scale quantized tensor.

        Args:
            data: The row-wise quantized data
            scale_inv: The row-wise inverse scaling factors
            colwise_data: The column-wise quantized data
            colwise_scale_inv: The column-wise inverse scaling factors
            scaling_mode: The scaling mode for quantization
            dq_dtype: The data type for dequantized values (default: bfloat16)
            data_layout: The data_layout specification (default: "NN"); the first
                character is used for the row-wise tensor, the second for the
                column-wise tensor
            flatten_axis: The quantization axis for the tensor
            group_sizes: Array containing the size of each group (default: None)
            original_shape: The original shape of the tensor before grouping (default: None)
            group_axis: The axis along which grouping is performed (default: 0)

        Returns:
            A ScaledTensor2x instance

        Raises:
            AssertionError: If data_layout does not contain exactly two characters
        """
        assert len(data_layout) == 2, f"Expect 2 layouts, got {data_layout}"
        rowwise_tensor = ScaledTensorFactory.create_1x(
            data,
            scale_inv,
            scaling_mode,
            dq_dtype,
            is_colwise=False,
            data_layout=data_layout[0],
            flatten_axis=flatten_axis,
            group_sizes=group_sizes,
            original_shape=original_shape,
            group_axis=group_axis,
        )
        colwise_tensor = ScaledTensorFactory.create_1x(
            colwise_data,
            colwise_scale_inv,
            scaling_mode,
            dq_dtype,
            is_colwise=True,
            data_layout=data_layout[1],
            flatten_axis=flatten_axis,
            group_sizes=group_sizes,
            original_shape=original_shape,
            group_axis=group_axis,
        )
        return ScaledTensor2x(rowwise_tensor, colwise_tensor)

    @staticmethod
    def create(
        data: jnp.ndarray,
        scale_inv: jnp.ndarray,
        colwise_data: jnp.ndarray,
        colwise_scale_inv: jnp.ndarray,
        scaling_mode: ScalingMode,
        dq_dtype: jnp.dtype = jnp.bfloat16,
        data_layout: str = "NN",
        q_layout: QuantizeLayout = QuantizeLayout.ROWWISE,
        flatten_axis: int = -1,
        group_sizes: jnp.ndarray = None,
        original_shape: Tuple[int] = None,
        group_axis: int = 0,
    ):
        """Creates a scaled tensor based on the quantization axis.

        Args:
            data: The quantized tensor data
            scale_inv: The inverse scaling factors
            colwise_data: The column-wise quantized data
            colwise_scale_inv: The column-wise inverse scaling factors
            scaling_mode: The scaling mode for quantization
            dq_dtype: The data type for dequantized values (default: bfloat16)
            data_layout: The data_layout specification (default: "NN")
            q_layout: The quantization axis (default: ROWWISE)
            flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)
            group_sizes: Array containing the size of each group (default: None)
            original_shape: The original shape of the tensor before grouping (default: None)
            group_axis: The axis along which grouping is performed (default: 0)

        Returns:
            Either a ScaledTensor1x or ScaledTensor2x instance depending on q_layout
        """
        if q_layout == QuantizeLayout.ROWWISE_COLWISE:
            return ScaledTensorFactory.create_2x(
                data,
                scale_inv,
                colwise_data,
                colwise_scale_inv,
                scaling_mode,
                dq_dtype,
                data_layout=data_layout,
                flatten_axis=flatten_axis,
                group_sizes=group_sizes,
                original_shape=original_shape,
                group_axis=group_axis,
            )

        is_colwise = q_layout == QuantizeLayout.COLWISE
        if is_colwise:
            # NOTE(review): the column-wise-only path reads data_layout[0], not
            # data_layout[1]; presumably callers pass the layout of the single
            # tensor first here. Confirm against callers.
            return ScaledTensorFactory.create_1x(
                colwise_data,
                colwise_scale_inv,
                scaling_mode,
                dq_dtype,
                is_colwise=is_colwise,
                data_layout=data_layout[0],
                flatten_axis=flatten_axis,
                group_sizes=group_sizes,
                original_shape=original_shape,
                group_axis=group_axis,
            )

        return ScaledTensorFactory.create_1x(
            data,
            scale_inv,
            scaling_mode,
            dq_dtype,
            is_colwise=is_colwise,
            data_layout=data_layout[0],
            flatten_axis=flatten_axis,
            group_sizes=group_sizes,
            original_shape=original_shape,
            group_axis=group_axis,
        )


def with_sharding_constraint_by_logical_axes(x, logical_axis_names: Tuple[str, ...]):
    """Applies sharding constraints to a tensor based on logical axis names.

    Scaled tensors are dispatched to their own
    apply_sharding_constraint_by_logical_axes method; any other input is
    forwarded to the sharding module's implementation.

    Args:
        x: The tensor to apply sharding constraints to
        logical_axis_names: Tuple of logical axis names for sharding

    Returns:
        The tensor with applied sharding constraints

    Raises:
        NotImplementedError: If x is a GroupedScaledTensor1x
    """
    # Checked before the generic ScaledTensor branch because the grouped class
    # is itself a ScaledTensor subclass.
    if isinstance(x, GroupedScaledTensor1x):
        raise NotImplementedError(
            "Sharding constraints are not implemented for GroupedScaledTensor1x"
        )

    if isinstance(x, ScaledTensor):
        return x.apply_sharding_constraint_by_logical_axes(logical_axis_names)

    return original_with_sharding_constraint_by_logical_axes(x, logical_axis_names)