# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Tensor classes for TE/JAX

This module provides tensor classes for handling quantized tensors in JAX, including
both single-scale (1x) and double-scale (2x) quantization schemes. It supports
rowwise and colwise quantization modes with proper scaling and dequantization.
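
Example (illustrative sketch; the arrays below are placeholder buffers rather than
real quantizer output, and the import path assumes the ``quantize`` package
re-exports these names)::

    import jax.numpy as jnp
    from transformer_engine.jax.quantize import ScaledTensorFactory, ScalingMode

    t = ScaledTensorFactory.create_1x(
        data=jnp.zeros((128, 64), jnp.float8_e4m3fn),
        scale_inv=jnp.ones((1,), jnp.float32),
        scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
    )
    hi_precision = t.dequantize()  # back to dq_dtype (bfloat16 by default)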
"""
from dataclasses import dataclass
from typing import Callable, Tuple
from abc import ABC, abstractmethod

import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class
from jax.ad_checkpoint import checkpoint_name as jax_checkpoint_name

from .scaling_modes import ScalingMode, TensorUsage
from .dequantizer import ScalingModeToDequantizerMap
from .misc import QuantizeLayout
from ..sharding import (
    with_sharding_constraint_by_logical_axes as original_with_sharding_constraint_by_logical_axes,
)

__all__ = [
    "TensorUsage",
    "AbstractBaseTensor",
    "NoScaleTensor",
    "ScaledTensor",
    "ScaledTensor1x",
    "ScaledTensor2x",
    "GroupedScaledTensor1x",
    "ScaledTensorFactory",
    "with_sharding_constraint_by_logical_axes",
]


@dataclass
class AbstractBaseTensor(ABC):
    """Abstract base class for all tensor types."""

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Reconstructs the tensor from its flattened representation.

        Args:
            aux_data: Auxiliary data needed for reconstruction
            children: The flattened tensor components

        Returns:
            A reconstructed tensor instance
        """
        return cls(*children, *aux_data)

    @property
    @abstractmethod
    def ndim(self):
        """Number of dimensions of the underlying quantized array."""

    @abstractmethod
    def dequantize(self):
        """Dequantizes the tensor back to its original precision.

        Returns:
            The dequantized tensor
        """

    @abstractmethod
    def get_tensor(self, usage: TensorUsage):
        """Returns the appropriate tensor based on the tensor usage and the scaling mode.
        If the tensor usage is not valid for the scaling mode, an error is raised.

        Args:
            usage: The usage of the tensor

        Returns:
            The tensor based on the usage
        """

    @abstractmethod
    def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
        """Applies sharding constraints to a tensor based on logical axis names.

        Args:
            logical_axis_names: Tuple of logical axis names for sharding

        Returns:
            The tensor with applied sharding constraints
        """

    @abstractmethod
    def checkpoint(self, quantizer):
        """Checkpoints the tensor with the given quantizer's checkpoint name if available.

        Args:
            quantizer: The quantizer to use for checkpointing. If None, no checkpointing is applied.

        Returns:
            The checkpointed tensor
        """


@dataclass
class AbstractBaseTensor1x(AbstractBaseTensor):
    """Abstract base class for single layout tensors."""

    data: jnp.ndarray
    amax: jnp.ndarray


@register_pytree_node_class
@dataclass
class NoScaleTensor(AbstractBaseTensor1x):
    """Higher-precision tensor."""

    def __post_init__(self):
        assert isinstance(self.data, jnp.ndarray), "NoScaleTensor's data must be a jnp.ndarray."

    def tree_flatten(self):
        """Flattens the tensor for JAX tree operations.

        Returns:
            A tuple containing (children, aux_data) for tree operations
        """
        children = (self.data, self.amax)
        aux_data = ()
        return (children, aux_data)

    @property
    def ndim(self):
        """Number of dimensions of the underlying array."""
        return self.data.ndim

    def dequantize(self):
        """This is a no-op for a higher-precision tensor so this simply returns the tensor's data."""
        return self.data

    def get_tensor(self, usage: TensorUsage):
        """Returns the tensor based on the tensor usage."""
        q_layout = ScalingMode.NO_SCALING.get_quantize_layout(usage)
        assert q_layout.is_rowwise_only, "Only ROWWISE layout is supported for NoScaleTensor"
        return self

    def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
        """Applies sharding constraints to a tensor based on logical axis names.

        Args:
            logical_axis_names: Tuple of logical axis names for sharding

        Returns:
            The tensor with applied sharding constraints
        """
        if not logical_axis_names:
            return self

        data = with_sharding_constraint_by_logical_axes(self.data, logical_axis_names)

        return NoScaleTensor(
            data=data,
            amax=self.amax,
        )

    def checkpoint(self, quantizer):
        """Checkpoints the tensor with the given quantizer's checkpoint name if available.

        Args:
            quantizer: The quantizer to use for checkpointing. If None, no checkpointing is applied.

        Returns:
            The checkpointed tensor
        """
        assert quantizer is None, "NoScaleTensor does not support quantization."
        return self


class ScaledTensor(ABC):
    """Abstract base class for scaled tensors."""


@register_pytree_node_class
@dataclass
class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor):
    """Single-scale quantized tensor implementation.

    This class represents a tensor quantized with a single scaling factor,
    supporting both row-wise and column-wise quantization modes.

    Attributes:
        data: The quantized tensor data
        scale_inv: The inverse scaling factors
        amax: The maximum absolute value of the tensor
        scaling_mode: The scaling mode used for quantization
        dq_dtype: The data type for dequantized values
        _dq_func: The dequantization function
        is_colwise: Whether the tensor uses column-wise quantization
        data_layout: The data_layout specification for the tensor
        flatten_axis: The axis along which the tensor can be flattened to 2D
        has_rht_applied: Whether the tensor had the Randomized Hadamard Transform (RHT) applied during quantization
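
    Example (illustrative; ``data``, ``scale_inv`` and ``scaling_mode`` are assumed
    to come from a matching quantizer)::

        t = ScaledTensorFactory.create_1x(data, scale_inv, scaling_mode=scaling_mode)
        t.get_tensor(TensorUsage.LHS)  # returns `t` itself when the layout matches
        hi = t.dequantize()            # cast back to dq_dtype via the stored _dq_func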
    """

    scale_inv: jnp.ndarray
    scaling_mode: ScalingMode
    dq_dtype: jnp.dtype
    _dq_func: Callable
    is_colwise: bool
    data_layout: str
    flatten_axis: int
    has_rht_applied: bool

    def __post_init__(self):
        """Validates and adjusts the scale_inv shape after initialization.

        Ensures the scale_inv shape matches the expected shape based on the scaling mode
        and quantization direction. Pads the scale_inv if necessary.
        """
        assert self.flatten_axis > 0
        assert (
            0 < self.flatten_axis < len(self.data.shape)
        ), f"flatten_axis {self.flatten_axis} is out of bounds for shape {self.data.shape}"

        if self.scaling_mode == ScalingMode.NO_SCALING:
            self.scale_inv = jnp.empty((0,), dtype=jnp.float32)
        else:
            unpadded_scale_shape = self.scaling_mode.get_scale_shape(
                self.data.shape,
                data_layout=self.data_layout,
                is_colwise=self.is_colwise,
                is_padded=False,
                # expect the flatten_axis wrt the N layout
                flatten_axis=(
                    self.flatten_axis
                    if self.data_layout == "N"
                    else self.data.ndim - self.flatten_axis
                ),
            )
            unpadded_scale_shape_broadcast = self.scaling_mode.get_scale_shape(
                self.data.shape,
                data_layout=self.data_layout,
                is_colwise=self.is_colwise,
                is_padded=False,
                # expect the flatten_axis wrt the N layout
                flatten_axis=(
                    self.flatten_axis
                    if self.data_layout == "N"
                    else self.data.ndim - self.flatten_axis
                ),
                broadcast_2d_scale_shape_to_1d=True,
            )
            assert self.scale_inv.shape in (unpadded_scale_shape, unpadded_scale_shape_broadcast), (
                f"Unpadded inverse scale factor has wrong shape, expected {unpadded_scale_shape} or"
                f" {unpadded_scale_shape_broadcast} but got {self.scale_inv.shape}."
            )

    def tree_flatten(self):
        """Flattens the tensor for JAX tree operations.

        Returns:
            A tuple containing (children, aux_data) for tree operations
        """
        children = (self.data, self.amax, self.scale_inv)
        aux_data = (
            self.scaling_mode,
            self.dq_dtype,
            self._dq_func,
            self.is_colwise,
            self.data_layout,
            self.flatten_axis,
            self.has_rht_applied,
        )
        return (children, aux_data)

    @property
    def ndim(self):
        return self.data.ndim

    def dequantize(self):
        """Dequantizes the tensor using the stored dequantization function.

        Returns:
            The dequantized tensor
        """
        return self._dq_func(self)

    def get_tensor(self, usage: TensorUsage):
        """Returns the tensor based on the tensor usage."""
        q_layout = self.scaling_mode.get_quantize_layout(usage)
        colwise_usage_valid = q_layout.is_colwise_only and self.is_colwise
        rowwise_usage_valid = q_layout.is_rowwise_only and not self.is_colwise

        if colwise_usage_valid or rowwise_usage_valid:
            return self

        raise ValueError(
            f"Calling get_tensor() with usage {usage} is not valid for this tensor as"
            f" self.is_colwise={self.is_colwise}!"
        )

    def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
        """Applies sharding constraints to a tensor based on logical axis names.

        Args:
            logical_axis_names: Tuple of logical axis names for sharding

        Returns:
            The tensor with applied sharding constraints
        """
        if not logical_axis_names:
            return self

        # axis_names were given for the N layout, so they need to be transposed for the T layout
        if self.data_layout == "T":
            assert self.flatten_axis > 0
            assert len(logical_axis_names) == self.data.ndim
            flatten_axis = self.data.ndim - self.flatten_axis
            axis_names = (
                *logical_axis_names[flatten_axis:],
                *logical_axis_names[:flatten_axis],
            )
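            # Illustrative example: N-layout names ("batch", "seq", "hidden") with
            # self.flatten_axis == 2 give flatten_axis == 1 here, so the T-layout
            # data is constrained with ("seq", "hidden", "batch").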
        else:
            axis_names = logical_axis_names

        data = with_sharding_constraint_by_logical_axes(self.data, axis_names)

        if self.scaling_mode.is_block_scaling:  # Both MXFP8 and NVFP4
            scale_inv = with_sharding_constraint_by_logical_axes(self.scale_inv, axis_names)
        else:
            scale_inv = self.scale_inv

        return ScaledTensor1x(
            data=data,
            amax=self.amax,
            scale_inv=scale_inv,
            scaling_mode=self.scaling_mode,
            dq_dtype=self.dq_dtype,
            _dq_func=self._dq_func,
            is_colwise=self.is_colwise,
            data_layout=self.data_layout,
            flatten_axis=self.flatten_axis,
            has_rht_applied=self.has_rht_applied,
        )

    def checkpoint(self, quantizer):
        """Checkpoints the tensor with the given quantizer's checkpoint name if available.

        Args:
            quantizer: The quantizer to use for checkpointing. If None, no checkpointing is applied.

        Returns:
            The checkpointed tensor
        """
        if quantizer is None or quantizer.checkpoint_name is None:
            return self

        return jax_checkpoint_name(self, name=quantizer.checkpoint_name)


@register_pytree_node_class
@dataclass
class GroupedScaledTensor1x(ScaledTensor1x):
    """Grouped Quantizer for an array.

    This class extends ScaledTensor1x to support quantization of an array in grouped manner,
    where elements are grouped along a specified axis.

    Attributes:
        group_sizes: Array containing the size of each group
        original_shape: The original shape of the tensor before grouping
        group_axis: The axis along which grouping is performed (default: 0)
    """

    group_sizes: jnp.ndarray
    original_shape: Tuple
    group_axis: int

    def __init__(
        self,
        data,
        scale_inv,
        amax,
        group_sizes,
        scaling_mode,
        dq_dtype,
        _dq_func,
        is_colwise,
        data_layout,
        flatten_axis,
        original_shape,
        group_axis=0,
    ):
        self.flatten_axis = flatten_axis
        self.group_sizes = group_sizes
        self.original_shape = original_shape
        self.group_axis = group_axis
        # TODO(Phuong): Handle RHT for grouped quantization once it supports NVFP4
        super().__init__(
            data=data,
            scale_inv=scale_inv,
            amax=amax,
            scaling_mode=scaling_mode,
            dq_dtype=dq_dtype,
            _dq_func=_dq_func,
            is_colwise=is_colwise,
            data_layout=data_layout,
            flatten_axis=flatten_axis,
            has_rht_applied=False,
        )

    def __post_init__(self):
        assert self.scale_inv.ndim == 1, "Only support flattened scale_inv"
        assert self.data.ndim == 1, "Only support flattened data"
        assert self.group_axis >= 0
        assert self.flatten_axis > 0

        data_ndim = len(self.original_shape)
        assert (
            0 < self.flatten_axis < data_ndim
        ), f"flatten_axis {self.flatten_axis} is out of bounds for data.ndim = {data_ndim}"

        assert (
            0 <= self.group_axis < data_ndim
        ), f"group_axis {self.group_axis} is out of bounds for shape {self.original_shape}"

        expected_scale_shape = self.scaling_mode.get_grouped_scale_shape(
            self.original_shape,
            self.group_sizes.size,
            self.group_axis,
            self.is_colwise,
            is_padded=True,
            flatten_axis=self.flatten_axis,
        )

        assert self.scale_inv.shape == expected_scale_shape, (
            f"Unexpected scale_inv shape! \nExpect {expected_scale_shape} for padded"
            f" scale_inv, got {self.scale_inv.shape}"
        )

    def tree_flatten(self):
        """Flattens the tensor for JAX tree operations.

        Returns:
            A tuple containing (children, aux_data) for tree operations
        """
        children = (self.data, self.scale_inv, self.amax, self.group_sizes)
        aux_data = (
            self.scaling_mode,
            self.dq_dtype,
            self._dq_func,
            self.is_colwise,
            self.data_layout,
            self.flatten_axis,
            self.original_shape,
            self.group_axis,
        )
        return (children, aux_data)

    def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
        raise NotImplementedError

    def checkpoint(self, quantizer):
        """Checkpoints the tensor with the given quantizer's checkpoint name if available.

        Args:
            quantizer: The quantizer to use for checkpointing. If None, no checkpointing is applied.

        Returns:
            The checkpointed tensor
        """
        if quantizer is None or quantizer.checkpoint_name is None:
            return self

        return jax_checkpoint_name(self, name=quantizer.checkpoint_name)


@register_pytree_node_class
@dataclass
class ScaledTensor2x(AbstractBaseTensor, ScaledTensor):
    """Double-scale quantized tensor implementation.

    This class represents a tensor quantized with both row-wise and column-wise scaling factors.

    Attributes:
        rowwise_tensor: The row-wise quantized component
        colwise_tensor: The column-wise quantized component
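
    Example (illustrative; the buffers are assumed to come from a 2x quantizer)::

        t2x = ScaledTensorFactory.create_2x(
            data, scale_inv, colwise_data, colwise_scale_inv, scaling_mode=scaling_mode
        )
        lhs = t2x.get_tensor(TensorUsage.LHS)  # resolves to one ScaledTensor1x component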
    """

    rowwise_tensor: ScaledTensor1x
    colwise_tensor: ScaledTensor1x

    def tree_flatten(self):
        """Flattens the tensor for JAX tree operations.

        Returns:
            A tuple containing (children, aux_data) for tree operations
        """
        children = (self.rowwise_tensor, self.colwise_tensor)
        aux_data = ()
        return (children, aux_data)

    @property
    def ndim(self):
        """Number of dimensions of the underlying row-wise tensor."""
        return self.rowwise_tensor.ndim

    def dequantize(self):
        """Dequantizes the tensor using the row-wise component's dequantization.

        Returns:
            The dequantized tensor
        """
        return self.rowwise_tensor.dequantize()

    def get_tensor(self, usage: TensorUsage):
        """Returns the tensor based on the tensor usage."""
        q_layout_rowwise = self.rowwise_tensor.scaling_mode.get_quantize_layout(usage)
        q_layout_colwise = self.colwise_tensor.scaling_mode.get_quantize_layout(usage)

        if q_layout_rowwise.is_rowwise_only:
            return self.rowwise_tensor

        if q_layout_colwise.is_colwise_only:
            return self.colwise_tensor

        raise ValueError(
            f"Calling get_tensor() with usage {usage} is not valid for this tensor as"
            f" q_layout_rowwise={q_layout_rowwise} and q_layout_colwise={q_layout_colwise}!"
        )

    def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
        """Applies sharding constraints to a tensor based on logical axis names.

        Args:
            logical_axis_names: Tuple of logical axis names for sharding

        Returns:
            The tensor with applied sharding constraints
        """
        if not logical_axis_names:
            return self

        rowwise_tensor = self.rowwise_tensor.apply_sharding_constraint_by_logical_axes(
            logical_axis_names
        )
        colwise_tensor = self.colwise_tensor.apply_sharding_constraint_by_logical_axes(
            logical_axis_names
        )

        return ScaledTensor2x(rowwise_tensor, colwise_tensor)

    def checkpoint(self, quantizer):
        raise NotImplementedError


@dataclass
class ScaledTensorFactory:
    """Factory class for creating scaled tensor instances.

    Provides static methods to create both single-scale (1x) and double-scale (2x)
    quantized tensors with various configurations.
    """

    @staticmethod
    def create_1x(
        data,
        scale_inv,
        amax=None,
        scaling_mode=ScalingMode.NO_SCALING,
        dq_dtype=jnp.bfloat16,
        is_colwise=False,
        data_layout="N",
        flatten_axis=-1,
        group_sizes=None,
        original_shape=None,
        group_axis=0,
        has_rht_applied=False,
    ):
        """Creates a single-scale quantized tensor.

        Args:
            data: The quantized tensor data
            scale_inv: The inverse scaling factors
            amax: The maximum absolute value of the tensor
            scaling_mode: The scaling mode for quantization
            dq_dtype: The data type for dequantized values (default: bfloat16)
            is_colwise: Whether to use column-wise quantization (default: False)
            data_layout: The data_layout specification (default: "N")
            flatten_axis: The axis along which the tensor can be flattened to 2D (default: -1)
            group_sizes: Array of ints containing the size of each group (default: None)
            original_shape: The original shape of the tensor before grouping (default: None)
            group_axis: The axis along which grouping is performed (default: 0)
            has_rht_applied: Whether the tensor had the Randomized Hadamard Transform (RHT) applied during quantization (default: False)

        Returns:
            A ScaledTensor1x or GroupedScaledTensor1x instance depending on whether group_sizes is provided
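
        Example (illustrative sketch; placeholder buffers, assuming tensor scaling
        uses a single per-tensor inverse scale)::

            t = ScaledTensorFactory.create_1x(
                data=jnp.zeros((64, 128), jnp.float8_e4m3fn),  # "T" layout of a (128, 64) tensor
                scale_inv=jnp.ones((1,), jnp.float32),
                scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
                is_colwise=True,
                data_layout="T",
            )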
        """
        if amax is None:
            amax = jnp.empty((1,), dtype=jnp.float32)

        dequantizer = ScalingModeToDequantizerMap.get(scaling_mode)

        if group_sizes is not None:
            assert (
                original_shape is not None
            ), "original_shape is not given for GroupedScaledTensor1x"
            flatten_axis = (len(original_shape) + flatten_axis) % len(original_shape)

            # Handling attrs of transposed tensors
            group_axis = (len(original_shape) + group_axis) % len(original_shape)
            if data_layout == "T":
                if original_shape[0] == group_sizes.size:
                    original_shape = (
                        original_shape[0],
                        *original_shape[flatten_axis:],
                        *original_shape[1:flatten_axis],
                    )
                    flatten_axis = len(original_shape) - flatten_axis + 1
                else:
                    original_shape = (
                        *original_shape[flatten_axis:],
                        *original_shape[:flatten_axis],
                    )
                    group_axis = flatten_axis
                    flatten_axis = len(original_shape) - flatten_axis

            return GroupedScaledTensor1x(
                data=data,
                scale_inv=scale_inv,
                amax=amax,
                scaling_mode=scaling_mode,
                dq_dtype=dq_dtype,
                _dq_func=dequantizer.grouped_dequantize,
                is_colwise=is_colwise,
                data_layout=data_layout,
                flatten_axis=flatten_axis,
                group_sizes=group_sizes,
                original_shape=original_shape,
                group_axis=group_axis,
            )

        # Handling attrs of transposed tensors
        flatten_axis = (data.ndim + flatten_axis) % data.ndim
        if data_layout == "T":
            flatten_axis = data.ndim - flatten_axis

        return ScaledTensor1x(
            data=data,
            scale_inv=scale_inv,
            amax=amax,
            scaling_mode=scaling_mode,
            dq_dtype=dq_dtype,
            _dq_func=dequantizer.dequantize,
            is_colwise=is_colwise,
            data_layout=data_layout,
            flatten_axis=flatten_axis,
            has_rht_applied=has_rht_applied,
        )

    @staticmethod
    def create_2x(
        data,
        scale_inv,
        colwise_data,
        colwise_scale_inv,
        amax=None,
        colwise_amax=None,
        scaling_mode=ScalingMode.NO_SCALING,
        dq_dtype=jnp.bfloat16,
        data_layout="NN",
        flatten_axis=-1,
        group_sizes=None,
        original_shape=None,
        group_axis=0,
        rowwise_has_rht_applied=False,
        colwise_has_rht_applied=False,
    ):
        """Creates a double-scale quantized tensor.

        Args:
            data: The row-wise quantized data
            scale_inv: The row-wise inverse scaling factors
            colwise_data: The column-wise quantized data
            colwise_scale_inv: The column-wise inverse scaling factors
            amax: The maximum absolute value of the row-wise tensor
            colwise_amax: The maximum absolute value of the column-wise tensor (defaults to ``amax``)
            scaling_mode: The scaling mode for quantization
            dq_dtype: The data type for dequantized values (default: bfloat16)
            data_layout: The data_layout specification (default: "NN")
            flatten_axis: The axis along which the tensor can be flattened to 2D (default: -1)
            group_sizes: Array containing the size of each group (default: None)
            original_shape: The original shape of the tensor before grouping (default: None)
            group_axis: The axis along which grouping is performed (default: 0)
            rowwise_has_rht_applied: Whether the row-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False)
            colwise_has_rht_applied: Whether the column-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False)

        Returns:
            A ScaledTensor2x instance
        """
        if amax is None:
            amax = jnp.empty((1,), dtype=jnp.float32)
        if colwise_amax is None:
            colwise_amax = amax

        assert len(data_layout) == 2, f"Expect 2 layouts, got {data_layout}"
        rowwise_tensor = ScaledTensorFactory.create_1x(
            data,
            scale_inv,
            amax,
            scaling_mode,
            dq_dtype,
            is_colwise=False,
            data_layout=data_layout[0],
            flatten_axis=flatten_axis,
            group_sizes=group_sizes,
            original_shape=original_shape,
            group_axis=group_axis,
            has_rht_applied=rowwise_has_rht_applied,
        )
        colwise_tensor = ScaledTensorFactory.create_1x(
            colwise_data,
            colwise_scale_inv,
            colwise_amax,
            scaling_mode,
            dq_dtype,
            is_colwise=True,
            data_layout=data_layout[1],
            flatten_axis=flatten_axis,
            group_sizes=group_sizes,
            original_shape=original_shape,
            group_axis=group_axis,
            has_rht_applied=colwise_has_rht_applied,
        )
        return ScaledTensor2x(rowwise_tensor, colwise_tensor)

    @staticmethod
    def create(
        data: jnp.ndarray,
        scale_inv: jnp.ndarray,
        colwise_data: jnp.ndarray,
        colwise_scale_inv: jnp.ndarray,
        amax=None,
        colwise_amax=None,
        scaling_mode: ScalingMode = ScalingMode.NO_SCALING,
        dq_dtype: jnp.dtype = jnp.bfloat16,
        data_layout: str = "NN",
        q_layout: QuantizeLayout = QuantizeLayout.ROWWISE,
        flatten_axis: int = -1,
        group_sizes: jnp.ndarray = None,
        original_shape: Tuple[int] = None,
        group_axis: int = 0,
        rowwise_has_rht_applied: bool = False,
        colwise_has_rht_applied: bool = False,
    ):
        """Creates a scaled tensor based on the quantization axis.

        Args:
            data: The quantized tensor data
            scale_inv: The inverse scaling factors
            colwise_data: The column-wise quantized data
            colwise_scale_inv: The column-wise inverse scaling factors
            amax: The maximum absolute value of the tensor (default: None)
            colwise_amax: The column-wise maximum absolute value (defaults to ``amax``)
            scaling_mode: The scaling mode for quantization
            dq_dtype: The data type for dequantized values (default: bfloat16)
            data_layout: The data_layout specification (default: "NN")
            q_layout: The quantization axis (default: ROWWISE)
            flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)
            group_sizes: Array containing the size of each group (default: None)
            original_shape: The original shape of the tensor before grouping (default: None)
            group_axis: The axis along which grouping is performed (default: 0)
            rowwise_has_rht_applied: Whether the row-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False)
            colwise_has_rht_applied: Whether the col-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False)

        Returns:
            Either a ScaledTensor1x or ScaledTensor2x instance depending on q_layout
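
        Dispatch summary (mirrors the branches below):

        * ``q_layout.is_rowwise_colwise`` -> ``ScaledTensor2x`` via ``create_2x``
        * ``q_layout.is_colwise_only`` -> a column-wise ``ScaledTensor1x``
        * otherwise -> a row-wise ``ScaledTensor1x``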
        """
        assert not rowwise_has_rht_applied, "RHT is not supported for rowwise quantization yet"

        if q_layout.is_rowwise_colwise:
            return ScaledTensorFactory.create_2x(
                data,
                scale_inv,
                colwise_data,
                colwise_scale_inv,
                amax,
                colwise_amax,
                scaling_mode,
                dq_dtype,
                data_layout=data_layout,
                flatten_axis=flatten_axis,
                group_sizes=group_sizes,
                original_shape=original_shape,
                group_axis=group_axis,
                rowwise_has_rht_applied=rowwise_has_rht_applied,
                colwise_has_rht_applied=colwise_has_rht_applied,
            )

        if q_layout.is_colwise_only:
            return ScaledTensorFactory.create_1x(
                colwise_data,
                colwise_scale_inv,
                colwise_amax if colwise_amax is not None else amax,
                scaling_mode,
                dq_dtype,
                is_colwise=True,
                data_layout=data_layout[0],
                flatten_axis=flatten_axis,
                group_sizes=group_sizes,
                original_shape=original_shape,
                group_axis=group_axis,
                has_rht_applied=colwise_has_rht_applied,
            )

        return ScaledTensorFactory.create_1x(
            data,
            scale_inv,
            amax,
            scaling_mode,
            dq_dtype,
            is_colwise=False,
            data_layout=data_layout[0],
            flatten_axis=flatten_axis,
            group_sizes=group_sizes,
            original_shape=original_shape,
            group_axis=group_axis,
            has_rht_applied=rowwise_has_rht_applied,
        )


def with_sharding_constraint_by_logical_axes(x, logical_axis_names: Tuple[str, ...]):
    """Applies sharding constraints to a tensor based on logical axis names.

    Args:
        x: The tensor to apply sharding constraints to
        logical_axis_names: Tuple of logical axis names for sharding

    Returns:
        The tensor with applied sharding constraints
    """
    if isinstance(x, GroupedScaledTensor1x):
        raise NotImplementedError

    if isinstance(x, AbstractBaseTensor):
        return x.apply_sharding_constraint_by_logical_axes(logical_axis_names)

    return original_with_sharding_constraint_by_logical_axes(x, logical_axis_names)