quantizer.py 24.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Tensor quantization classes for TE/JAX.

This module provides classes and utilities for quantizing tensors in JAX.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from functools import partial
from typing import Union, Optional

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class
17
from transformer_engine_jax import QuantizeLayout
18
19
20
21
22
23
24
25
26

from .scaling_modes import ScalingMode
from .tensor import ScaledTensor1x, ScaledTensor2x, ScaledTensorFactory
from .helper import (
    QuantizeConfig,
    AmaxComputeAlgo,
)

__all__ = [
    "QuantizeLayout",
    "Quantizer",
    "QuantizerSet",
    "CurrentScaleQuantizer",
    "DelayedScaleQuantizer",
    "BlockScaleQuantizer",
    "QuantizerFactory",
    "noop_quantizer_set",
    "compute_scale_from_amax",
]


def compute_scale_from_amax(
    amax: jnp.ndarray, q_dtype: jnp.dtype, scale: Optional[jnp.ndarray] = None
) -> jnp.ndarray:
    """Derive a quantization scale factor from an absolute-maximum value.

    The result maps ``amax`` onto the largest finite value of ``q_dtype``,
    reduced by a factor of ``2**QuantizeConfig.MARGIN``. Wherever ``amax`` is
    not a positive finite number, the fallback ``scale`` is used instead.

    Args:
        amax: Maximum absolute value of the tensor
        q_dtype: Quantization data type; its ``finfo.max`` bounds the range
        scale: Fallback scale used where ``amax`` is zero, negative, NaN or
            infinite. Defaults to ``jnp.ones((1,))``.

    Returns:
        Scale value
    """
    fallback = jnp.ones((1,)) if scale is None else scale
    q_max = jnp.astype(jnp.finfo(q_dtype).max, jnp.float32)
    candidate = (q_max / amax) / (2**QuantizeConfig.MARGIN)
    # `amax > 0` rejects zero (which would give an infinite scale) and NaN;
    # `isfinite` additionally rejects +/-inf.
    usable = jnp.logical_and(amax > 0.0, jnp.isfinite(amax))
    return jnp.where(usable, candidate, fallback)


