quantizer.py 33.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Tensor quantization classes for TE/JAX.

This module provides classes and utilities for quantizing tensors in JAX.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from functools import partial
12
13
from typing import Union, Optional, Tuple
import warnings
14
15
16
17

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class
18
from transformer_engine_jax import QuantizeLayout
19
from transformer_engine.common import recipe
20
21

from .scaling_modes import ScalingMode
22
from .tensor import ScaledTensor, ScaledTensor1x, ScaledTensor2x, ScaledTensorFactory
23
24
25
from .helper import (
    QuantizeConfig,
    AmaxComputeAlgo,
26
    _get_scaling_mode,
27
)
28
from .device_utils import is_fp8_gemm_with_all_layouts_supported
29
30

# Public API of this module.
__all__ = [
    "QuantizeLayout",
    "Quantizer",
    "QuantizerSet",
    "CurrentScaleQuantizer",
    "DelayedScaleQuantizer",
    "BlockScaleQuantizer",
    "GroupedQuantizer",
    "QuantizerFactory",
    "noop_quantizer_set",
    "compute_scale_from_amax",
]


44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
def compute_scale_from_amax(
    amax: jnp.ndarray, q_dtype: jnp.dtype, scale: Optional[jnp.ndarray] = None
) -> jnp.ndarray:
    """Derive a quantization scale factor from an absolute-maximum value.

    The scale maps ``amax`` onto the largest finite value of ``q_dtype``, reduced
    by a factor of ``2**QuantizeConfig.MARGIN``. Wherever ``amax`` is not a
    strictly positive, finite number, the fallback ``scale`` is used instead.

    Args:
        amax: Maximum absolute value of the tensor.
        q_dtype: Quantization data type; ``jnp.finfo(q_dtype).max`` bounds the range.
        scale: Fallback scale used where ``amax`` is <= 0, NaN or infinite.
            Defaults to ``jnp.ones((1,))``.

    Returns:
        The computed scale value.
    """
    q_max = jnp.astype(jnp.finfo(q_dtype).max, jnp.float32)
    fallback = jnp.ones((1,)) if scale is None else scale
    candidate = (q_max / amax) / (2**QuantizeConfig.MARGIN)
    # Reject non-positive amax (including NaN, for which the comparison is False) ...
    candidate = jnp.where(amax > 0.0, candidate, fallback)
    # ... and infinite amax, which would otherwise yield a zero scale.
    return jnp.where(jnp.isfinite(amax), candidate, fallback)


