scaling_modes.py 29.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""
Scaling mode implementations for quantization in JAX.

This module provides implementations of different scaling modes for tensor quantization,
including delayed scaling and block scaling strategies.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Tuple, Dict
16
from functools import reduce, lru_cache
17
import operator
18
import numpy as np
19

Alp Dener's avatar
Alp Dener committed
20
from jax.experimental.custom_partitioning import BATCHING
21
22
23
from jax.tree_util import register_pytree_node_class
import jax.numpy as jnp

24
25
from transformer_engine_jax import JAXX_Scaling_Mode, QuantizeLayout
from .device_utils import is_fp8_gemm_with_all_layouts_supported
26

27

28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# Names exported by ``from ... import *``; everything else here is an implementation detail.
__all__ = ["QuantizeShardyRules", "ScalingMode", "TensorUsage"]


class TensorUsage(Enum):
    """Role of a tensor in a GEMM operation.

    For a GEMM ``C = A * B`` in which either operand may appear as-is or transposed:
    - LHS: A is in the normal form
    - LHS_TRANS: A is in the transposed form
    - RHS: B is in the normal form
    - RHS_TRANS: B is in the transposed form

    The tensor usage is used in the ScaledTensor.get_tensor() method.
    """

    # Left-hand side operand (A): normal form, then transposed form.
    LHS = 0
    LHS_TRANS = 1
    # Right-hand side operand (B): normal form, then transposed form.
    RHS = 2
    RHS_TRANS = 3

    def __eq__(self, other):
        # Equal only to another TensorUsage that carries the same value.
        return isinstance(other, TensorUsage) and self.value == other.value

    def __hash__(self):
        # Hash the underlying value so hashing stays consistent with __eq__.
        return hash(self.value)


def DIVUP(a, b):
    """Return ``a / b`` rounded up (ceiling division).

    Floor-dividing by ``-b`` rounds toward negative infinity, so negating that quotient
    yields the ceiling of ``a / b`` using only floor division.
    """
    floor_of_negated_quotient = a // -b
    return -floor_of_negated_quotient


@dataclass
class QuantizeShardyRules:
    """Information necessary to shard scale tensors with Shardy.

    Attributes:
        input_spec: Specification for the input axes, one Shardy factor name per input
          dimension.
        rowwise_rule: Sharding rule for the row-wise scale tensor, depends on
          the axes in `input_spec`
        colwise_rule: Likewise for the column-wise scale tensor.
        factor_sizes: For block scaling, contains the block size factor, which is
          used in `input_spec`. Empty when no extra factors are needed.
    """

    # Variable-length tuples of factor names (``Tuple[str]`` would mean a 1-tuple).
    input_spec: Tuple[str, ...]
    rowwise_rule: Tuple[str, ...]
    colwise_rule: Tuple[str, ...]
    factor_sizes: Dict[str, int]


class ScalingModeMetadataImpl(ABC):
    """Base class for scaling mode implementations.

    This abstract class defines the interface for different scaling mode implementations:
    scale data types and shapes (plain and grouped), the quantize layout needed for each
    GEMM tensor usage, and the Shardy sharding rules for the scale tensors.
    """

    @abstractmethod
    def get_scale_dtype(self) -> jnp.dtype:
        """Get the data type for scale tensors.

        Returns:
            The data type used for scale tensors
        """

    @abstractmethod
    def get_scale_shape(
        self,
        data_shape: Tuple[int, ...],
        is_colwise: bool = False,
        is_padded: bool = True,
        flatten_axis: int = -1,
    ) -> Tuple[int, ...]:
        """Get the shape for scale tensors.

        Args:
            data_shape: The shape of the tensor being quantized
            is_colwise: Whether the scaling is column-wise
            is_padded: Whether to return padded shape
            flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)

        Returns:
            The shape for scale tensors
        """

    @abstractmethod
    def get_grouped_scale_shape(
        self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1
    ) -> Tuple[int, ...]:
        """Get the shape for scale tensors in this mode.

        Args:
            data_shape: Original shape of the data tensor
            n_groups: Number of groups in grouped quantization
            group_axis: The axis along which grouping is performed
            is_colwise: Whether to use column-wise scaling
            is_padded: Whether to use padded shapes
            flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)

        Returns:
            The shape for scale tensors
        """

    # No lru_cache here: every concrete subclass overrides this method (and may cache its
    # own implementation), so caching the abstract stub would have no effect.
    @abstractmethod
    def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout:
        """Get the quantize layout for the tensor usage.

        Args:
            usage: The usage of the tensor

        Returns:
            The quantize layout for the tensor usage
        """

    @abstractmethod
    def get_shardy_sharding_rules(
        self, input_rank, unique_var, flatten_axis
    ) -> QuantizeShardyRules:
        """Sharding rules for the input and (row, col)wise scale tensors.

        Args:
            input_rank: The rank of the input tensor (for which we produce the scale tensor)
            unique_var: An otherwise unused Shardy variable name prefix
            flatten_axis: Axis along which data can be flattened to 2D for quantization.

        Returns:
            The Shardy rules for the scaling mode
        """


