# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""
Scaling mode implementations for quantization in JAX.

This module provides implementations of different scaling modes for tensor quantization,
including delayed scaling and block scaling strategies.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Tuple, Dict
from functools import reduce, lru_cache
import operator
import numpy as np

from jax.experimental.custom_partitioning import BATCHING, CompoundFactor
from jax.tree_util import register_pytree_node_class
import jax.numpy as jnp

from transformer_engine_jax import JAXX_Scaling_Mode
from .misc import QuantizeLayout
from .device_utils import is_fp8_gemm_with_all_layouts_supported


__all__ = [
    "QuantizeShardyRules",
    "ScalingMode",
    "TensorUsage",
]


class TensorUsage(Enum):
    """Enum indicating tensor usage in GEMM operations.

    Given a GEMM operation: C = A * B in which A and B can be in the normal or transposed form.
    The tensor usage can be:
    - LHS: A is in the normal form
    - LHS_TRANS: A is in the transposed form
    - RHS: B is in the normal form
    - RHS_TRANS: B is in the transposed form

    The tensor usage is used in the ScaledTensor.get_tensor() method.
    """

    # LHS: Left-hand side, RHS: Right-hand side
    # LHS_TRANS: Left-hand side transposed, RHS_TRANS: Right-hand side transposed
    LHS = 0
    LHS_TRANS = 1
    RHS = 2
    RHS_TRANS = 3

    def __eq__(self, other):
        if not isinstance(other, TensorUsage):
            return False
        return self.value == other.value

    def __hash__(self):
        return hash(self.value)
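
# For example, a GEMM computing C = A^T * B would request its first operand with
# TensorUsage.LHS_TRANS and its second with TensorUsage.RHS when calling
# ScaledTensor.get_tensor().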


def DIVUP(a, b):
    "Divide a by b and then round up"
    return -(a // -b)
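    # For example, DIVUP(129, 128) == 2 while DIVUP(128, 128) == 1; this ceiling
    # division is used below when padding scale dimensions up to their block alignment.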


@dataclass
class QuantizeShardyRules:
    """Information necessary to shard scale tensors with Shardy.

    Attributes:
        input_spec: Specification for the input axes
        rowwise_out_spec: Sharding spec for the rowwise quantized data
        rowwise_scale_spec: Sharding spec for the rowwise scale
        colwise_out_spec: Sharding spec for the colwise quantized data
        colwise_scale_spec: Sharding spec for the colwise scale
        factor_sizes: For block scaling, contains the block size factor
    """

    input_spec: Tuple[str]
    rowwise_out_spec: Tuple[str]
    rowwise_scale_spec: Tuple[str]
    colwise_out_spec: Tuple[str]
    colwise_scale_spec: Tuple[str]
    factor_sizes: Dict[str, int]


class ScalingModeMetadataImpl(ABC):
    """Base class for scaling mode implementations.

    This abstract class defines the interface for different scaling mode implementations,
    providing methods to get scale data types and shapes.
    """

    @abstractmethod
    def get_scale_dtype(self) -> jnp.dtype:
        """Get the data type for scale tensors.

        Returns:
            The data type used for scale tensors
        """

    @abstractmethod
    def get_data_layout(self) -> str:
        """Get the data layout for rowwise and colwise scaling.

        Returns:
            The data layout, two characters, e.g. "NT", where each is either "N" (default) or "T" for transposed. The first character refers to the rowwise layout and the second refers to the colwise layout.
        """


    @abstractmethod
    def get_scale_shape(
        self,
        data_shape: Tuple[int, ...],
        data_layout: str = "N",
        is_colwise: bool = False,
        is_padded: bool = True,
        flatten_axis: int = -1,
    ) -> Tuple[int, ...]:
        """Get the shape for scale tensors.

        Args:
            data_shape: The shape of the tensor being quantized
            data_layout: Layout of the data shape, either "N" (default) or "T" for transposed.
            is_colwise: Whether the scaling is column-wise
            is_padded: Whether to return padded shape
            flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)

        Returns:
            The shape for scale tensors
        """

    @abstractmethod
    def get_grouped_scale_shape(
        self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1
    ) -> Tuple[int]:
        """Get the shape for scale tensors in this mode.

        Args:
            data_shape: Original shape of the data tensor
            n_groups: Number of groups in grouped quantization
            group_axis: The axis along which grouping is performed
            is_colwise: Whether to use column-wise scaling
            is_padded: Whether to use padded shapes
            flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)

150
151
152
153
        Returns:
            The shape for scale tensors
        """

    @lru_cache(maxsize=4)
    @abstractmethod
    def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout:
        """Get the quantize layout for the tensor usage.

        Args:
            usage: The usage of the tensor

        Returns:
            The quantize layout for the tensor usage
        """

    @abstractmethod
    def get_shardy_sharding_rules(
        self,
        input_shape,
        unique_var,
        flatten_axis,
        q_layout,
        broadcast_2d_scale_shape_to_1d,
        is_colwise_transposed,
    ) -> QuantizeShardyRules:
        """Sharding rules for the input and (row, col)wise scale tensors.

        Args:
            input_shape: The shape of the input tensor (for which we produce the scale tensor)
            unique_var: An otherwise unused Shardy variable name prefix
            flatten_axis: Axis along which data can be flattened to 2D for quantization
            q_layout: The layout of the quantized tensor
            broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D.
            is_colwise_transposed: Whether the column-wise tensors are transposed.

        Returns:
            The Shardy rules for the scaling mode
        """


class NoScalingModeMetadataImpl(ScalingModeMetadataImpl):
    """Implementation for no scaling mode.

    This implementation provides metadata for no scaling mode, for using non-quantized higher-precision datatypes such as bf16.
    """

    def get_scale_dtype(self) -> jnp.dtype:
        """Get the data type for scale tensors. This is a placeholder and won't be used for higher-precision values that don't have scaling.

        Returns:
            The data type used for scale tensors (float32)
        """
        return jnp.float32

    def get_data_layout(self) -> str:
        """Get the data layout for rowwise and colwise scaling.

        Returns:
            The data layout, two characters, e.g. "NT", where each is either "N" (default) or "T" for transposed. The first character refers to the rowwise layout and the second refers to the colwise layout.
        """
        return "NN"

    def get_scale_shape(
        self,
        data_shape: Tuple[int, ...],
        data_layout: str = "N",
        is_colwise: bool = False,
        is_padded: bool = True,
        flatten_axis: int = -1,
        broadcast_2d_scale_shape_to_1d: bool = True,
    ) -> Tuple[int, ...]:
        """Get the shape for scale tensors. This always returns an empty shape because this mode applies no scaling.

        Args:
            data_shape: The shape of the tensor being scaled
            is_colwise: Whether the scaling is column-wise
            is_padded: Whether to return padded shape
            flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.

        Returns:
            The shape for scale tensors - (0,), an empty placeholder since no scale is used
        """
        del (
            data_shape,
            data_layout,
            is_colwise,
            is_padded,
            flatten_axis,
            broadcast_2d_scale_shape_to_1d,
        )
        return (0,)

    @lru_cache(maxsize=4)
    def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout:
        """Get the quantize layout for the tensor usage.

        Args:
            usage: The usage of the tensor

        Returns:
            The quantize layout for the tensor usage
        """
        return QuantizeLayout.ROWWISE

    def get_grouped_scale_shape(
        self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1
    ) -> Tuple[int]:
        """Get the shape for scale tensors in this mode.

        Args:
            data_shape: Original shape of the data tensor
            is_colwise: Whether to use column-wise scaling
            is_padded: Whether to use padded shapes
            flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.

        Returns:
            The shape for scale tensors
        """
        del data_shape, group_axis, is_colwise
        assert isinstance(n_groups, int)
        return (n_groups,)

    def get_shardy_sharding_rules(
        self,
        input_shape,
        unique_var,
        flatten_axis,
        q_layout,
        broadcast_2d_scale_shape_to_1d,
        is_colwise_transposed,
    ) -> QuantizeShardyRules:
        """Sharding rules for the input and (row, col)wise scale tensors.

        Args:
            input_shape: The shape of the input tensor (for which we produce the scale tensor)
            unique_var: An otherwise unused Shardy variable name prefix
            flatten_axis: Axis along which data can be flattened to 2D for quantization
            q_layout: The layout of the quantized tensor
            broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D.
            is_colwise_transposed: Whether the column-wise tensors are transposed.

        Returns:
            The Shardy rules for the scaling mode
        """
        del broadcast_2d_scale_shape_to_1d
        input_spec = tuple(f"{unique_var}_x_{i}" for i in range(len(input_shape)))
        output_spec = tuple(input_spec)
        return QuantizeShardyRules(
            input_spec,
            output_spec,
            (BATCHING + f"{unique_var}_scale",),
            (BATCHING + f"{unique_var}_colwise_output",),
            (BATCHING + f"{unique_var}_colwise_scale",),
            {},
        )


class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl):
    """Implementation for current scaling mode.

    This implementation provides metadata for current scaling mode, including scale data type and shape.
    """

    def get_scale_dtype(self) -> jnp.dtype:
        """Get the data type for scale tensors in delayed scaling.

        Returns:
            The data type used for scale tensors (float32)
        """
        return jnp.float32

    def get_data_layout(self) -> str:
        """Get the data layout for rowwise and colwise scaling.

        Returns:
            The data layout, two characters, e.g. "NT", where each is either "N" (default) or "T" for transposed. The first character refers to the rowwise layout and the second refers to the colwise layout.
        """
        return "NT"

    def get_scale_shape(
        self,
        data_shape: Tuple[int, ...],
        data_layout: str = "N",
        is_colwise: bool = False,
        is_padded: bool = True,
        flatten_axis: int = -1,
        broadcast_2d_scale_shape_to_1d: bool = True,
    ) -> Tuple[int, ...]:
        """Get the shape for scale tensors in delayed scaling.

        Args:
            data_shape: The shape of the tensor being scaled
            data_layout: Layout of the data shape, either "N" (default) or "T" for transposed.
            is_colwise: Whether the scaling is column-wise
            is_padded: Whether to return padded shape
            flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
            broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Defaults to True.

        Returns:
            The shape for scale tensors - (1,), or (0,) if the data tensor is empty
        """
        del data_layout, is_colwise, broadcast_2d_scale_shape_to_1d
        if np.prod(data_shape) == 0:
            return (0,)
        return (1,)

    @lru_cache(maxsize=4)
    def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout:
        """Get the quantize layout for the tensor usage.

        Args:
            usage: The usage of the tensor

        Returns:
            The quantize layout for the tensor usage
        """
        if is_fp8_gemm_with_all_layouts_supported():
            return QuantizeLayout.ROWWISE

        if usage in (TensorUsage.LHS, TensorUsage.RHS_TRANS):
            return QuantizeLayout.ROWWISE
        return QuantizeLayout.COLWISE

    def get_grouped_scale_shape(
        self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1
    ) -> Tuple[int]:
        """Get the shape for scale tensors in this mode.

        Args:
            data_shape: Original shape of the data tensor
            is_colwise: Whether to use column-wise scaling
            is_padded: Whether to use padded shapes
            flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.

        Returns:
            The shape for scale tensors
        """
        del data_shape, group_axis, is_colwise
        assert isinstance(n_groups, int)
        return (n_groups,)

    def get_shardy_sharding_rules(
        self,
        input_shape,
        unique_var,
        flatten_axis,
        q_layout,
        broadcast_2d_scale_shape_to_1d,
        is_colwise_transposed,
    ) -> QuantizeShardyRules:
        """Sharding rules for the input and (row, col)wise scale tensors.

        Args:
            input_shape: The shape of the input tensor (for which we produce the scale tensor)
            unique_var: An otherwise unused Shardy variable name prefix
            flatten_axis: Axis along which data can be flattened to 2D for quantization
            q_layout: The layout of the quantized tensor
            broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D.
            is_colwise_transposed: Whether the colwise scaling is transposed
        Returns:
            The Shardy rules for the scaling mode
        """
        del broadcast_2d_scale_shape_to_1d
        input_spec = tuple(f"{unique_var}x_{i}" for i in range(len(input_shape)))
        output_spec = input_spec
        colwise_output_spec = (BATCHING + f"{unique_var}_colwise_output",)

        if q_layout.has_colwise:
            from ..cpp_extensions.misc import multidim_transpose

            colwise_output_spec = input_spec
            if is_colwise_transposed:
                colwise_output_spec = multidim_transpose(
                    colwise_output_spec, transpose_axis=flatten_axis
                )
        scale = (BATCHING + unique_var + "_scale_inv",)
        return QuantizeShardyRules(input_spec, output_spec, scale, colwise_output_spec, scale, {})


class DelayedScalingModeMetadataImpl(CurrentScalingModeMetadataImpl):
    """Implementation for delayed scaling mode.

    This implementation provides metadata for delayed scaling mode, including scale data type and shape.
    """


class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
    """Implementation for block scaling mode.

    This implementation provides metadata for block scaling mode, which uses
    block-based scaling with specific alignment requirements.

    Attributes:
        _block_dims: Dimensions of the scaling blocks
        _block_alignment: Alignment requirements for blocks
    """


    def __init__(self, block_dims: Tuple[int], scale_dtype: jnp.dtype, data_layout: str):
        """Initialize block scaling mode implementation.

        Args:
            block_dims: Dimensions of the scaling blocks
            scale_dtype: Data type of the scale tensor
            data_layout: Layout for rowwise and colwise scaling, two characters, e.g. "NT", where each is either "N" (default) or "T" for transposed. The first character refers to the rowwise layout and the second refers to the colwise layout.
        """
        self._block_dims = block_dims
        self._scale_dtype = scale_dtype
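        # When is_padded=True, scale block counts are rounded up to these alignments
        # (128 and 4) in get_scale_shape and get_grouped_scale_shape below.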
        self._block_alignment = (128, 4)
        self._data_layout = data_layout

    def get_scale_dtype(self) -> jnp.dtype:
        """Get the data type for scale tensors in block scaling.

        Returns:
            The data type used for scale tensors (e.g. float8_e8m0fnu for MXFP8,
            float8_e4m3fn for NVFP4)
        """
        return self._scale_dtype

    def get_data_layout(self) -> str:
        """Get the data layout for rowwise and colwise scaling.

        Returns:
            The data layout, two characters, e.g. "NT", where each is either "N" (default) or "T" for transposed. The first character refers to the rowwise layout and the second refers to the colwise layout.
        """
        return self._data_layout

    def _apply_scale_shape_correction(self, data_shape, n_scale_blocks, scale_block_dim):
        """Remove excess padding from the scale shape and return the shape with respect to the original data shape."""
        if len(data_shape) > 1:
            # handle last dim
            assert data_shape[-1] % scale_block_dim == 0
            last = data_shape[-1] // scale_block_dim
            scale_shape = (last,)
            assert n_scale_blocks % last == 0
            n_scale_blocks //= last
            # handle middle dim, exclude first and last
            for mid in reversed(data_shape[1:-1]):
                scale_shape = (mid,) + scale_shape
                assert n_scale_blocks % mid == 0
                n_scale_blocks //= mid
            scale_shape = (n_scale_blocks,) + scale_shape
        else:
            scale_shape = (n_scale_blocks,)

        assert len(scale_shape) == len(
            data_shape
        ), f"scale_shape {scale_shape}, data_shape {data_shape}"
        return scale_shape

    def get_scale_shape(
        self,
        data_shape: Tuple[int, ...],
        data_layout: str = "N",
        is_colwise: bool = False,
        is_padded: bool = True,
        flatten_axis: int = -1,
        broadcast_2d_scale_shape_to_1d: bool = False,
    ) -> Tuple[int, ...]:
        """Get the shape for scale tensors in block scaling.

        Args:
            data_shape: The shape of the tensor being quantized
            data_layout: Layout of the data shape, either "N" (default) or "T" for transposed.
            is_colwise: Whether the scaling is column-wise
            is_padded: Whether to return padded shape
            flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
            broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Defaults to False.

        Returns:
            The shape for scale tensors
        """
        flatten_axis = (len(data_shape) + flatten_axis) % len(data_shape)
        assert (
            0 < flatten_axis < len(data_shape)
        ), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}"

        block_alignment = self._block_alignment if is_padded else (1, 1)

        if is_colwise:
            assert data_layout == self._data_layout[1], (
                f"Data layout must match colwise layout, received {data_layout} but expected"
                f" {self._data_layout[1]}"
            )
        else:
            assert data_layout == self._data_layout[0], (
                f"Data layout must match rowwise layout, received {data_layout} but expected"
                f" {self._data_layout[0]}"
            )

        if is_colwise and self._data_layout[1] == "T":
            # TODO(Phuong): rework this hack so that we don't implicitly change is_colwise value
            is_colwise = False  # now rowwise in T is colwise in N
            if flatten_axis < 0:
                flatten_axis = len(data_shape) + flatten_axis
            # flatten_axis is given wrt N layout, convert to T layout
            flatten_axis = len(data_shape) - flatten_axis

        if is_colwise:
            block_y, block_x = self._block_dims
            alignment_y, alignment_x = block_alignment
        else:
            block_x, block_y = self._block_dims
            alignment_x, alignment_y = block_alignment

        is_block_2d = block_x > 1 and block_y > 1
        assert data_shape[flatten_axis - 1] % block_x == 0, (
            f"Data shape {data_shape} should be divisible by block_x {block_x} in axis"
            f" {flatten_axis - 1}"
        )
        assert (
            data_shape[-1] % block_y == 0
        ), f"Data shape {data_shape} should be divisible by block_y {block_y} in axis -1"

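        # For 2D blocks (e.g. 16x16 NVFP4), compute the scale shape as if the block were
        # 1 x block_y: the 2D scales are broadcast to the 1D block-scale layout that the
        # GEMM consumes (see ScalingMode.is_1d_block_scaling).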
        if broadcast_2d_scale_shape_to_1d and is_block_2d:
            block_x = 1

        flattened_first_dim = reduce(operator.mul, data_shape[:flatten_axis], 1)
        flattened_last_dim = reduce(operator.mul, data_shape[flatten_axis:], 1)

        assert flattened_first_dim % block_x == 0, (
            f"Flattened first dim - mutiplication of axes={tuple(range(0, flatten_axis))} of shape"
            f" {data_shape} - should be divisible by block_x {block_x}"
        )
        assert flattened_last_dim % block_y == 0, (
            "Flattened last dim - mutiplication of"
            f" axes={tuple(range(flatten_axis, len(data_shape)))} of shape {data_shape} - should be"
            f" divisible by block_y {block_y}"
        )

        n_block_x = int(flattened_first_dim / block_x)
        n_block_y = int(flattened_last_dim / block_y)

        # padding
        n_block_x = int(((n_block_x + alignment_x - 1) // alignment_x) * alignment_x)
        n_block_y = int(((n_block_y + alignment_y - 1) // alignment_y) * alignment_y)

        first_dim_scale_shape = self._apply_scale_shape_correction(
            data_shape[:flatten_axis], n_block_x, block_x
        )
        last_dim_scale_shape = self._apply_scale_shape_correction(
            data_shape[flatten_axis:], n_block_y, block_y
        )

        return (*first_dim_scale_shape, *last_dim_scale_shape)

    @lru_cache(maxsize=4)
    def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout:
        """Get the quantize layout for the tensor usage.

        Args:
            usage: The usage of the tensor

        Returns:
            The quantize layout for the tensor usage
        """
        # If we need to support 1x1x for inference in the future
        # if get_quantize_config().INFERENCE_MODE:
        #     assert usage not in (TensorUsage.LHS_TRANS, TensorUsage.RHS_TRANS), (f"Invalid usage {usage} as we are in MXFP8_1D_SCALING 1x1x (FWD only) mode so no transposed usage is needed!")
        #     if usage == TensorUsage.LHS:
        #         return QuantizeLayout.ROWWISE
        #     return QuantizeLayout.COLWISE

        if usage in (TensorUsage.LHS, TensorUsage.RHS_TRANS):
            return QuantizeLayout.ROWWISE
        return QuantizeLayout.COLWISE

    def get_grouped_scale_shape(
        self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1
    ) -> Tuple[int]:
        """Get the shape for grouped scale tensors in this mode.
        If padded, the estimated maximal possible shape for the grouped scale tensor is returned instead.

        Args:
            data_shape: Original shape of the data tensor
            n_groups: Number of groups in grouped quantization
            group_axis: The axis along which grouping is performed
            is_colwise: Whether to use column-wise scaling
            is_padded: Whether to use padded shapes
            flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.

        Returns:
            The shape for scale tensors
        """
        assert isinstance(n_groups, int)
        block_alignment = self._block_alignment if is_padded else (1, 1)

        if is_colwise:
            block_y, block_x = self._block_dims
            alignment_y, alignment_x = block_alignment
        else:
            block_x, block_y = self._block_dims
            alignment_x, alignment_y = block_alignment

        if flatten_axis < 0:
            flatten_axis = len(data_shape) + flatten_axis
        assert (
            0 < flatten_axis < len(data_shape)
        ), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}"

        assert data_shape[flatten_axis - 1] % block_x == 0, (
            f"Data shape {data_shape} should be divisible by block_x {block_x} in axis"
            f" {flatten_axis - 1}"
        )
        assert (
            data_shape[-1] % block_y == 0
        ), f"Data shape {data_shape} should be divisible by block_y {block_y} in axis -1"

        flattened_first_dim = reduce(operator.mul, data_shape[:flatten_axis], 1)
        flattened_last_dim = reduce(operator.mul, data_shape[flatten_axis:], 1)

        assert flattened_first_dim % block_x == 0, (
            f"Flattened first dim - mutiplication of axes={tuple(range(0, flatten_axis))} of shape"
            f" {data_shape} - should be divisible by block_x {block_x}"
        )
        assert flattened_last_dim % block_y == 0, (
            "Flattened last dim - mutiplication of"
            f" axes={tuple(range(flatten_axis, len(data_shape)))} of shape {data_shape} - should be"
            f" divisible by block_y {block_y}"
        )

        n_block_x = int(flattened_first_dim // block_x)
        n_block_y = int(flattened_last_dim // block_y)

        """
            Given the scale shape of [M, N], and G groups, and padding alignment (128, 4),
            The worst scenario is when we have (G-1) groups with 1 row and 1 group with (M-G+1) rows.
            Then:
                max_padded_rows = (G-1) * 128 + DIVUP(M-G+1, 128) * 128
                max_padded_cols = DIVUP(N, 4) * 4
                max_scale_size = max_padded_rows * max_padded_cols
        """
        if is_padded:
            n_block_x = (n_groups - 1) * alignment_x + DIVUP(
                n_block_x - n_groups + 1, alignment_x
            ) * alignment_x
            n_block_y = DIVUP(n_block_y, alignment_y) * alignment_y

        return (n_block_x * n_block_y,)

    def get_shardy_sharding_rules(
        self,
        input_shape,
        unique_var,
        flatten_axis,
        q_layout,
        broadcast_2d_scale_shape_to_1d,
        is_colwise_transposed,
    ) -> QuantizeShardyRules:
        """Sharding rules for the input and (row, col)wise scale tensors.

        Args:
            input_shape: The shape of the input tensor (for which we produce the scale tensor)
            unique_var: An otherwise unused Shardy variable name prefix
            flatten_axis: Axis along which data can be flattened to 2D for quantization
            q_layout: The layout of the quantized tensor
            broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D.
            is_colwise_transposed: Whether the column-wise tensors are transposed.
        Returns:
            The Shardy rules for the scaling mode
        """
        is_rowwise = q_layout.has_rowwise
        is_colwise = q_layout.has_colwise

        input_rank = len(input_shape)
        flatten_axis = (flatten_axis + input_rank) % input_rank
        input_spec = [f"{unique_var}_x_{i}" for i in range(input_rank)]

        assert (
            self._block_dims[1] != 1
        ), f"Expect 1D rowwise or 2D block. Got _block_dims={self._block_dims}"
        # For 2D block scaling, only support when with broadcast_2d_scale_shape_to_1d
        if self._block_dims[0] != 1:
            assert self._block_dims[0] == self._block_dims[1] and broadcast_2d_scale_shape_to_1d, (
                f"Got broadcast_2d_scale_shape_to_1d={broadcast_2d_scale_shape_to_1d},"
                f" _block_dims={self._block_dims}"
            )

        block_size_1d = self._block_dims[1]

        # We have to use two different factors in the two CompoundFactors because of Shardy
        # verifier requirements, even though they are the same.
        # No CompoundFactor is needed if the dim has the same size as the blocksize
        blocksizes = {}
        rowwise_var = f"{unique_var}_None"
        colwise_var = f"{unique_var}_None"
        if is_rowwise and not input_shape[-1] == block_size_1d:
            rowwise_var = input_spec[-1] + "_compound"
            input_spec[-1] = CompoundFactor(rowwise_var, "blocksize_x")
            blocksizes["blocksize_x"] = block_size_1d
        if is_colwise and not input_shape[flatten_axis - 1] == block_size_1d:
            colwise_var = input_spec[flatten_axis - 1] + "_compound"
            input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "blocksize_y")
            blocksizes["blocksize_y"] = block_size_1d

        # The rowwise and colwise scale tensors should be sharded the same way as the input.
        # However, we need to adjust the dimensions where the block scaling factor applies.
        if is_rowwise:
            rowwise_out = input_spec.copy()
            rowwise_scale = input_spec.copy()
            rowwise_scale[-1] = rowwise_var
        else:
            rowwise_out = [
                BATCHING + f"{unique_var}_rowwise_output",
            ]
            rowwise_scale = [
                BATCHING + f"{unique_var}_rowwise_scale_inv",
            ]

        if is_colwise:
            colwise_out = input_spec.copy()
            colwise_scale = input_spec.copy()
            colwise_scale[flatten_axis - 1] = colwise_var
            if is_colwise_transposed:
                from ..cpp_extensions.misc import multidim_transpose

                colwise_out = multidim_transpose(colwise_out, transpose_axis=flatten_axis)
                colwise_scale = multidim_transpose(colwise_scale, transpose_axis=flatten_axis)
        else:
            colwise_out = [
                BATCHING + f"{unique_var}_colwise_output",
            ]
            colwise_scale = [
                BATCHING + f"{unique_var}_colwise_scale_inv",
            ]

        return QuantizeShardyRules(
            tuple(input_spec),
            tuple(rowwise_out),
            tuple(rowwise_scale),
            tuple(colwise_out),
            tuple(colwise_scale),
            blocksizes,
        )


@dataclass(frozen=True)
@register_pytree_node_class
class ScalingMode(Enum):
    """Enumeration of tensor scaling modes with their corresponding metadata implementations.

    This class defines the available scaling modes for tensor quantization:
    - DELAYED_TENSOR_SCALING: Uses delayed scaling with FP8 data type and float32 scales
    - MXFP8_1D_SCALING: Uses block-based scaling with FP8 data type and E8M0 scales
    - CURRENT_TENSOR_SCALING: Uses current scaling with FP8 data type and float32 scales
    - NVFP4_1D_SCALING: Uses block-based scaling (1x16 blocks) with FP4 data type and E4M3 scales
    - NVFP4_2D_SCALING: Uses block-based scaling (16x16 blocks) with FP4 data type and E4M3 scales
    - NO_SCALING: No scaling applied
    """

    NO_SCALING = JAXX_Scaling_Mode.NO_SCALING
    DELAYED_TENSOR_SCALING = JAXX_Scaling_Mode.DELAYED_TENSOR_SCALING
    MXFP8_1D_SCALING = JAXX_Scaling_Mode.MXFP8_1D_SCALING
    CURRENT_TENSOR_SCALING = JAXX_Scaling_Mode.CURRENT_TENSOR_SCALING
    NVFP4_1D_SCALING = JAXX_Scaling_Mode.NVFP4_1D_SCALING
    NVFP4_2D_SCALING = JAXX_Scaling_Mode.NVFP4_2D_SCALING

    def _get_impl(self) -> ScalingModeMetadataImpl:
        """Get the implementation for this scaling mode.

        Returns:
            The scaling mode implementation

        Raises:
            ValueError: If the scaling mode is invalid
        """
        impl = SCALING_MODES_TO_IMPL.get(self)
        if impl is None:
            raise ValueError("Invalid scaling mode")
        return impl

    def get_scale_dtype(self):
        """Get the data type for scale tensors in this mode.

        Returns:
            The data type for scale tensors
        """
        return self._get_impl().get_scale_dtype()

    def get_scale_shape_2x(
        self, data_shape, is_padded=True, flatten_axis=-1, broadcast_2d_scale_shape_to_1d=False
    ) -> Tuple[Tuple[int]]:
        """Get shapes for both row-wise and column-wise scaling.

        Args:
            data_shape: Shape of the data tensor
            is_padded: Whether to use padded shapes
            flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
            broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Defaults to False.

        Returns:
            Tuple of (rowwise_scale_shape, colwise_scale_shape)
        """
        data_layout = self._get_impl().get_data_layout()
        rowwise_layout = data_layout[0]
        assert (
            rowwise_layout == "N"
        ), f"For rowwise layout only 'N' is supported, received {rowwise_layout}"
        colwise_layout = data_layout[1]

        rowwise_scale_shape = self.get_scale_shape(
            data_shape,
            data_layout=rowwise_layout,
            is_colwise=False,
            is_padded=is_padded,
            flatten_axis=flatten_axis,
            broadcast_2d_scale_shape_to_1d=broadcast_2d_scale_shape_to_1d,
        )

        colwise_data_shape = data_shape
        if colwise_layout == "T":
            colwise_data_shape = data_shape[flatten_axis:] + data_shape[:flatten_axis]
        colwise_scale_shape = self.get_scale_shape(
            colwise_data_shape,
            data_layout=colwise_layout,
            is_colwise=True,
            is_padded=is_padded,
            flatten_axis=flatten_axis,
            broadcast_2d_scale_shape_to_1d=broadcast_2d_scale_shape_to_1d,
        )
        return (rowwise_scale_shape, colwise_scale_shape)

    def get_scale_shape(
        self,
        data_shape,
        data_layout="N",
        is_colwise=False,
        is_padded=True,
        flatten_axis=-1,
        broadcast_2d_scale_shape_to_1d=False,
    ) -> Tuple[int]:
        """Get the shape for scale tensors in this mode.

        Args:
            data_shape: Shape of the data tensor
            data_layout: Layout of the data shape, either "N" (default) or "T" for transposed.
            is_colwise: Whether to use column-wise scaling
            is_padded: Whether to use padded shapes
            flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
            broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Defaults to False.

        Returns:
            The shape for scale tensors
        """
        return self._get_impl().get_scale_shape(
            data_shape,
            data_layout=data_layout,
            is_colwise=is_colwise,
            is_padded=is_padded,
            flatten_axis=flatten_axis,
            broadcast_2d_scale_shape_to_1d=broadcast_2d_scale_shape_to_1d,
        )

    def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout:
        """Get the quantize layout for the tensor usage.

        Args:
            usage: The usage of the tensor

        Returns:
            The quantize layout for the tensor usage
        """
        return self._get_impl().get_quantize_layout(usage)

    def get_shardy_sharding_rules(
        self,
        input_shape,
        unique_var,
        flatten_axis,
        q_layout,
        broadcast_2d_scale_shape_to_1d=False,
    ) -> QuantizeShardyRules:
        """Sharding rules for the input and (row, col)wise scale tensors.

        Args:
            input_shape: The shape of the input tensor (for which we produce the scale tensor)
            unique_var: An otherwise unused Shardy variable name prefix
            flatten_axis: Axis along which data can be flattened to 2D for quantization.
            q_layout: The layout of the quantized tensor
            broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Defaults to False.

        Returns:
            The Shardy rules for the scaling mode
        """
        return self._get_impl().get_shardy_sharding_rules(
            input_shape,
            unique_var,
            flatten_axis,
            q_layout,
            broadcast_2d_scale_shape_to_1d,
            self.is_colwise_transposed,
        )

    def get_grouped_scale_shape_2x(
        self, data_shape, n_groups, group_axis, is_padded=True, flatten_axis=-1
    ) -> Tuple[Tuple[int]]:
        """Get shapes for both row-wise and column-wise scaling.

        Args:
            data_shape: Shape of the data tensor
            n_groups: Number of groups for grouped quantization
            group_axis: The axis along which grouping is performed
            is_padded: Whether to use padded shapes
            flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)

        Returns:
            Tuple of (rowwise_scale_shape, colwise_scale_shape)
        """
        rowwise_scale_shape = self.get_grouped_scale_shape(
            data_shape,
            n_groups,
            group_axis,
            is_colwise=False,
            is_padded=is_padded,
            flatten_axis=flatten_axis,
        )
        colwise_scale_shape = self.get_grouped_scale_shape(
            data_shape,
            n_groups,
            group_axis,
            is_colwise=True,
            is_padded=is_padded,
            flatten_axis=flatten_axis,
        )
        return (rowwise_scale_shape, colwise_scale_shape)

    def get_grouped_scale_shape(
        self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1
    ) -> Tuple[int]:
        """Get the shape for grouped scale tensors in this mode.

        Args:
            data_shape: Shape of the data tensor
            n_groups: Number of groups for grouped quantization
            group_axis: The axis along which grouping is performed
            is_colwise: Whether to use column-wise scaling
            is_padded: Whether to use padded shapes
            flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.

        Returns:
            The shape for scale tensors
        """
        return self._get_impl().get_grouped_scale_shape(
            data_shape,
            n_groups,
            group_axis,
            is_colwise=is_colwise,
            is_padded=is_padded,
            flatten_axis=flatten_axis,
        )

    def is_tensor_scaling(self) -> bool:
        """Check if this scaling mode is per-tensor scaling.

        Returns:
            True if the scaling mode is tensor scaling, False otherwise
        """
        return self in (
            ScalingMode.DELAYED_TENSOR_SCALING,
            ScalingMode.CURRENT_TENSOR_SCALING,
        )

    def is_1d_block_scaling(self) -> bool:
        """Check if this scaling mode is 1D block scaling.

        Returns:
            True if the scaling mode is 1D block scaling, False otherwise
        """
        # Both 1D and 2D NVFP4 scaling are treated as 1D block scaling since the 2D scales are broadcast to 1D because it is required for the GEMM.
        return self == ScalingMode.MXFP8_1D_SCALING or self.is_nvfp4_scaling

    @property
    def is_block_scaling(self) -> bool:
        """Check if this scaling mode is block scaling.

        Returns:
            True if the scaling mode is block scaling, False otherwise
        """
        # Currently we only have 1D block scaling modes
        return self.is_1d_block_scaling()

    def get_compatible_q_dtypes(self) -> set[jnp.dtype]:
        """Returns a set of compatible quantized data types for this scaling mode.

        Returns:
            A set of compatible quantized data types
        """
        if self in (
            ScalingMode.DELAYED_TENSOR_SCALING,
            ScalingMode.CURRENT_TENSOR_SCALING,
            ScalingMode.MXFP8_1D_SCALING,
        ):
            return {jnp.float8_e5m2, jnp.float8_e4m3fn}
        if self in (ScalingMode.NVFP4_1D_SCALING, ScalingMode.NVFP4_2D_SCALING):
            return {jnp.float4_e2m1fn}
        if self == ScalingMode.NO_SCALING:
            return {jnp.float16, jnp.bfloat16, jnp.float32}
        raise ValueError(f"Invalid scaling mode: {self}")

    @property
    def is_nvfp4_scaling(self) -> bool:
        """Check if this scaling mode is NVFP4 scaling.

        Returns:
            True if the scaling mode is NVFP4 scaling, False otherwise
        """
        return self in (ScalingMode.NVFP4_1D_SCALING, ScalingMode.NVFP4_2D_SCALING)

    @property
    def is_mxfp8_scaling(self) -> bool:
        """Check if this scaling mode is NVFP4 scaling.

        Returns:
            True if the scaling mode is NVFP4 scaling, False otherwise
        """
        return self == ScalingMode.MXFP8_1D_SCALING

    @property
    def is_colwise_transposed(self) -> bool:
        """Check if this scaling mode uses transposed layout for column-wise scaling.

        Returns:
            True if the scaling mode uses transposed layout for column-wise scaling, False otherwise
        """
        return self.is_tensor_scaling() or self.is_nvfp4_scaling

    def __eq__(self, other):
        """Compare this scaling mode with another.

        Args:
            other: The other scaling mode to compare with

        Returns:
            True if the modes are equal, False otherwise
        """
        if not isinstance(other, ScalingMode):
            return False
        return self.value == other.value

    def tree_flatten(self):
        """Flatten this scaling mode for JAX tree operations.

        Returns:
            Tuple of (children, aux_data) for tree operations
        """
        return (), (self.value)

    @classmethod
    def tree_unflatten(cls, aux_data, _children):
        """Reconstruct a scaling mode from its flattened representation.

        Args:
            aux_data: Auxiliary data containing the mode value
            _children: Unused children data

        Returns:
            A reconstructed ScalingMode instance
        """
        return cls(aux_data)


SCALING_MODES_TO_IMPL: Dict[ScalingMode, ScalingModeMetadataImpl] = {
    ScalingMode.NO_SCALING: NoScalingModeMetadataImpl(),
    ScalingMode.DELAYED_TENSOR_SCALING: DelayedScalingModeMetadataImpl(),
    ScalingMode.MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(
        block_dims=(1, 32),
        scale_dtype=jnp.float8_e8m0fnu,
        data_layout="NN",
    ),
    ScalingMode.CURRENT_TENSOR_SCALING: CurrentScalingModeMetadataImpl(),
    ScalingMode.NVFP4_1D_SCALING: BlockScalingModeMetadataImpl(
        block_dims=(1, 16),
        scale_dtype=jnp.float8_e4m3fn,
        data_layout="NT",
    ),
    ScalingMode.NVFP4_2D_SCALING: BlockScalingModeMetadataImpl(
        block_dims=(16, 16), scale_dtype=jnp.float8_e4m3fn, data_layout="NT"
    ),
}
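

# Illustrative example (not executed at import time): for an unpadded MXFP8 tensor of
# shape (128, 256),
#     ScalingMode.MXFP8_1D_SCALING.get_scale_shape_2x((128, 256), is_padded=False)
# returns ((128, 8), (4, 256)): one E8M0 scale per 32 contiguous values along the last
# axis for the rowwise copy, and per 32 values along the first axis for the colwise copy.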