# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""
Scaling mode implementations for quantization in JAX.

This module provides implementations of different scaling modes for tensor quantization,
including delayed scaling and block scaling strategies.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Tuple, Dict
16
from functools import reduce, lru_cache
17
import operator
18
import numpy as np
19

20
from jax.experimental.custom_partitioning import BATCHING, CompoundFactor
21
22
23
from jax.tree_util import register_pytree_node_class
import jax.numpy as jnp

24
25
from transformer_engine_jax import JAXX_Scaling_Mode, QuantizeLayout
from .device_utils import is_fp8_gemm_with_all_layouts_supported
26

27

28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
__all__ = [
    "QuantizeShardyRules",
    "ScalingMode",
    "TensorUsage",
]


class TensorUsage(Enum):
    """Role a tensor plays as an operand of a GEMM.

    For a GEMM ``C = A * B`` each operand may be consumed either in its normal
    form or in its transposed form, which gives four possible usages:

    - LHS: A (left-hand side) in the normal form
    - LHS_TRANS: A (left-hand side) in the transposed form
    - RHS: B (right-hand side) in the normal form
    - RHS_TRANS: B (right-hand side) in the transposed form

    The tensor usage is consumed by the ScaledTensor.get_tensor() method.
    """

    LHS = 0
    LHS_TRANS = 1
    RHS = 2
    RHS_TRANS = 3

    def __hash__(self):
        # Hash by value so that hashing stays consistent with the value-based __eq__.
        return hash(self.value)

    def __eq__(self, other):
        # Two usages are equal when their values match; any object that is not
        # a TensorUsage compares unequal.
        return isinstance(other, TensorUsage) and self.value == other.value
62
63


64
65
66
67
68
def DIVUP(a, b):
    """Divide a by b and round the result up.

    The ceiling is obtained from floor division alone: floor-dividing by the
    negated divisor and negating the result rounds towards positive infinity.
    """
    floor_of_negated_quotient = a // -b
    return -floor_of_negated_quotient


69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
@dataclass
class QuantizeShardyRules:
    """Bundle of the information needed to shard scale tensors with Shardy.

    Attributes:
        input_spec: Specification for the input axes.
        rowwise_rule: Sharding rule for the row-wise scale tensor; it is
          expressed in terms of the axes named in `input_spec`.
        colwise_rule: Sharding rule for the column-wise scale tensor; it is
          likewise expressed in terms of the axes named in `input_spec`.
        factor_sizes: For block scaling, holds the block size factor that is
          referenced from `input_spec`.
    """

    input_spec: Tuple[str]
    rowwise_rule: Tuple[str]
    colwise_rule: Tuple[str]
    factor_sizes: Dict[str, int]
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102


class ScalingModeMetadataImpl(ABC):
    """Base class for scaling mode implementations.

    This abstract class defines the interface for different scaling mode implementations,
    providing methods to get scale data types, data layouts, scale shapes (plain and
    grouped), the quantize layout for a given tensor usage, and Shardy sharding rules.
    """

    @abstractmethod
    def get_scale_dtype(self) -> jnp.dtype:
        """Get the data type for scale tensors.

        Returns:
            The data type used for scale tensors
        """

    @abstractmethod
    def get_data_layout(self) -> str:
        """Get the data layout for rowwise and colwise scaling.

        Returns:
            The data layout, two characters, e.g. "NT", where each is either "N" (default) or "T" for transposed. The first character refers to the rowwise layout and the second refers to the colwise layout.
        """

    @abstractmethod
    def get_scale_shape(
        self,
        data_shape: Tuple[int, ...],
        data_layout: str = "N",
        is_colwise: bool = False,
        is_padded: bool = True,
        flatten_axis: int = -1,
        broadcast_2d_scale_shape_to_1d: bool = False,
    ) -> Tuple[int, ...]:
        """Get the shape for scale tensors.

        Args:
            data_shape: The shape of the tensor being quantized
            data_layout: Layout of the data shape, either "N" (default) or "T" for transposed.
            is_colwise: Whether the scaling is column-wise
            is_padded: Whether to return padded shape
            flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)
            broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D.
                Every concrete implementation in this module accepts this argument, so it
                is part of the abstract signature as well. Concrete implementations may
                choose a different default.

        Returns:
            The shape for scale tensors
        """

    @abstractmethod
    def get_grouped_scale_shape(
        self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1
    ) -> Tuple[int]:
        """Get the shape for scale tensors in this mode.

        Args:
            data_shape: Original shape of the data tensor
            n_groups: Number of groups in grouped quantization
            group_axis: The axis along which grouping is performed
            is_colwise: Whether to use column-wise scaling
            is_padded: Whether to use padded shapes
            flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)

        Returns:
            The shape for scale tensors
        """

    @lru_cache(maxsize=4)
    @abstractmethod
    def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout:
        """Get the quantize layout for the tensor usage.

        Args:
            usage: The usage of the tensor

        Returns:
            The quantize layout for the tensor usage
        """

    @abstractmethod
    def get_shardy_sharding_rules(
        self,
        input_shape,
        unique_var,
        flatten_axis,
        broadcast_2d_scale_shape_to_1d,
    ) -> QuantizeShardyRules:
        """Sharding rules for the input and (row, col)wise scale tensors.

        Args:
            input_shape: The shape of the input tensor (for which we produce the scale tensor)
            unique_var: An otherwise unused Shardy variable name prefix
            flatten_axis: Axis along which data can be flattened to 2D for quantization
            broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D.

        Returns:
            The Shardy rules for the scaling mode
        """