class NoScalingModeMetadataImpl(ScalingModeMetadataImpl):
    """Implementation for no scaling mode.

    This implementation provides metadata for no scaling mode, for using non-quantized
    higher-precision datatypes such as bf16.
    """

    def get_scale_dtype(self) -> jnp.dtype:
        """Get the data type for scale tensors.

        This is a placeholder and won't be used for higher-precision values that don't have
        scaling.

        Returns:
            The data type used for scale tensors (float32)
        """
        return jnp.float32

    def get_scale_shape(
        self,
        data_shape: Tuple[int, ...],
        is_colwise: bool = False,
        is_padded: bool = True,
        flatten_axis: int = -1,
    ) -> Tuple[int, ...]:
        """Get the shape for scale tensors.

        This mode applies no scaling, so the scale tensor is always zero-sized.

        Args:
            data_shape: The shape of the tensor being scaled (unused)
            is_colwise: Whether the scaling is column-wise (unused)
            is_padded: Whether to return padded shape (unused)
            flatten_axis: Axis along which data can be flattened to 2D for quantization.
                Defaults to -1. (unused)

        Returns:
            The shape for scale tensors - always ``(0,)``
        """
        del data_shape, is_colwise, is_padded, flatten_axis
        return (0,)

    @lru_cache(maxsize=4)
    def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout:
        """Get the quantize layout for the tensor usage.

        Args:
            usage: The usage of the tensor (unused)

        Returns:
            Always ``QuantizeLayout.ROWWISE``, independent of ``usage``
        """
        del usage
        return QuantizeLayout.ROWWISE

    def get_grouped_scale_shape(
        self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1
    ) -> Tuple[int, ...]:
        """Get the shape for grouped scale tensors in this mode.

        Args:
            data_shape: Original shape of the data tensor (unused)
            n_groups: Number of groups in grouped quantization; must be a Python int
            group_axis: The axis along which grouping is performed (unused)
            is_colwise: Whether to use column-wise scaling (unused)
            is_padded: Whether to use padded shapes (unused)
            flatten_axis: Axis along which data can be flattened to 2D for quantization.
                Defaults to -1. (unused)

        Returns:
            ``(n_groups,)`` - one scale entry per group
        """
        del data_shape, group_axis, is_colwise, is_padded, flatten_axis
        assert isinstance(n_groups, int)
        return (n_groups,)

    def get_shardy_sharding_rules(
        self, input_rank, unique_var, flatten_axis
    ) -> QuantizeShardyRules:
        """Sharding rules for the input and (row, col)wise scale tensors.

        Args:
            input_rank: The rank of the input tensor (for which we produce the scale tensor)
            unique_var: An otherwise unused Shardy variable name prefix
            flatten_axis: Axis along which data can be flattened to 2D for quantization
                (unused)

        Returns:
            The Shardy rules for the scaling mode
        """
        del flatten_axis
        input_spec = tuple(f"{unique_var}{i}" for i in range(input_rank))
        # The single-dimension scale tensor gets its own factor, prefixed with JAX's BATCHING
        # marker and shared by the row-wise and column-wise rules. NOTE(review): presumably
        # this keeps the scale independent of the input's factors — confirm with Shardy docs.
        scale_var = BATCHING + unique_var + "_scale_inv"
        return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {})