@register_pytree_node_class
@dataclass
class Quantizer(ABC):
    """Base class for quantizers.

    This abstract class defines the interface for tensor quantization, providing
    methods for quantization and scale management. Subclasses implement
    ``get_data_layout`` and ``_quantize_func``.

    Attributes:
        q_dtype: The data type for quantized values
        scaling_mode: The scaling mode to use for quantization
        q_layout: The quantization axis (row-wise, column-wise, or both)
    """

    q_dtype: jnp.dtype
    scaling_mode: ScalingMode
    q_layout: QuantizeLayout

    def tree_flatten(self):
        """Flatten the quantizer for JAX tree operations.

        The base quantizer holds no array state: every field is returned as
        static auxiliary data and there are no children.

        Returns:
            Tuple of (children, aux_data) for tree operations
        """
        children = ()
        aux_data = (self.q_dtype, self.scaling_mode, self.q_layout)
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Reconstruct a quantizer from its flattened representation.

        Arguments are passed positionally as ``aux_data`` followed by
        ``children``, so their order must match the dataclass field order.

        Args:
            aux_data: Auxiliary data containing quantizer parameters
            children: Array state from ``tree_flatten`` (empty for the base class)

        Returns:
            A reconstructed Quantizer instance
        """
        return cls(*aux_data, *children)

    def update(self, *args, **kwargs):
        """Update quantizer state (no-op in base class)."""
        del args, kwargs

    def is_2x2x(self) -> bool:
        """Check if quantizer uses both row-wise and column-wise quantization.

        Returns:
            True if using both row-wise and column-wise quantization
        """
        return self.q_layout == QuantizeLayout.ROWWISE_COLWISE

    @abstractmethod
    def get_data_layout(self) -> str:
        """Get the data layout string.

        Returns:
            Data layout in string format
        """

    @abstractmethod
    def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> ScaledTensor1x:
        """Core quantization function to be implemented by subclasses.

        Args:
            x: Input tensor to quantize
            is_colwise: Whether to use column-wise quantization
            dq_dtype: Data type for dequantized values, default is x.dtype
            flatten_axis: The quantization axis for the tensor

        Returns:
            A ScaledTensor1x containing the quantized data
        """

    def quantize(self, x, is_rowwise=False, is_colwise=False, dq_dtype=None, flatten_axis=-1):
        """Quantize a tensor using the internal _quantize_func().

        Both a row-wise and a column-wise tensor are produced when both flags
        are set or when this quantizer's layout is ROWWISE_COLWISE. Otherwise
        ``is_colwise`` selects column-wise output, and the default is row-wise.

        Args:
            x: Input tensor to quantize
            is_rowwise: Whether to use row-wise quantization
            is_colwise: Whether to use column-wise quantization
            dq_dtype: Data type for dequantized values
            flatten_axis: The quantization axis for the tensor

        Returns:
            A ScaledTensor1x or ScaledTensor2x containing the quantized data
        """
        if (is_rowwise and is_colwise) or self.is_2x2x():
            rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
            colwise_tensor = self._quantize_func(
                x, is_colwise=True, dq_dtype=dq_dtype, flatten_axis=flatten_axis
            )
            return ScaledTensor2x(rowwise_tensor, colwise_tensor)

        if is_colwise:
            return self._quantize_func(
                x, is_colwise=True, dq_dtype=dq_dtype, flatten_axis=flatten_axis
            )

        return self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)

    def get_scale_shapes(self, data_shape, is_padded=True, flatten_axis=-1):
        """Get shapes for scale tensors.

        Args:
            data_shape: Shape of the input tensor
            is_padded: Whether to use padded shapes
            flatten_axis: The quantization axis for the tensor

        Returns:
            Tuple of (rowwise_scale_shape, colwise_scale_shape)
        """
        return self.scaling_mode.get_scale_shape_2x(data_shape, is_padded, flatten_axis)

    def get_scale_dtype(self):
        """Get the data type for scale tensors.

        Returns:
            The data type for scale tensors
        """
        return self.scaling_mode.get_scale_dtype()


@register_pytree_node_class
@dataclass
class CurrentScaleQuantizer(Quantizer):
    """Quantizer implementation using current scaling.

    The float32 per-tensor scale is derived from the amax of the tensor being
    quantized, at quantization time.

    Attributes:
        scaling_mode: Set to CURRENT_TENSOR_SCALING
        q_layout: Quantization axis (default: ROWWISE_COLWISE)
    """

    scaling_mode: ScalingMode = ScalingMode.CURRENT_TENSOR_SCALING
    q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE

    def get_data_layout(self) -> str:
        """Get the data layout string.

        Returns:
            "NT" for ROWWISE_COLWISE, "N" for ROWWISE, "T" for COLWISE

        Raises:
            ValueError: If quantization axis is invalid
        """
        data_layout = "NT"
        if self.q_layout == QuantizeLayout.ROWWISE_COLWISE:
            return data_layout
        if self.q_layout == QuantizeLayout.ROWWISE:
            return data_layout[0]
        if self.q_layout == QuantizeLayout.COLWISE:
            return data_layout[1]
        raise ValueError(f"Invalid q_layout: {self.q_layout}")

    def _quantize_func(
        self, x: jnp.ndarray, is_colwise=False, dq_dtype=None, flatten_axis=-1
    ) -> ScaledTensor1x:
        """Quantize a tensor row-wise with a scale computed from its current amax.

        ``is_colwise`` is accepted for interface compatibility but not used here.
        ``quantize()`` builds the column-wise tensor by transposing this result.

        Args:
            x: Input tensor to quantize
            is_colwise: Unused; see above
            dq_dtype: Data type for dequantized values, default is x.dtype
            flatten_axis: The quantization axis for the tensor

        Returns:
            A ScaledTensor1x containing the quantized data
        """
        dq_dtype = dq_dtype if dq_dtype is not None else x.dtype

        compute_dtype = jnp.float32
        dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
        amax = jnp.max(jnp.abs(x)).reshape((1,))
        # Dividing by amax directly yields an infinite scale for an all-zero
        # tensor, and 0 * inf turns every element into NaN. compute_scale_from_amax
        # falls back to a scale of 1 when amax is zero or non-finite.
        scale = compute_scale_from_amax(amax, self.q_dtype)
        scaled_x = x.astype(compute_dtype) * scale

        clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype)
        scale_inv = 1.0 / scale
        return ScaledTensorFactory.create_1x(
            data=clipped_scaled_x,
            scale_inv=scale_inv,
            scaling_mode=self.scaling_mode,
            dq_dtype=dq_dtype,
            flatten_axis=flatten_axis,
        )

    def quantize(
        self,
        x,
        is_rowwise: Optional[bool] = None,
        is_colwise: Optional[bool] = None,
        dq_dtype=None,
        flatten_axis=-1,
    ):
        """Quantize a tensor using the internal _quantize_func().

        The tensor is always quantized row-wise once. The column-wise result
        reuses that data, transposed, and the same ``scale_inv``.

        Args:
            x: Input tensor to quantize
            is_rowwise: Whether to return row-wise data; defaults from q_layout
            is_colwise: Whether to return column-wise data; defaults from q_layout
            dq_dtype: Data type for dequantized values, default is x.dtype
            flatten_axis: The quantization axis for the tensor; after
                normalization it must satisfy 0 < flatten_axis < x.ndim

        Returns:
            A ScaledTensor1x or ScaledTensor2x containing the quantized data
        """
        dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
        if flatten_axis < 0:
            flatten_axis += x.ndim
        assert 0 < flatten_axis < x.ndim, "flatten_axis is out of bounds!"

        is_rowwise = (
            is_rowwise
            if is_rowwise is not None
            else (self.q_layout == QuantizeLayout.ROWWISE or self.is_2x2x())
        )
        is_colwise = (
            is_colwise
            if is_colwise is not None
            else (self.q_layout == QuantizeLayout.COLWISE or self.is_2x2x())
        )

        rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
        colwise_tensor = None
        if is_colwise:
            colwise_tensor = ScaledTensorFactory.create_1x(
                # Move axes [flatten_axis, ndim) in front of axes [0, flatten_axis).
                data=jnp.transpose(
                    rowwise_tensor.data, (*range(flatten_axis, x.ndim), *range(flatten_axis))
                ),
                scale_inv=rowwise_tensor.scale_inv,
                scaling_mode=self.scaling_mode,
                dq_dtype=dq_dtype,
                is_colwise=True,
                data_layout="T",
                flatten_axis=flatten_axis,
            )
        if is_colwise and is_rowwise:
            return ScaledTensor2x(rowwise_tensor, colwise_tensor)
        if is_colwise:
            return colwise_tensor
        return rowwise_tensor


@register_pytree_node_class
@dataclass
class DelayedScaleQuantizer(CurrentScaleQuantizer):
    """Quantizer implementation using delayed scaling.

    This quantizer uses delayed scaling mode with float32 scales and maintains
    a history of maximum absolute values for dynamic scaling. ``_quantize_func``
    quantizes with the scale held from earlier steps, then records the current
    amax to produce the scale for later calls.

    Attributes:
        scaling_mode: Set to DELAYED_TENSOR_SCALING
        q_layout: Quantization axis (default: ROWWISE_COLWISE)
        scale: Current scaling factor (default: ones of shape (1,))
        amax_history: History of maximum absolute values (default: zeros of
            length QuantizeConfig.AMAX_HISTORY_LEN); slot 0 receives the
            current step's amax in ``update``
    """

    scaling_mode: ScalingMode = ScalingMode.DELAYED_TENSOR_SCALING
    q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE

    scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32))
    amax_history: jnp.ndarray = field(
        default_factory=lambda: jnp.zeros((QuantizeConfig.AMAX_HISTORY_LEN,), jnp.float32)
    )

    def tree_flatten(self):
        """Flatten the quantizer for JAX tree operations.

        ``scale`` and ``amax_history`` are array children; the remaining fields
        are static auxiliary data.

        Returns:
            Tuple of (children, aux_data) for tree operations
        """
        children = (self.scale, self.amax_history)
        aux_data = (self.q_dtype, self.scaling_mode, self.q_layout)
        return (children, aux_data)

    def _quantize_func(
        self, x: jnp.ndarray, is_colwise=False, dq_dtype=None, flatten_axis=-1
    ) -> ScaledTensor1x:
        """Quantize function helper for delayed scaling FP8.

        Args:
            x: Input tensor to quantize
            is_colwise: Whether to use column-wise quantization
            dq_dtype: Data type for dequantized values
            flatten_axis: The quantization axis for the tensor
        Returns:
            A ScaledTensor1x containing the quantized data
        """
        dq_dtype = dq_dtype if dq_dtype is not None else x.dtype

        compute_dtype = jnp.float32
        dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
        scaled_x = x.astype(compute_dtype) * self.scale

        # quantize() in the old dot.py do this way, leave this code block here for future debugging
        # compute_dtype = x.dtype
        # dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
        # scaled_x = x * self.scale.astype(compute_dtype)

        clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype)
        # scale_inv is taken from the scale actually used above, before update()
        # replaces self.scale with the one derived from this tensor's amax.
        scale_inv = 1.0 / self.scale
        self.update(jnp.max(jnp.abs(x)).reshape((1,)))
        return ScaledTensorFactory.create_1x(
            data=clipped_scaled_x,
            scale_inv=scale_inv,
            scaling_mode=self.scaling_mode,
            dq_dtype=dq_dtype,
            flatten_axis=flatten_axis,
        )

    @staticmethod
    @jax.jit
    def _update_amax_history(amax_history, new_amax):
        """Update AMAX history with new maximum value.

        Args:
            amax_history: Current AMAX history
            new_amax: New maximum value to add; element 0 is written to slot 0

        Returns:
            Updated AMAX history
        """
        amax_history = amax_history.at[0].set(new_amax[0])
        return amax_history

    @staticmethod
    @partial(jax.jit, static_argnums=(2,))
    def _compute_scale(amax_history, scale, q_dtype):
        """Compute new scale based on AMAX history.

        Args:
            amax_history: History of maximum absolute values
            scale: Current scale, kept where the reduced amax is unusable
            q_dtype: Quantization data type

        Returns:
            Updated scale value
        """
        # Reduce the history to one amax: the max over the whole history for
        # AmaxComputeAlgo.MAX, otherwise the entry in slot 0.
        # NOTE(review): QuantizeConfig is read while tracing this jitted function,
        # so a later change to AMAX_COMPUTE_ALGO may not affect cached traces — confirm.
        if QuantizeConfig.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
            amax = jnp.max(amax_history, axis=-1, keepdims=True)
        else:
            amax = amax_history[0:1]

        return compute_scale_from_amax(amax, q_dtype, scale=scale)

    @staticmethod
    @jax.jit
    def _roll_and_reset_amax_history(amax_history):
        """Roll AMAX history and reset first element.

        Args:
            amax_history: Current AMAX history

        Returns:
            History rolled left by one position with slot 0 set to 0.0
        """
        updated_amax_history = jnp.roll(amax_history, -1, -1)
        amax_history = updated_amax_history.at[0].set(0.0)
        return amax_history

    def update(self, new_amax: jnp.ndarray):
        """Update AMAX history and compute new scale.

        Writes ``new_amax`` into slot 0 of the history and recomputes
        ``self.scale`` from it. It then rolls the history and clears slot 0 for
        the next step. ``self.scale`` and ``self.amax_history`` are reassigned
        on this instance.

        Args:
            new_amax: New maximum absolute value to add to history
        """
        amax_history = self._update_amax_history(self.amax_history, new_amax)
        self.scale = self._compute_scale(amax_history, self.scale, self.q_dtype)
        self.amax_history = self._roll_and_reset_amax_history(amax_history)


@register_pytree_node_class
@dataclass
class BlockScaleQuantizer(Quantizer):
    """Quantizer implementation using block-based scaling.

    This quantizer uses block scaling mode with E8M0 (power-of-two) scales,
    computing one scale per block of elements.

    Attributes:
        scaling_mode: Set to MXFP8_1D_SCALING
        q_layout: Quantization axis (default: ROWWISE_COLWISE)
    """

    scaling_mode: ScalingMode = ScalingMode.MXFP8_1D_SCALING
    q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE

    def get_data_layout(self) -> str:
        """Get the data layout string.

        Returns:
            "NN" when both row-wise and column-wise data are produced, else "N"
        """
        if self.is_2x2x():
            return "NN"
        return "N"

    def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> ScaledTensor1x:
        """Quantize function helper for block scaling FP8.

        Args:
            x: Input tensor to quantize
            is_colwise: Whether to use column-wise quantization
            dq_dtype: Data type for dequantized values, default is x.dtype
            flatten_axis: The quantization axis for the tensor; after
                normalization it must satisfy 0 < flatten_axis < x.ndim

        Returns:
            A ScaledTensor1x containing the quantized data
        """
        # TODO(Phuong): use quantize_func from JAX
        if flatten_axis < 0:
            flatten_axis = x.ndim + flatten_axis
        # flatten_axis == 0 would make the `flatten_axis - 1` indexing below wrap
        # around to the last axis and build a wrong block reshape.
        assert (
            0 < flatten_axis < x.ndim
        ), f"Invalid flatten_axis: {flatten_axis} for tensor of shape {x.shape}"

        dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
        x_shape = x.shape
        scale_shape = self.scaling_mode.get_scale_shape(
            x_shape, is_colwise, is_padded=False, flatten_axis=flatten_axis
        )
        scale_dtype = self.scaling_mode.get_scale_dtype()
        # Split dim (flatten_axis - 1) and the last dim each into
        # (num_blocks, block_size) so every block can be reduced independently.
        x = x.reshape(
            *x_shape[: flatten_axis - 1],
            scale_shape[flatten_axis - 1],
            x_shape[flatten_axis - 1] // scale_shape[flatten_axis - 1],
            *x_shape[flatten_axis:-1],
            scale_shape[-1],
            x_shape[-1] // scale_shape[-1],
        )
        # After the reshape, the two block-size axes sit at index flatten_axis and -1.
        amax = jnp.max(jnp.abs(x), axis=(flatten_axis, -1), keepdims=True)
        MAX = jnp.finfo(self.q_dtype).max.astype(jnp.float32)
        scales = amax.astype(jnp.float32) / MAX

        scales_q = self._cast_to_e8m0_with_rounding_up(scales)
        scaled_x = x / self._e8m0_to_dtype(scales_q, jnp.float32)

        clipped_x = jnp.clip(scaled_x, -MAX, MAX)
        x_q = clipped_x.astype(self.q_dtype).reshape(x_shape)
        scales_q = scales_q.reshape(scale_shape).view(scale_dtype)

        return ScaledTensorFactory.create_1x(
            x_q,
            scales_q,
            self.scaling_mode,
            is_colwise=is_colwise,
            dq_dtype=dq_dtype,
            flatten_axis=flatten_axis,
        )

    def _cast_to_e8m0_with_rounding_up(self, scales):
        """Cast scales to E8M0 format with rounding up.

        Takes the float32 biased-exponent field (bits >> 23) and increments it
        when the mantissa is non-zero, so the result is a power of two that is
        not smaller than the input. The sign bit is assumed clear, which holds
        for the non-negative amax-derived scales passed in by _quantize_func.

        Args:
            scales: Input scales to convert

        Returns:
            Scales in E8M0 format (uint8 exponent bytes)
        """
        temp = scales.astype(jnp.float32).view(jnp.uint32)
        exp = temp >> 23
        mant = temp & 0x7FFFFF
        # Round up when the mantissa is non-zero, except at exponent 0xFE (no
        # larger finite exponent) and for subnormals (exp == 0) whose mantissa
        # is at most half of its range.
        is_ru = jnp.logical_and(
            jnp.logical_and((mant > 0), (exp != 0xFE)),
            ~jnp.logical_and((exp == 0), (mant <= 0x400000)),
        )
        exp = jnp.where(is_ru, exp + 1, exp)
        new_scales = exp.astype(jnp.uint8)
        return new_scales

    def _e8m0_to_dtype(self, x, dtype):
        """Convert E8M0 format to specified data type.

        Each byte is placed back into the float32 exponent field, giving
        2**(e - 127). The all-zero encoding would decode to 0.0, so it is
        replaced by a small non-zero value to keep later divisions finite.

        Args:
            x: Input in E8M0 format
            dtype: Target data type

        Returns:
            Converted values in target data type
        """
        temp = x.astype(jnp.uint32)
        exp = temp << 23
        new_x = exp.view(jnp.float32)
        near_zero_value = 2**-15 if dtype == jnp.float16 else 2**-127
        new_x = jnp.where(new_x == 0, jnp.array(near_zero_value, jnp.float32), new_x)
        return new_x.astype(dtype)


@register_pytree_node_class
@dataclass
class QuantizerSet:
    """Bundle of the quantizers used by one quantized operation.

    Holds one quantizer each for the input, kernel and gradient tensors. Any
    entry may be ``None`` (QuantizerFactory produces ``None`` quantizers for
    ScalingMode.NO_SCALING).

    Attributes:
        x: Quantizer for input tensors
        kernel: Quantizer for kernel tensors
        dgrad: Quantizer for gradient tensors
    """

    x: Optional[Quantizer]
    kernel: Optional[Quantizer]
    dgrad: Optional[Quantizer]

    def tree_flatten(self):
        """Flatten the set into JAX pytree form.

        All three quantizers are pytree children; there is no static
        auxiliary data.

        Returns:
            Tuple of (children, aux_data) for tree operations
        """
        return (self.x, self.kernel, self.dgrad), ()

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild a QuantizerSet from its pytree parts.

        Args:
            aux_data: Static data from tree_flatten (always empty here)
            children: The (x, kernel, dgrad) quantizers

        Returns:
            A reconstructed QuantizerSet instance
        """
        ctor_args = (*aux_data, *children)
        return cls(*ctor_args)