65
66
67
68
69
70
71
72
73
74
75
@register_pytree_node_class
@dataclass
class Quantizer(ABC):
    """Base class for quantizers.

    This abstract class defines the interface for tensor quantization, providing
    methods for quantization and scale management. Subclasses implement
    ``_quantize_func``; ``quantize`` dispatches to it for the requested layouts.

    Attributes:
        q_dtype: The data type for quantized values
        scaling_mode: The scaling mode to use for quantization
        q_layout: The quantization axis (row-wise, column-wise, or both)
        data_layout: Layout string. It is returned whole for ROWWISE_COLWISE;
            otherwise its first character applies to ROWWISE and its second to
            COLWISE (see ``get_data_layout``).
    """

    q_dtype: jnp.dtype
    scaling_mode: ScalingMode
    q_layout: QuantizeLayout
    data_layout: str

    def tree_flatten(self):
        """Flatten the quantizer for JAX tree operations.

        The base quantizer holds no array state, so every field is static
        auxiliary data and there are no children.

        Returns:
            Tuple of (children, aux_data) for tree operations
        """
        children = ()
        aux_data = (self.q_dtype, self.scaling_mode, self.q_layout, self.data_layout)
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Reconstruct a quantizer from its flattened representation.

        Args:
            aux_data: Auxiliary data containing quantizer parameters, in field order
            children: Child leaves from ``tree_flatten`` (empty for the base class),
                appended positionally after ``aux_data``

        Returns:
            A reconstructed Quantizer instance
        """
        return cls(*aux_data, *children)

    def update(self, *args, **kwargs):
        """Update quantizer state (no-op in base class)."""
        del args, kwargs

    def is_2x2x(self) -> bool:
        """Check if quantizer uses both row-wise and column-wise quantization.

        Returns:
            True if using both row-wise and column-wise quantization
        """
        return self.q_layout == QuantizeLayout.ROWWISE_COLWISE

    def get_data_layout(self) -> str:
        """Get the data layout string for the configured quantization layout.

        Returns:
            The full ``data_layout`` for ROWWISE_COLWISE, ``data_layout[0]`` for
            ROWWISE, or ``data_layout[1]`` for COLWISE

        Raises:
            ValueError: If ``q_layout`` is not one of the three layouts above
        """
        if self.q_layout == QuantizeLayout.ROWWISE_COLWISE:
            return self.data_layout
        if self.q_layout == QuantizeLayout.ROWWISE:
            return self.data_layout[0]
        if self.q_layout == QuantizeLayout.COLWISE:
            return self.data_layout[1]
        raise ValueError(f"Invalid q_layout: {self.q_layout}")

    @abstractmethod
    def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> ScaledTensor1x:
        """Core quantization function to be implemented by subclasses.

        Args:
            x: Input tensor to quantize
            is_colwise: Whether to use column-wise quantization
            dq_dtype: Data type for dequantized values, default is x.dtype
            flatten_axis: The quantization axis for the tensor

        Returns:
            A ScaledTensor1x containing the quantized data
        """

    def quantize(
        self, x, is_rowwise=False, is_colwise=False, dq_dtype=None, flatten_axis=-1, **kwargs
    ) -> ScaledTensor:
        """Quantize a tensor using the internal _quantize_func().

        When the quantizer is configured for ROWWISE_COLWISE, both tensors are
        produced regardless of ``is_rowwise``/``is_colwise``. Otherwise a
        column-wise tensor is produced if ``is_colwise`` is set, else a row-wise one.

        Args:
            x: Input tensor to quantize
            is_rowwise: Whether to use row-wise quantization
            is_colwise: Whether to use column-wise quantization
            dq_dtype: Data type for dequantized values
            flatten_axis: The quantization axis for the tensor
            **kwargs: Ignored; accepted for signature compatibility with subclasses

        Returns:
            A ScaledTensor1x or ScaledTensor2x containing the quantized data
        """
        del kwargs
        if (is_rowwise and is_colwise) or self.is_2x2x():
            rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
            colwise_tensor = self._quantize_func(
                x, is_colwise=True, dq_dtype=dq_dtype, flatten_axis=flatten_axis
            )
            return ScaledTensor2x(rowwise_tensor, colwise_tensor)

        if is_colwise:
            return self._quantize_func(
                x, is_colwise=True, dq_dtype=dq_dtype, flatten_axis=flatten_axis
            )

        return self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)

    def get_scale_shapes(self, data_shape, is_padded=True, flatten_axis=-1, **kwargs):
        """Get shapes for scale tensors.

        Args:
            data_shape: Shape of the input tensor
            is_padded: Whether to use padded shapes
            flatten_axis: The quantization axis for the tensor
            **kwargs: Ignored; accepted for signature compatibility with subclasses

        Returns:
            Tuple of (rowwise_scale_shape, colwise_scale_shape)
        """
        del kwargs
        return self.scaling_mode.get_scale_shape_2x(data_shape, is_padded, flatten_axis)

    def get_scale_dtype(self):
        """Get the data type for scale tensors.

        Returns:
            The data type for scale tensors
        """
        return self.scaling_mode.get_scale_dtype()


@register_pytree_node_class
@dataclass
class CurrentScaleQuantizer(Quantizer):
    """Quantizer implementation using current scaling.

    The scale is recomputed from the absolute maximum of each input tensor at
    quantization time, using float32 scales.

    Attributes:
        scaling_mode: Set to ScalingMode.CURRENT_TENSOR_SCALING
        q_layout: Quantization axis (default: ROWWISE_COLWISE)
        data_layout: Data layout string (default: "NT")
    """

    scaling_mode: ScalingMode = ScalingMode.CURRENT_TENSOR_SCALING
    q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE
    data_layout: str = "NT"

    def _quantize_func(
        self, x: jnp.ndarray, is_colwise=False, dq_dtype=None, flatten_axis=-1
    ) -> ScaledTensor1x:
        """Quantize function helper for current scaling FP8.

        Args:
            x: Input tensor to quantize
            is_colwise: Unused; quantize() derives the column-wise tensor by
                transposing the row-wise result
            dq_dtype: Data type for dequantized values, default is x.dtype
            flatten_axis: The quantization axis for the tensor

        Returns:
            A ScaledTensor1x containing the quantized data
        """
        del is_colwise
        dq_dtype = dq_dtype if dq_dtype is not None else x.dtype

        compute_dtype = jnp.float32
        dtype_max = jnp.finfo(self.q_dtype).max.astype(compute_dtype)
        amax = jnp.max(jnp.abs(x)).reshape((1,))
        # compute_scale_from_amax falls back to a scale of 1 when amax is zero or
        # non-finite. Dividing by amax directly gives an infinite scale for an
        # all-zero tensor, and 0 * inf = NaN would end up in the quantized data.
        scale = compute_scale_from_amax(amax, self.q_dtype)
        scaled_x = x.astype(compute_dtype) * scale

        clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype)
        scale_inv = 1.0 / scale
        return ScaledTensorFactory.create_1x(
            data=clipped_scaled_x,
            scale_inv=scale_inv,
            scaling_mode=self.scaling_mode,
            dq_dtype=dq_dtype,
            flatten_axis=flatten_axis,
        )

    def quantize(
        self,
        x,
        is_rowwise: Optional[bool] = None,
        is_colwise: Optional[bool] = None,
        dq_dtype=None,
        flatten_axis=-1,
    ):
        """Quantize a tensor using the internal _quantize_func().

        A single row-wise quantization is always performed. The column-wise
        tensor, when requested, reuses that data transposed so the axes from
        ``flatten_axis`` onward come first, sharing the same ``scale_inv``.

        Args:
            x: Input tensor to quantize
            is_rowwise: Whether to use row-wise quantization; defaults from q_layout
            is_colwise: Whether to use column-wise quantization; defaults from q_layout
            dq_dtype: Data type for dequantized values, default is x.dtype
            flatten_axis: The quantization axis for the tensor; negative values
                count from the end and must resolve to 0 < flatten_axis < x.ndim

        Returns:
            A ScaledTensor1x or ScaledTensor2x containing the quantized data
        """
        dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
        if flatten_axis < 0:
            flatten_axis += x.ndim
        assert 0 < flatten_axis < x.ndim, "flatten_axis is out of bounds!"

        is_rowwise = (
            is_rowwise
            if is_rowwise is not None
            else (self.q_layout == QuantizeLayout.ROWWISE or self.is_2x2x())
        )
        is_colwise = (
            is_colwise
            if is_colwise is not None
            else (self.q_layout == QuantizeLayout.COLWISE or self.is_2x2x())
        )

        rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)

        colwise_tensor = None
        if is_colwise:
            colwise_tensor = ScaledTensorFactory.create_1x(
                # Move axes [flatten_axis, ndim) in front of axes [0, flatten_axis).
                data=jnp.transpose(
                    rowwise_tensor.data, (*range(flatten_axis, x.ndim), *range(flatten_axis))
                ),
                scale_inv=rowwise_tensor.scale_inv,
                scaling_mode=self.scaling_mode,
                dq_dtype=dq_dtype,
                is_colwise=True,
                data_layout="T",
                flatten_axis=flatten_axis,
            )

        if is_colwise and is_rowwise:
            return ScaledTensor2x(rowwise_tensor, colwise_tensor)
        if is_colwise:
            return colwise_tensor
        return rowwise_tensor

302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332

@register_pytree_node_class
@dataclass
class DelayedScaleQuantizer(CurrentScaleQuantizer):
    """Quantizer implementation using delayed scaling.

    This quantizer uses delayed scaling mode with float32 scales and maintains
    a history of maximum absolute values for dynamic scaling. Each call to
    ``_quantize_func`` quantizes with the stored ``scale`` and then calls
    ``update``, which rewrites ``self.scale`` and ``self.amax_history`` in place.

    Attributes:
        scaling_mode: Set to ScalingMode.DELAYED_TENSOR_SCALING
        q_layout: Quantization axis (default: ROWWISE_COLWISE)
        scale: Current scaling factor
        amax_history: History of maximum absolute values
    """

    scaling_mode: ScalingMode = ScalingMode.DELAYED_TENSOR_SCALING
    q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE

    scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32))
    amax_history: jnp.ndarray = field(
        default_factory=lambda: jnp.zeros((QuantizeConfig.AMAX_HISTORY_LEN,), jnp.float32)
    )

    def tree_flatten(self):
        """Flatten the quantizer for JAX tree operations.

        ``scale`` and ``amax_history`` are array state and become children; the
        remaining fields are static auxiliary data.

        Returns:
            Tuple of (children, aux_data) for tree operations
        """
        children = (self.scale, self.amax_history)
        aux_data = (self.q_dtype, self.scaling_mode, self.q_layout, self.data_layout)
        return (children, aux_data)

    def _quantize_func(
        self, x: jnp.ndarray, is_colwise=False, dq_dtype=None, flatten_axis=-1
    ) -> ScaledTensor1x:
        """Quantize function helper for delayed scaling FP8.

        Args:
            x: Input tensor to quantize
            is_colwise: Whether to use column-wise quantization
            dq_dtype: Data type for dequantized values
            flatten_axis: The quantization axis for the tensor
        Returns:
            A ScaledTensor1x containing the quantized data
        """
        dq_dtype = dq_dtype if dq_dtype is not None else x.dtype

        compute_dtype = jnp.float32
        dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
        scaled_x = x.astype(compute_dtype) * self.scale

        # quantize() in the old dot.py do this way, leave this code block here for future debugging
        # compute_dtype = x.dtype
        # dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
        # scaled_x = x * self.scale.astype(compute_dtype)

        clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype)
        # scale_inv is taken from the scale used above, before update() replaces it.
        scale_inv = 1.0 / self.scale
        self.update(jnp.max(jnp.abs(x)).reshape((1,)))
        return ScaledTensorFactory.create_1x(
            data=clipped_scaled_x,
            scale_inv=scale_inv,
            scaling_mode=self.scaling_mode,
            dq_dtype=dq_dtype,
            flatten_axis=flatten_axis,
        )

    @staticmethod
    @jax.jit
    def _update_amax_history(amax_history, new_amax):
        """Update AMAX history with new maximum value.

        Args:
            amax_history: Current AMAX history
            new_amax: New maximum value to add; only ``new_amax[0]`` is used

        Returns:
            Updated AMAX history with slot 0 overwritten
        """
        amax_history = amax_history.at[0].set(new_amax[0])
        return amax_history

    @staticmethod
    @partial(jax.jit, static_argnums=(2,))
    def _compute_scale(amax_history, scale, q_dtype):
        """Compute new scale based on AMAX history.

        Args:
            amax_history: History of maximum absolute values
            scale: Current scale, used as fallback for zero/non-finite amax
            q_dtype: Quantization data type (static under jit)

        Returns:
            Updated scale value
        """
        # NOTE(review): QuantizeConfig.AMAX_COMPUTE_ALGO is read while tracing and is
        # not a jit argument, so a change after the first trace is presumably not
        # picked up by the cached compilation — confirm if the config can change.
        if QuantizeConfig.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
            amax = jnp.max(amax_history, axis=-1, keepdims=True)
        else:
            amax = amax_history[0:1]

        return compute_scale_from_amax(amax, q_dtype, scale=scale)

    @staticmethod
    @jax.jit
    def _roll_and_reset_amax_history(amax_history):
        """Roll AMAX history and reset first element.

        Args:
            amax_history: Current AMAX history

        Returns:
            History rolled by -1 along the last axis, with slot 0 set to 0.0
        """
        updated_amax_history = jnp.roll(amax_history, -1, -1)
        amax_history = updated_amax_history.at[0].set(0.0)
        return amax_history

    def update(self, new_amax: jnp.ndarray):
        """Update AMAX history and compute new scale.

        Mutates ``self.amax_history`` and ``self.scale`` in place.

        Args:
            new_amax: New maximum absolute value to add to history
        """
        amax_history = self._update_amax_history(self.amax_history, new_amax)
        self.scale = self._compute_scale(amax_history, self.scale, self.q_dtype)
        self.amax_history = self._roll_and_reset_amax_history(amax_history)


@register_pytree_node_class
@dataclass
class BlockScaleQuantizer(Quantizer):
    """Quantizer implementation using block-based scaling.

    This quantizer uses block scaling mode with FP8 scales and block-based
    quantization for improved efficiency. Scales are stored as E8M0 bytes
    (8-bit exponents), produced by ``_cast_to_e8m0_with_rounding_up``.

    Attributes:
        scaling_mode: Set to ScalingMode.MXFP8_1D_SCALING
        q_layout: Quantization axis (default: ROWWISE_COLWISE)
        data_layout: Data layout string (default: "NN")
    """

    scaling_mode: ScalingMode = ScalingMode.MXFP8_1D_SCALING
    q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE
    data_layout: str = "NN"

    def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> ScaledTensor1x:
        """Quantize function helper for block scaling FP8.

        Args:
            x: Input tensor to quantize
            is_colwise: Whether to use column-wise quantization
            dq_dtype: Data type for dequantized values, default is x.dtype
            flatten_axis: The quantization axis for the tensor; negative values
                count from the end and must resolve to 0 < flatten_axis < x.ndim

        Returns:
            A ScaledTensor1x containing the quantized data
        """
        # TODO(Phuong): use quantize_func from JAX
        if flatten_axis < 0:
            flatten_axis = x.ndim + flatten_axis
        # flatten_axis == 0 is rejected: the reshape below indexes
        # x_shape[flatten_axis - 1], which would wrap around to the last axis.
        assert (
            0 < flatten_axis < x.ndim
        ), f"Invalid flatten_axis: {flatten_axis} for tensor of shape {x.shape}"

        dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
        x_shape = x.shape
        scale_shape = self.scaling_mode.get_scale_shape(
            x_shape, is_colwise, is_padded=False, flatten_axis=flatten_axis
        )
        scale_dtype = self.scaling_mode.get_scale_dtype()
        # Split axis (flatten_axis - 1) and the last axis into (num_blocks, block_size)
        # pairs, with num_blocks taken from the unpadded scale shape.
        x = x.reshape(
            *x_shape[: flatten_axis - 1],
            scale_shape[flatten_axis - 1],
            x_shape[flatten_axis - 1] // scale_shape[flatten_axis - 1],
            *x_shape[flatten_axis:-1],
            scale_shape[-1],
            x_shape[-1] // scale_shape[-1],
        )
        # After the reshape, axis `flatten_axis` and the last axis hold the
        # intra-block elements; reducing over them yields one amax per block.
        amax = jnp.max(jnp.abs(x), axis=(flatten_axis, -1), keepdims=True)
        dtype_max = jnp.finfo(self.q_dtype).max.astype(jnp.float32)
        scales = amax.astype(jnp.float32) / dtype_max

        scales_q = self._cast_to_e8m0_with_rounding_up(scales)
        scaled_x = x / self._e8m0_to_dtype(scales_q, jnp.float32)

        clipped_x = jnp.clip(scaled_x, -dtype_max, dtype_max)
        x_q = clipped_x.astype(self.q_dtype).reshape(x_shape)
        scales_q = scales_q.reshape(scale_shape).view(scale_dtype)

        return ScaledTensorFactory.create_1x(
            x_q,
            scales_q,
            scaling_mode=self.scaling_mode,
            is_colwise=is_colwise,
            dq_dtype=dq_dtype,
            flatten_axis=flatten_axis,
        )

    def _cast_to_e8m0_with_rounding_up(self, scales):
        """Cast scales to E8M0 format with rounding up.

        The float32 exponent field is taken as the E8M0 code and incremented
        whenever the mantissa is non-zero, so the result is >= the input. The
        sign bit is not masked; this assumes non-negative scales (here they are
        amax / dtype_max).

        Args:
            scales: Input scales to convert

        Returns:
            Scales in E8M0 format as uint8
        """
        temp = scales.astype(jnp.float32).view(jnp.uint32)
        exp = temp >> 23
        mant = temp & 0x7FFFFF
        # Round up when the mantissa is non-zero, except:
        #  - exp == 0xFE: presumably to avoid producing 0xFF (reserved/NaN in E8M0);
        #  - subnormals with mant <= 0x400000 (value <= 2**-127), which stay at code 0.
        is_ru = jnp.logical_and(
            jnp.logical_and((mant > 0), (exp != 0xFE)),
            ~jnp.logical_and((exp == 0), (mant <= 0x400000)),
        )
        exp = jnp.where(is_ru, exp + 1, exp)
        new_scales = exp.astype(jnp.uint8)
        return new_scales

    def _e8m0_to_dtype(self, x, dtype):
        """Convert E8M0 format to specified data type.

        Code ``c`` is placed in the float32 exponent field, giving ``2**(c - 127)``.
        Code 0 would decode to 0.0 and is replaced by a small positive value
        (2**-15 for float16, otherwise 2**-127) so it is safe to divide by.

        Args:
            x: Input in E8M0 format
            dtype: Target data type

        Returns:
            Converted values in target data type
        """
        temp = x.astype(jnp.uint32)
        exp = temp << 23
        new_x = exp.view(jnp.float32)
        near_zero_value = 2**-15 if dtype == jnp.float16 else 2**-127
        new_x = jnp.where(new_x == 0, jnp.array(near_zero_value, jnp.float32), new_x)
        return new_x.astype(dtype)


@register_pytree_node_class
@dataclass
class QuantizerSet:
    """Container grouping the quantizers applied to one operation's tensors.

    Each slot may be ``None`` when the corresponding tensor is not quantized.
    The set is a JAX pytree: the three quantizers are its children and it
    carries no static auxiliary data.

    Attributes:
        x: Quantizer applied to input tensors
        kernel: Quantizer applied to kernel tensors
        dgrad: Quantizer applied to gradient tensors
    """

    x: Optional[Quantizer]
    kernel: Optional[Quantizer]
    dgrad: Optional[Quantizer]

    def tree_flatten(self):
        """Split the set into pytree children and auxiliary data.

        Returns:
            A ``(children, aux_data)`` pair where ``children`` is the
            ``(x, kernel, dgrad)`` tuple and ``aux_data`` is empty.
        """
        no_static_data = ()
        return (self.x, self.kernel, self.dgrad), no_static_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild a QuantizerSet from ``tree_flatten`` output.

        Args:
            aux_data: Static data from ``tree_flatten`` (always empty here)
            children: The ``(x, kernel, dgrad)`` quantizers, in that order

        Returns:
            A new QuantizerSet instance
        """
        positional = (*aux_data, *children)
        return cls(*positional)