class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl):
    """Implementation for current (per-tensor) scaling mode.

    This implementation provides metadata for current scaling mode, including scale data
    type and shape: a single float32 scale covers the whole tensor.
    """

    def get_scale_dtype(self) -> jnp.dtype:
        """Get the data type for scale tensors in per-tensor scaling.

        Returns:
            The data type used for scale tensors (float32)
        """
        return jnp.float32

    def get_scale_shape(
        self,
        data_shape: Tuple[int, ...],
        is_colwise: bool = False,
        is_padded: bool = True,
        flatten_axis: int = -1,
    ) -> Tuple[int, ...]:
        """Get the shape for scale tensors in per-tensor scaling.

        Args:
            data_shape: The shape of the tensor being scaled
            is_colwise: Whether the scaling is column-wise (unused; one scale is shared)
            is_padded: Whether to return padded shape (unused)
            flatten_axis: Axis along which data can be flattened to 2D for quantization.
                Defaults to -1. (unused)

        Returns:
            ``(1,)`` - a single scale for the whole tensor - or ``(0,)`` when the tensor
            has no elements
        """
        del is_colwise, is_padded, flatten_axis
        # A zero-element tensor has nothing to scale, so its scale tensor is empty too.
        if np.prod(data_shape) == 0:
            return (0,)
        return (1,)

    @lru_cache(maxsize=4)
    def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout:
        """Get the quantize layout for the tensor usage.

        The result is cached per ``(self, usage)`` pair, so the device capability query
        below runs at most once per usage.

        Args:
            usage: The usage of the tensor

        Returns:
            The quantize layout for the tensor usage
        """
        # When the device reports support for all FP8 GEMM layouts, the row-wise data serves
        # every usage.
        if is_fp8_gemm_with_all_layouts_supported():
            return QuantizeLayout.ROWWISE

        # Otherwise only LHS and transposed RHS consume row-wise data; LHS_TRANS and RHS need
        # the column-wise copy.
        if usage in (TensorUsage.LHS, TensorUsage.RHS_TRANS):
            return QuantizeLayout.ROWWISE
        return QuantizeLayout.COLWISE

    def get_grouped_scale_shape(
        self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1
    ) -> Tuple[int, ...]:
        """Get the shape for grouped scale tensors in this mode.

        Args:
            data_shape: Original shape of the data tensor (unused)
            n_groups: Number of groups in grouped quantization; must be a Python int
            group_axis: The axis along which grouping is performed (unused)
            is_colwise: Whether to use column-wise scaling (unused)
            is_padded: Whether to use padded shapes (unused)
            flatten_axis: Axis along which data can be flattened to 2D for quantization.
                Defaults to -1. (unused)

        Returns:
            ``(n_groups,)`` - one per-tensor scale for each group
        """
        del data_shape, group_axis, is_colwise, is_padded, flatten_axis
        assert isinstance(n_groups, int)
        return (n_groups,)

    def get_shardy_sharding_rules(
        self, input_rank, unique_var, flatten_axis
    ) -> QuantizeShardyRules:
        """Sharding rules for the input and (row, col)wise scale tensors.

        Args:
            input_rank: The rank of the input tensor (for which we produce the scale tensor)
            unique_var: An otherwise unused Shardy variable name prefix
            flatten_axis: Axis along which data can be flattened to 2D for quantization
                (unused)

        Returns:
            The Shardy rules for the scaling mode
        """
        del flatten_axis
        input_spec = tuple(f"{unique_var}{i}" for i in range(input_rank))
        # The single-dimension scale tensor gets its own BATCHING-prefixed factor, shared by
        # the row-wise and column-wise rules and unrelated to any input factor.
        scale_var = BATCHING + unique_var + "_scale_inv"
        return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {})