@dataclass
class QuantizerFactory:
    """Factory class for creating quantizers.

    This class provides static methods to create individual quantizers and
    sets of quantizers with various configurations.
    """

    # Maps each supported scaling mode to the Quantizer subclass implementing it.
    quantizer_type_map = {
        ScalingMode.DELAYED_TENSOR_SCALING: DelayedScaleQuantizer,
        ScalingMode.CURRENT_TENSOR_SCALING: CurrentScaleQuantizer,
        ScalingMode.MXFP8_1D_SCALING: BlockScaleQuantizer,
    }

    @staticmethod
    def create(
        n_quantizers: int = 1,
        scaling_mode: ScalingMode = None,
        q_dtype: jnp.dtype = None,
        q_layout: QuantizeLayout = None,
        **kwargs,
    ) -> Union[Quantizer, tuple, None]:
        """Create one or more quantizers with specified parameters.

        Args:
            n_quantizers: Number of quantizers to create
            scaling_mode: Scaling mode to use; ScalingMode.NO_SCALING yields None
            q_dtype: Quantization data type
            q_layout: Quantization axis
            **kwargs: Additional arguments for quantizer initialization

        Returns:
            A single quantizer (or None) when n_quantizers == 1, else a tuple

        Raises:
            ValueError: If no quantizer class is registered for scaling_mode
        """
        assert isinstance(scaling_mode, ScalingMode), "Invalid scaling_mode type"
        if scaling_mode == ScalingMode.NO_SCALING:
            quantizers = [None] * n_quantizers
        else:
            quantizer_type = QuantizerFactory.quantizer_type_map.get(scaling_mode)
            if quantizer_type is None:
                raise ValueError(f"No quantizer registered for scaling_mode: {scaling_mode}")
            quantizers = [
                quantizer_type(
                    q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout, **kwargs
                )
                for _ in range(n_quantizers)
            ]
        return quantizers[0] if len(quantizers) == 1 else tuple(quantizers)

    @staticmethod
    def _create_set(scaling_mode, fwd_dtype, bwd_dtype, is_2x2x, **kwargs) -> QuantizerSet:
        """Create a set of quantizers for forward and backward passes.

        Args:
            scaling_mode: Scaling mode to use
            fwd_dtype: Data type for forward pass (input and kernel quantizers)
            bwd_dtype: Data type for backward pass (gradient quantizer)
            is_2x2x: Whether to use 2x2x quantization
            **kwargs: Only ``quantize_meta_set`` is read; its per-tensor
                ``scale``/``amax_history`` seed the quantizers. Other keys are ignored.

        Returns:
            A QuantizerSet instance
        """
        if is_2x2x:
            q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE_COLWISE
        else:
            q_layout_x = QuantizeLayout.ROWWISE
            q_layout_kernel = QuantizeLayout.COLWISE
            q_layout_dgrad = None

        if "quantize_meta_set" in kwargs:
            quantize_meta_set = kwargs.get("quantize_meta_set")
            args_x = {
                "scale": quantize_meta_set.x.scale,
                "amax_history": quantize_meta_set.x.amax_history,
            }
            args_kernel = {
                "scale": quantize_meta_set.kernel.scale,
                "amax_history": quantize_meta_set.kernel.amax_history,
            }
            args_grad = {
                "scale": quantize_meta_set.grad.scale,
                "amax_history": quantize_meta_set.grad.amax_history,
            }
        else:
            args_x = args_kernel = args_grad = {}

        q_x = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_layout_x, **args_x)
        q_kernel = QuantizerFactory.create(
            1, scaling_mode, fwd_dtype, q_layout_kernel, **args_kernel
        )
        q_dgrad = QuantizerFactory.create(1, scaling_mode, bwd_dtype, q_layout_dgrad, **args_grad)
        return QuantizerSet(x=q_x, kernel=q_kernel, dgrad=q_dgrad)

    @staticmethod
    def create_set(
        n_quantizer_sets: int = 1,
        scaling_mode: ScalingMode = None,
        fwd_dtype: jnp.dtype = None,
        bwd_dtype: jnp.dtype = None,
        is_2x2x: bool = None,
        **kwargs,
    ) -> Union[QuantizerSet, tuple]:
        """Create one or more sets of quantizers.

        Args:
            n_quantizer_sets: Number of quantizer sets to create
            scaling_mode: Scaling mode to use, default is QuantizeConfig.SCALING_MODE
            fwd_dtype: Data type for forward pass, default is QuantizeConfig.FWD_DTYPE
            bwd_dtype: Data type for backward pass, default is QuantizeConfig.BWD_DTYPE
            is_2x2x: Whether to use 2x2x quantization, default is QuantizeConfig.IF_QUANTIZE_2X
            **kwargs: Additional arguments for quantizer initialization

        Returns:
            A single quantizer set or tuple of quantizer sets
        """
        # Fall back to config only when an argument is omitted. `x or default`
        # would silently override an explicit falsy value such as is_2x2x=False.
        if scaling_mode is None:
            scaling_mode = QuantizeConfig.SCALING_MODE
        if fwd_dtype is None:
            fwd_dtype = QuantizeConfig.FWD_DTYPE
        if bwd_dtype is None:
            bwd_dtype = QuantizeConfig.BWD_DTYPE
        if is_2x2x is None:
            is_2x2x = QuantizeConfig.IF_QUANTIZE_2X

        q_set = [
            QuantizerFactory._create_set(scaling_mode, fwd_dtype, bwd_dtype, is_2x2x, **kwargs)
            for _ in range(n_quantizer_sets)
        ]

        return q_set[0] if len(q_set) == 1 else tuple(q_set)


# Module-level QuantizerSet built with ScalingMode.NO_SCALING, for which
# QuantizerFactory.create returns None, so x, kernel and dgrad are all None.
noop_quantizer_set = QuantizerFactory.create_set(
    n_quantizer_sets=1, scaling_mode=ScalingMode.NO_SCALING
)