583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
@register_pytree_node_class
@dataclass
class GroupedQuantizer(Quantizer):
    """Quantizer for grouped arrays.

    This class extends Quantizer to support quantization of arrays in grouped manner,
    where elements are grouped along a specified axis then quantized separately.

    Attributes:
        data_layout: The data layout specification; overwritten in __post_init__
            with the first sub-quantizer's data_layout
        n_groups: Number of groups for quantization
        quantizers: Tuple of quantizers for each group
    """

    data_layout: Optional[str] = None
    n_groups: int = 1
    quantizers: Tuple[Quantizer, ...] = field(default_factory=lambda: (None,))

    def tree_flatten(self):
        """Flatten the quantizer for JAX tree operations.

        Returns:
            Tuple of (children, aux_data) for tree operations
        """
        children = (self.quantizers,)
        aux_data = (self.q_dtype, self.scaling_mode, self.q_layout, self.data_layout, self.n_groups)
        return (children, aux_data)

    def __post_init__(self):
        """Create the per-group quantizers when none were supplied and sync data_layout."""
        if self.quantizers[0] is None:
            quantizers = QuantizerFactory.create(
                self.n_groups, self.scaling_mode, self.q_dtype, self.q_layout
            )
            # create() may return a single quantizer rather than a tuple.
            self.quantizers = (quantizers,) if not isinstance(quantizers, tuple) else quantizers
        self.data_layout = self.quantizers[0].data_layout

    def _create_grouped_tensor_from_tensor_list(
        self, tensor_list, group_sizes, original_shape, group_axis, mode
    ):
        """Combine per-group ScaledTensor1x results into one grouped tensor.

        Args:
            tensor_list: Per-group quantized tensors (all row-wise or all column-wise)
            group_sizes: Size of each group
            original_shape: Shape of the input before grouping
            group_axis: The axis along which grouping was performed
            mode: 0 concatenates the flattened per-group data; 1 sums the
                per-group data element-wise (used when each group's tensor is the
                full input with out-of-group elements zeroed). Scale inverses are
                always concatenated.

        Returns:
            A grouped ScaledTensor1x built with the metadata of tensor_list[0]
        """
        # TODO(Ming Huang): Consider to apply Enum for mode.
        assert mode in [0, 1]
        grouped_data = (
            [] if mode == 0 else jnp.zeros(tensor_list[0].data.shape, tensor_list[0].data.dtype)
        )
        grouped_scale_inv = []

        for tensor in tensor_list:
            if mode == 0:
                grouped_data.append(tensor.data.flatten())
            else:
                grouped_data += tensor.data
            grouped_scale_inv.append(tensor.scale_inv.flatten())

        grouped_data = jnp.concatenate(grouped_data) if mode == 0 else grouped_data.flatten()
        grouped_scale_inv = jnp.concatenate(grouped_scale_inv)

        return ScaledTensorFactory.create_1x(
            grouped_data,
            grouped_scale_inv,
            scaling_mode=self.scaling_mode,
            dq_dtype=tensor_list[0].dq_dtype,
            is_colwise=tensor_list[0].is_colwise,
            data_layout=tensor_list[0].data_layout,
            flatten_axis=tensor_list[0].flatten_axis,
            group_sizes=group_sizes,
            original_shape=original_shape,
            group_axis=group_axis,
        )

    def _quantize_func(self, *args, **kwargs):
        """Not supported: quantization is delegated to the per-group quantizers.

        Defined only to satisfy the abstract base class. Call ``quantize`` instead.

        Raises:
            NotImplementedError: Always
        """
        del args, kwargs
        raise NotImplementedError(
            "GroupedQuantizer does not implement _quantize_func; use quantize() instead."
        )

    def quantize(
        self,
        x,
        is_rowwise: Optional[bool] = None,
        is_colwise: Optional[bool] = None,
        dq_dtype=None,
        flatten_axis=-1,
        group_sizes=None,
        group_axis=0,
    ):
        """Quantize a tensor in grouped manner.

        Expected input shape: [M, K] or [G, K, N]
        Split to x.shape[group_axis] number of groups if group_sizes is not given

        Args:
            x: Input tensor to quantize
            is_rowwise: Whether to use row-wise quantization; defaults from q_layout
            is_colwise: Whether to use column-wise quantization; defaults from q_layout
            dq_dtype: Data type for dequantized values, default is x.dtype
            flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)
            group_sizes: Array of ints containing the size of each group (default: None)
            group_axis: The axis along which grouping is performed (default: 0)

        Returns:
            A ScaledTensor1x or ScaledTensor2x containing the quantized data
        """
        assert group_axis == 0, "Only group_axis == 0 is supported now!"

        dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
        if flatten_axis < 0:
            flatten_axis += x.ndim
        assert 0 < flatten_axis < x.ndim, "flatten_axis is out of bounds!"

        is_rowwise = (
            is_rowwise
            if is_rowwise is not None
            else (self.q_layout == QuantizeLayout.ROWWISE or self.is_2x2x())
        )
        is_colwise = (
            is_colwise
            if is_colwise is not None
            else (self.q_layout == QuantizeLayout.COLWISE or self.is_2x2x())
        )
        assert is_rowwise or is_colwise, "No quantization layout is specified"

        original_shape = x.shape

        if group_sizes is not None:
            assert not is_colwise, "Not yet implememted!"
            assert group_sizes.ndim == 1, (
                "GroupedQuantizer only support 1D group_sizes, got group_sizes.ndim ="
                f" {group_sizes.ndim}"
            )

            _zeros = partial(jax.lax.full_like, fill_value=0)

            # Row index along axis 0 for every element, used to build group masks.
            x_iota = jax.lax.broadcasted_iota(group_sizes.dtype, x.shape, 0)
            group_ends = jnp.cumulative_sum(group_sizes)
            group_starts = jax.lax.concatenate(
                [_zeros(group_sizes)[:1], group_ends[:-1]],
                dimension=0,
            )
            x_zero = _zeros(x)

            # Each group quantizes the full-shaped input with rows outside
            # [start, end) zeroed; results are summed afterwards (mode 1).
            tensor_list = []
            for i in range(len(group_sizes)):
                mask = jax.lax.bitwise_and(group_starts[i] <= x_iota, x_iota < group_ends[i])
                x_selected = jax.lax.select(mask, x, x_zero)
                tensor = self.quantizers[i].quantize(
                    x_selected, is_rowwise, is_colwise, dq_dtype, flatten_axis
                )
                tensor_list.append(tensor)
            combine_mode = 1  # Add
        else:
            group_sizes = jnp.ones(x.shape[group_axis], dtype=jnp.int32)
            x = jnp.split(x, x.shape[group_axis], axis=group_axis)

            tensor_list = []
            for i in range(len(group_sizes)):
                tensor = self.quantizers[i].quantize(
                    x[i], is_rowwise, is_colwise, dq_dtype, flatten_axis
                )
                tensor_list.append(tensor)
            combine_mode = 0  # Concate

        grouped_rowwise_tensor = grouped_colwise_tensor = None
        if is_rowwise:
            rowwise_tensor_list = [tensor.get_rowwise_tensor() for tensor in tensor_list]
            grouped_rowwise_tensor = self._create_grouped_tensor_from_tensor_list(
                rowwise_tensor_list, group_sizes, original_shape, group_axis, combine_mode
            )
        if is_colwise:
            colwise_tensor_list = [tensor.get_colwise_tensor() for tensor in tensor_list]
            grouped_colwise_tensor = self._create_grouped_tensor_from_tensor_list(
                colwise_tensor_list, group_sizes, original_shape, group_axis, combine_mode
            )

        if is_colwise and is_rowwise:
            return ScaledTensor2x(grouped_rowwise_tensor, grouped_colwise_tensor)
        if is_colwise:
            return grouped_colwise_tensor
        return grouped_rowwise_tensor

    def get_scale_shapes(
        self, data_shape, is_padded=True, flatten_axis=-1, group_sizes=None, **kwargs
    ):
        """Get shapes for grouped scale tensors.

        Args:
            data_shape: Shape of the input tensor
            is_padded: Whether to use padded shapes
            flatten_axis: The quantization axis for the tensor
            group_sizes: Non-empty sequence of group sizes (required)
            **kwargs: Ignored; accepted for signature compatibility with Quantizer

        Returns:
            Result of ``scaling_mode.get_grouped_scale_shape_2x``
        """
        del kwargs
        # Explicit None/length check: plain truthiness raises for multi-element arrays.
        assert group_sizes is not None and len(group_sizes) > 0, "Empty group_sizes was given!"
        return self.scaling_mode.get_grouped_scale_shape_2x(
            data_shape, group_sizes, is_padded, flatten_axis
        )


