# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Tensor quantization classes for TE/JAX.

This module provides classes and utilities for quantizing tensors in JAX.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from functools import partial
from typing import Union, Optional, Tuple
import warnings

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class
from transformer_engine.common import recipe

from .scaling_modes import ScalingMode
from .misc import QuantizeLayout
from .hadamard import apply_rht
from .tensor import (
    ScaledTensor,
    ScaledTensor1x,
    ScaledTensor2x,
    ScaledTensorFactory,
    NoScaleTensor,
)
from .helper import (
    get_quantize_config,
    get_quantize_config_with_recipe,
    AmaxComputeAlgo,
    TensorSource,
)
from .device_utils import is_fp8_gemm_with_all_layouts_supported
from ..sharding import get_num_devices_in_mesh
38
39
40
41

# Public API of this module (names exported by ``from ... import *``).
__all__ = [
    "Quantizer",
    "QuantizerSet",
    "CurrentScaleQuantizer",
    "DelayedScaleQuantizer",
    "BlockScaleQuantizer",
    "GroupedQuantizer",
    "QuantizerFactory",
    "noop_quantizer_set",
    "compute_scale_from_amax",
]


52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
def compute_scale_from_amax(
    amax: jnp.ndarray, q_dtype: jnp.dtype, scale: Optional[jnp.ndarray] = None
) -> jnp.ndarray:
    """Compute scale from amax value.

    The scale is ``finfo(q_dtype).max / amax`` divided by ``2 ** MARGIN``, where
    ``MARGIN`` is read from the active quantize config. Wherever ``amax`` is not
    strictly positive or is not finite, the fallback ``scale`` is returned instead.

    Args:
        amax: Maximum absolute value of the tensor; the result must have shape (1,)
        q_dtype: Quantization data type
        scale: Fallback scale used when amax is zero, negative, inf or NaN.
            Defaults to a float32 array of ones with shape (1,).

    Returns:
        Scale value of shape (1,)
    """
    fp8_max = jnp.astype(jnp.finfo(q_dtype).max, jnp.float32)
    if scale is None:
        # Explicit float32 so the fallback does not depend on the x64 setting.
        scale = jnp.ones((1,), jnp.float32)
    sf = (fp8_max / amax) / (2 ** get_quantize_config().MARGIN)
    # amax <= 0 would produce inf/negative scales; keep the fallback scale there.
    sf = jnp.where(amax > 0.0, sf, scale)
    # inf/NaN amax would produce 0/NaN scales; keep the fallback scale there.
    sf = jnp.where(jnp.isfinite(amax), sf, scale)
    assert sf.shape == (1,), f"Expected sf.shape == (1,), but got {sf.shape}"
    return sf