183

184
185
186
187
188
189
190
191
192
193
194
195
196
197
class NoScalingModeMetadataImpl(ScalingModeMetadataImpl):
    """Implementation for no scaling mode.

    This implementation provides metadata for no scaling mode, for using non-quantized higher-precision datatypes such as bf16.
    """

    def get_scale_dtype(self) -> jnp.dtype:
        """Get the data type for scale tensors. This is a placeholder and won't be used for higher-precision values that don't have scaling.

        Returns:
            The data type used for scale tensors (float32)
        """
        return jnp.float32

    def get_data_layout(self) -> str:
        """Get the data layout for rowwise and colwise scaling.

        Returns:
            The data layout, two characters, e.g. "NT", where each is either "N" (default) or "T" for transposed. The first character refers to the rowwise layout and the second refers to the colwise layout.
            This mode always returns "NN".
        """
        return "NN"

    def get_scale_shape(
        self,
        data_shape: Tuple[int, ...],
        data_layout: str = "N",
        is_colwise: bool = False,
        is_padded: bool = True,
        flatten_axis: int = -1,
        broadcast_2d_scale_shape_to_1d: bool = True,
    ) -> Tuple[int, ...]:
        """Get the shape for scale tensors. This always returns the same zero-length shape because this mode applies no scaling.

        All arguments are accepted only for interface compatibility and are ignored.

        Args:
            data_shape: The shape of the tensor being scaled
            data_layout: Layout of the data shape, either "N" (default) or "T" for transposed.
            is_colwise: Whether the scaling is column-wise
            is_padded: Whether to return padded shape
            flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
            broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Defaults to True.

        Returns:
            The shape for scale tensors - (0,)
        """
        del (
            data_shape,
            data_layout,
            is_colwise,
            is_padded,
            flatten_axis,
            broadcast_2d_scale_shape_to_1d,
        )
        return (0,)

    @lru_cache(maxsize=4)
    def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout:
        """Get the quantize layout for the tensor usage.

        Args:
            usage: The usage of the tensor. The result does not depend on it.

        Returns:
            The quantize layout for the tensor usage, always QuantizeLayout.ROWWISE
        """
        return QuantizeLayout.ROWWISE

    def get_grouped_scale_shape(
        self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1
    ) -> Tuple[int]:
        """Get the shape for scale tensors in this mode.

        Only `n_groups` influences the result; the other arguments are accepted for
        interface compatibility.

        Args:
            data_shape: Original shape of the data tensor
            n_groups: Number of groups in grouped quantization, must be an int
            group_axis: The axis along which grouping is performed
            is_colwise: Whether to use column-wise scaling
            is_padded: Whether to use padded shapes
            flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.

        Returns:
            The shape for scale tensors - (n_groups,)
        """
        del data_shape, group_axis, is_colwise
        assert isinstance(n_groups, int)
        return (n_groups,)

    def get_shardy_sharding_rules(
        self,
        input_shape,
        unique_var,
        flatten_axis,
        broadcast_2d_scale_shape_to_1d,
    ) -> QuantizeShardyRules:
        """Sharding rules for the input and (row, col)wise scale tensors.

        Args:
            input_shape: The shape of the input tensor (for which we produce the scale tensor)
            unique_var: An otherwise unused Shardy variable name prefix
            flatten_axis: Axis along which data can be flattened to 2D for quantization
            broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D.

        Returns:
            The Shardy rules for the scaling mode
        """
        del flatten_axis, broadcast_2d_scale_shape_to_1d
        # One variable per input axis, named "<unique_var><axis index>".
        input_spec = tuple(f"{unique_var}{i}" for i in range(len(input_shape)))
        # The rowwise and colwise rules share the same single scale variable.
        scale_var = BATCHING + unique_var + "_scale_inv"
        return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {})