class DelayedScalingModeMetadataImpl(CurrentScalingModeMetadataImpl):
    """Metadata implementation for delayed (per-tensor) scaling.

    Delayed scaling has the same scale dtype, scale shapes, quantize layouts and Shardy
    rules as current scaling, so every method is inherited unchanged from
    ``CurrentScalingModeMetadataImpl``.
    """


class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
    """Implementation for block scaling mode.

    This implementation provides metadata for block scaling mode, which uses
    block-based scaling with specific alignment requirements.

    Shapes are reasoned about on a 2D view of the data: the axes before ``flatten_axis``
    are multiplied into a "first" (x) dimension and the remaining axes into a "last" (y)
    dimension.

    Attributes:
        _block_dims: Scaling block dimensions (x, y) for the row-wise layout; the
            column-wise layout uses them swapped
        _block_alignment: Multiples (x, y) that padded scale extents are rounded up to for
            the row-wise layout; swapped for the column-wise layout
    """

    def __init__(self, block_dims: Tuple[int, int]):
        """Initialize block scaling mode implementation.

        Args:
            block_dims: Dimensions of the scaling blocks
        """
        self._block_dims = block_dims
        self._block_alignment = (128, 4)

    def get_scale_dtype(self) -> jnp.dtype:
        """Get the data type for scale tensors in block scaling.

        Returns:
            The data type used for scale tensors (float8_e8m0fnu)
        """
        return jnp.float8_e8m0fnu

    def _apply_scale_shape_correction(self, data_shape, n_scale_blocks, scale_block_dim):
        """Remove excess padding from the scale shape and return the shape with respect to the original data shape.

        The flattened block count ``n_scale_blocks`` is split back over the axes of
        ``data_shape``: the last axis gets ``data_shape[-1] // scale_block_dim`` entries,
        middle axes keep their data size, and the first axis absorbs whatever remains
        (including any padding).
        """
        if len(data_shape) > 1:
            # handle last dim
            assert data_shape[-1] % scale_block_dim == 0
            last = data_shape[-1] // scale_block_dim
            scale_shape = (last,)
            assert n_scale_blocks % last == 0
            n_scale_blocks //= last
            # handle middle dim, exclude first and last
            for mid in reversed(data_shape[1:-1]):
                scale_shape = (mid,) + scale_shape
                assert n_scale_blocks % mid == 0
                n_scale_blocks //= mid
            scale_shape = (n_scale_blocks,) + scale_shape
        else:
            scale_shape = (n_scale_blocks,)

        assert len(scale_shape) == len(
            data_shape
        ), f"scale_shape {scale_shape}, data_shape {data_shape}"
        return scale_shape

    def _get_oriented_block_params(self, is_colwise, is_padded):
        """Return ``(block_x, block_y, alignment_x, alignment_y)`` for the requested layout.

        Alignment is ``(1, 1)`` (no padding) when ``is_padded`` is False; the column-wise
        layout swaps both the block dims and the alignment.
        """
        block_alignment = self._block_alignment if is_padded else (1, 1)
        if is_colwise:
            block_y, block_x = self._block_dims
            alignment_y, alignment_x = block_alignment
        else:
            block_x, block_y = self._block_dims
            alignment_x, alignment_y = block_alignment
        return block_x, block_y, alignment_x, alignment_y

    @staticmethod
    def _count_blocks(data_shape, block_x, block_y, flatten_axis):
        """Validate ``data_shape`` against the block dims and count unpadded blocks.

        Returns:
            ``(flatten_axis, n_block_x, n_block_y)`` with ``flatten_axis`` normalized to a
            non-negative index and the block counts as Python ints.

        Raises:
            AssertionError: If ``flatten_axis`` is out of bounds or a dimension is not
                divisible by its block size.
        """
        if flatten_axis < 0:
            flatten_axis = len(data_shape) + flatten_axis
        assert (
            0 < flatten_axis < len(data_shape)
        ), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}"

        assert data_shape[flatten_axis - 1] % block_x == 0, (
            f"Data shape {data_shape} should be divisible by block_x {block_x} in axis"
            f" {flatten_axis - 1}"
        )
        assert (
            data_shape[-1] % block_y == 0
        ), f"Data shape {data_shape} should be divisible by block_y {block_y} in axis -1"

        flattened_first_dim = reduce(operator.mul, data_shape[:flatten_axis], 1)
        flattened_last_dim = reduce(operator.mul, data_shape[flatten_axis:], 1)

        assert flattened_first_dim % block_x == 0, (
            f"Flattened first dim - mutiplication of axes={tuple(range(0, flatten_axis))} of shape"
            f" {data_shape} - should be divisible by block_x {block_x}"
        )
        assert flattened_last_dim % block_y == 0, (
            "Flattened last dim - mutiplication of"
            f" axes={tuple(range(flatten_axis, len(data_shape)))} of shape {data_shape} - should be"
            f" divisible by block_y {block_y}"
        )

        # Integer division: the asserts above guarantee exactness, and // avoids the float
        # rounding that int(a / b) would introduce for very large sizes.
        return (
            flatten_axis,
            int(flattened_first_dim // block_x),
            int(flattened_last_dim // block_y),
        )

    def get_scale_shape(
        self,
        data_shape: Tuple[int, ...],
        is_colwise: bool = False,
        is_padded: bool = True,
        flatten_axis: int = -1,
    ) -> Tuple[int, ...]:
        """Get the shape for scale tensors in block scaling.

        Args:
            data_shape: The shape of the tensor being quantized
            is_colwise: Whether the scaling is column-wise
            is_padded: Whether to return padded shape
            flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.

        Returns:
            The shape for scale tensors, with the same rank as ``data_shape``

        Raises:
            AssertionError: If ``flatten_axis`` is out of bounds or ``data_shape`` is not
                divisible by the block dimensions.
        """
        block_x, block_y, alignment_x, alignment_y = self._get_oriented_block_params(
            is_colwise, is_padded
        )
        flatten_axis, n_block_x, n_block_y = self._count_blocks(
            data_shape, block_x, block_y, flatten_axis
        )

        # padding: round the block counts up to the alignment (no-op when is_padded=False)
        n_block_x = DIVUP(n_block_x, alignment_x) * alignment_x
        n_block_y = DIVUP(n_block_y, alignment_y) * alignment_y

        first_dim_scale_shape = self._apply_scale_shape_correction(
            data_shape[:flatten_axis], n_block_x, block_x
        )
        last_dim_scale_shape = self._apply_scale_shape_correction(
            data_shape[flatten_axis:], n_block_y, block_y
        )

        return (*first_dim_scale_shape, *last_dim_scale_shape)

    @lru_cache(maxsize=4)
    def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout:
        """Get the quantize layout for the tensor usage.

        Args:
            usage: The usage of the tensor

        Returns:
            ``ROWWISE`` for LHS and RHS_TRANS, ``COLWISE`` otherwise
        """
        # If we need to support 1x1x for inference in the future
        # if get_quantize_config().INFERENCE_MODE:
        #     assert usage not in (TensorUsage.LHS_TRANS, TensorUsage.RHS_TRANS), (f"Invalid usage {usage} as we are in MXFP8_1D_SCALING 1x1x (FWD only) mode so no transposed usage is needed!")
        #     if usage == TensorUsage.LHS:
        #         return QuantizeLayout.ROWWISE
        #     return QuantizeLayout.COLWISE

        if usage in (TensorUsage.LHS, TensorUsage.RHS_TRANS):
            return QuantizeLayout.ROWWISE
        return QuantizeLayout.COLWISE

    def get_grouped_scale_shape(
        self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1
    ) -> Tuple[int, ...]:
        """Get the shape for grouped scale tensors in this mode.

        If padded, the estimated maximal possible shape for the grouped scale tensor is
        returned instead, since the per-group padding depends on how rows are split.

        Args:
            data_shape: Original shape of the data tensor
            n_groups: Number of groups in grouped quantization; must be a Python int
            group_axis: The axis along which grouping is performed (unused)
            is_colwise: Whether to use column-wise scaling
            is_padded: Whether to use padded shapes
            flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.

        Returns:
            A 1-element shape ``(total_scale_elements,)`` for the flattened grouped scales

        Raises:
            AssertionError: If ``n_groups`` is not an int, ``flatten_axis`` is out of bounds,
                or ``data_shape`` is not divisible by the block dimensions.
        """
        assert isinstance(n_groups, int)
        # NOTE(review): group_axis is not used; the worst-case formula below assumes groups
        # split the flattened first dimension — confirm with callers.
        del group_axis

        block_x, block_y, alignment_x, alignment_y = self._get_oriented_block_params(
            is_colwise, is_padded
        )
        _, n_block_x, n_block_y = self._count_blocks(data_shape, block_x, block_y, flatten_axis)

        if is_padded:
            # Each group's scale rows are padded independently. For an unpadded scale shape
            # [M, N], G groups and alignment (ax, ay), the worst case is (G-1) groups with 1 row
            # each and one group with the remaining (M-G+1) rows:
            #     max_padded_rows = (G-1) * ax + DIVUP(M-G+1, ax) * ax
            #     max_padded_cols = DIVUP(N, ay) * ay
            #     max_scale_size  = max_padded_rows * max_padded_cols
            n_block_x = (n_groups - 1) * alignment_x + DIVUP(
                n_block_x - n_groups + 1, alignment_x
            ) * alignment_x
            n_block_y = DIVUP(n_block_y, alignment_y) * alignment_y

        return (n_block_x * n_block_y,)

    def get_shardy_sharding_rules(
        self, input_rank, unique_var, flatten_axis
    ) -> QuantizeShardyRules:
        """Sharding rules for the input and (row, col)wise scale tensors.

        Args:
            input_rank: The rank of the input tensor (for which we produce the scale tensor)
            unique_var: An otherwise unused Shardy variable name prefix
            flatten_axis: Axis along which data can be flattened to 2D for quantization
                (currently unused, see the NOTE below)

        Returns:
            The Shardy rules for the scaling mode
        """
        del flatten_axis
        input_spec = [f"{unique_var}{i}" for i in range(input_rank)]
        rowwise = [f"{unique_var}scale_inv_rowwise{i}" for i in range(input_rank)]
        colwise = [f"{unique_var}scale_inv_colwise{i}" for i in range(input_rank)]

        # NOTE (Alp): Padding the scales breaks the size relationship in CompoundFactors.
        #             Unfortunately, because Shardy rules are applied to the inner primitive, the
        #             only way to preserve the relationship is to lower unpadded scales to the
        #             underlying custom call and pad them in C++. Until that's implemented, the
        #             Shardy rules for block scales have to be completely disconnected from the
        #             Shardy rules for the tensor they belong to.

        # # We have to use two different factors in the two CompoundFactors because of Shardy
        # # verifier requirements, even though they are the same.
        # rowwise_var = unique_var
        # colwise_var = f"{unique_var}_"
        # input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "block_size_colwise")
        # input_spec[-1] = CompoundFactor(rowwise_var, "block_size_rowwise")

        # # The rowwise and colwise scale tensors should be sharded the same way as the input.
        # # However, we need to adjust the dimensions where the block scaling factor applies.
        # rowwise = input_spec.copy()
        # rowwise[-1] = rowwise_var

        # colwise = input_spec.copy()
        # colwise[flatten_axis - 1] = colwise_var

        # # This implementation needs to be updated for different block dims.
        # assert self._block_dims == (1, 32)

        return QuantizeShardyRules(
            tuple(input_spec),
            tuple(rowwise),
            tuple(colwise),
            {},  # {"block_size_rowwise": 32, "block_size_colwise": 32},
        )


@dataclass(frozen=True)
@register_pytree_node_class
class ScalingMode(Enum):
    """Enumeration of tensor scaling modes with their corresponding metadata implementations.

    This class defines the available scaling modes for tensor quantization:
    - DELAYED_TENSOR_SCALING: Uses delayed scaling with FP8 data type and float32 scales
    - MXFP8_1D_SCALING: Uses block-based scaling with FP8 data type and E8M0 scales
    - CURRENT_TENSOR_SCALING: Uses current scaling with FP8 data type and float32 scales
    - NO_SCALING: No scaling applied

    Metadata queries are delegated to the implementation registered for each member in
    ``SCALING_MODES_TO_IMPL``.
    """

    NO_SCALING = JAXX_Scaling_Mode.NO_SCALING
    DELAYED_TENSOR_SCALING = JAXX_Scaling_Mode.DELAYED_TENSOR_SCALING
    MXFP8_1D_SCALING = JAXX_Scaling_Mode.MXFP8_1D_SCALING
    CURRENT_TENSOR_SCALING = JAXX_Scaling_Mode.CURRENT_TENSOR_SCALING

    def _get_impl(self) -> ScalingModeMetadataImpl:
        """Get the implementation for this scaling mode.

        Returns:
            The scaling mode implementation

        Raises:
            ValueError: If the scaling mode is invalid
        """
        impl = SCALING_MODES_TO_IMPL.get(self)
        if impl is None:
            raise ValueError("Invalid scaling mode")
        return impl

    def get_scale_dtype(self):
        """Get the data type for scale tensors in this mode.

        Returns:
            The data type for scale tensors
        """
        return self._get_impl().get_scale_dtype()

    def get_scale_shape_2x(
        self, data_shape, is_padded=True, flatten_axis=-1
    ) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
        """Get shapes for both row-wise and column-wise scaling.

        Args:
            data_shape: Shape of the data tensor
            is_padded: Whether to use padded shapes
            flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.

        Returns:
            Tuple of (rowwise_scale_shape, colwise_scale_shape)
        """
        rowwise_scale_shape = self.get_scale_shape(
            data_shape, is_colwise=False, is_padded=is_padded, flatten_axis=flatten_axis
        )
        colwise_scale_shape = self.get_scale_shape(
            data_shape, is_colwise=True, is_padded=is_padded, flatten_axis=flatten_axis
        )
        return (rowwise_scale_shape, colwise_scale_shape)

    def get_scale_shape(
        self, data_shape, is_colwise, is_padded=True, flatten_axis=-1
    ) -> Tuple[int, ...]:
        """Get the shape for scale tensors in this mode.

        Args:
            data_shape: Shape of the data tensor
            is_colwise: Whether to use column-wise scaling
            is_padded: Whether to use padded shapes
            flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.

        Returns:
            The shape for scale tensors
        """
        return self._get_impl().get_scale_shape(data_shape, is_colwise, is_padded, flatten_axis)

    def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout:
        """Get the quantize layout for the tensor usage.

        Args:
            usage: The usage of the tensor

        Returns:
            The quantize layout for the tensor usage
        """
        return self._get_impl().get_quantize_layout(usage)

    def get_shardy_sharding_rules(
        self, input_rank, unique_var, flatten_axis=-1
    ) -> QuantizeShardyRules:
        """Sharding rules for the input and (row, col)wise scale tensors.

        Args:
            input_rank: The rank of the input tensor (for which we produce the scale tensor)
            unique_var: An otherwise unused Shardy variable name prefix
            flatten_axis: Axis along which data can be flattened to 2D for quantization.
                Defaults to -1.

        Returns:
            The Shardy rules for the scaling mode
        """
        return self._get_impl().get_shardy_sharding_rules(input_rank, unique_var, flatten_axis)

    def get_grouped_scale_shape_2x(
        self, data_shape, n_groups, group_axis, is_padded=True, flatten_axis=-1
    ) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
        """Get shapes for both row-wise and column-wise scaling.

        Args:
            data_shape: Shape of the data tensor
            n_groups: Number of groups for grouped quantization
            group_axis: The axis along which grouping is performed
            is_padded: Whether to use padded shapes
            flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)

        Returns:
            Tuple of (rowwise_scale_shape, colwise_scale_shape)
        """
        rowwise_scale_shape = self.get_grouped_scale_shape(
            data_shape,
            n_groups,
            group_axis,
            is_colwise=False,
            is_padded=is_padded,
            flatten_axis=flatten_axis,
        )
        colwise_scale_shape = self.get_grouped_scale_shape(
            data_shape,
            n_groups,
            group_axis,
            is_colwise=True,
            is_padded=is_padded,
            flatten_axis=flatten_axis,
        )
        return (rowwise_scale_shape, colwise_scale_shape)

    def get_grouped_scale_shape(
        self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1
    ) -> Tuple[int, ...]:
        """Get the grouped scale tensor shape for one layout (row-wise or column-wise).

        Args:
            data_shape: Shape of the data tensor
            n_groups: Number of groups for grouped quantization
            group_axis: The axis along which grouping is performed
            is_colwise: Whether to use column-wise scaling
            is_padded: Whether to use padded shapes
            flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.

        Returns:
            The shape for the grouped scale tensor
        """
        return self._get_impl().get_grouped_scale_shape(
            data_shape,
            n_groups,
            group_axis,
            is_colwise=is_colwise,
            is_padded=is_padded,
            flatten_axis=flatten_axis,
        )

    def is_tensor_scaling(self) -> bool:
        """Check if this scaling mode is per-tensor scaling.

        Returns:
            True if the scaling mode is tensor scaling, False otherwise
        """
        return self in (
            ScalingMode.DELAYED_TENSOR_SCALING,
            ScalingMode.CURRENT_TENSOR_SCALING,
        )

    def is_1d_block_scaling(self) -> bool:
        """Check if this scaling mode is 1D block scaling.

        Returns:
            True if the scaling mode is 1D block scaling, False otherwise
        """
        return self == ScalingMode.MXFP8_1D_SCALING

    def __eq__(self, other):
        """Compare this scaling mode with another.

        Args:
            other: The other scaling mode to compare with

        Returns:
            True if the modes are equal, False otherwise
        """
        if not isinstance(other, ScalingMode):
            return False
        return self.value == other.value

    def __hash__(self):
        """Hash by value, consistent with ``__eq__``.

        Without an explicit ``__hash__``, ``@dataclass(frozen=True)`` synthesizes one from the
        (empty) field tuple, so every member would hash to ``hash(())``.
        """
        return hash(self.value)

    def tree_flatten(self):
        """Flatten this scaling mode for JAX tree operations.

        Returns:
            Tuple of (children, aux_data): no children, and the mode's value as aux_data
        """
        return (), self.value

    @classmethod
    def tree_unflatten(cls, aux_data, _children):
        """Reconstruct a scaling mode from its flattened representation.

        Args:
            aux_data: Auxiliary data containing the mode value
            _children: Unused children data

        Returns:
            A reconstructed ScalingMode instance
        """
        return cls(aux_data)


# Metadata implementation instance for each scaling mode; ScalingMode._get_impl() looks
# modes up in this table and raises ValueError for any mode missing from it.
SCALING_MODES_TO_IMPL: Dict[ScalingMode, ScalingModeMetadataImpl] = {
    ScalingMode.DELAYED_TENSOR_SCALING: DelayedScalingModeMetadataImpl(),
    # Row-wise layout: blocks of 1 x 32 elements (32 along the last axis) share one scale.
    ScalingMode.MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(block_dims=(1, 32)),
    # WAR (workaround). NOTE(review): the original comment does not say what is being worked
    # around here — confirm with the authors.
    ScalingMode.CURRENT_TENSOR_SCALING: CurrentScalingModeMetadataImpl(),
    ScalingMode.NO_SCALING: NoScalingModeMetadataImpl(),
}