74
75
76
77
78
79
80
81
82
83
84
@register_pytree_node_class
@dataclass
class Quantizer(ABC):
    """Base class for quantizers.

    This abstract class defines the interface for tensor quantization, providing
    methods for quantization and scale management.

    Attributes:
        q_dtype: The data type for quantized values
        scaling_mode: The scaling mode to use for quantization
        q_layout: The quantization axis (row-wise, column-wise, or both)
        data_layout: The data layout string (e.g., "NT")
        checkpoint_name: Optional name for checkpointing quantization state
    """

    q_dtype: jnp.dtype
    scaling_mode: ScalingMode
    q_layout: QuantizeLayout
    data_layout: str
    checkpoint_name: Optional[str] = None

    def tree_flatten(self):
        """Flatten the quantizer for JAX tree operations.

        The base quantizer has no array children; every field is carried as
        static auxiliary data.

        Returns:
            Tuple of (children, aux_data) for tree operations
        """
        children = ()
        aux_data = (
            self.q_dtype,
            self.scaling_mode,
            self.q_layout,
            self.data_layout,
            self.checkpoint_name,
        )
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Reconstruct a quantizer from its flattened representation.

        Args:
            aux_data: Auxiliary data containing quantizer parameters
            children: Children data, passed positionally after aux_data

        Returns:
            A reconstructed Quantizer instance
        """
        return cls(*aux_data, *children)

    def update(self, *args, **kwargs):
        """Update quantizer state (no-op in base class)."""
        del args, kwargs

    def get_data_layout(self) -> str:
        """Get the data layout string for the configured q_layout.

        Returns:
            The full data_layout when quantizing both row-wise and column-wise,
            its first character for row-wise only, or its second character for
            column-wise only.

        Raises:
            ValueError: If quantization axis is invalid
        """
        if self.q_layout.is_rowwise_colwise:
            return self.data_layout
        if self.q_layout.is_rowwise_only:
            return self.data_layout[0]
        if self.q_layout.is_colwise_only:
            return self.data_layout[1]
        raise ValueError(f"Invalid q_layout: {self.q_layout}")

    @abstractmethod
    def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> ScaledTensor1x:
        """Core quantization function to be implemented by subclasses.

        Args:
            x: Input tensor to quantize
            is_colwise: Whether to use column-wise quantization
            dq_dtype: Data type for dequantized values, default is x.dtype
            flatten_axis: The quantization axis for the tensor

        Returns:
            A ScaledTensor1x containing the quantized data
        """

    def quantize(
        self, x, is_rowwise=None, is_colwise=None, dq_dtype=None, flatten_axis=-1, **kwargs
    ) -> ScaledTensor:
        """Quantize a tensor using the internal _quantize_func().

        Args:
            x: Input tensor to quantize
            is_rowwise: Whether to use row-wise quantization. Defaults to
                ``self.q_layout.has_rowwise`` when None.
            is_colwise: Whether to use column-wise quantization. Defaults to
                ``self.q_layout.has_colwise`` when None.
            dq_dtype: Data type for dequantized values
            flatten_axis: The quantization axis for the tensor
            **kwargs: Accepted and ignored

        Returns:
            A ScaledTensor1x or ScaledTensor2x containing the quantized data
        """
        del kwargs

        is_rowwise = is_rowwise if is_rowwise is not None else self.q_layout.has_rowwise
        is_colwise = is_colwise if is_colwise is not None else self.q_layout.has_colwise

        if is_rowwise and is_colwise:
            rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
            colwise_tensor = self._quantize_func(
                x, is_colwise=True, dq_dtype=dq_dtype, flatten_axis=flatten_axis
            )
            return ScaledTensor2x(rowwise_tensor, colwise_tensor)

        if is_colwise:
            return self._quantize_func(
                x, is_colwise=True, dq_dtype=dq_dtype, flatten_axis=flatten_axis
            )

        # Row-wise is the fallback, including when both flags are falsy.
        return self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)

    def get_scale_shapes(self, data_shape, is_padded=True, flatten_axis=-1, **kwargs):
        """Get shapes for scale tensors.

        Args:
            data_shape: Shape of the input tensor
            is_padded: Whether to use padded shapes
            flatten_axis: The quantization axis for the tensor
            **kwargs: Accepted and ignored

        Returns:
            Tuple of (rowwise_scale_shape, colwise_scale_shape)
        """
        del kwargs
        return self.scaling_mode.get_scale_shape_2x(data_shape, is_padded, flatten_axis)

    def get_scale_dtype(self):
        """Get the data type for scale tensors.

        Returns:
            The data type for scale tensors
        """
        return self.scaling_mode.get_scale_dtype()


@register_pytree_node_class
@dataclass
class CurrentScaleQuantizer(Quantizer):
    """Quantizer implementation using current scaling.

    This quantizer uses current scaling mode with float32 scales: the scale is
    derived from the amax of the tensor being quantized.

    Attributes:
        scaling_mode: Set to CURRENT_TENSOR_SCALING
        q_layout: Quantization axis (default: ROWWISE_COLWISE)
        data_layout: Data layout string (default: "NT")
    """

    scaling_mode: ScalingMode = ScalingMode.CURRENT_TENSOR_SCALING
    q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE
    data_layout: str = "NT"

    def _quantize_func(
        self,
        x: Union[jnp.ndarray, NoScaleTensor],
        is_colwise=False,
        dq_dtype=None,
        flatten_axis=-1,
    ) -> ScaledTensor1x:
        """Quantize function helper for current scaling FP8.

        Args:
            x: Input tensor to quantize. A NoScaleTensor may carry a precomputed
                amax; otherwise the amax is computed from the data.
            is_colwise: Unused here; column-wise output is produced by quantize()
            dq_dtype: Data type for dequantized values
            flatten_axis: The quantization axis for the tensor

        Returns:
            A ScaledTensor1x containing the quantized data
        """
        if isinstance(x, jnp.ndarray):
            x = NoScaleTensor(data=x, amax=None)

        dq_dtype = dq_dtype if dq_dtype is not None else x.data.dtype

        compute_dtype = jnp.float32
        dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)

        # Use an explicit None check: `x.amax or ...` would call bool() on an
        # array, which fails on traced values under jit.
        amax = x.amax
        if amax is None:
            amax = jnp.max(jnp.abs(x.data)).reshape((1,))

        scale = (dtype_max / amax) / (2 ** get_quantize_config().MARGIN)
        # A zero or non-finite amax yields an inf/NaN/zero scale and hence NaN
        # output; fall back to a unit scale, as compute_scale_from_amax does.
        scale = jnp.where((amax > 0.0) & jnp.isfinite(amax), scale, jnp.ones_like(scale))
        scaled_x = x.data.astype(compute_dtype) * scale

        clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype)
        scale_inv = 1.0 / scale
        return ScaledTensorFactory.create_1x(
            data=clipped_scaled_x,
            scale_inv=scale_inv,
            scaling_mode=self.scaling_mode,
            dq_dtype=dq_dtype,
            flatten_axis=flatten_axis,
        )

    def quantize(
        self, x, is_rowwise: bool = None, is_colwise: bool = None, dq_dtype=None, flatten_axis=-1
    ):
        """Quantize a tensor using the internal _quantize_func().

        _quantize_func() is called exactly once (row-wise); the column-wise
        tensor is a transposed view of that result sharing the same scale_inv.

        Args:
            x: Input tensor to quantize
            is_rowwise: Whether to use row-wise quantization. Defaults to
                ``self.q_layout.has_rowwise`` when None.
            is_colwise: Whether to use column-wise quantization. Defaults to
                ``self.q_layout.has_colwise`` when None.
            dq_dtype: Data type for dequantized values
            flatten_axis: The quantization axis for the tensor

        Returns:
            A ScaledTensor1x or ScaledTensor2x containing the quantized data
        """
        if isinstance(x, jnp.ndarray):
            x = NoScaleTensor(data=x, amax=None)

        dq_dtype = dq_dtype if dq_dtype is not None else x.data.dtype
        if flatten_axis < 0:
            flatten_axis += x.ndim
        assert 0 < flatten_axis < x.ndim, "flatten_axis is out of bounds!"

        is_rowwise = is_rowwise if is_rowwise is not None else self.q_layout.has_rowwise
        is_colwise = is_colwise if is_colwise is not None else self.q_layout.has_colwise

        # Always computed, even for colwise-only output, because the colwise
        # data below is derived from it.
        rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
        colwise_tensor = None
        if is_colwise:
            colwise_tensor = ScaledTensorFactory.create_1x(
                # Move the axes from flatten_axis onward in front of the leading axes.
                data=jnp.transpose(
                    rowwise_tensor.data, (*range(flatten_axis, x.ndim), *range(flatten_axis))
                ),
                scale_inv=rowwise_tensor.scale_inv,
                scaling_mode=self.scaling_mode,
                dq_dtype=dq_dtype,
                is_colwise=True,
                data_layout="T",
                flatten_axis=flatten_axis,
            )

        if is_colwise and is_rowwise:
            return ScaledTensor2x(rowwise_tensor, colwise_tensor)
        if is_colwise:
            return colwise_tensor
        return rowwise_tensor

318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338

@register_pytree_node_class
@dataclass
class DelayedScaleQuantizer(CurrentScaleQuantizer):
    """Quantizer implementation using delayed scaling.

    This quantizer uses delayed scaling mode with float32 scales and maintains
    a history of maximum absolute values for dynamic scaling.

    Attributes:
        scaling_mode: Set to DELAYED_TENSOR_SCALING
        q_layout: Quantization axis (default: ROWWISE_COLWISE)
        scale: Current scaling factor
        amax_history: History of maximum absolute values
    """

    scaling_mode: ScalingMode = ScalingMode.DELAYED_TENSOR_SCALING
    q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE

    scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32))
    amax_history: jnp.ndarray = field(
        default_factory=lambda: jnp.zeros((get_quantize_config().AMAX_HISTORY_LEN,), jnp.float32)
    )

    def tree_flatten(self):
        """Flatten the quantizer for JAX tree operations.

        The scale and amax history are array children; the remaining fields are
        static auxiliary data.

        Returns:
            Tuple of (children, aux_data) for tree operations
        """
        children = (self.scale, self.amax_history)
        aux_data = (
            self.q_dtype,
            self.scaling_mode,
            self.q_layout,
            self.data_layout,
            self.checkpoint_name,
        )
        return (children, aux_data)

    def _quantize_func(
        self, x: jnp.ndarray, is_colwise=False, dq_dtype=None, flatten_axis=-1
    ) -> ScaledTensor1x:
        """Quantize function helper for delayed scaling FP8.

        Quantizes with the scale currently stored on the quantizer, then records
        the amax of ``x`` so the stored scale is refreshed for the next call.

        Args:
            x: Input tensor to quantize. A NoScaleTensor may carry a precomputed amax.
            is_colwise: Unused here; column-wise output is produced by quantize()
            dq_dtype: Data type for dequantized values
            flatten_axis: The quantization axis for the tensor
        Returns:
            A ScaledTensor1x containing the quantized data
        """
        if isinstance(x, jnp.ndarray):
            x = NoScaleTensor(data=x, amax=None)

        dq_dtype = dq_dtype if dq_dtype is not None else x.data.dtype

        compute_dtype = jnp.float32
        dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
        scaled_x = x.data.astype(compute_dtype) * self.scale

        # quantize() in the old dot.py do this way, leave this code block here for future debugging
        # compute_dtype = x.dtype
        # dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
        # scaled_x = x * self.scale.astype(compute_dtype)

        clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype)
        # scale_inv must be taken before update() replaces self.scale below.
        scale_inv = 1.0 / self.scale

        # Use an explicit None check: `x.amax or ...` would call bool() on an
        # array, which fails on traced values under jit.
        amax = x.amax
        if amax is None:
            amax = jnp.max(jnp.abs(x.data)).reshape((1,))
        # Note, this updating of amax here will only be called once because the
        # "quantize" method impl inherited from CurrentScaleQuantizer only calls
        # _quantize_func once then transposes the result for colwise quantization.
        # So we don't have to worry about update being called twice for 2x
        # quantization.
        self.update(amax)
        return ScaledTensorFactory.create_1x(
            data=clipped_scaled_x,
            scale_inv=scale_inv,
            scaling_mode=self.scaling_mode,
            dq_dtype=dq_dtype,
            flatten_axis=flatten_axis,
        )

    @staticmethod
    @jax.jit
    def _update_amax_history(amax_history, new_amax):
        """Update AMAX history with new maximum value.

        Args:
            amax_history: Current AMAX history
            new_amax: New maximum value; its first element is written to slot 0

        Returns:
            Updated AMAX history
        """
        amax_history = amax_history.at[0].set(new_amax[0])
        return amax_history

    @staticmethod
    @partial(jax.jit, static_argnums=(2,))
    def _compute_scale(amax_history, scale, q_dtype):
        """Compute new scale based on AMAX history.

        Args:
            amax_history: History of maximum absolute values
            scale: Current scale, used as fallback when the amax is unusable
            q_dtype: Quantization data type

        Returns:
            Updated scale value
        """
        # 2. Calculate the current scale
        if get_quantize_config().AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
            amax = jnp.max(amax_history, axis=-1, keepdims=True)
        else:
            # Any other algo uses only the most recent amax (slot 0).
            amax = amax_history[0:1]

        return compute_scale_from_amax(amax, q_dtype, scale=scale)

    @staticmethod
    @jax.jit
    def _roll_and_reset_amax_history(amax_history):
        """Roll AMAX history and reset first element.

        Args:
            amax_history: Current AMAX history

        Returns:
            Updated AMAX history
        """
        # Shift left by one along the last axis (slot 0 wraps to the end), then
        # clear slot 0 so it is ready to receive the next amax.
        updated_amax_history = jnp.roll(amax_history, -1, -1)
        amax_history = updated_amax_history.at[0].set(0.0)
        return amax_history

    def update(self, new_amax: jnp.ndarray):
        """Update AMAX history and compute new scale.

        Mutates ``self.scale`` and ``self.amax_history`` in place.

        Args:
            new_amax: New maximum absolute value to add to history
        """
        amax_history = self._update_amax_history(self.amax_history, new_amax)
        self.scale = self._compute_scale(amax_history, self.scale, self.q_dtype)
        self.amax_history = self._roll_and_reset_amax_history(amax_history)


@register_pytree_node_class
@dataclass
class BlockScaleQuantizer(Quantizer):
    """Quantizer implementation using block-based scaling.

    This quantizer uses block scaling mode with E8M0 (power-of-two) scales, one
    scale per block of the input.

    Attributes:
        scaling_mode: Set to MXFP8_1D_SCALING
        q_layout: Quantization axis (default: ROWWISE_COLWISE)
        data_layout: Data layout string (default: "NN")
    """

    scaling_mode: ScalingMode = ScalingMode.MXFP8_1D_SCALING
    q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE
    data_layout: str = "NN"

    def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> ScaledTensor1x:
        """Quantize function helper for block scaling FP8.

        Args:
            x: Input tensor to quantize
            is_colwise: Whether to use column-wise quantization
            dq_dtype: Data type for dequantized values
            flatten_axis: The quantization axis for the tensor

        Returns:
            A ScaledTensor1x containing the quantized data
        """
        if isinstance(x, NoScaleTensor):
            # No need for amax in MXFP8 block scaling, so simply extract the
            # jnp.ndarray data tensor from the NoScaleTensor x.
            x = x.data

        # TODO(Phuong): use quantize_func from JAX
        if flatten_axis < 0:
            flatten_axis = x.ndim + flatten_axis
        # NOTE(review): flatten_axis == 0 passes this check, but the reshape below
        # indexes flatten_axis - 1 and would then address the last dimension twice.
        assert (
            0 <= flatten_axis < x.ndim
        ), f"Invalid flatten_axis: {flatten_axis} for tensor of shape {x.shape}"

        dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
        x_shape = x.shape
        scale_shape = self.scaling_mode.get_scale_shape(
            x_shape, is_colwise=is_colwise, is_padded=False, flatten_axis=flatten_axis
        )
        scale_dtype = self.scaling_mode.get_scale_dtype()
        # Split dimension (flatten_axis - 1) and the last dimension each into
        # (number of scales, elements per scale).
        x = x.reshape(
            *x_shape[: flatten_axis - 1],
            scale_shape[flatten_axis - 1],
            x_shape[flatten_axis - 1] // scale_shape[flatten_axis - 1],
            *x_shape[flatten_axis:-1],
            scale_shape[-1],
            x_shape[-1] // scale_shape[-1],
        )
        # After the reshape, axis `flatten_axis` and the last axis are the two
        # per-block element axes; reduce over both to get one amax per block.
        amax = jnp.max(jnp.abs(x), axis=(flatten_axis, -1), keepdims=True)
        MAX = jnp.finfo(self.q_dtype).max.astype(jnp.float32)
        scales = amax.astype(jnp.float32) / MAX

        scales_q = self._cast_to_e8m0_with_rounding_up(scales)
        scaled_x = x / self._e8m0_to_dtype(scales_q, jnp.float32)

        clipped_x = jnp.clip(scaled_x, -MAX, MAX)
        x_q = clipped_x.astype(self.q_dtype).reshape(x_shape)
        scales_q = scales_q.reshape(scale_shape).view(scale_dtype)

        return ScaledTensorFactory.create_1x(
            x_q,
            scales_q,
            scaling_mode=self.scaling_mode,
            is_colwise=is_colwise,
            dq_dtype=dq_dtype,
            flatten_axis=flatten_axis,
        )

    def _cast_to_e8m0_with_rounding_up(self, scales):
        """Cast scales to E8M0 format with rounding up.

        Args:
            scales: Input scales to convert

        Returns:
            Scales in E8M0 format, stored as uint8 exponents
        """
        # Reinterpret the float32 bits and separate exponent from mantissa.
        temp = scales.astype(jnp.float32).view(jnp.uint32)
        exp = temp >> 23
        mant = temp & 0x7FFFFF
        # Bump the exponent when the mantissa is non-zero, except when the
        # exponent is already 0xFE or when the exponent is 0 with a mantissa of
        # at most 0x400000.
        is_ru = jnp.logical_and(
            jnp.logical_and((mant > 0), (exp != 0xFE)),
            ~jnp.logical_and((exp == 0), (mant <= 0x400000)),
        )
        exp = jnp.where(is_ru, exp + 1, exp)
        new_scales = exp.astype(jnp.uint8)
        return new_scales

    def _e8m0_to_dtype(self, x, dtype):
        """Convert E8M0 format to specified data type.

        Args:
            x: Input in E8M0 format
            dtype: Target data type

        Returns:
            Converted values in target data type
        """
        # Place the 8-bit exponent into the float32 exponent field.
        temp = x.astype(jnp.uint32)
        exp = temp << 23
        new_x = exp.view(jnp.float32)
        # A zero exponent decodes to 0.0; replace it with a small positive value.
        near_zero_value = 2**-15 if dtype == jnp.float16 else 2**-127
        new_x = jnp.where(new_x == 0, jnp.array(near_zero_value, jnp.float32), new_x)
        return new_x.astype(dtype)


572
573
574
575
576
577
578
579
580
581
582
583
@register_pytree_node_class
@dataclass
class NVFP4Quantizer(Quantizer):
    """Quantizer implementation using NVFP4 two-level scaling.

    This quantizer combines a per-tensor float32 scale with per-block scales
    stored in the scaling mode's scale dtype.

    Attributes:
        scaling_mode: Set to NVFP4_1D_SCALING or NVFP4_2D_SCALING
        q_layout: Quantization axis
        data_layout: Data layout string (default: "NT")
        use_rht: Whether to apply Randomized Hadamard Transform (RHT) before quantization.
            Only takes effect for column-wise quantization in NVFP4_1D_SCALING mode.
        stochastic_rounding_rng_state: RNG state for stochastic rounding, must be of shape
            (num_devices_in_mesh, 4) and dtype uint32. If None, stochastic rounding is disabled.
    """

    scaling_mode: ScalingMode = ScalingMode.NVFP4_1D_SCALING
    q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE
    data_layout: str = "NT"
    use_rht: bool = False
    stochastic_rounding_rng_state: Optional[jnp.ndarray] = None

    def __post_init__(self):
        """Validate that the quantizer is configured for NVFP4."""
        assert (
            self.q_dtype == jnp.float4_e2m1fn
        ), "NVFP4 quantization must use a q_dtype of float4_e2m1fn"
        assert self.scaling_mode.is_nvfp4_scaling, "NVFP4Quantizer must use NVFP4 scaling modes"

    def tree_flatten(self):
        """Flatten the quantizer for JAX tree operations.

        The stochastic rounding RNG state is the only child; all other fields
        are static auxiliary data.

        Returns:
            Tuple of (children, aux_data) for tree operations
        """
        children = (self.stochastic_rounding_rng_state,)
        aux_data = (
            self.q_dtype,
            self.scaling_mode,
            self.q_layout,
            self.data_layout,
            self.checkpoint_name,
            self.use_rht,
        )
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Reconstruct a quantizer from its flattened representation.

        Args:
            aux_data: Auxiliary data containing quantizer parameters
            children: One-element tuple holding the stochastic rounding RNG state

        Returns:
            A reconstructed Quantizer instance
        """
        stochastic_rounding_rng_state = children[0]
        return cls(*aux_data, stochastic_rounding_rng_state=stochastic_rounding_rng_state)

    def _apply_stochastic_rounding(self, x):
        """Stochastically round already-scaled values onto the FP4 value grid.

        Args:
            x: Scaled values to round

        Returns:
            Values rounded up or down to a neighbouring grid point, with the
            original sign
        """
        assert (
            self.stochastic_rounding_rng_state is not None
        ), "Stochastic rounding RNG state is not initialized"
        expected_sr_rng_state_shape = (get_num_devices_in_mesh(), 4)
        assert self.stochastic_rounding_rng_state.shape == expected_sr_rng_state_shape, (
            "Stochastic rounding RNG state must be of shape (num_devices_in_mesh, 4). Expected"
            f" {expected_sr_rng_state_shape}, but got {self.stochastic_rounding_rng_state.shape}"
        )
        assert (
            self.stochastic_rounding_rng_state.dtype == jnp.uint32
        ), "Stochastic rounding RNG state must be of dtype uint32"

        # Default RNG state in JAX expects 2x 32-bit integers, use first 2 uint32s for
        # initial state and fold in the other 2 uint32s
        key_bits = jnp.array(
            [
                # only take the first device's RNG state as the pure-JAX stochastic
                # rounding impl only uses a single-device
                self.stochastic_rounding_rng_state[0][0],
                self.stochastic_rounding_rng_state[0][1],
            ],
            dtype=jnp.uint32,
        )
        key = jax.random.wrap_key_data(key_bits)
        key = jax.jit(jax.random.fold_in)(key, self.stochastic_rounding_rng_state[0][2])
        key = jax.jit(jax.random.fold_in)(key, self.stochastic_rounding_rng_state[0][3])

        abs_x = jnp.abs(x)
        sign_x = jnp.sign(x)

        # Largest grid value (0, 0.5, 1, 2, 3, 4, 6) that is <= abs_x.
        floor = (
            (abs_x >= 0.5) * 0.5
            + (abs_x >= 1) * 0.5
            + (abs_x >= 2)
            + (abs_x >= 3)
            + (abs_x >= 4)
            + (abs_x >= 6) * 2
        )
        # Smallest grid value (0.5, 1, 2, 3, 4, 6) that is >= abs_x, capped at 6.
        ceil = (
            0.5
            + (abs_x > 0.5) * 0.5
            + (abs_x > 1) * 1
            + (abs_x > 2)
            + (abs_x > 3)
            + (abs_x > 4) * 2
        )
        # Relative position between the two neighbours. Where floor == ceil this
        # is NaN or inf; both branches of the where() below then give the same value.
        frac = (abs_x - floor) / (ceil - floor)

        rand = jax.random.uniform(key, abs_x.shape)
        return sign_x * jnp.where(frac >= rand, ceil, floor)

    def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> ScaledTensor1x:
        """Quantize function helper for NVFP4 block scaling.

        Args:
            x: Input tensor to quantize. A NoScaleTensor may carry a precomputed amax.
            is_colwise: Whether to use column-wise quantization
            dq_dtype: Data type for dequantized values
            flatten_axis: The quantization axis for the tensor

        Returns:
            A ScaledTensor1x containing the quantized data
        """
        # TODO(Phuong): use quantize_func from JAX
        if flatten_axis < 0:
            flatten_axis = x.ndim + flatten_axis
        assert (
            0 <= flatten_axis < x.ndim
        ), f"Invalid flatten_axis: {flatten_axis} for tensor of shape {x.shape}"

        # We currently only have a single flag 'use_rht' on the quantizer. To avoid an
        # unused rowwise flag, we assume RHT is only used for colwise quantization for now.
        use_rht = self.use_rht and is_colwise and self.scaling_mode == ScalingMode.NVFP4_1D_SCALING

        global_amax = None
        if isinstance(x, NoScaleTensor):
            # RHT changes the amax, so the precalculated amax is only reused when
            # RHT is not going to be applied.
            global_amax = x.amax if not use_rht else None
            x = x.data

        # Transpose if required
        rowwise_flatten_axis = flatten_axis
        data_layout = self.data_layout[0]
        if is_colwise:
            x = jnp.transpose(x, (*range(flatten_axis, x.ndim), *range(flatten_axis)))
            data_layout = self.data_layout[1]
            # convert flatten_axis from N layout to T layout
            flatten_axis = x.ndim - flatten_axis
        x_shape = x.shape

        if use_rht:
            x = apply_rht(x)

        dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
        scale_shape = self.scaling_mode.get_scale_shape(
            x_shape,
            data_layout=data_layout,
            is_colwise=is_colwise,
            is_padded=False,
            flatten_axis=rowwise_flatten_axis,
        )
        scale_dtype = self.scaling_mode.get_scale_dtype()
        # Split dimension (flatten_axis - 1) and the last dimension each into
        # (number of scales, elements per scale).
        x = x.reshape(
            *x_shape[: flatten_axis - 1],
            scale_shape[flatten_axis - 1],
            x_shape[flatten_axis - 1] // scale_shape[flatten_axis - 1],
            *x_shape[flatten_axis:-1],
            scale_shape[-1],
            x_shape[-1] // scale_shape[-1],
        )

        # Dtype max constants
        DATA_DTYPE_MAX = jnp.finfo(self.q_dtype).max.astype(jnp.float32)
        SCALE_DTYPE_MAX = jnp.finfo(scale_dtype).max.astype(jnp.float32)

        # Level 1: Current Tensor Scaling
        global_amax = (
            global_amax
            if global_amax is not None
            else jnp.max(jnp.abs(x)).reshape((1,)).astype(jnp.float32)
        )
        tensor_scale = DATA_DTYPE_MAX * SCALE_DTYPE_MAX / global_amax
        # Cap at float32 max (covers division by a zero amax).
        tensor_scale = jnp.minimum(
            tensor_scale, jnp.array(jnp.finfo(jnp.float32).max, dtype=jnp.float32)
        )
        tensor_scale = jnp.where(
            tensor_scale == jnp.array(0.0, dtype=jnp.float32),
            jnp.array(1.0, dtype=jnp.float32),
            tensor_scale,
        )
        tensor_scale_inv = 1.0 / tensor_scale

        # Level 2: Block Scaling
        # After the reshape, axis `flatten_axis` and the last axis are the two
        # per-block element axes; reduce over both to get one amax per block.
        block_amax = jnp.max(jnp.abs(x), axis=(flatten_axis, -1), keepdims=True).astype(
            jnp.float32
        )
        block_scale_inv = jnp.divide(block_amax, DATA_DTYPE_MAX)
        block_scale_inv = block_scale_inv * tensor_scale
        block_scale_inv = jnp.minimum(
            block_scale_inv, jnp.array(jnp.finfo(jnp.float32).max, dtype=jnp.float32)
        )
        block_scale_inv = jnp.clip(block_scale_inv, -SCALE_DTYPE_MAX, SCALE_DTYPE_MAX)
        # We cast block_scale_inv to scale_dtype here to account for any rounding during
        # the cast. This will ensure the quantized data incorporates the rounded scale
        # value into its computation so dequantization is accurate.
        block_scale_inv = block_scale_inv.astype(scale_dtype)
        # Note, with JIT jax removes this intermediate cast leading to slightly incorrect
        # results during DQ and worse convergence to the original tensor during many
        # samples of Q+SR->DQ. So we use reduce_precision to simulate the cast to
        # scale_dtype.
        assert scale_dtype == jnp.float8_e4m3fn, "Only float8_e4m3fn is supported for scale_dtype"
        block_scale_inv = jax.lax.reduce_precision(block_scale_inv, 4, 3)
        block_scale = jnp.minimum(
            jnp.divide(1.0, block_scale_inv.astype(jnp.float32) * tensor_scale_inv),
            jnp.array(jnp.finfo(jnp.float32).max, dtype=jnp.float32),
        )

        # Apply scaling
        scaled_x = x.astype(jnp.float32) * block_scale
        if self.stochastic_rounding_rng_state is not None:
            scaled_x = self._apply_stochastic_rounding(scaled_x)
        clipped_x = jnp.clip(scaled_x, -DATA_DTYPE_MAX, DATA_DTYPE_MAX)

        # Cast to the right dtype
        quantized_data = clipped_x.reshape(x_shape).astype(self.q_dtype)
        block_scale_inv = block_scale_inv.reshape(scale_shape).astype(scale_dtype)

        # In the 2D scaling mode, the scale shape is 2D but it needs to be broadcasted to
        # 1D for GEMM.
        # TODO(Phuong): expose this broadcast_2d_scale_shape_to_1d option to the
        # quantizer.quantize() API
        broadcasted_1d_scale_shape = self.scaling_mode.get_scale_shape(
            x_shape,
            data_layout=data_layout,
            is_colwise=is_colwise,
            is_padded=False,
            flatten_axis=rowwise_flatten_axis,
            broadcast_2d_scale_shape_to_1d=True,
        )

        # Broadcast and tile x to match the target shape
        def repeat_to_shape(x, target_shape):
            x_shape = x.shape
            reps = [int(t // s) for s, t in zip(x_shape, target_shape)]
            return jnp.tile(x, reps)

        block_scale_inv = repeat_to_shape(block_scale_inv, broadcasted_1d_scale_shape)

        return ScaledTensorFactory.create_1x(
            data=quantized_data,
            data_layout=data_layout,
            is_colwise=is_colwise,
            scale_inv=block_scale_inv,
            amax=global_amax,
            scaling_mode=self.scaling_mode,
            dq_dtype=dq_dtype,
            flatten_axis=rowwise_flatten_axis,
            has_rht_applied=use_rht,
        )


825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
@register_pytree_node_class
@dataclass
class QuantizerSet:
    """Container holding one quantizer per tensor role.

    The set groups the quantizers for input tensors, kernel tensors and
    gradient tensors, and is registered as a JAX pytree node.

    Attributes:
        x: Quantizer for input tensors
        kernel: Quantizer for kernel tensors
        dgrad: Quantizer for gradient tensors
    """

    x: Optional[Quantizer]
    kernel: Optional[Quantizer]
    dgrad: Optional[Quantizer]

    def tree_flatten(self):
        """Split the set into pytree children and auxiliary data.

        Returns:
            A ``(children, aux_data)`` pair. The children are the quantizers in
            ``(x, kernel, dgrad)`` order; the auxiliary data is an empty tuple.
        """
        return (self.x, self.kernel, self.dgrad), ()

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild a quantizer set from the output of ``tree_flatten``.

        Args:
            aux_data: Auxiliary data (empty for this class)
            children: Tuple of quantizers in ``(x, kernel, dgrad)`` order

        Returns:
            A reconstructed QuantizerSet instance
        """
        # Auxiliary entries (none today) come first, followed by the quantizers.
        init_args = (*aux_data, *children)
        return cls(*init_args)


867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
@register_pytree_node_class
@dataclass
class GroupedQuantizer(Quantizer):
    """Quantizer for grouped arrays.

    This class extends Quantizer to support quantization of arrays in grouped manner,
    where elements are grouped along a specified axis then quantized separately, each
    group by its own entry of ``quantizers``.

    Attributes:
        data_layout: The data layout specification; overwritten in ``__post_init__`` with
            the data layout of the first per-group quantizer
        n_groups: Number of groups for quantization
        quantizers: Tuple of quantizers for each group; created through
            ``QuantizerFactory.create`` when left at its default value
    """

    data_layout: Optional[str] = None
    n_groups: int = 1
    quantizers: Tuple[Quantizer, ...] = field(default_factory=lambda: (None,))

    # Ways of merging per-group results in _create_grouped_tensor_from_tensor_list.
    # These are deliberately un-annotated so that they are not dataclass fields.
    # TODO(Ming Huang): Consider to apply Enum for mode.
    _COMBINE_CONCAT = 0
    _COMBINE_ADD = 1

    def tree_flatten(self):
        """Flatten the quantizer for JAX tree operations.

        Returns:
            Tuple of (children, aux_data) for tree operations. The only child is the
            tuple of per-group quantizers; everything else is auxiliary data.
        """
        children = (self.quantizers,)
        aux_data = (
            self.q_dtype,
            self.scaling_mode,
            self.q_layout,
            self.data_layout,
            self.checkpoint_name,
            self.n_groups,
        )
        return (children, aux_data)

    def __post_init__(self):
        """Create the per-group quantizers if needed and adopt their data layout."""
        # The default value (None,) marks "no quantizers supplied yet".
        if self.quantizers[0] is None:
            quantizers = QuantizerFactory.create(
                self.n_groups, self.scaling_mode, self.q_dtype, self.q_layout
            )
            # The factory returns a bare quantizer when only one was requested.
            self.quantizers = (quantizers,) if not isinstance(quantizers, tuple) else quantizers
        self.data_layout = self.quantizers[0].data_layout

    def _create_grouped_tensor_from_tensor_list(
        self, tensor_list, group_sizes, original_shape, group_axis, mode
    ):
        """Merge per-group quantized tensors into a single grouped tensor.

        Args:
            tensor_list: Per-group quantized tensors exposing ``data`` and ``scale_inv``
            group_sizes: Size of each group, forwarded to the created tensor
            original_shape: Shape of the un-grouped input, forwarded to the created tensor
            group_axis: Axis along which grouping was performed
            mode: 0 to concatenate the flattened per-group data, 1 to add the per-group
                data element-wise before flattening

        Returns:
            The tensor returned by ``ScaledTensorFactory.create_1x`` for the merged data
            and the concatenated scale inverses
        """
        assert mode in [self._COMBINE_CONCAT, self._COMBINE_ADD]
        first = tensor_list[0]

        if mode == self._COMBINE_CONCAT:
            grouped_data = jnp.concatenate([tensor.data.flatten() for tensor in tensor_list])
        else:
            # Every tensor has the shape of the first one; accumulate them element-wise.
            grouped_data = jnp.zeros(first.data.shape, first.data.dtype)
            for tensor in tensor_list:
                grouped_data += tensor.data
            grouped_data = grouped_data.flatten()

        # Scale inverses are always concatenated, regardless of the data combine mode.
        grouped_scale_inv = jnp.concatenate([tensor.scale_inv.flatten() for tensor in tensor_list])

        return ScaledTensorFactory.create_1x(
            grouped_data,
            grouped_scale_inv,
            scaling_mode=self.scaling_mode,
            dq_dtype=first.dq_dtype,
            is_colwise=first.is_colwise,
            data_layout=first.data_layout,
            flatten_axis=first.flatten_axis,
            group_sizes=group_sizes,
            original_shape=original_shape,
            group_axis=group_axis,
        )

    def _quantize_func(self, *args, **kwargs):
        """Unused placeholder; grouped quantization is implemented in ``quantize``."""

    def quantize(
        self,
        x,
        is_rowwise: bool = None,
        is_colwise: bool = None,
        dq_dtype=None,
        flatten_axis=-1,
        group_sizes=None,
        group_axis=0,
    ):
        """Quantize a tensor in grouped manner.

        Expected input shape: [M, K] or [G, K, N]
        Split to x.shape[group_axis] number of groups if group_sizes is not given

        When ``group_sizes`` is given, group ``i`` covers a contiguous range of indices
        along axis 0. Everything outside that range is zeroed, the masked copy is quantized
        with ``quantizers[i]``, and the per-group results are added together. Otherwise
        ``x`` is split into slices of size one along ``group_axis``, each slice is quantized
        with its own quantizer, and the results are concatenated.

        Args:
            x: Input tensor to quantize
            is_rowwise: Whether to use row-wise quantization (default: taken from q_layout)
            is_colwise: Whether to use column-wise quantization (default: taken from q_layout)
            dq_dtype: Data type for dequantized values (default: dtype of x)
            flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)
            group_sizes: Array of ints containing the size of each group (default: None)
            group_axis: The axis along which grouping is performed (default: 0)

        Returns:
            A ScaledTensor1x or ScaledTensor2x containing the quantized data
        """
        assert group_axis == 0, "Only group_axis == 0 is supported now!"

        dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
        if flatten_axis < 0:
            flatten_axis += x.ndim
        assert 0 < flatten_axis < x.ndim, "flatten_axis is out of bounds!"

        is_rowwise = is_rowwise if is_rowwise is not None else self.q_layout.has_rowwise
        is_colwise = is_colwise if is_colwise is not None else self.q_layout.has_colwise
        assert is_rowwise or is_colwise, "No quantization layout is specified"

        original_shape = x.shape

        if group_sizes is not None:
            assert not is_colwise, "Not yet implememted!"
            assert group_sizes.ndim == 1, (
                "GroupedQuantizer only support 1D group_sizes, got group_sizes.ndim ="
                f" {group_sizes.ndim}"
            )
            n_input_groups = len(group_sizes)
            assert n_input_groups <= len(self.quantizers), (
                f"Got {n_input_groups} groups but only {len(self.quantizers)} quantizers are"
                " available"
            )

            _zeros = partial(jax.lax.full_like, fill_value=0)

            # Index of every element along axis 0, used to build the per-group masks.
            x_iota = jax.lax.broadcasted_iota(group_sizes.dtype, x.shape, 0)
            group_ends = jnp.cumulative_sum(group_sizes)
            group_starts = jax.lax.concatenate(
                [_zeros(group_sizes)[:1], group_ends[:-1]],
                dimension=0,
            )
            x_zero = _zeros(x)

            tensor_list = []
            for i in range(n_input_groups):
                # Keep only the rows in [group_starts[i], group_ends[i]); zero out the rest.
                mask = jax.lax.bitwise_and(group_starts[i] <= x_iota, x_iota < group_ends[i])
                x_selected = jax.lax.select(mask, x, x_zero)
                tensor = self.quantizers[i].quantize(
                    x_selected, is_rowwise, is_colwise, dq_dtype, flatten_axis
                )
                tensor_list.append(tensor)
            combine_mode = self._COMBINE_ADD
        else:
            n_input_groups = x.shape[group_axis]
            assert n_input_groups <= len(self.quantizers), (
                f"Got {n_input_groups} groups but only {len(self.quantizers)} quantizers are"
                " available"
            )
            group_sizes = jnp.ones(n_input_groups, dtype=jnp.int32)
            chunks = jnp.split(x, n_input_groups, axis=group_axis)

            tensor_list = [
                self.quantizers[i].quantize(chunk, is_rowwise, is_colwise, dq_dtype, flatten_axis)
                for i, chunk in enumerate(chunks)
            ]
            combine_mode = self._COMBINE_CONCAT

        grouped_rowwise_tensor = grouped_colwise_tensor = None
        if is_rowwise:
            rowwise_tensor_list = [tensor.get_rowwise_tensor() for tensor in tensor_list]
            grouped_rowwise_tensor = self._create_grouped_tensor_from_tensor_list(
                rowwise_tensor_list, group_sizes, original_shape, group_axis, combine_mode
            )
        if is_colwise:
            colwise_tensor_list = [tensor.get_colwise_tensor() for tensor in tensor_list]
            grouped_colwise_tensor = self._create_grouped_tensor_from_tensor_list(
                colwise_tensor_list, group_sizes, original_shape, group_axis, combine_mode
            )

        if is_colwise and is_rowwise:
            return ScaledTensor2x(grouped_rowwise_tensor, grouped_colwise_tensor)
        if is_colwise:
            return grouped_colwise_tensor
        return grouped_rowwise_tensor

    def get_scale_shapes(self, data_shape, is_padded=True, flatten_axis=-1, group_sizes=None):
        """Return the grouped scale shapes for the given data shape.

        Args:
            data_shape: Shape of the data to be quantized
            is_padded: Forwarded to the scaling mode (default: True)
            flatten_axis: Forwarded to the scaling mode (default: -1)
            group_sizes: Group description forwarded to the scaling mode; must be truthy

        Returns:
            The result of ``scaling_mode.get_grouped_scale_shape_2x``
        """
        # NOTE(review): this truthiness check raises for multi-element arrays (ambiguous
        # truth value); presumably callers pass a static Python value here - confirm.
        assert group_sizes, "Empty group_sizes was given!"
        return self.scaling_mode.get_grouped_scale_shape_2x(
            data_shape, group_sizes, is_padded, flatten_axis
        )


1050
1051
1052
1053
1054
1055
1056
1057
1058
@dataclass
class QuantizerFactory:
    """Factory class for creating quantizers.

    This class provides static methods to create individual quantizers and
    sets of quantizers with various configurations.
    """

    # Quantizer class used for each scaling mode. NO_SCALING is intentionally absent:
    # create() returns None for it instead of instantiating a quantizer.
    quantizer_type_map = {
        ScalingMode.DELAYED_TENSOR_SCALING: DelayedScaleQuantizer,
        ScalingMode.CURRENT_TENSOR_SCALING: CurrentScaleQuantizer,
        ScalingMode.MXFP8_1D_SCALING: BlockScaleQuantizer,
        ScalingMode.NVFP4_1D_SCALING: NVFP4Quantizer,
        ScalingMode.NVFP4_2D_SCALING: NVFP4Quantizer,
    }

    @staticmethod
    def create(
        n_quantizers: int = 1,
        scaling_mode: ScalingMode = None,
        q_dtype: jnp.dtype = None,
        q_layout: QuantizeLayout = None,
        n_groups: int = None,
        checkpoint_name: Optional[str] = None,
        **kwargs,
    ) -> Union[Optional[Quantizer], Tuple[Optional[Quantizer], ...]]:
        """Create one or more quantizers with specified parameters.

        Args:
            n_quantizers: Number of quantizers to create
            scaling_mode: Scaling mode to use
            q_dtype: Quantization data type
            q_layout: Quantization axis
            n_groups: Number of quantizers if GroupedQuantizer; a truthy value makes the
                factory build GroupedQuantizer instances instead of looking up the
                quantizer class for scaling_mode
            checkpoint_name: Optional name for checkpointing quantizations
            **kwargs: Additional arguments for quantizer initialization

        Returns:
            A single quantizer when exactly one was created, otherwise a tuple of
            quantizers. Every entry is None when scaling_mode is NO_SCALING.
        """
        assert isinstance(scaling_mode, ScalingMode), "Invalid scaling_mode type"
        if n_groups:
            if n_quantizers != 1:
                warnings.warn(
                    "Using more than one GroupedQuantizer for a grouped input is not recommended"
                )
            quantizer_type = GroupedQuantizer
            kwargs["n_groups"] = n_groups
        else:
            quantizer_type = QuantizerFactory.quantizer_type_map.get(scaling_mode)

        if scaling_mode == ScalingMode.NO_SCALING:
            quantizers = [None] * n_quantizers
        else:
            # Fail with a readable message instead of "'NoneType' object is not callable".
            assert (
                quantizer_type is not None
            ), f"No quantizer is registered for scaling_mode {scaling_mode}"
            quantizers = [
                quantizer_type(
                    q_dtype=q_dtype,
                    scaling_mode=scaling_mode,
                    q_layout=q_layout,
                    checkpoint_name=checkpoint_name,
                    **kwargs,
                )
                for _ in range(n_quantizers)
            ]
        return quantizers[0] if len(quantizers) == 1 else tuple(quantizers)

    @staticmethod
    def _create_set(
        x_scaling_mode,
        kernel_scaling_mode,
        grad_scaling_mode,
        fwd_dtype,
        bwd_dtype,
        is_2x2x,
        n_groups,
        checkpoint_name: Optional[str] = None,
        **kwargs,
    ) -> QuantizerSet:
        """Create a set of quantizers for forward and backward passes.

        Args:
            x_scaling_mode: Scaling mode to use for input tensor 'x'
            kernel_scaling_mode: Scaling mode to use for kernel tensor
            grad_scaling_mode: Scaling mode to use for gradient tensor
            fwd_dtype: Data type for forward pass
            bwd_dtype: Data type for backward pass
            is_2x2x: Whether to use 2x2x quantization
            n_groups: Forwarded to QuantizerFactory.create for each of the three quantizers
            checkpoint_name: Optional name for checkpointing quantizations
            **kwargs: Additional arguments for quantizer initialization. Only the
                "quantize_meta_set" entry is read here; its x, kernel and grad members
                supply extra keyword arguments for the respective quantizer.

        Returns:
            A QuantizerSet instance
        """
        if is_2x2x:
            q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE_COLWISE
        else:
            q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE
            if kernel_scaling_mode.is_1d_block_scaling():
                q_layout_kernel = QuantizeLayout.COLWISE
            if get_quantize_config().INFERENCE_MODE:
                q_layout_dgrad = None

        if "quantize_meta_set" in kwargs:
            quantize_meta_set = kwargs.get("quantize_meta_set")
            args_x = quantize_meta_set.x.get_kwargs_dictionary()
            args_kernel = quantize_meta_set.kernel.get_kwargs_dictionary()
            args_grad = quantize_meta_set.grad.get_kwargs_dictionary()
        else:
            # The shared empty dict is only ever unpacked with **, never mutated.
            args_x = args_kernel = args_grad = {}

        q_x = QuantizerFactory.create(
            1,
            x_scaling_mode,
            fwd_dtype,
            q_layout_x,
            n_groups,
            checkpoint_name=checkpoint_name,
            **args_x,
        )
        q_kernel = QuantizerFactory.create(
            1,
            kernel_scaling_mode,
            fwd_dtype,
            q_layout_kernel,
            n_groups,
            checkpoint_name=checkpoint_name,
            **args_kernel,
        )
        q_dgrad = QuantizerFactory.create(
            1,
            grad_scaling_mode,
            bwd_dtype,
            q_layout_dgrad,
            n_groups,
            checkpoint_name=checkpoint_name,
            **args_grad,
        )
        return QuantizerSet(x=q_x, kernel=q_kernel, dgrad=q_dgrad)

    @staticmethod
    def create_set(
        n_quantizer_sets: int = 1,
        scaling_mode: Optional[ScalingMode] = None,
        fwd_dtype: jnp.dtype = None,
        bwd_dtype: jnp.dtype = None,
        is_2x2x: bool = None,
        n_groups: int = None,
        checkpoint_name: Optional[str] = None,
        # TODO(jberchtold): rename fp8_recipe to quantization_recipe
        fp8_recipe: Optional[recipe.Recipe] = None,
        **kwargs,
    ) -> Union[QuantizerSet, Tuple[QuantizerSet, ...]]:
        """Create one or more sets of quantizers.

        Args:
            n_quantizer_sets: Number of quantizer sets to create
            scaling_mode: Scaling mode to use, default is get_quantize_config().get_scaling_mode
            fwd_dtype: Data type for forward pass, default is get_quantize_config().FWD_DTYPE.
                Replaced by the recipe's forward dtype when fp8_recipe is given.
            bwd_dtype: Data type for backward pass, default is get_quantize_config().BWD_DTYPE.
                Replaced by the recipe's backward dtype when fp8_recipe is given.
            is_2x2x: Whether to use 2x2x quantization. When None it is derived from the
                scaling mode of 'x': True for 1D block scaling, the negation of
                is_fp8_gemm_with_all_layouts_supported() for tensor scaling, False otherwise.
            n_groups: Forwarded to QuantizerFactory.create for every quantizer in the set
            checkpoint_name: Optional name for checkpointing quantizations
            fp8_recipe: Recipe to use for quantization. Scaling mode can be specified directly
                via the scaling_mode parameter or indirectly via recipe. Recipe is preferred
                as it will support additional recipes in future where scaling mode differs
                between x, kernel, and grad in the quantizer set.
            **kwargs: Additional arguments for quantizer initialization

        Returns:
            A single quantizer set or tuple of quantizer sets
        """

        assert scaling_mode is None or fp8_recipe is None, (
            "Cannot specify both scaling_mode and fp8_recipe when creating a quantizer set. Scaling"
            " mode can be specified directly via the scaling_mode parameter or indirectly via"
            " recipe. Recipe is preferred as it will support additional recipes in future where"
            " scaling mode differs between x, kernel, and grad in the quantizer set."
        )

        if fp8_recipe is not None:
            quantize_config = get_quantize_config_with_recipe(fp8_recipe)
            x_scaling_mode = quantize_config.get_scaling_mode(TensorSource.X)
            kernel_scaling_mode = quantize_config.get_scaling_mode(TensorSource.KERNEL)
            grad_scaling_mode = quantize_config.get_scaling_mode(TensorSource.DGRAD)
            fwd_dtype = quantize_config.FWD_DTYPE
            bwd_dtype = quantize_config.BWD_DTYPE
        else:
            if scaling_mode is not None:
                x_scaling_mode = scaling_mode
                kernel_scaling_mode = scaling_mode
                grad_scaling_mode = scaling_mode
            else:
                x_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.X)
                kernel_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.KERNEL)
                grad_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.DGRAD)

            # Compare against None explicitly: numpy dtype instances are falsy, so
            # `fwd_dtype or default` would silently discard a dtype passed by the caller.
            if fwd_dtype is None:
                fwd_dtype = get_quantize_config().FWD_DTYPE
            if bwd_dtype is None:
                bwd_dtype = get_quantize_config().BWD_DTYPE
        if is_2x2x is None:
            # TODO(Jeremy): check x, kernel, grad separately for 2x
            if x_scaling_mode.is_1d_block_scaling():
                is_2x2x = True
            elif x_scaling_mode.is_tensor_scaling():
                is_2x2x = not is_fp8_gemm_with_all_layouts_supported()
            else:  # NO_SCALING ignores is_2x2x for now
                is_2x2x = False
        is_inference_mode = get_quantize_config().INFERENCE_MODE
        assert not is_inference_mode, "Inference mode is not supported yet!"

        q_set = [
            QuantizerFactory._create_set(
                x_scaling_mode=x_scaling_mode,
                kernel_scaling_mode=kernel_scaling_mode,
                grad_scaling_mode=grad_scaling_mode,
                fwd_dtype=fwd_dtype,
                bwd_dtype=bwd_dtype,
                is_2x2x=is_2x2x,
                n_groups=n_groups,
                checkpoint_name=checkpoint_name,
                **kwargs,
            )
            for _ in range(n_quantizer_sets)
        ]

        return q_set[0] if len(q_set) == 1 else tuple(q_set)


1278
# Module-level quantizer set built with NO_SCALING, for which QuantizerFactory.create
# returns None instead of a quantizer object.
noop_quantizer_set = QuantizerFactory.create_set(
    scaling_mode=ScalingMode.NO_SCALING,
    is_2x2x=False,
)