290
291
class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl):
    """Implementation for current scaling mode.

    This implementation provides metadata for current scaling mode, including scale data type and shape.
    """

    def get_scale_dtype(self) -> jnp.dtype:
        """Get the data type for scale tensors in current scaling.

        Returns:
            The data type used for scale tensors (float32)
        """
        return jnp.float32

    def get_data_layout(self) -> str:
        """Get the data layout for rowwise and colwise scaling.

        Returns:
            The data layout, two characters, e.g. "NT", where each is either "N" (default) or "T" for transposed. The first character refers to the rowwise layout and the second refers to the colwise layout.
            This mode always returns "NT".
        """
        return "NT"

    def get_scale_shape(
        self,
        data_shape: Tuple[int, ...],
        data_layout: str = "N",
        is_colwise: bool = False,
        is_padded: bool = True,
        flatten_axis: int = -1,
        broadcast_2d_scale_shape_to_1d: bool = True,
    ) -> Tuple[int, ...]:
        """Get the shape for scale tensors in current scaling.

        Only `data_shape` influences the result; the other arguments are accepted for
        interface compatibility.

        Args:
            data_shape: The shape of the tensor being scaled
            data_layout: Layout of the data shape, either "N" (default) or "T" for transposed.
            is_colwise: Whether the scaling is column-wise
            is_padded: Whether to return padded shape
            flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
            broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Defaults to True.

        Returns:
            The shape for scale tensors - (0,) when the data tensor has no elements, (1,) otherwise
        """
        del data_layout, is_colwise, broadcast_2d_scale_shape_to_1d
        # A tensor is empty exactly when one of its dimensions is zero. Checking
        # the dimensions directly avoids multiplying them in a fixed-width
        # integer type, where a very large product could wrap around to zero.
        if any(dim == 0 for dim in data_shape):
            return (0,)
        return (1,)

    @lru_cache(maxsize=4)
    def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout:
        """Get the quantize layout for the tensor usage.

        Results are memoized by lru_cache, so the outcome of
        is_fp8_gemm_with_all_layouts_supported() is only consulted on a cache miss.

        Args:
            usage: The usage of the tensor

        Returns:
            The quantize layout for the tensor usage
        """
        # When the FP8 GEMM supports every layout, rowwise data is sufficient.
        if is_fp8_gemm_with_all_layouts_supported():
            return QuantizeLayout.ROWWISE

        if usage in (TensorUsage.LHS, TensorUsage.RHS_TRANS):
            return QuantizeLayout.ROWWISE
        return QuantizeLayout.COLWISE

    def get_grouped_scale_shape(
        self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1
    ) -> Tuple[int]:
        """Get the shape for scale tensors in this mode.

        Only `n_groups` influences the result; the other arguments are accepted for
        interface compatibility.

        Args:
            data_shape: Original shape of the data tensor
            n_groups: Number of groups in grouped quantization, must be an int
            group_axis: The axis along which grouping is performed
            is_colwise: Whether to use column-wise scaling
            is_padded: Whether to use padded shapes
            flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.

        Returns:
            The shape for scale tensors - (n_groups,)
        """
        del data_shape, group_axis, is_colwise
        assert isinstance(n_groups, int)
        return (n_groups,)

    def get_shardy_sharding_rules(
        self,
        input_shape,
        unique_var,
        flatten_axis,
        broadcast_2d_scale_shape_to_1d,
    ) -> QuantizeShardyRules:
        """Sharding rules for the input and (row, col)wise scale tensors.

        Args:
            input_shape: The shape of the input tensor (for which we produce the scale tensor)
            unique_var: An otherwise unused Shardy variable name prefix
            flatten_axis: Axis along which data can be flattened to 2D for quantization
            broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D.

        Returns:
            The Shardy rules for the scaling mode
        """
        del flatten_axis, broadcast_2d_scale_shape_to_1d
        # One variable per input axis, named "<unique_var><axis index>".
        input_spec = tuple(f"{unique_var}{i}" for i in range(len(input_shape)))
        # The rowwise and colwise rules share the same single scale variable.
        scale_var = BATCHING + unique_var + "_scale_inv"
        return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {})
396

397

398
399
400
401
402
403
404
class DelayedScalingModeMetadataImpl(CurrentScalingModeMetadataImpl):
    """Implementation for delayed scaling mode.

    Delayed scaling overrides nothing: its scale data type, data layout, scale
    shapes, quantize layout and Shardy rules are all inherited unchanged from
    CurrentScalingModeMetadataImpl.
    """