@dataclass
class QuantizerFactory:
    """Factory class for creating quantizers.

    This class provides static methods to create individual quantizers and
    sets of quantizers with various configurations.
    """

    # Scaling mode -> quantizer class that implements it. ScalingMode.NO_SCALING
    # is deliberately absent: `create` returns None placeholders for that mode
    # instead of instantiating a quantizer.
    quantizer_type_map = {
        ScalingMode.DELAYED_TENSOR_SCALING: DelayedScaleQuantizer,
        ScalingMode.CURRENT_TENSOR_SCALING: CurrentScaleQuantizer,
        ScalingMode.MXFP8_1D_SCALING: BlockScaleQuantizer,
    }

    @staticmethod
    def create(
        n_quantizers: int = 1,
        scaling_mode: ScalingMode = None,
        q_dtype: jnp.dtype = None,
        q_layout: QuantizeLayout = None,
        n_groups: int = None,
        **kwargs,
    ) -> Union[Quantizer, Tuple[Optional[Quantizer], ...], None]:
        """Create one or more quantizers with specified parameters.

        Args:
            n_quantizers: Number of quantizers to create
            scaling_mode: Scaling mode to use
            q_dtype: Quantization data type
            q_layout: Quantization layout (QuantizeLayout)
            n_groups: Number of groups. When truthy, GroupedQuantizer instances are
                created (and ``n_groups`` is forwarded to them) instead of the
                quantizer type registered for ``scaling_mode``
            **kwargs: Additional arguments for quantizer initialization, forwarded
                unchanged to the quantizer constructor

        Returns:
            A single quantizer when exactly one is created, otherwise a tuple of
            quantizers. For ScalingMode.NO_SCALING every entry is None.

        Raises:
            NotImplementedError: If no quantizer type is registered for
                ``scaling_mode`` (and ``n_groups`` is not given).
        """
        # (Phuong): add this assert back when NVTE_NO_SCALING is fully implemented
        assert isinstance(scaling_mode, ScalingMode), "Invalid scaling_mode type"
        if n_groups:
            if n_quantizers != 1:
                warnings.warn(
                    "Using more than one GroupedQuantizer for a grouped input is not recommended"
                )
            quantizer_type = GroupedQuantizer
            kwargs["n_groups"] = n_groups
        else:
            quantizer_type = QuantizerFactory.quantizer_type_map.get(scaling_mode)

        if scaling_mode == ScalingMode.NO_SCALING:
            quantizers = [None] * n_quantizers
        else:
            # Fail with a clear message rather than `TypeError: 'NoneType' object
            # is not callable` when a mode has no registered quantizer class.
            if quantizer_type is None:
                raise NotImplementedError(
                    f"No quantizer is registered for scaling_mode={scaling_mode}"
                )
            quantizers = [
                quantizer_type(
                    q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout, **kwargs
                )
                for _ in range(n_quantizers)
            ]
        return quantizers[0] if len(quantizers) == 1 else tuple(quantizers)

    @staticmethod
    def _create_set(
        scaling_mode, fwd_dtype, bwd_dtype, is_2x2x, n_groups, **kwargs
    ) -> QuantizerSet:
        """Create a set of quantizers for forward and backward passes.

        Args:
            scaling_mode: Scaling mode to use
            fwd_dtype: Data type for forward pass (x and kernel quantizers)
            bwd_dtype: Data type for backward pass (dgrad quantizer)
            is_2x2x: Whether to use 2x2x quantization. When true, every quantizer
                uses QuantizeLayout.ROWWISE_COLWISE
            n_groups: Number of groups, forwarded to ``QuantizerFactory.create``
                (a truthy value selects GroupedQuantizer)
            **kwargs: Additional arguments for quantizer initialization. Only
                ``quantize_meta_set`` is read here; its ``x``/``kernel``/``grad``
                entries supply ``scale`` and ``amax_history`` for the matching
                quantizer. Other keys are not forwarded.

        Returns:
            A QuantizerSet instance
        """
        if is_2x2x:
            q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE_COLWISE
        else:
            q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE
            if scaling_mode.is_1d_block_scaling():
                q_layout_kernel = QuantizeLayout.COLWISE
            if QuantizeConfig.INFERENCE_MODE:
                q_layout_dgrad = None

        if "quantize_meta_set" in kwargs:
            quantize_meta_set = kwargs["quantize_meta_set"]
            # One independent argument dict per tensor role (x, kernel, grad).
            args_x, args_kernel, args_grad = (
                {"scale": meta.scale, "amax_history": meta.amax_history}
                for meta in (quantize_meta_set.x, quantize_meta_set.kernel, quantize_meta_set.grad)
            )
        else:
            args_x, args_kernel, args_grad = {}, {}, {}

        q_x = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_layout_x, n_groups, **args_x)
        q_kernel = QuantizerFactory.create(
            1, scaling_mode, fwd_dtype, q_layout_kernel, n_groups, **args_kernel
        )
        q_dgrad = QuantizerFactory.create(
            1, scaling_mode, bwd_dtype, q_layout_dgrad, n_groups, **args_grad
        )
        return QuantizerSet(x=q_x, kernel=q_kernel, dgrad=q_dgrad)

    @staticmethod
    def create_set(
        n_quantizer_sets: int = 1,
        scaling_mode: Optional[ScalingMode] = None,
        fwd_dtype: jnp.dtype = None,
        bwd_dtype: jnp.dtype = None,
        is_2x2x: bool = None,
        n_groups: int = None,
        fp8_recipe: Optional[recipe.Recipe] = None,
        **kwargs,
    ) -> Union[QuantizerSet, Tuple[QuantizerSet, ...]]:
        """Create one or more sets of quantizers.

        Args:
            n_quantizer_sets: Number of quantizer sets to create
            scaling_mode: Scaling mode to use, default is QuantizeConfig.SCALING_MODE
            fwd_dtype: Data type for forward pass, default is QuantizeConfig.FWD_DTYPE
            bwd_dtype: Data type for backward pass, default is QuantizeConfig.BWD_DTYPE
            is_2x2x: Whether to use 2x2x quantization. When None it is derived from
                the scaling mode: True for 1D block scaling; for tensor scaling,
                True unless an FP8 GEMM supporting all layouts is available;
                False otherwise
            n_groups: Number of groups, forwarded to every quantizer in the set
                (a truthy value selects GroupedQuantizer)
            fp8_recipe: Recipe to use for quantization. Scaling mode can be
                specified directly via the scaling_mode parameter or indirectly
                via recipe, but not both. Recipe is preferred as it will support
                additional recipes in future where scaling mode differs between
                x, kernel, and grad in the quantizer set.
            **kwargs: Additional arguments for quantizer initialization

        Returns:
            A single quantizer set or tuple of quantizer sets
        """
        assert scaling_mode is None or fp8_recipe is None, (
            "Cannot specify both scaling_mode and fp8_recipe when creating a quantizer set. Scaling"
            " mode can be specified directly via the scaling_mode parameter or indirectly via"
            " recipe. Recipe is preferred as it will support additional recipes in future where"
            " scaling mode differs between x, kernel, and grad in the quantizer set."
        )

        if fp8_recipe is not None:
            # TODO(jberchtold): once recipe and scaling mode are decoupled update this logic
            scaling_mode = _get_scaling_mode(fp8_recipe)
        elif scaling_mode is None:
            scaling_mode = QuantizeConfig.SCALING_MODE

        # Fall back to the QuantizeConfig defaults only when an argument was omitted;
        # `x or default` would also replace any explicitly passed falsy value.
        if fwd_dtype is None:
            fwd_dtype = QuantizeConfig.FWD_DTYPE
        if bwd_dtype is None:
            bwd_dtype = QuantizeConfig.BWD_DTYPE

        if is_2x2x is None:
            if scaling_mode.is_1d_block_scaling():
                is_2x2x = True
            elif scaling_mode.is_tensor_scaling():
                is_2x2x = not is_fp8_gemm_with_all_layouts_supported()
            else:  # NO_SCALING ignores is_2x2x for now
                is_2x2x = False
        assert not QuantizeConfig.INFERENCE_MODE, "Inference mode is not supported yet!"

        q_set = [
            QuantizerFactory._create_set(
                scaling_mode, fwd_dtype, bwd_dtype, is_2x2x, n_groups, **kwargs
            )
            for _ in range(n_quantizer_sets)
        ]
        return q_set[0] if len(q_set) == 1 else tuple(q_set)


# Quantizer set for ScalingMode.NO_SCALING. QuantizerFactory.create returns None
# for that mode, so the set's x, kernel and dgrad entries are all None.
noop_quantizer_set = QuantizerFactory.create_set(scaling_mode=ScalingMode.NO_SCALING, is_2x2x=False)