# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Tensor classes for TE/JAX

This module provides tensor classes for handling quantized tensors in JAX, including
both single-scale (1x) and double-scale (2x) quantization schemes. It supports
rowwise and colwise quantization modes with proper scaling and dequantization.
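
Example (a minimal illustrative sketch; q_data and q_scale_inv are placeholders standing in
for the output of a TE/JAX quantizer and are not produced here):

    import jax.numpy as jnp

    # Higher-precision tensors are wrapped in NoScaleTensor; dequantize() is a no-op.
    hp = NoScaleTensor(data=jnp.ones((128, 256), jnp.bfloat16),
                       amax=jnp.zeros((1,), jnp.float32))
    assert hp.dequantize() is hp.data

    # Quantized data and its inverse scale factors are wrapped via the factory.
    scaled = ScaledTensorFactory.create_1x(
        q_data, q_scale_inv, scaling_mode=ScalingMode.MXFP8_1D_SCALING, dq_dtype=jnp.bfloat16
    )
    x_hp = scaled.dequantize()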
"""
from dataclasses import dataclass
from typing import Callable, Tuple
from abc import ABC, abstractmethod

import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

from transformer_engine_jax import QuantizeLayout

from .scaling_modes import ScalingMode, TensorUsage
from .dequantizer import ScalingModeToDequantizerMap
from ..sharding import (
    with_sharding_constraint_by_logical_axes as original_with_sharding_constraint_by_logical_axes,
)

__all__ = [
    "TensorUsage",
    "AbstractBaseTensor",
    "NoScaleTensor",
    "ScaledTensor",
    "ScaledTensor1x",
    "ScaledTensor2x",
    "GroupedScaledTensor1x",
    "ScaledTensorFactory",
    "with_sharding_constraint_by_logical_axes",
]


@dataclass
class AbstractBaseTensor(ABC):
    """Abstract base class for all tensor types."""

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Reconstructs the tensor from its flattened representation.

        Args:
            aux_data: Auxiliary data needed for reconstruction
            children: The flattened tensor components

        Returns:
            A reconstructed tensor instance
        """
        return cls(*children, *aux_data)

    @property
    @abstractmethod
    def ndim(self):
        """Number of dimensions of the underlying quantized array."""

    @abstractmethod
    def dequantize(self):
        """Dequantizes the tensor back to its original precision.

        Returns:
            The dequantized tensor
        """

    @abstractmethod
    def get_tensor(self, usage: TensorUsage):
        """Returns the appropriate tensor based on the tensor usage and the scaling mode.
        If the tensor usage is not valid for the scaling mode, an error is raised.

        Args:
            usage: The usage of the tensor

        Returns:
            The tensor based on the usage
        """

    @abstractmethod
    def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
        """Applies sharding constraints to a tensor based on logical axis names.

        Args:
            logical_axis_names: Tuple of logical axis names for sharding

        Returns:
            The tensor with applied sharding constraints
        """


@dataclass
class AbstractBaseTensor1x(AbstractBaseTensor):
    """Abstract base class for single layout tensors."""

    data: jnp.ndarray
    amax: jnp.ndarray


@register_pytree_node_class
@dataclass
class NoScaleTensor(AbstractBaseTensor1x):
    """Higher-precision tensor."""

    def __post_init__(self):
        assert isinstance(self.data, jnp.ndarray), "NoScaleTensor's data must be a jnp.ndarray."

    def tree_flatten(self):
        """Flattens the tensor for JAX tree operations.

        Returns:
            A tuple containing (children, aux_data) for tree operations
        """
        children = (self.data, self.amax)
        aux_data = ()
        return (children, aux_data)

    @property
    def ndim(self):
        """Number of dimensions of the underlying array."""
        return self.data.ndim

    def dequantize(self):
        """No-op for a higher-precision tensor; simply returns the tensor's data."""
        return self.data

    def get_tensor(self, usage: TensorUsage):
        """Returns the tensor based on the tensor usage."""
        q_layout = ScalingMode.NO_SCALING.get_quantize_layout(usage)
        assert (
            q_layout == QuantizeLayout.ROWWISE
        ), "Only ROWWISE layout is supported for NoScaleTensor"
        return self

    def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
        """Applies sharding constraints to a tensor based on logical axis names.

        Args:
            logical_axis_names: Tuple of logical axis names for sharding

        Returns:
            The tensor with applied sharding constraints
        """
        if not logical_axis_names:
            return self

        data = with_sharding_constraint_by_logical_axes(self.data, logical_axis_names)

        return NoScaleTensor(
            data=data,
            amax=self.amax,
        )


class ScaledTensor(ABC):
    """Abstract base class for scaled tensors."""


@register_pytree_node_class
@dataclass
class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor):
    """Single-scale quantized tensor implementation.

    This class represents a tensor quantized with a single scaling factor,
    supporting both row-wise and column-wise quantization modes.

    Attributes:
        data: The quantized tensor data
        scale_inv: The inverse scaling factors
        amax: The maximum absolute value of the tensor
        scaling_mode: The scaling mode used for quantization
        dq_dtype: The data type for dequantized values
        _dq_func: The dequantization function
        is_colwise: Whether the tensor uses column-wise quantization
        data_layout: The data_layout specification for the tensor
        flatten_axis: The axis along which the tensor is flattened to 2D for quantization
        has_rht_applied: Whether the tensor had the Randomized Hadamard Transform (RHT) applied during quantization
    """

    scale_inv: jnp.ndarray
    scaling_mode: ScalingMode
    dq_dtype: jnp.dtype
    _dq_func: Callable
    is_colwise: bool
    data_layout: str
    flatten_axis: int
    has_rht_applied: bool

    def __post_init__(self):
        """Validates and adjusts the scale_inv shape after initialization.

        Ensures the scale_inv shape matches the expected shape based on the scaling mode
        and quantization direction. Pads the scale_inv if necessary.
        """
        assert self.flatten_axis > 0
        assert (
            0 < self.flatten_axis < len(self.data.shape)
        ), f"flatten_axis {self.flatten_axis} is out of bounds for shape {self.data.shape}"

        if self.scaling_mode == ScalingMode.NO_SCALING:
            self.scale_inv = jnp.empty((0,), dtype=jnp.float32)
        else:
            unpadded_scale_shape = self.scaling_mode.get_scale_shape(
                self.data.shape,
                data_layout=self.data_layout,
                is_colwise=self.is_colwise,
                is_padded=False,
                # expect the flatten_axis wrt the N layout
                flatten_axis=(
                    self.flatten_axis
                    if self.data_layout == "N"
                    else self.data.ndim - self.flatten_axis
                ),
            )
            unpadded_scale_shape_broadcast = self.scaling_mode.get_scale_shape(
                self.data.shape,
                data_layout=self.data_layout,
                is_colwise=self.is_colwise,
                is_padded=False,
                # expect the flatten_axis wrt the N layout
                flatten_axis=(
                    self.flatten_axis
                    if self.data_layout == "N"
                    else self.data.ndim - self.flatten_axis
                ),
                broadcast_2d_scale_shape_to_1d=True,
            )
            assert self.scale_inv.shape in (unpadded_scale_shape, unpadded_scale_shape_broadcast), (
                f"Unpadded inverse scale factor has wrong shape, expected {unpadded_scale_shape} or"
                f" {unpadded_scale_shape_broadcast} but got {self.scale_inv.shape}."
            )

    def tree_flatten(self):
        """Flattens the tensor for JAX tree operations.

        Returns:
            A tuple containing (children, aux_data) for tree operations
        """
        children = (self.data, self.amax, self.scale_inv)
        aux_data = (
            self.scaling_mode,
            self.dq_dtype,
            self._dq_func,
            self.is_colwise,
            self.data_layout,
            self.flatten_axis,
            self.has_rht_applied,
        )
        return (children, aux_data)

    @property
    def ndim(self):
        """Number of dimensions of the underlying quantized array."""
        return self.data.ndim

    def dequantize(self):
        """Dequantizes the tensor using the stored dequantization function.

        Returns:
            The dequantized tensor
        """
        return self._dq_func(self)

    def get_tensor(self, usage: TensorUsage):
        """Returns the tensor based on the tensor usage."""
        q_layout = self.scaling_mode.get_quantize_layout(usage)
        colwise_usage_valid = q_layout == QuantizeLayout.COLWISE and self.is_colwise
        rowwise_usage_valid = q_layout == QuantizeLayout.ROWWISE and not self.is_colwise

        if colwise_usage_valid or rowwise_usage_valid:
            return self

        raise ValueError(
            f"Calling get_tensor() with usage {usage} is not valid for this tensor as"
            f" self.is_colwise={self.is_colwise}!"
        )

    def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
        """Applies sharding constraints to a tensor based on logical axis names.

        Args:
            logical_axis_names: Tuple of logical axis names for sharding

        Returns:
            The tensor with applied sharding constraints
        """
        if not logical_axis_names:
            return self

        # axis_names were given for the N layout, so they need to be transposed for the T layout
        if self.data_layout == "T":
            assert self.flatten_axis > 0
            assert len(logical_axis_names) == self.data.ndim
            flatten_axis = self.data.ndim - self.flatten_axis
            axis_names = (
                *logical_axis_names[flatten_axis:],
                *logical_axis_names[:flatten_axis],
            )
        else:
            axis_names = logical_axis_names

        data = with_sharding_constraint_by_logical_axes(self.data, axis_names)

        if self.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
            # TODO(Phuong): Handle padding !?
            scale_inv = with_sharding_constraint_by_logical_axes(self.scale_inv, axis_names)
        else:
            scale_inv = self.scale_inv

        return ScaledTensor1x(
            data=data,
            scale_inv=scale_inv,
            amax=self.amax,
            scaling_mode=self.scaling_mode,
            dq_dtype=self.dq_dtype,
            _dq_func=self._dq_func,
            is_colwise=self.is_colwise,
            data_layout=self.data_layout,
            flatten_axis=self.flatten_axis,
            has_rht_applied=self.has_rht_applied,
        )


@register_pytree_node_class
@dataclass
class GroupedScaledTensor1x(ScaledTensor1x):
    """Grouped single-scale quantized tensor.

    This class extends ScaledTensor1x to support quantizing an array in a grouped manner,
    where elements are grouped along a specified axis.

    Attributes:
        group_sizes: Array containing the size of each group
        original_shape: The original shape of the tensor before grouping
        group_axis: The axis along which grouping is performed (default: 0)
    """

    group_sizes: jnp.ndarray
    original_shape: Tuple
    group_axis: int

    def __init__(
        self,
        data,
        scale_inv,
        amax,
        group_sizes,
        scaling_mode,
        dq_dtype,
        _dq_func,
        is_colwise,
        data_layout,
        flatten_axis,
        original_shape,
        group_axis=0,
    ):
        self.flatten_axis = flatten_axis
        self.group_sizes = group_sizes
        self.original_shape = original_shape
        self.group_axis = group_axis
        # TODO(Phuong): Handle RHT for grouped quantization once grouped quantization supports NVFP4
        super().__init__(
            data=data,
            scale_inv=scale_inv,
            amax=amax,
            scaling_mode=scaling_mode,
            dq_dtype=dq_dtype,
            _dq_func=_dq_func,
            is_colwise=is_colwise,
            data_layout=data_layout,
            flatten_axis=flatten_axis,
            has_rht_applied=False,
        )

    def __post_init__(self):
        assert self.scale_inv.ndim == 1, "Only support flattened scale_inv"
        assert self.data.ndim == 1, "Only support flattened data"
        assert self.group_axis >= 0
        assert self.flatten_axis > 0

        data_ndim = len(self.original_shape)
        assert (
            0 < self.flatten_axis < data_ndim
        ), f"flatten_axis {self.flatten_axis} is out of bounds for data.ndim = {data_ndim}"

        assert (
            0 <= self.group_axis < data_ndim
        ), f"group_axis {self.group_axis} is out of bounds for shape {self.original_shape}"

        expected_scale_shape = self.scaling_mode.get_grouped_scale_shape(
            self.original_shape,
            self.group_sizes.size,
            self.group_axis,
            self.is_colwise,
            is_padded=True,
            flatten_axis=self.flatten_axis,
        )

        assert self.scale_inv.shape == expected_scale_shape, (
            f"Unexpected scale_inv shape! \nExpect {expected_scale_shape} for padded"
            f" scale_inv, got {self.scale_inv.shape}"
        )

    def tree_flatten(self):
        """Flattens the tensor for JAX tree operations.

        Returns:
            A tuple containing (children, aux_data) for tree operations
        """
        children = (self.data, self.scale_inv, self.amax, self.group_sizes)
        aux_data = (
            self.scaling_mode,
            self.dq_dtype,
            self._dq_func,
            self.is_colwise,
            self.data_layout,
            self.flatten_axis,
            self.original_shape,
            self.group_axis,
        )
        return (children, aux_data)

    def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
        raise NotImplementedError


@register_pytree_node_class
@dataclass
class ScaledTensor2x(AbstractBaseTensor, ScaledTensor):
    """Double-scale quantized tensor implementation.

    This class represents a tensor quantized with both row-wise and column-wise scaling factors.

    Attributes:
        rowwise_tensor: The row-wise quantized component
        colwise_tensor: The column-wise quantized component
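
    Example (illustrative sketch; tensor_2x is assumed to be built by
    ScaledTensorFactory.create_2x from quantized row-wise and column-wise buffers):

        x_hp = tensor_2x.dequantize()        # dequantizes via the row-wise component
        colwise = tensor_2x.colwise_tensor   # ScaledTensor1x with column-wise scaling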
    """

    rowwise_tensor: ScaledTensor1x
    colwise_tensor: ScaledTensor1x

    def tree_flatten(self):
        """Flattens the tensor for JAX tree operations.

        Returns:
            A tuple containing (children, aux_data) for tree operations
        """
        children = (self.rowwise_tensor, self.colwise_tensor)
        aux_data = ()
        return (children, aux_data)

    @property
    def ndim(self):
        """Number of dimensions of the underlying row-wise tensor."""
        return self.rowwise_tensor.ndim

    def dequantize(self):
        """Dequantizes the tensor using the row-wise component's dequantization.

        Returns:
            The dequantized tensor
        """
        return self.rowwise_tensor.dequantize()

    def get_tensor(self, usage: TensorUsage):
        """Returns the tensor based on the tensor usage."""
        q_layout_rowwise = self.rowwise_tensor.scaling_mode.get_quantize_layout(usage)
        q_layout_colwise = self.colwise_tensor.scaling_mode.get_quantize_layout(usage)

        if q_layout_rowwise == QuantizeLayout.ROWWISE:
            return self.rowwise_tensor

        if q_layout_colwise == QuantizeLayout.COLWISE:
            return self.colwise_tensor

        raise ValueError(
            f"Calling get_tensor() with usage {usage} is not valid for this tensor as"
            f" q_layout_rowwise={q_layout_rowwise} and q_layout_colwise={q_layout_colwise}!"
        )

    def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
        """Applies sharding constraints to a tensor based on logical axis names.

        Args:
            logical_axis_names: Tuple of logical axis names for sharding

        Returns:
            The tensor with applied sharding constraints
        """
        if not logical_axis_names:
            return self

        rowwise_tensor = self.rowwise_tensor.apply_sharding_constraint_by_logical_axes(
            logical_axis_names
        )
        colwise_tensor = self.colwise_tensor.apply_sharding_constraint_by_logical_axes(
            logical_axis_names
        )

        return ScaledTensor2x(rowwise_tensor, colwise_tensor)


@dataclass
class ScaledTensorFactory:
    """Factory class for creating scaled tensor instances.

    Provides static methods to create both single-scale (1x) and double-scale (2x)
    quantized tensors with various configurations.
    """

    @staticmethod
    def create_1x(
        data,
        scale_inv,
        amax=None,
        scaling_mode=ScalingMode.NO_SCALING,
        dq_dtype=jnp.bfloat16,
        is_colwise=False,
        data_layout="N",
        flatten_axis=-1,
        group_sizes=None,
        original_shape=None,
        group_axis=0,
        has_rht_applied=False,
    ):
        """Creates a single-scale quantized tensor.

        Args:
            data: The quantized tensor data
            scale_inv: The inverse scaling factors
            amax: The maximum absolute value of the tensor
            scaling_mode: The scaling mode for quantization
            dq_dtype: The data type for dequantized values (default: bfloat16)
            is_colwise: Whether to use column-wise quantization (default: False)
            data_layout: The data_layout specification (default: "N")
            flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)
            group_sizes: Array of ints containing the size of each group (default: None)
            original_shape: The original shape of the tensor before grouping (default: None)
            group_axis: The axis along which grouping is performed (default: 0)
            has_rht_applied: Whether the tensor had the Randomized Hadamard Transform (RHT) applied during quantization (default: False)

        Returns:
            A ScaledTensor1x or GroupedScaledTensor1x instance depending on whether group_sizes is provided
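
        Example (illustrative sketch; q_data and q_scale_inv are assumed to come from a
        TE/JAX quantizer, with shapes matching the chosen scaling mode):

            tensor = ScaledTensorFactory.create_1x(
                q_data,
                q_scale_inv,
                scaling_mode=ScalingMode.MXFP8_1D_SCALING,
                dq_dtype=jnp.bfloat16,
            )
            x_hp = tensor.dequantize()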
        """
        if amax is None:
            amax = jnp.empty((1,), dtype=jnp.float32)

        dequantizer = ScalingModeToDequantizerMap.get(scaling_mode)

        if group_sizes is not None:
            assert (
                original_shape is not None
            ), "original_shape is not given for GroupedScaledTensor1x"
            flatten_axis = len(original_shape) + flatten_axis if flatten_axis < 0 else flatten_axis

            # Handling attrs of transposed tensors
            group_axis = len(original_shape) + group_axis if group_axis < 0 else group_axis
            if data_layout == "T":
                if original_shape[0] == group_sizes.size:
                    original_shape = (
                        original_shape[0],
                        *original_shape[flatten_axis:],
                        *original_shape[1:flatten_axis],
                    )
                    flatten_axis = len(original_shape) - flatten_axis + 1
                else:
                    original_shape = (
                        *original_shape[flatten_axis:],
                        *original_shape[:flatten_axis],
                    )
                    group_axis = flatten_axis
                    flatten_axis = len(original_shape) - flatten_axis

            return GroupedScaledTensor1x(
                data=data,
                scale_inv=scale_inv,
                amax=amax,
                scaling_mode=scaling_mode,
                dq_dtype=dq_dtype,
                _dq_func=dequantizer.grouped_dequantize,
                is_colwise=is_colwise,
                data_layout=data_layout,
                flatten_axis=flatten_axis,
                group_sizes=group_sizes,
                original_shape=original_shape,
                group_axis=group_axis,
            )

        # Handling attrs of transposed tensors
        flatten_axis = data.ndim + flatten_axis if flatten_axis < 0 else flatten_axis
        if data_layout == "T":
            flatten_axis = data.ndim - flatten_axis

        return ScaledTensor1x(
            data=data,
            scale_inv=scale_inv,
            amax=amax,
            scaling_mode=scaling_mode,
            dq_dtype=dq_dtype,
            _dq_func=dequantizer.dequantize,
            is_colwise=is_colwise,
            data_layout=data_layout,
            flatten_axis=flatten_axis,
            has_rht_applied=has_rht_applied,
        )

    @staticmethod
    def create_2x(
        data,
        scale_inv,
        colwise_data,
        colwise_scale_inv,
        amax=None,
        colwise_amax=None,
        scaling_mode=ScalingMode.NO_SCALING,
        dq_dtype=jnp.bfloat16,
        data_layout="NN",
        flatten_axis=-1,
        group_sizes=None,
        original_shape=None,
        group_axis=0,
        rowwise_has_rht_applied=False,
        colwise_has_rht_applied=False,
    ):
        """Creates a double-scale quantized tensor.

        Args:
            data: The row-wise quantized data
            scale_inv: The row-wise inverse scaling factors
            colwise_data: The column-wise quantized data
            colwise_scale_inv: The column-wise inverse scaling factors
            amax: The maximum absolute value of the tensor
            colwise_amax: The maximum absolute value for the column-wise tensor (default: amax)
            scaling_mode: The scaling mode for quantization
            dq_dtype: The data type for dequantized values (default: bfloat16)
            data_layout: The data_layout specification (default: "NN")
            flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)
            group_sizes: Array containing the size of each group (default: None)
            original_shape: The original shape of the tensor before grouping (default: None)
            group_axis: The axis along which grouping is performed (default: 0)
            rowwise_has_rht_applied: Whether the row-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False)
            colwise_has_rht_applied: Whether the column-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False)

        Returns:
            A ScaledTensor2x instance
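
        Example (illustrative sketch; the row-wise and column-wise buffers are assumed to
        come from a quantizer that produced both layouts):

            pair = ScaledTensorFactory.create_2x(
                q_data,
                q_scale_inv,
                q_colwise_data,
                q_colwise_scale_inv,
                scaling_mode=ScalingMode.MXFP8_1D_SCALING,
                data_layout="NN",
            )
            rowwise = pair.rowwise_tensor   # ScaledTensor1x, row-wise scaling
            colwise = pair.colwise_tensor   # ScaledTensor1x, column-wise scaling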
        """
        if amax is None:
            amax = jnp.empty((1,), dtype=jnp.float32)
        if colwise_amax is None:
            colwise_amax = amax

        assert len(data_layout) == 2, f"Expect 2 layouts, got {data_layout}"
        rowwise_tensor = ScaledTensorFactory.create_1x(
            data,
            scale_inv,
            amax,
            scaling_mode,
            dq_dtype,
            is_colwise=False,
            data_layout=data_layout[0],
            flatten_axis=flatten_axis,
            group_sizes=group_sizes,
            original_shape=original_shape,
            group_axis=group_axis,
            has_rht_applied=rowwise_has_rht_applied,
        )
        colwise_tensor = ScaledTensorFactory.create_1x(
            colwise_data,
            colwise_scale_inv,
            colwise_amax,
            scaling_mode,
            dq_dtype,
            is_colwise=True,  # TODO(Phuong): set this correctly
            data_layout=data_layout[1],
            flatten_axis=flatten_axis,
            group_sizes=group_sizes,
            original_shape=original_shape,
            group_axis=group_axis,
            has_rht_applied=colwise_has_rht_applied,
        )
        return ScaledTensor2x(rowwise_tensor, colwise_tensor)

    @staticmethod
    def create(
        data: jnp.ndarray,
        scale_inv: jnp.ndarray,
        colwise_data: jnp.ndarray,
        colwise_scale_inv: jnp.ndarray,
        amax=None,
        colwise_amax=None,
        scaling_mode: ScalingMode = ScalingMode.NO_SCALING,
        dq_dtype: jnp.dtype = jnp.bfloat16,
        data_layout: str = "NN",
        q_layout: QuantizeLayout = QuantizeLayout.ROWWISE,
        flatten_axis: int = -1,
        group_sizes: jnp.ndarray = None,
        original_shape: Tuple[int] = None,
        group_axis: int = 0,
        rowwise_has_rht_applied: bool = False,
        colwise_has_rht_applied: bool = False,
    ):
        """Creates a scaled tensor based on the quantization axis.

        Args:
            data: The quantized tensor data
            scale_inv: The inverse scaling factors
            colwise_data: The column-wise quantized data
            colwise_scale_inv: The column-wise inverse scaling factors
            amax: The maximum absolute value of the tensor (default: None)
            colwise_amax: The maximum absolute value for the column-wise tensor (default: amax)
            scaling_mode: The scaling mode for quantization
            dq_dtype: The data type for dequantized values (default: bfloat16)
            data_layout: The data_layout specification (default: "NN")
            q_layout: The quantize layout selecting row-wise, column-wise, or both (default: ROWWISE)
            flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)
            group_sizes: Array containing the size of each group (default: None)
            original_shape: The original shape of the tensor before grouping (default: None)
            group_axis: The axis along which grouping is performed (default: 0)
            rowwise_has_rht_applied: Whether the row-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False)
            colwise_has_rht_applied: Whether the col-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False)

        Returns:
            Either a ScaledTensor1x or ScaledTensor2x instance depending on q_layout
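
        Example (illustrative sketch; the quantized buffers are placeholders assumed to come
        from a TE/JAX quantizer):

            # ROWWISE_COLWISE packs both layouts into a ScaledTensor2x ...
            both = ScaledTensorFactory.create(
                q_data, q_scale_inv, q_colwise_data, q_colwise_scale_inv,
                scaling_mode=ScalingMode.MXFP8_1D_SCALING,
                q_layout=QuantizeLayout.ROWWISE_COLWISE,
            )
            # ... while ROWWISE (or COLWISE) yields a single ScaledTensor1x.
            rowwise_only = ScaledTensorFactory.create(
                q_data, q_scale_inv, None, None,
                scaling_mode=ScalingMode.MXFP8_1D_SCALING,
                q_layout=QuantizeLayout.ROWWISE,
            )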
        """
        assert not rowwise_has_rht_applied, "RHT is not supported for rowwise quantization yet"

        if q_layout == QuantizeLayout.ROWWISE_COLWISE:
            return ScaledTensorFactory.create_2x(
                data,
                scale_inv,
                colwise_data,
                colwise_scale_inv,
                amax,
                colwise_amax,
                scaling_mode,
                dq_dtype,
                data_layout=data_layout,
                flatten_axis=flatten_axis,
                group_sizes=group_sizes,
                original_shape=original_shape,
                group_axis=group_axis,
                rowwise_has_rht_applied=rowwise_has_rht_applied,
                colwise_has_rht_applied=colwise_has_rht_applied,
            )

        is_colwise = q_layout == QuantizeLayout.COLWISE
        if is_colwise:
            return ScaledTensorFactory.create_1x(
                colwise_data,
                colwise_scale_inv,
                colwise_amax if colwise_amax is not None else amax,
                scaling_mode,
                dq_dtype,
                is_colwise=is_colwise,
                data_layout=data_layout[0],
                flatten_axis=flatten_axis,
                group_sizes=group_sizes,
                original_shape=original_shape,
                group_axis=group_axis,
                has_rht_applied=colwise_has_rht_applied,
            )

        return ScaledTensorFactory.create_1x(
            data,
            scale_inv,
            amax,
            scaling_mode,
            dq_dtype,
            is_colwise=is_colwise,
            data_layout=data_layout[0],
            flatten_axis=flatten_axis,
            group_sizes=group_sizes,
            original_shape=original_shape,
            group_axis=group_axis,
            has_rht_applied=rowwise_has_rht_applied,
        )


def with_sharding_constraint_by_logical_axes(x, logical_axis_names: Tuple[str, ...]):
    """Applies sharding constraints to a tensor based on logical axis names.

    Args:
        x: The tensor to apply sharding constraints to
        logical_axis_names: Tuple of logical axis names for sharding

    Returns:
        The tensor with applied sharding constraints
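
    Example (illustrative sketch; the logical axis names are placeholders that must be
    registered with the surrounding mesh axis rules):

        y = with_sharding_constraint_by_logical_axes(x, ("batch", "hidden"))

    Plain jnp.ndarray inputs fall through to the base sharding helper, while
    AbstractBaseTensor subclasses dispatch to their own
    apply_sharding_constraint_by_logical_axes method.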
    """
    if isinstance(x, GroupedScaledTensor1x):
        raise NotImplementedError

    if isinstance(x, AbstractBaseTensor):
        return x.apply_sharding_constraint_by_logical_axes(logical_axis_names)

    return original_with_sharding_constraint_by_logical_axes(x, logical_axis_names)