405
406
407
408
409
410
411
412
413
414
415
class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
    """Implementation for block scaling mode.

    This implementation provides metadata for block scaling mode, which uses
    block-based scaling with specific alignment requirements.

    Attributes:
        _block_dims: Dimensions of the scaling blocks
        _scale_dtype: Data type of the scale tensor
        _block_alignment: Alignment requirements for blocks, fixed to (128, 4)
        _data_layout: Two-character layout string for rowwise and colwise scaling
    """

    def __init__(self, block_dims: Tuple[int], scale_dtype: jnp.dtype, data_layout: str):
        """Initialize block scaling mode implementation.

        Args:
            block_dims: Dimensions of the scaling blocks
            scale_dtype: Data type of the scale tensor
            data_layout: Layout for rowwise and colwise scaling, two characters, e.g. "NT", where each is either "N" (default) or "T" for transposed. The first character refers to the rowwise layout and the second refers to the colwise layout.
        """
        self._block_dims = block_dims
        self._scale_dtype = scale_dtype
        self._block_alignment = (128, 4)
        self._data_layout = data_layout

    def get_scale_dtype(self) -> jnp.dtype:
        """Get the data type for scale tensors in block scaling.

        Returns:
            The data type used for scale tensors, as passed to the constructor
        """
        return self._scale_dtype

    def get_data_layout(self) -> str:
        """Get the data layout for rowwise and colwise scaling.

        Returns:
            The data layout, two characters, e.g. "NT", where each is either "N" (default) or "T" for transposed. The first character refers to the rowwise layout and the second refers to the colwise layout.
        """
        return self._data_layout

    def _apply_scale_shape_correction(self, data_shape, n_scale_blocks, scale_block_dim):
        """Remove excess padding from the scale shape and return the shape with respect to the original data shape.

        Args:
            data_shape: The (partial) data shape whose rank the result must match
            n_scale_blocks: Total number of scale blocks to distribute over the axes
            scale_block_dim: Block size along the last axis of `data_shape`

        Returns:
            A scale shape with the same rank as `data_shape`. The last axis holds
            `data_shape[-1] // scale_block_dim` blocks, the middle axes keep their data
            sizes, and the first axis absorbs whatever remains of `n_scale_blocks`.
        """
        if len(data_shape) > 1:
            # handle last dim
            assert data_shape[-1] % scale_block_dim == 0
            last = data_shape[-1] // scale_block_dim
            scale_shape = (last,)
            assert n_scale_blocks % last == 0
            n_scale_blocks //= last
            # handle middle dim, exclude first and last
            for mid in reversed(data_shape[1:-1]):
                scale_shape = (mid,) + scale_shape
                assert n_scale_blocks % mid == 0
                n_scale_blocks //= mid
            scale_shape = (n_scale_blocks,) + scale_shape
        else:
            scale_shape = (n_scale_blocks,)

        assert len(scale_shape) == len(
            data_shape
        ), f"scale_shape {scale_shape}, data_shape {data_shape}"
        return scale_shape

    def get_scale_shape(
        self,
        data_shape: Tuple[int, ...],
        data_layout: str = "N",
        is_colwise: bool = False,
        is_padded: bool = True,
        flatten_axis: int = -1,
        broadcast_2d_scale_shape_to_1d: bool = False,
    ) -> Tuple[int, ...]:
        """Get the shape for scale tensors in block scaling.

        Args:
            data_shape: The shape of the tensor being quantized
            data_layout: Layout of the data shape, either "N" (default) or "T" for transposed.
                Must match the rowwise or colwise character of this mode's data layout,
                depending on `is_colwise`.
            is_colwise: Whether the scaling is column-wise
            is_padded: Whether to return padded shape
            flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
            broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Defaults to False.

        Returns:
            The shape for scale tensors
        """
        # Normalize a negative flatten_axis to its non-negative equivalent.
        flatten_axis = (len(data_shape) + flatten_axis) % len(data_shape)
        assert (
            0 < flatten_axis < len(data_shape)
        ), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}"

        block_alignment = self._block_alignment if is_padded else (1, 1)

        if is_colwise:
            assert data_layout == self._data_layout[1], (
                f"Data layout must match colwise layout, received {data_layout} but expected"
                f" {self._data_layout[1]}"
            )
        else:
            assert data_layout == self._data_layout[0], (
                f"Data layout must match rowwise layout, received {data_layout} but expected"
                f" {self._data_layout[0]}"
            )

        if is_colwise and self._data_layout[1] == "T":
            # TODO(Phuong): rework this hack so that we don't implicitly change is_colwise value
            is_colwise = False  # now rowwise in T is colwise in N
            # flatten_axis is given wrt N layout, convert to T layout.
            # It is already non-negative here because it was normalized above.
            flatten_axis = len(data_shape) - flatten_axis

        if is_colwise:
            block_y, block_x = self._block_dims
            alignment_y, alignment_x = block_alignment
        else:
            block_x, block_y = self._block_dims
            alignment_x, alignment_y = block_alignment

        is_block_2d = block_x > 1 and block_y > 1
        assert data_shape[flatten_axis - 1] % block_x == 0, (
            f"Data shape {data_shape} should be divisible by block_x {block_x} in axis"
            f" {flatten_axis - 1}"
        )
        assert (
            data_shape[-1] % block_y == 0
        ), f"Data shape {data_shape} should be divisible by block_y {block_y} in axis -1"

        # The divisibility check above uses the real 2D block size; only the
        # resulting scale shape is broadcast along the first dimension.
        if broadcast_2d_scale_shape_to_1d and is_block_2d:
            block_x = 1

        flattened_first_dim = reduce(operator.mul, data_shape[:flatten_axis], 1)
        flattened_last_dim = reduce(operator.mul, data_shape[flatten_axis:], 1)

        assert flattened_first_dim % block_x == 0, (
            f"Flattened first dim - mutiplication of axes={tuple(range(0, flatten_axis))} of shape"
            f" {data_shape} - should be divisible by block_x {block_x}"
        )
        assert flattened_last_dim % block_y == 0, (
            "Flattened last dim - mutiplication of"
            f" axes={tuple(range(flatten_axis, len(data_shape)))} of shape {data_shape} - should be"
            f" divisible by block_y {block_y}"
        )

        # Exact integer division (divisibility is asserted above); this matches
        # get_grouped_scale_shape and avoids going through floating point.
        n_block_x = int(flattened_first_dim // block_x)
        n_block_y = int(flattened_last_dim // block_y)

        # padding
        n_block_x = int(((n_block_x + alignment_x - 1) // alignment_x) * alignment_x)
        n_block_y = int(((n_block_y + alignment_y - 1) // alignment_y) * alignment_y)

        first_dim_scale_shape = self._apply_scale_shape_correction(
            data_shape[:flatten_axis], n_block_x, block_x
        )
        last_dim_scale_shape = self._apply_scale_shape_correction(
            data_shape[flatten_axis:], n_block_y, block_y
        )

        return (*first_dim_scale_shape, *last_dim_scale_shape)

    @lru_cache(maxsize=4)
    def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout:
        """Get the quantize layout for the tensor usage.

        Args:
            usage: The usage of the tensor

        Returns:
            The quantize layout for the tensor usage: ROWWISE for LHS and RHS_TRANS,
            COLWISE otherwise
        """
        # If we need to support 1x1x for inference in the future
        # if get_quantize_config().INFERENCE_MODE:
        #     assert usage not in (TensorUsage.LHS_TRANS, TensorUsage.RHS_TRANS), (f"Invalid usage {usage} as we are in MXFP8_1D_SCALING 1x1x (FWD only) mode so no transposed usage is needed!")
        #     if usage == TensorUsage.LHS:
        #         return QuantizeLayout.ROWWISE
        #     return QuantizeLayout.COLWISE

        if usage in (TensorUsage.LHS, TensorUsage.RHS_TRANS):
            return QuantizeLayout.ROWWISE
        return QuantizeLayout.COLWISE

    def get_grouped_scale_shape(
        self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1
    ) -> Tuple[int]:
        """Get the shape for grouped scale tensors in this mode.
        If padded: The estimated maximal possible shape for grouped scale tensor is returned instead.

        Args:
            data_shape: Original shape of the data tensor
            n_groups: Number of groups in grouped quantization, must be an int
            group_axis: The axis along which grouping is performed (not read by this
                implementation)
            is_colwise: Whether to use column-wise scaling
            is_padded: Whether to use padded shapes
            flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.

        Returns:
            The shape for scale tensors, a 1-tuple holding the total number of scale elements
        """
        assert isinstance(n_groups, int)
        block_alignment = self._block_alignment if is_padded else (1, 1)

        if is_colwise:
            block_y, block_x = self._block_dims
            alignment_y, alignment_x = block_alignment
        else:
            block_x, block_y = self._block_dims
            alignment_x, alignment_y = block_alignment

        if flatten_axis < 0:
            flatten_axis = len(data_shape) + flatten_axis
        assert (
            0 < flatten_axis < len(data_shape)
        ), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}"

        assert data_shape[flatten_axis - 1] % block_x == 0, (
            f"Data shape {data_shape} should be divisible by block_x {block_x} in axis"
            f" {flatten_axis - 1}"
        )
        assert (
            data_shape[-1] % block_y == 0
        ), f"Data shape {data_shape} should be divisible by block_y {block_y} in axis -1"

        flattened_first_dim = reduce(operator.mul, data_shape[:flatten_axis], 1)
        flattened_last_dim = reduce(operator.mul, data_shape[flatten_axis:], 1)

        assert flattened_first_dim % block_x == 0, (
            f"Flattened first dim - mutiplication of axes={tuple(range(0, flatten_axis))} of shape"
            f" {data_shape} - should be divisible by block_x {block_x}"
        )
        assert flattened_last_dim % block_y == 0, (
            "Flattened last dim - mutiplication of"
            f" axes={tuple(range(flatten_axis, len(data_shape)))} of shape {data_shape} - should be"
            f" divisible by block_y {block_y}"
        )

        n_block_x = int(flattened_first_dim // block_x)
        n_block_y = int(flattened_last_dim // block_y)

        # Given the scale shape of [M, N], and G groups, and padding alignment (128, 4),
        # the worst scenario is when we have (G-1) groups with 1 row each and 1 group
        # with (M-G+1) rows. Then:
        #     max_padded_rows = (G-1) * 128 + DIVUP(M-G+1, 128) * 128
        #     max_padded_cols = DIVUP(N, 4) * 4
        #     max_scale_size = max_padded_rows * max_padded_cols
        if is_padded:
            n_block_x = (n_groups - 1) * alignment_x + DIVUP(
                n_block_x - n_groups + 1, alignment_x
            ) * alignment_x
            n_block_y = DIVUP(n_block_y, alignment_y) * alignment_y

        return (n_block_x * n_block_y,)

    def get_shardy_sharding_rules(
        self,
        input_shape,
        unique_var,
        flatten_axis,
        broadcast_2d_scale_shape_to_1d,
    ) -> QuantizeShardyRules:
        """Sharding rules for the input and (row, col)wise scale tensors.

        Args:
            input_shape: The shape of the input tensor (for which we produce the scale tensor)
            unique_var: An otherwise unused Shardy variable name prefix
            flatten_axis: Axis along which data can be flattened to 2D for quantization
            broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D.

        Returns:
            The Shardy rules for the scaling mode
        """
        # TODO(Phuong): to rework the shardy rule to handle transposes after NVFP4 is upstreamed
        input_rank = len(input_shape)
        input_spec = [f"{unique_var}_{i}" for i in range(input_rank)]
        flatten_axis = (flatten_axis + input_rank) % input_rank

        assert (
            self._block_dims[1] != 1
        ), f"Expect 1D rowwise or 2D block. Got _block_dims={self._block_dims}"
        # For 2D block scaling, only support when with broadcast_2d_scale_shape_to_1d
        if self._block_dims[0] != 1:
            assert self._block_dims[0] == self._block_dims[1] and broadcast_2d_scale_shape_to_1d, (
                f"Got broadcast_2d_scale_shape_to_1d={broadcast_2d_scale_shape_to_1d},"
                f" _block_dims={self._block_dims}"
            )

        block_size_1d = self._block_dims[1]

        # We have to use two different factors in the two CompoundFactors because of Shardy
        # verifier requirements, even though they are the same.
        blocksizes = {}
        colwise_var = f"{unique_var}_None"
        rowwise_var = f"{unique_var}_None"
        # An axis whose size equals the block size keeps its plain variable and
        # the corresponding scale axis falls back to the "<unique_var>_None" name.
        if input_shape[-1] != block_size_1d:
            rowwise_var = input_spec[-1] + "_compound"
            input_spec[-1] = CompoundFactor(rowwise_var, "blocksize_x")
            blocksizes["blocksize_x"] = block_size_1d
        if input_shape[flatten_axis - 1] != block_size_1d:
            colwise_var = input_spec[flatten_axis - 1] + "_compound"
            input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "blocksize_y")
            blocksizes["blocksize_y"] = block_size_1d

        # The rowwise and colwise scale tensors should be sharded the same way as the input.
        # However, we need to adjust the dimensions where the block scaling factor applies.
        rowwise = input_spec.copy()
        rowwise[-1] = rowwise_var

        colwise = input_spec.copy()
        colwise[flatten_axis - 1] = colwise_var

        return QuantizeShardyRules(
            tuple(input_spec),
            tuple(rowwise),
            tuple(colwise),
            blocksizes,
        )

720
721
722
723
724
725
726

@dataclass(frozen=True)
@register_pytree_node_class
class ScalingMode(Enum):
    """Enumeration of tensor scaling modes with their corresponding metadata implementations.

    Each member wraps the matching ``JAXX_Scaling_Mode`` value and delegates all
    shape/dtype/layout queries to the implementation registered for it in
    ``SCALING_MODES_TO_IMPL``.

    Available scaling modes for tensor quantization:
    - NO_SCALING: No scaling applied
    - DELAYED_TENSOR_SCALING: Uses delayed scaling with FP8 data type and float32 scales
    - MXFP8_1D_SCALING: Uses block-based scaling with FP8 data type and E8M0 scales
    - CURRENT_TENSOR_SCALING: Uses current scaling with FP8 data type and float32 scales
    - NVFP4_1D_SCALING: Uses block-based scaling with FP4 data type and E4M3 scales
    - NVFP4_2D_SCALING: Uses block-based scaling with FP4 data type and E4M3 scales
    """

    NO_SCALING = JAXX_Scaling_Mode.NO_SCALING
    DELAYED_TENSOR_SCALING = JAXX_Scaling_Mode.DELAYED_TENSOR_SCALING
    MXFP8_1D_SCALING = JAXX_Scaling_Mode.MXFP8_1D_SCALING
    CURRENT_TENSOR_SCALING = JAXX_Scaling_Mode.CURRENT_TENSOR_SCALING
    NVFP4_1D_SCALING = JAXX_Scaling_Mode.NVFP4_1D_SCALING
    NVFP4_2D_SCALING = JAXX_Scaling_Mode.NVFP4_2D_SCALING

    def _get_impl(self) -> ScalingModeMetadataImpl:
        """Get the implementation for this scaling mode.

        Returns:
            The scaling mode implementation

        Raises:
            ValueError: If no implementation is registered for this scaling mode
        """
        impl = SCALING_MODES_TO_IMPL.get(self)
        if impl is None:
            # Name the offending mode, matching the error raised by get_compatible_q_dtypes().
            raise ValueError(f"Invalid scaling mode: {self}")
        return impl

    def get_scale_dtype(self):
        """Get the data type for scale tensors in this mode.

        Returns:
            The data type for scale tensors
        """
        return self._get_impl().get_scale_dtype()

    def get_scale_shape_2x(
        self, data_shape, is_padded=True, flatten_axis=-1, broadcast_2d_scale_shape_to_1d=False
    ) -> Tuple[Tuple[int]]:
        """Get shapes for both row-wise and column-wise scaling.

        Args:
            data_shape: Shape of the data tensor
            is_padded: Whether to use padded shapes
            flatten_axis: Axis along which data can be flattened to 2D for quantization.
                Defaults to -1.
            broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D.
                Defaults to False.

        Returns:
            Tuple of (rowwise_scale_shape, colwise_scale_shape)
        """
        # The implementation reports a two-character layout: [0] row-wise, [1] column-wise.
        data_layout = self._get_impl().get_data_layout()
        rowwise_layout = data_layout[0]
        assert (
            rowwise_layout == "N"
        ), f"For rowwise layout only 'N' is supported, received {rowwise_layout}"
        colwise_layout = data_layout[1]

        rowwise_scale_shape = self.get_scale_shape(
            data_shape,
            data_layout=rowwise_layout,
            is_colwise=False,
            is_padded=is_padded,
            flatten_axis=flatten_axis,
            broadcast_2d_scale_shape_to_1d=broadcast_2d_scale_shape_to_1d,
        )

        colwise_data_shape = data_shape
        if colwise_layout == "T":
            # Transposed column-wise data: the dims from flatten_axis onwards are moved in
            # front of the leading dims before querying the scale shape.
            colwise_data_shape = data_shape[flatten_axis:] + data_shape[:flatten_axis]
        colwise_scale_shape = self.get_scale_shape(
            colwise_data_shape,
            data_layout=colwise_layout,
            is_colwise=True,
            is_padded=is_padded,
            flatten_axis=flatten_axis,
            broadcast_2d_scale_shape_to_1d=broadcast_2d_scale_shape_to_1d,
        )
        return (rowwise_scale_shape, colwise_scale_shape)

    def get_scale_shape(
        self,
        data_shape,
        data_layout="N",
        is_colwise=False,
        is_padded=True,
        flatten_axis=-1,
        broadcast_2d_scale_shape_to_1d=False,
    ) -> Tuple[int]:
        """Get the shape for scale tensors in this mode.

        Args:
            data_shape: Shape of the data tensor
            data_layout: Layout of the data shape, either "N" (default) or "T" for transposed.
            is_colwise: Whether to use column-wise scaling
            is_padded: Whether to use padded shapes
            flatten_axis: Axis along which data can be flattened to 2D for quantization.
                Defaults to -1.
            broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D.
                Defaults to False.

        Returns:
            The shape for scale tensors
        """
        return self._get_impl().get_scale_shape(
            data_shape,
            data_layout=data_layout,
            is_colwise=is_colwise,
            is_padded=is_padded,
            flatten_axis=flatten_axis,
            broadcast_2d_scale_shape_to_1d=broadcast_2d_scale_shape_to_1d,
        )

    def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout:
        """Get the quantize layout for the tensor usage.

        Args:
            usage: The usage of the tensor

        Returns:
            The quantize layout for the tensor usage
        """
        return self._get_impl().get_quantize_layout(usage)

    # NOTE(review): the return annotation below says Tuple[Tuple[str]], but the result is
    # whatever the registered implementation returns; at least one implementation in this
    # file appears to build a QuantizeShardyRules. Confirm and tighten the annotation.
    def get_shardy_sharding_rules(
        self,
        input_shape,
        unique_var,
        flatten_axis=-1,
        broadcast_2d_scale_shape_to_1d=False,
    ) -> Tuple[Tuple[str]]:
        """Sharding rules for the input and (row, col)wise scale tensors.

        Args:
            input_shape: The shape of the input tensor (for which we produce the scale tensor)
            unique_var: An otherwise unused Shardy variable name prefix
            flatten_axis: Axis along which data can be flattened to 2D for quantization.
            broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D.
                Defaults to False.

        Returns:
            The Shardy rules for the scaling mode
        """
        return self._get_impl().get_shardy_sharding_rules(
            input_shape, unique_var, flatten_axis, broadcast_2d_scale_shape_to_1d
        )

    def get_grouped_scale_shape_2x(
        self, data_shape, n_groups, group_axis, is_padded=True, flatten_axis=-1
    ) -> Tuple[Tuple[int]]:
        """Get grouped scale shapes for both row-wise and column-wise scaling.

        Args:
            data_shape: Shape of the data tensor
            n_groups: Number of groups for grouped quantization
            group_axis: The axis along which grouping is performed
            is_padded: Whether to use padded shapes
            flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)

        Returns:
            Tuple of (rowwise_scale_shape, colwise_scale_shape)
        """
        rowwise_scale_shape = self.get_grouped_scale_shape(
            data_shape,
            n_groups,
            group_axis,
            is_colwise=False,
            is_padded=is_padded,
            flatten_axis=flatten_axis,
        )
        colwise_scale_shape = self.get_grouped_scale_shape(
            data_shape,
            n_groups,
            group_axis,
            is_colwise=True,
            is_padded=is_padded,
            flatten_axis=flatten_axis,
        )
        return (rowwise_scale_shape, colwise_scale_shape)

    def get_grouped_scale_shape(
        self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1
    ) -> Tuple[int]:
        """Get the grouped scale shape for a single scaling direction.

        Args:
            data_shape: Shape of the data tensor
            n_groups: Number of groups for grouped quantization
            group_axis: The axis along which grouping is performed
            is_colwise: Whether to use column-wise scaling
            is_padded: Whether to use padded shapes
            flatten_axis: Axis along which data can be flattened to 2D for quantization.
                Defaults to -1.

        Returns:
            The scale shape for the requested direction, as produced by the implementation
        """
        return self._get_impl().get_grouped_scale_shape(
            data_shape,
            n_groups,
            group_axis,
            is_colwise=is_colwise,
            is_padded=is_padded,
            flatten_axis=flatten_axis,
        )

    def is_tensor_scaling(self) -> bool:
        """Check if this scaling mode is per-tensor scaling.

        Returns:
            True if the scaling mode is tensor scaling, False otherwise
        """
        return self in (
            ScalingMode.DELAYED_TENSOR_SCALING,
            ScalingMode.CURRENT_TENSOR_SCALING,
        )

    def is_1d_block_scaling(self) -> bool:
        """Check if this scaling mode is 1D block scaling.

        Returns:
            True if the scaling mode is 1D block scaling, False otherwise
        """
        # Both 1D and 2D NVFP4 scaling are treated as 1D block scaling since the 2D scales
        # are broadcast to 1D because it is required for the GEMM.
        return self == ScalingMode.MXFP8_1D_SCALING or self.is_nvfp4_scaling

    @property
    def is_block_scaling(self) -> bool:
        """Check if this scaling mode is block scaling.

        Returns:
            True if the scaling mode is block scaling, False otherwise
        """
        # Currently we only have 1D block scaling modes
        return self.is_1d_block_scaling()

    def get_compatible_q_dtypes(self) -> set[jnp.dtype]:
        """Returns a set of compatible quantized data types for this scaling mode.

        Returns:
            A set of compatible quantized data types

        Raises:
            ValueError: If the scaling mode is not one of the known modes
        """
        if self in (
            ScalingMode.DELAYED_TENSOR_SCALING,
            ScalingMode.CURRENT_TENSOR_SCALING,
            ScalingMode.MXFP8_1D_SCALING,
        ):
            return {jnp.float8_e5m2, jnp.float8_e4m3fn}
        if self in (ScalingMode.NVFP4_1D_SCALING, ScalingMode.NVFP4_2D_SCALING):
            return {jnp.float4_e2m1fn}
        if self == ScalingMode.NO_SCALING:
            return {jnp.float16, jnp.bfloat16, jnp.float32}
        raise ValueError(f"Invalid scaling mode: {self}")

    @property
    def is_nvfp4_scaling(self) -> bool:
        """Check if this scaling mode is NVFP4 scaling.

        Returns:
            True if the scaling mode is NVFP4 scaling, False otherwise
        """
        return self in (ScalingMode.NVFP4_1D_SCALING, ScalingMode.NVFP4_2D_SCALING)

    @property
    def is_mxfp8_scaling(self) -> bool:
        """Check if this scaling mode is MXFP8 scaling.

        Returns:
            True if the scaling mode is MXFP8 scaling, False otherwise
        """
        return self == ScalingMode.MXFP8_1D_SCALING

    @property
    def is_colwise_transposed(self) -> bool:
        """Check if this scaling mode uses transposed layout for column-wise scaling.

        Returns:
            True if the scaling mode uses transposed layout for column-wise scaling, False otherwise
        """
        return self.is_tensor_scaling() or self.is_nvfp4_scaling

    def __eq__(self, other):
        """Compare this scaling mode with another.

        Args:
            other: The other scaling mode to compare with

        Returns:
            True if the modes are equal, False otherwise
        """
        if not isinstance(other, ScalingMode):
            return False
        return self.value == other.value

    def __hash__(self):
        """Hash this scaling mode consistently with ``__eq__``.

        Overriding ``__eq__`` without ``__hash__`` leaves hashing to whatever the class
        decorators synthesize. Defining it explicitly on ``value`` (the same key that
        ``__eq__`` compares) keeps equal modes hashing equal and distinct modes spread
        apart when used as dictionary keys, mirroring ``TensorUsage``.

        Returns:
            The hash of the underlying mode value
        """
        return hash(self.value)

    def tree_flatten(self):
        """Flatten this scaling mode for JAX tree operations.

        Returns:
            Tuple of (children, aux_data) for tree operations. There are no children;
            aux_data is the bare mode value (not wrapped in a tuple).
        """
        return (), self.value

    @classmethod
    def tree_unflatten(cls, aux_data, _children):
        """Reconstruct a scaling mode from its flattened representation.

        Args:
            aux_data: Auxiliary data containing the mode value
            _children: Unused children data

        Returns:
            A reconstructed ScalingMode instance
        """
        return cls(aux_data)


# Registry consulted by ScalingMode._get_impl(): one metadata implementation per mode.
# Block-scaling entries carry their block dimensions, scale dtype and the
# (row-wise, column-wise) data layout string.
SCALING_MODES_TO_IMPL: Dict[ScalingMode, ScalingModeMetadataImpl] = {
    ScalingMode.NO_SCALING: NoScalingModeMetadataImpl(),
    ScalingMode.DELAYED_TENSOR_SCALING: DelayedScalingModeMetadataImpl(),
    ScalingMode.MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(
        block_dims=(1, 32),
        scale_dtype=jnp.float8_e8m0fnu,
        data_layout="NN",
    ),
    ScalingMode.CURRENT_TENSOR_SCALING: CurrentScalingModeMetadataImpl(),
    ScalingMode.NVFP4_1D_SCALING: BlockScalingModeMetadataImpl(
        block_dims=(1, 16),
        scale_dtype=jnp.float8_e4m3fn,
        data_layout="NT",
    ),
    ScalingMode.NVFP4_2D_SCALING: BlockScalingModeMetadataImpl(
        block_dims=(16, 16),
        scale_dtype=jnp.float8_e4m3fn,
        data_layout="NT",
    ),
}