# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Tensor classes for TE/JAX

This module provides tensor classes for handling quantized tensors in JAX, including
both single-scale (1x) and double-scale (2x) quantization schemes. It supports
rowwise and colwise quantization modes with proper scaling and dequantization.
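
A typical flow (illustrative sketch; the quantized inputs and scaling mode are
placeholders produced by a matching quantizer, and ``TensorUsage.LHS`` is
assumed to be a valid usage):

    tensor = ScaledTensorFactory.create_2x(
        data, scale_inv, colwise_data, colwise_scale_inv,
        scaling_mode=scaling_mode, dq_dtype=jnp.bfloat16,
    )
    lhs = tensor.get_tensor(TensorUsage.LHS)  # usage-appropriate component
    hp = lhs.dequantize()                     # back to bfloat16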
"""
from dataclasses import dataclass
from typing import Callable, Tuple
from abc import ABC, abstractmethod

import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

from .scaling_modes import ScalingMode, TensorUsage
from .dequantizer import ScalingModeToDequantizerMap
from .misc import QuantizeLayout
from ..sharding import (
    with_sharding_constraint_by_logical_axes as original_with_sharding_constraint_by_logical_axes,
)

__all__ = [
    "TensorUsage",
    "AbstractBaseTensor",
    "NoScaleTensor",
    "ScaledTensor",
    "ScaledTensor1x",
    "ScaledTensor2x",
    "GroupedScaledTensor1x",
    "ScaledTensorFactory",
    "with_sharding_constraint_by_logical_axes",
]


@dataclass
class AbstractBaseTensor(ABC):
    """Abstract base class for all tensor types."""

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Reconstructs the tensor from its flattened representation.

        Args:
            aux_data: Auxiliary data needed for reconstruction
            children: The flattened tensor components

        Returns:
            A reconstructed tensor instance
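
        Example (sketch): paired with ``tree_flatten``, this lets JAX treat
        tensor instances as pytrees, where ``tensor`` is any registered
        subclass instance::

            import jax

            leaves, treedef = jax.tree_util.tree_flatten(tensor)
            same_tensor = jax.tree_util.tree_unflatten(treedef, leaves)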
        """
        return cls(*children, *aux_data)

    @property
    @abstractmethod
    def ndim(self):
        """Number of dimensions of the underlying quantized array."""

    @abstractmethod
    def dequantize(self):
        """Dequantizes the tensor back to its original precision.

        Returns:
            The dequantized tensor
        """

    @abstractmethod
    def get_tensor(self, usage: TensorUsage):
        """Returns the appropriate tensor based on the tensor usage and the scaling mode.
        If the tensor usage is not valid for the scaling mode, an error is raised.

        Args:
            usage: The usage of the tensor

        Returns:
            The tensor based on the usage
        """

    @abstractmethod
    def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
        """Applies sharding constraints to a tensor based on logical axis names.

        Args:
            logical_axis_names: Tuple of logical axis names for sharding

        Returns:
            The tensor with applied sharding constraints
        """


@dataclass
class AbstractBaseTensor1x(AbstractBaseTensor):
    """Abstract base class for single layout tensors."""

    data: jnp.ndarray
    amax: jnp.ndarray


@register_pytree_node_class
@dataclass
class NoScaleTensor(AbstractBaseTensor1x):
    """Higher-precision tensor."""

    def __post_init__(self):
        assert isinstance(self.data, jnp.ndarray), "NoScaleTensor's data must be a jnp.ndarray."

    def tree_flatten(self):
        """Flattens the tensor for JAX tree operations.

        Returns:
            A tuple containing (children, aux_data) for tree operations
        """
        children = (self.data, self.amax)
        aux_data = ()
        return (children, aux_data)

    @property
    def ndim(self):
        """Number of dimensions of the underlying array."""
        return self.data.ndim

    def dequantize(self):
        """This is a no-op for a higher-precision tensor so this simply returns the tensor's data."""
        return self.data

    def get_tensor(self, usage: TensorUsage):
        """Returns the tensor based on the tensor usage."""
        q_layout = ScalingMode.NO_SCALING.get_quantize_layout(usage)
        assert q_layout.is_rowwise_only, "Only ROWWISE layout is supported for NoScaleTensor"
        return self

    def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
        """Applies sharding constraints to a tensor based on logical axis names.

        Args:
            logical_axis_names: Tuple of logical axis names for sharding

        Returns:
            The tensor with applied sharding constraints
        """
        if not logical_axis_names:
            return self

        data = with_sharding_constraint_by_logical_axes(self.data, logical_axis_names)

        return NoScaleTensor(
            data=data,
            amax=self.amax,
        )


class ScaledTensor(ABC):
    """Abstract base class for scaled tensors."""


@register_pytree_node_class
@dataclass
class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor):
    """Single-scale quantized tensor implementation.

    This class represents a tensor quantized with a single scaling factor,
    supporting both row-wise and column-wise quantization modes.

    Attributes:
        data: The quantized tensor data
        scale_inv: The inverse scaling factors
        amax: The maximum absolute value of the tensor
        scaling_mode: The scaling mode used for quantization
        dq_dtype: The data type for dequantized values
        _dq_func: The dequantization function
        is_colwise: Whether the tensor uses column-wise quantization
        data_layout: The data_layout specification for the tensor
        flatten_axis: The quantization axis for the tensor
        has_rht_applied: Whether the tensor had the Randomized Hadamard Transform (RHT) applied during quantization
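
    Example (sketch; assumes ``tensor`` was produced by ``ScaledTensorFactory``
    and that ``TensorUsage.LHS`` maps to a rowwise layout for its scaling mode):

        hp = tensor.dequantize()                   # back to dq_dtype precision
        same = tensor.get_tensor(TensorUsage.LHS)  # returns this tensor if valid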
    """

    scale_inv: jnp.ndarray
    scaling_mode: ScalingMode
    dq_dtype: jnp.dtype
    _dq_func: Callable
    is_colwise: bool
    data_layout: str
    flatten_axis: int
    has_rht_applied: bool

    def __post_init__(self):
        """Validates and adjusts the scale_inv shape after initialization.

        Ensures the scale_inv shape matches the expected shape based on the scaling mode
        and quantization direction. Pads the scale_inv if necessary.
        """
        assert self.flatten_axis > 0
        assert (
            0 < self.flatten_axis < len(self.data.shape)
        ), f"flatten_axis {self.flatten_axis} is out of bounds for shape {self.data.shape}"

        if self.scaling_mode == ScalingMode.NO_SCALING:
            self.scale_inv = jnp.empty((0,), dtype=jnp.float32)
        else:
            unpadded_scale_shape = self.scaling_mode.get_scale_shape(
                self.data.shape,
                data_layout=self.data_layout,
                is_colwise=self.is_colwise,
                is_padded=False,
                # expect the flatten_axis wrt the N layout
                flatten_axis=(
                    self.flatten_axis
                    if self.data_layout == "N"
                    else self.data.ndim - self.flatten_axis
                ),
            )
            unpadded_scale_shape_broadcast = self.scaling_mode.get_scale_shape(
                self.data.shape,
                data_layout=self.data_layout,
                is_colwise=self.is_colwise,
                is_padded=False,
                # expect the flatten_axis wrt the N layout
                flatten_axis=(
                    self.flatten_axis
                    if self.data_layout == "N"
                    else self.data.ndim - self.flatten_axis
                ),
                broadcast_2d_scale_shape_to_1d=True,
            )
            assert self.scale_inv.shape in (unpadded_scale_shape, unpadded_scale_shape_broadcast), (
                f"Unpadded inverse scale factor has wrong shape, expected {unpadded_scale_shape} or"
                f" {unpadded_scale_shape_broadcast} but got {self.scale_inv.shape}."
            )

    def tree_flatten(self):
        """Flattens the tensor for JAX tree operations.

        Returns:
            A tuple containing (children, aux_data) for tree operations
        """
        children = (self.data, self.amax, self.scale_inv)
        aux_data = (
            self.scaling_mode,
            self.dq_dtype,
            self._dq_func,
            self.is_colwise,
            self.data_layout,
            self.flatten_axis,
            self.has_rht_applied,
        )
        return (children, aux_data)

    @property
    def ndim(self):
        """Number of dimensions of the underlying quantized array."""
        return self.data.ndim

    def dequantize(self):
        """Dequantizes the tensor using the stored dequantization function.

        Returns:
            The dequantized tensor
        """
        return self._dq_func(self)

    def get_tensor(self, usage: TensorUsage):
        """Returns the tensor based on the tensor usage."""
        q_layout = self.scaling_mode.get_quantize_layout(usage)
        colwise_usage_valid = q_layout.is_colwise_only and self.is_colwise
        rowwise_usage_valid = q_layout.is_rowwise_only and not self.is_colwise

        if colwise_usage_valid or rowwise_usage_valid:
            return self

        raise ValueError(
            f"Calling get_tensor() with usage {usage} is not valid for this tensor as"
            f" self.is_colwise={self.is_colwise}!"
        )

    def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
        """Applies sharding constraints to a tensor based on logical axis names.

        Args:
            logical_axis_names: Tuple of logical axis names for sharding

        Returns:
            The tensor with applied sharding constraints
        """
        if not logical_axis_names:
            return self

        # axis_names were given for the N layout, so they need to be transposed for the T layout
        if self.data_layout == "T":
            assert self.flatten_axis > 0
            assert len(logical_axis_names) == self.data.ndim
            flatten_axis = self.data.ndim - self.flatten_axis
            axis_names = (
                *logical_axis_names[flatten_axis:],
                *logical_axis_names[:flatten_axis],
            )
        else:
            axis_names = logical_axis_names

        data = with_sharding_constraint_by_logical_axes(self.data, axis_names)

        if self.scaling_mode.is_block_scaling:  # Both MXFP8 and NVFP4
            scale_inv = with_sharding_constraint_by_logical_axes(self.scale_inv, axis_names)
        else:
            scale_inv = self.scale_inv

        return ScaledTensor1x(
            data=data,
            amax=self.amax,
            scale_inv=scale_inv,
            scaling_mode=self.scaling_mode,
            dq_dtype=self.dq_dtype,
            _dq_func=self._dq_func,
            is_colwise=self.is_colwise,
            data_layout=self.data_layout,
            flatten_axis=self.flatten_axis,
            has_rht_applied=self.has_rht_applied,
        )


@register_pytree_node_class
@dataclass
class GroupedScaledTensor1x(ScaledTensor1x):
    """Grouped Quantizer for an array.

    This class extends ScaledTensor1x to support quantization of an array in grouped manner,
    where elements are grouped along a specified axis.

    Attributes:
        group_sizes: Array containing the size of each group
        original_shape: The original shape of the tensor before grouping
        group_axis: The axis along which grouping is performed (default: 0)
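
    Example (sketch; shapes and scaling mode are placeholders, and both data
    and scale_inv must already be flattened 1D buffers with the padded grouped
    scale layout expected by the scaling mode):

        grouped = ScaledTensorFactory.create_1x(
            data=flat_q_data,                  # 1D quantized buffer
            scale_inv=flat_scale_inv,          # 1D padded scale buffer
            scaling_mode=scaling_mode,
            group_sizes=jnp.array([16, 48, 64]),
            original_shape=(128, 64),          # shape before flattening
        )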
    """

    group_sizes: jnp.ndarray
    original_shape: Tuple
    group_axis: int

    def __init__(
        self,
        data,
        scale_inv,
        amax,
        group_sizes,
        scaling_mode,
        dq_dtype,
        _dq_func,
        is_colwise,
        data_layout,
        flatten_axis,
        original_shape,
        group_axis=0,
    ):
        self.flatten_axis = flatten_axis
        self.group_sizes = group_sizes
        self.original_shape = original_shape
        self.group_axis = group_axis
        # TODO(Phuong): Handle RHT for grouped quantization once grouped quantization supports NVFP4
        super().__init__(
            data=data,
            scale_inv=scale_inv,
            amax=amax,
            scaling_mode=scaling_mode,
            dq_dtype=dq_dtype,
            _dq_func=_dq_func,
            is_colwise=is_colwise,
            data_layout=data_layout,
            flatten_axis=flatten_axis,
            has_rht_applied=False,
        )

    def __post_init__(self):
        assert self.scale_inv.ndim == 1, "Only support flattened scale_inv"
        assert self.data.ndim == 1, "Only support flattened data"
        assert self.group_axis >= 0
        assert self.flatten_axis > 0

        data_ndim = len(self.original_shape)
        assert (
            0 < self.flatten_axis < data_ndim
        ), f"flatten_axis {self.flatten_axis} is out of bounds for data.ndim = {data_ndim}"

        assert (
            0 <= self.group_axis < data_ndim
        ), f"group_axis {self.group_axis} is out of bounds for shape {self.original_shape}"

        expected_scale_shape = self.scaling_mode.get_grouped_scale_shape(
            self.original_shape,
            self.group_sizes.size,
            self.group_axis,
            self.is_colwise,
            is_padded=True,
            flatten_axis=self.flatten_axis,
        )

        assert self.scale_inv.shape == expected_scale_shape, (
            f"Unexpected scale_inv shape! \nExpect {expected_scale_shape} for padded"
            f" scale_inv, got {self.scale_inv.shape}"
        )

    def tree_flatten(self):
        """Flattens the tensor for JAX tree operations.

        Returns:
            A tuple containing (children, aux_data) for tree operations
        """
        children = (self.data, self.scale_inv, self.amax, self.group_sizes)
        aux_data = (
            self.scaling_mode,
            self.dq_dtype,
            self._dq_func,
            self.is_colwise,
            self.data_layout,
            self.flatten_axis,
            self.original_shape,
            self.group_axis,
        )
        return (children, aux_data)

    def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
        raise NotImplementedError


@register_pytree_node_class
@dataclass
class ScaledTensor2x(AbstractBaseTensor, ScaledTensor):
    """Double-scale quantized tensor implementation.

    This class represents a tensor quantized with both row-wise and column-wise scaling factors.

    Attributes:
        rowwise_tensor: The row-wise quantized component
        colwise_tensor: The column-wise quantized component
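
    Example (sketch; assumes ``TensorUsage.LHS`` / ``TensorUsage.RHS`` resolve
    to rowwise / colwise layouts for the scaling mode in use):

        rowwise = tensor_2x.get_tensor(TensorUsage.LHS)  # -> rowwise_tensor
        colwise = tensor_2x.get_tensor(TensorUsage.RHS)  # -> colwise_tensor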
    """

    rowwise_tensor: ScaledTensor1x
    colwise_tensor: ScaledTensor1x

    def tree_flatten(self):
        """Flattens the tensor for JAX tree operations.

        Returns:
            A tuple containing (children, aux_data) for tree operations
        """
        children = (self.rowwise_tensor, self.colwise_tensor)
        aux_data = ()
        return (children, aux_data)

    @property
    def ndim(self):
        """Number of dimensions of the underlying row-wise tensor."""
        return self.rowwise_tensor.ndim

    def dequantize(self):
        """Dequantizes the tensor using the row-wise component's dequantization.

        Returns:
            The dequantized tensor
        """
        return self.rowwise_tensor.dequantize()

    def get_tensor(self, usage: TensorUsage):
        """Returns the tensor based on the tensor usage."""
        q_layout_rowwise = self.rowwise_tensor.scaling_mode.get_quantize_layout(usage)
        q_layout_colwise = self.colwise_tensor.scaling_mode.get_quantize_layout(usage)

        if q_layout_rowwise.is_rowwise_only:
            return self.rowwise_tensor

        if q_layout_colwise.is_colwise_only:
            return self.colwise_tensor

        raise ValueError(
            f"Calling get_tensor() with usage {usage} is not valid for this tensor as"
            f" q_layout_rowwise={q_layout_rowwise} and q_layout_colwise={q_layout_colwise}!"
        )

    def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
        """Applies sharding constraints to a tensor based on logical axis names.

        Args:
            logical_axis_names: Tuple of logical axis names for sharding

        Returns:
            The tensor with applied sharding constraints
        """
        if not logical_axis_names:
            return self

        rowwise_tensor = self.rowwise_tensor.apply_sharding_constraint_by_logical_axes(
            logical_axis_names
        )
        colwise_tensor = self.colwise_tensor.apply_sharding_constraint_by_logical_axes(
            logical_axis_names
        )

        return ScaledTensor2x(rowwise_tensor, colwise_tensor)


@dataclass
class ScaledTensorFactory:
    """Factory class for creating scaled tensor instances.

    Provides static methods to create both single-scale (1x) and double-scale (2x)
    quantized tensors with various configurations.
    """

    @staticmethod
    def create_1x(
        data,
        scale_inv,
        amax=None,
        scaling_mode=ScalingMode.NO_SCALING,
        dq_dtype=jnp.bfloat16,
        is_colwise=False,
        data_layout="N",
        flatten_axis=-1,
        group_sizes=None,
        original_shape=None,
        group_axis=0,
        has_rht_applied=False,
    ):
        """Creates a single-scale quantized tensor.

        Args:
            data: The quantized tensor data
            scale_inv: The inverse scaling factors
            amax: The maximum absolute value of the tensor
            scaling_mode: The scaling mode for quantization
            dq_dtype: The data type for dequantized values (default: bfloat16)
            is_colwise: Whether to use column-wise quantization (default: False)
            data_layout: The data_layout specification (default: "N")
            flatten_axis: The quantization axis for the tensor
            group_sizes: Array of ints containing the size of each group (default: None)
            original_shape: The original shape of the tensor before grouping (default: None)
            group_axis: The axis along which grouping is performed (default: 0)
            has_rht_applied: Whether the tensor had the Randomized Hadamard Transform (RHT) applied during quantization (default: False)

        Returns:
            A ScaledTensor1x or GroupedScaledTensor1x instance depending on whether group_sizes is provided
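
        Example (sketch; ``q_data`` and ``q_scale_inv`` are placeholder arrays
        from a matching quantizer, and ``MXFP8_1D_SCALING`` is assumed to be an
        available ``ScalingMode`` member):

            tensor = ScaledTensorFactory.create_1x(
                data=q_data,
                scale_inv=q_scale_inv,
                scaling_mode=ScalingMode.MXFP8_1D_SCALING,
                is_colwise=False,
                flatten_axis=-1,
            )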
        """
        if amax is None:
            amax = jnp.empty((1,), dtype=jnp.float32)

        dequantizer = ScalingModeToDequantizerMap.get(scaling_mode)

        if group_sizes is not None:
            assert (
                original_shape is not None
            ), "original_shape is not given for GroupedScaledTensor1x"
            flatten_axis = (len(original_shape) + flatten_axis) % len(original_shape)

            # Handling attrs of transposed tensors
            group_axis = (len(original_shape) + group_axis) % len(original_shape)
            if data_layout == "T":
                if original_shape[0] == group_sizes.size:
                    original_shape = (
                        original_shape[0],
                        *original_shape[flatten_axis:],
                        *original_shape[1:flatten_axis],
                    )
                    flatten_axis = len(original_shape) - flatten_axis + 1
                else:
                    original_shape = (
                        *original_shape[flatten_axis:],
                        *original_shape[:flatten_axis],
                    )
                    group_axis = flatten_axis
                    flatten_axis = len(original_shape) - flatten_axis

            return GroupedScaledTensor1x(
                data=data,
                scale_inv=scale_inv,
                amax=amax,
                scaling_mode=scaling_mode,
                dq_dtype=dq_dtype,
                _dq_func=dequantizer.grouped_dequantize,
                is_colwise=is_colwise,
                data_layout=data_layout,
                flatten_axis=flatten_axis,
                group_sizes=group_sizes,
                original_shape=original_shape,
                group_axis=group_axis,
            )

        # Handling attrs of transposed tensors
        flatten_axis = (data.ndim + flatten_axis) % data.ndim
        if data_layout == "T":
            flatten_axis = data.ndim - flatten_axis

        return ScaledTensor1x(
            data=data,
            scale_inv=scale_inv,
            amax=amax,
            scaling_mode=scaling_mode,
            dq_dtype=dq_dtype,
            _dq_func=dequantizer.dequantize,
            is_colwise=is_colwise,
            data_layout=data_layout,
            flatten_axis=flatten_axis,
            has_rht_applied=has_rht_applied,
        )

    @staticmethod
    def create_2x(
        data,
        scale_inv,
        colwise_data,
        colwise_scale_inv,
        amax=None,
        colwise_amax=None,
        scaling_mode=ScalingMode.NO_SCALING,
        dq_dtype=jnp.bfloat16,
        data_layout="NN",
        flatten_axis=-1,
        group_sizes=None,
        original_shape=None,
        group_axis=0,
        rowwise_has_rht_applied=False,
        colwise_has_rht_applied=False,
    ):
        """Creates a double-scale quantized tensor.

        Args:
            data: The row-wise quantized data
            scale_inv: The row-wise inverse scaling factors
            colwise_data: The column-wise quantized data
            colwise_scale_inv: The column-wise inverse scaling factors
            amax: The maximum absolute value of the tensor
            colwise_amax: The maximum absolute value for the column-wise tensor (default: amax)
            scaling_mode: The scaling mode for quantization
            dq_dtype: The data type for dequantized values (default: bfloat16)
            data_layout: The data_layout specification (default: "NN")
            flatten_axis: The quantization axis for the tensor
            group_sizes: Array containing the size of each group (default: None)
            original_shape: The original shape of the tensor before grouping (default: None)
            group_axis: The axis along which grouping is performed (default: 0)
            rowwise_has_rht_applied: Whether the row-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False)
            colwise_has_rht_applied: Whether the column-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False)

        Returns:
            A ScaledTensor2x instance
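
        Example (sketch; all four input arrays are placeholders produced by a
        matching quantizer, with both copies in the "N" layout):

            tensor_2x = ScaledTensorFactory.create_2x(
                data, scale_inv,                  # rowwise component
                colwise_data, colwise_scale_inv,  # colwise component
                scaling_mode=scaling_mode,
                data_layout="NN",
            )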
        """
        if amax is None:
            amax = jnp.empty((1,), dtype=jnp.float32)
        if colwise_amax is None:
            colwise_amax = amax

        assert len(data_layout) == 2, f"Expect 2 layouts, got {data_layout}"
        rowwise_tensor = ScaledTensorFactory.create_1x(
            data,
            scale_inv,
            amax,
            scaling_mode,
            dq_dtype,
            is_colwise=False,
            data_layout=data_layout[0],
            flatten_axis=flatten_axis,
            group_sizes=group_sizes,
            original_shape=original_shape,
            group_axis=group_axis,
            has_rht_applied=rowwise_has_rht_applied,
        )
        colwise_tensor = ScaledTensorFactory.create_1x(
            colwise_data,
            colwise_scale_inv,
            colwise_amax,
            scaling_mode,
            dq_dtype,
            is_colwise=True,
            data_layout=data_layout[1],
            flatten_axis=flatten_axis,
            group_sizes=group_sizes,
            original_shape=original_shape,
            group_axis=group_axis,
            has_rht_applied=colwise_has_rht_applied,
        )
        return ScaledTensor2x(rowwise_tensor, colwise_tensor)

    @staticmethod
    def create(
        data: jnp.ndarray,
        scale_inv: jnp.ndarray,
        colwise_data: jnp.ndarray,
        colwise_scale_inv: jnp.ndarray,
        amax=None,
        colwise_amax=None,
        scaling_mode: ScalingMode = ScalingMode.NO_SCALING,
        dq_dtype: jnp.dtype = jnp.bfloat16,
        data_layout: str = "NN",
        q_layout: QuantizeLayout = QuantizeLayout.ROWWISE,
        flatten_axis: int = -1,
        group_sizes: jnp.ndarray = None,
        original_shape: Tuple[int] = None,
        group_axis: int = 0,
        rowwise_has_rht_applied: bool = False,
        colwise_has_rht_applied: bool = False,
    ):
        """Creates a scaled tensor based on the quantization axis.

        Args:
            data: The quantized tensor data
            scale_inv: The inverse scaling factors
            colwise_data: The column-wise quantized data
            colwise_scale_inv: The column-wise inverse scaling factors
            amax: The maximum absolute value of the tensor (default: None)
            colwise_amax: The maximum absolute value for the column-wise tensor (default: amax)
            scaling_mode: The scaling mode for quantization
            dq_dtype: The data type for dequantized values (default: bfloat16)
            data_layout: The data_layout specification (default: "NN")
            q_layout: The quantization axis (default: ROWWISE)
            flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)
            group_sizes: Array containing the size of each group (default: None)
            original_shape: The original shape of the tensor before grouping (default: None)
            group_axis: The axis along which grouping is performed (default: 0)
            rowwise_has_rht_applied: Whether the row-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False)
            colwise_has_rht_applied: Whether the col-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False)

        Returns:
            Either a ScaledTensor1x or ScaledTensor2x instance depending on q_layout
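
        Example (sketch; ``QuantizeLayout.ROWWISE_COLWISE`` is assumed to be
        the member that selects both components):

            tensor = ScaledTensorFactory.create(
                data, scale_inv, colwise_data, colwise_scale_inv,
                scaling_mode=scaling_mode,
                q_layout=QuantizeLayout.ROWWISE_COLWISE,  # returns ScaledTensor2x
            )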
        """
        assert not rowwise_has_rht_applied, "RHT is not supported for rowwise quantization yet"

        if q_layout.is_rowwise_colwise:
            return ScaledTensorFactory.create_2x(
                data,
                scale_inv,
                colwise_data,
                colwise_scale_inv,
                amax,
                colwise_amax,
                scaling_mode,
                dq_dtype,
                data_layout=data_layout,
                flatten_axis=flatten_axis,
                group_sizes=group_sizes,
                original_shape=original_shape,
                group_axis=group_axis,
                rowwise_has_rht_applied=rowwise_has_rht_applied,
                colwise_has_rht_applied=colwise_has_rht_applied,
            )

        if q_layout.is_colwise_only:
            return ScaledTensorFactory.create_1x(
                colwise_data,
                colwise_scale_inv,
                colwise_amax if colwise_amax is not None else amax,
                scaling_mode,
                dq_dtype,
                is_colwise=True,
                data_layout=data_layout[0],
                flatten_axis=flatten_axis,
                group_sizes=group_sizes,
                original_shape=original_shape,
                group_axis=group_axis,
                has_rht_applied=colwise_has_rht_applied,
            )

        return ScaledTensorFactory.create_1x(
            data,
            scale_inv,
            amax,
            scaling_mode,
            dq_dtype,
            is_colwise=False,
            data_layout=data_layout[0],
            flatten_axis=flatten_axis,
            group_sizes=group_sizes,
            original_shape=original_shape,
            group_axis=group_axis,
            has_rht_applied=rowwise_has_rht_applied,
        )


def with_sharding_constraint_by_logical_axes(x, logical_axis_names: Tuple[str, ...]):
    """Applies sharding constraints to a tensor based on logical axis names.

    Args:
        x: The tensor to apply sharding constraints to
        logical_axis_names: Tuple of logical axis names for sharding

    Returns:
        The tensor with applied sharding constraints
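
    Example (sketch; the logical axis names are placeholders that must be
    registered with the enclosing mesh / sharding rules):

        sharded = with_sharding_constraint_by_logical_axes(tensor, ("batch", "hidden"))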
    """
    if isinstance(x, GroupedScaledTensor1x):
        raise NotImplementedError

    if isinstance(x, AbstractBaseTensor):
        return x.apply_sharding_constraint_by_logical_axes(logical_axis_names)

    return original_with_sharding_constraint_by_logical_axes(x, logical_axis_names)