quantization.py 43.1 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
#
# See LICENSE for license information.
"""JAX/TE custom ops for quantization"""
5
6
import operator
from functools import reduce
7
from typing import Tuple, Optional, Union
8
import math
9

10

11
import jax
12
import jax.numpy as jnp
13
from jax import dtypes, ffi
14
from jax.experimental.custom_partitioning import SdyShardingRule, BATCHING
15
from jax.sharding import PartitionSpec
16

17
import transformer_engine_jax
18

19
from .amax import AmaxScope, calculate_amax, calculate_post_rht_amax
20
21
22
23
from .base import BasePrimitive, register_primitive
from .misc import (
    get_padded_spec,
    check_valid_batch_dims,
24
    te_dtype_to_jax_dtype,
25
    jax_dtype_to_te_dtype,
26
27
    multidim_transpose,
    should_apply_1x_fused_dbias_war_for_arch_l_100,
28
    get_min_device_compute_capability,
29
    NamedSharding,
30
)
31
32
33
from ..sharding import (
    all_reduce_max_along_all_axes_except_PP,
    all_reduce_sum_along_dp_fsdp,
34
    get_num_devices_in_mesh,
35
)
36
from ..quantize import (
37
38
39
40
    ScaledTensor2x,
    ScaledTensor,
    ScaledTensorFactory,
    GroupedScaledTensor1x,
41
    Quantizer,
42
    GroupedQuantizer,
43
44
45
    QuantizeLayout,
    ScalingMode,
    compute_scale_from_amax,
46
    NoScaleTensor,
47
    get_rht_matrix,
48
)
49
50


51
# Public API of this module (grouped_quantize / grouped_dbias are defined elsewhere in this module).
__all__ = ["quantize", "quantize_dbias", "grouped_quantize", "grouped_dbias"]
52
53


54
class BaseDBiasQuantizePrimitive(BasePrimitive):
    """
    Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias.

    Operands (positional): x, scale, amax, sr_rng_state, post_rht_amax, rht_matrix.
    Static (keyword) args: out_dtype, scaling_mode, q_layout, flatten_axis,
    scale_dtype, is_dbias, is_outer, stochastic_rounding, use_rht.
    Outputs (outer primitive): rowwise_out, colwise_out, rowwise_scale_inv,
    colwise_scale_inv, updated_amax, dbias.  The inner primitive additionally
    returns the C++ workspace buffer, which is dropped by ``impl``.
    """

    name = "te_dbias_quantize_ffi"
    multiple_results = True
    # Indices of the static keyword arguments (operands occupy positions 0-5).
    impl_static_args = (
        6,  # out_dtype
        7,  # scaling_mode
        8,  # q_layout
        9,  # flatten_axis
        10,  # scale_dtype
        11,  # is_dbias
        12,  # is_outer
        13,  # stochastic_rounding
        14,  # use_rht
    )
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        amax_aval,
        sr_rng_state_aval,
        post_rht_amax_aval,
        rht_matrix_aval,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        stochastic_rounding,
        use_rht,
    ):
        """
        te_dbias_quantize_p abstract: validates operand avals and computes the
        shapes/dtypes of all outputs (including the workspace for the dbias path).
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        out_shape = x_aval.shape
        assert scale_aval is None or scale_aval.dtype == jnp.float32
        if stochastic_rounding:
            assert ScalingMode(
                scaling_mode
            ).is_nvfp4_scaling, "stochastic_rounding can only be used with NVFP4 scaling modes"
            # JAX doesn't support 64-bit by default so use 4x uint32 instead of 2x int64
            assert sr_rng_state_aval is not None and sr_rng_state_aval.dtype == jnp.uint32, (
                "sr_rng_state must be a uint32 array when stochastic_rounding is True but"
                f" received {sr_rng_state_aval}"
            )
            if is_outer and get_num_devices_in_mesh() > 1:
                # Outer (unpartitioned) view: one (4,) rng state per device.
                assert (
                    sr_rng_state_aval.shape[0] == get_num_devices_in_mesh()
                    and sr_rng_state_aval.shape[1] == 4
                ), (
                    "sr_rng_state must be of shape (num_devices, 4) when stochastic_rounding is"
                    f" True and is_outer is True but received {sr_rng_state_aval.shape}"
                )
            else:
                # We cannot assert the shape is exactly (4,) here because if the quantized data is not perfectly sharded across all devices then we will have extra rng state here. For example, this could occur when the weights are not sharded when using data parallelism. However, this is okay because the extra rng state will simply not be used and each device still has a unique rng state.
                assert sr_rng_state_aval.size >= 4, (
                    "Sharded sr_rng_state must have at least 4 elements per device when"
                    f" stochastic_rounding is True but received {sr_rng_state_aval.shape}"
                )

        # Rowwise output keeps the input shape; a (1,) dummy is emitted otherwise.
        if QuantizeLayout(q_layout).has_rowwise:
            rowwise_out_shape = out_shape
        else:
            rowwise_out_shape = (1,)
        rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype)

        assert out_dtype in ScalingMode(scaling_mode).get_compatible_q_dtypes(), (
            f"out_dtype {out_dtype} not compatible with scaling_mode {scaling_mode}. out_dtype must"
            f" be one of {ScalingMode(scaling_mode).get_compatible_q_dtypes()}"
        )

        # amax is updated in place by the FFI call (see operand_output_aliases in lowering).
        updated_amax_aval = amax_aval

        if use_rht:
            assert (
                x_aval.dtype == jnp.bfloat16
            ), "x must be of dtype bfloat16 to be eligible for RHT cast fusion."

            # Normalize flatten_axis so the 2D (rows, cols) view can be computed.
            if flatten_axis < 0:
                flatten_axis += len(x_aval.shape)
            rows = reduce(operator.mul, x_aval.shape[:flatten_axis], 1)
            cols = reduce(operator.mul, x_aval.shape[flatten_axis:], 1)
            assert rows % 64 == 0 and cols % 128 == 0, (
                "Rows must be multiple of 64 and cols multiple of 128 when use_rht is True to be"
                f" eligible for RHT cast fusion. Received rows {rows} and cols {cols} of 2D shape"
                f" from original shape of {x_aval.shape} with flatten_axis {flatten_axis}."
            )

            assert (
                rht_matrix_aval is not None
                and rht_matrix_aval.dtype == jnp.bfloat16
                and rht_matrix_aval.shape == (16, 16)
            ), "rht_matrix must be of shape (16, 16) and dtype bfloat16"
            assert (
                post_rht_amax_aval is not None
                and post_rht_amax_aval.dtype == jnp.float32
                and post_rht_amax_aval.size == 1
            ), "post_rht_amax must be of dtype float32"

        # Inner primitive uses padded scale shapes; the outer one slices them back.
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(
            x_aval.shape,
            is_padded=not is_outer,
            flatten_axis=flatten_axis,
            broadcast_2d_scale_shape_to_1d=True,
        )

        if QuantizeLayout(q_layout).has_colwise:
            if ScalingMode(scaling_mode).is_colwise_transposed:
                colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis)
            else:
                colwise_out_shape = out_shape
        else:
            # No colwise output requested: emit (1,) dummies.
            colwise_out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        if is_dbias:
            dbias_shape = x_aval.shape[flatten_axis:]
            gi_hidden_size = reduce(operator.mul, x_aval.shape[flatten_axis:], 1)
            # Query the TE C++ layer for the reduction workspace it needs.
            (wkspace_info,) = transformer_engine_jax.get_dbias_quantize_workspace_sizes(
                x_aval.size // gi_hidden_size,
                gi_hidden_size,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
                jax_dtype_to_te_dtype(scale_dtype),
                scaling_mode,
                q_layout.value,
            )
            wkspace_shape = wkspace_info[0]
            wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
        else:
            dbias_shape = (1,)
            wkspace_shape = (1,)
            wkspace_dtype = jnp.float32
        dbias_aval = jax.core.ShapedArray(shape=dbias_shape, dtype=dtype)
        wkspace_aval = jax.core.ShapedArray(shape=wkspace_shape, dtype=wkspace_dtype)

        return (
            rowwise_out_aval,
            colwise_out_aval,
            scale_inv_aval,
            colwise_scale_inv_aval,
            updated_amax_aval,
            dbias_aval,
            wkspace_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dbias_quantize_p outer primitive abstract: same as ``abstract`` but
        drops the workspace aval, which callers never see.
        """
        (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
            _,
        ) = BaseDBiasQuantizePrimitive.abstract(*args, **kwargs)
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        amax,
        sr_rng_state,
        post_rht_amax,
        rht_matrix,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        stochastic_rounding,
        use_rht,
    ):
        """
        te_dbias_quantize_p lowering rules: emit the FFI custom call.
        """
        del out_dtype, scale_dtype, is_outer
        x_aval, scale_aval, amax_aval, _, _, _ = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval.dtype == amax_aval.dtype == jnp.float32
        return ffi.ffi_lowering(
            BaseDBiasQuantizePrimitive.name,
            operand_output_aliases={2: 4},  # donate amax buffer to updated_amax
        )(
            ctx,
            x,
            scale,
            amax,
            sr_rng_state,
            post_rht_amax,
            rht_matrix,
            scaling_mode=scaling_mode.value,
            # q_layout is a QuantizeLayout enum whose .value is itself an enum,
            # hence the double .value to reach the raw integer.
            q_layout=q_layout.value.value,
            flatten_axis=flatten_axis,
            is_dbias=is_dbias,
            stochastic_rounding=stochastic_rounding,
            use_rht=use_rht,
        )

    @staticmethod
    def impl(
        x,
        scale,
        amax,
        sr_rng_state,
        post_rht_amax,
        rht_matrix,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        stochastic_rounding,
        use_rht,
    ):
        """
        te_dbias_quantize_p implementation: bind the inner primitive, then slice
        the padded scale-inverse buffers down to their unpadded shapes and drop
        the workspace output.
        """
        del is_outer
        assert BaseDBiasQuantizePrimitive.inner_primitive is not None
        (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
            _,
        ) = BaseDBiasQuantizePrimitive.inner_primitive.bind(
            x,
            scale,
            amax,
            sr_rng_state,
            post_rht_amax,
            rht_matrix,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
            scale_dtype=scale_dtype,
            is_dbias=is_dbias,
            is_outer=False,
            stochastic_rounding=stochastic_rounding,
            use_rht=use_rht,
        )
        # Unpadded scale shapes for slicing the inner primitive's padded outputs.
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(
            x.shape, is_padded=False, flatten_axis=flatten_axis, broadcast_2d_scale_shape_to_1d=True
        )
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
        )
        if q_layout.has_colwise:
            colwise_scale_inv = jax.lax.slice(
                colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
            )
        return (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
        )  # Exclude wkspace

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        stochastic_rounding,
        use_rht,
    ):
        """
        to describe batch rules for vmap
        """
        del is_outer
        check_valid_batch_dims(batch_dims)
        assert BaseDBiasQuantizePrimitive.outer_primitive is not None
        x, scale, amax, sr_rng_state, post_rht_amax, rht_matrix = batched_args
        x_bdim, scale_bdim, amax_bdim, _, _, _ = batch_dims

        out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim, x_bdim
        # NOTE(review): is_outer is deliberately(?) not forwarded to this bind,
        # unlike every other static arg — confirm the outer primitive tolerates
        # its absence under vmap.
        return (
            BaseDBiasQuantizePrimitive.outer_primitive.bind(
                x,
                scale,
                amax,
                sr_rng_state,
                post_rht_amax,
                rht_matrix,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                q_layout=q_layout,
                flatten_axis=flatten_axis,
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
                stochastic_rounding=stochastic_rounding,
                use_rht=use_rht,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        stochastic_rounding,
        use_rht,
        mesh,
        arg_infos,
        result_infos,
    ):
        """Derive output shardings from the input sharding (GSPMD path).

        Outputs follow the input spec; the colwise output/scales are transposed
        when the scaling mode stores colwise data transposed.
        """
        del (
            out_dtype,
            result_infos,
            scale_dtype,
            is_outer,
            stochastic_rounding,
            use_rht,
        )  # Unused.

        x_spec = get_padded_spec(arg_infos[0])
        amax_spec = get_padded_spec(arg_infos[2])
        out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*x_spec),
            desc="BaseDBiasQuantizePrimitive.out_sharding",
        )
        if q_layout.has_colwise:
            if ScalingMode(scaling_mode).is_colwise_transposed:
                colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
            else:
                colwise_out_spec = x_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_out_spec),
            desc="BaseDBiasQuantizePrimitive.colwise_out_sharding",
        )

        # dbias spans the hidden dims only; it is all-reduced in partition().
        dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="BaseDBiasQuantizePrimitive.dbias_sharding",
        )

        # Tensor-scaling scales are scalars (replicated); block scaling shards
        # the scale tensor like the data.
        scale_inv_spec = colwise_scale_inv_spec = (None,)
        if ScalingMode(scaling_mode).is_block_scaling:
            scale_inv_spec = x_spec

        if q_layout.has_colwise:
            if (
                ScalingMode(scaling_mode).is_block_scaling
                and ScalingMode(scaling_mode).is_colwise_transposed
            ):
                colwise_scale_inv_spec = multidim_transpose(
                    scale_inv_spec, transpose_axis=flatten_axis
                )
            else:
                colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="BaseDBiasQuantizePrimitive.colwise_scale_inv",
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax"
        )

        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        stochastic_rounding,
        use_rht,
        mesh,
        arg_infos,
        result_infos,
    ):
        """Build the per-shard implementation plus input/output shardings.

        The sharded impl runs the local quantize and then performs the cross-
        device reductions that the sharding specs alone cannot express (amax
        max-reduction for delayed scaling, dbias sum over DP/FSDP axes).
        """
        del result_infos, is_outer  # Unused.

        x_spec = get_padded_spec(arg_infos[0])
        amax_spec = get_padded_spec(arg_infos[2])
        out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*x_spec),
            desc="BaseDBiasQuantizePrimitive.out_sharding",
        )

        if q_layout.has_colwise:
            if ScalingMode(scaling_mode).is_colwise_transposed:
                colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
            else:
                colwise_out_spec = x_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_out_spec),
            desc="BaseDBiasQuantizePrimitive.colwise_out_sharding",
        )

        dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="BaseDBiasQuantizePrimitive.dbias_sharding",
        )

        scale_inv_spec = colwise_scale_inv_spec = (None,)
        if ScalingMode(scaling_mode).is_block_scaling:
            scale_inv_spec = x_spec

        if q_layout.has_colwise:
            if (
                ScalingMode(scaling_mode).is_block_scaling
                and ScalingMode(scaling_mode).is_colwise_transposed
            ):
                colwise_scale_inv_spec = multidim_transpose(
                    scale_inv_spec, transpose_axis=flatten_axis
                )
            else:
                colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="BaseDBiasQuantizePrimitive.colwise_scale_inv",
        )

        arg_shardings = list(arg_i.sharding for arg_i in arg_infos)
        # Shard the per-device rng states along the same mesh axes as x so each
        # shard receives its own (…, 4) state.
        arg_shardings[3] = NamedSharding(
            mesh,
            PartitionSpec(tuple(x for x in x_spec if x is not None), None),
            desc="BaseDBiasQuantizePrimitive.sr_rng_state",
        )
        arg_shardings = tuple(arg_shardings)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

        def sharded_impl(x, scale, amax, sr_rng_state, post_rht_amax, rht_matrix):
            if sr_rng_state.size > 4:
                # See comment in abstract method for explanation of why we cannot assert exact shape
                sr_rng_state = sr_rng_state.flatten()[:4]
            (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                local_amax,
                local_dbias,
            ) = BaseDBiasQuantizePrimitive.impl(
                x,
                scale,
                amax,
                sr_rng_state,
                post_rht_amax,
                rht_matrix,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                q_layout=q_layout,
                flatten_axis=flatten_axis,
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
                is_outer=True,
                stochastic_rounding=stochastic_rounding,
                use_rht=use_rht,
            )

            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                # Delayed scaling needs the global amax for the next iteration.
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            else:
                global_updated_amax = local_amax

            if is_dbias:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            else:
                global_dbias = local_dbias

            return (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                global_updated_amax,
                global_dbias,
            )

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        stochastic_rounding,
        use_rht,
        mesh,
        value_types,
        result_types,
    ):
        """Sharding rule for the Shardy partitioner (SdyShardingRule)."""
        del (
            out_dtype,
            scale_dtype,
            is_outer,
            stochastic_rounding,
            use_rht,
            mesh,
            result_types,
        )

        prefix = "DBiasQuantize"
        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            value_types[0].shape,
            unique_var=prefix,
            flatten_axis=flatten_axis,
            q_layout=q_layout,
            broadcast_2d_scale_shape_to_1d=True,
        )

        # dbias shares the input's hidden-dim factors; the scalar-like operands
        # get BATCHING-prefixed factor names so they stay replicated.
        input_spec = scale_rules.input_spec
        dbias = input_spec[flatten_axis:] if is_dbias else (prefix + "_dbias",)
        amax = (BATCHING + prefix + "_amax",)
        scale = (BATCHING + prefix + "_scale",)
        sr_rng_state = (
            BATCHING + prefix + "_sr_rng_state_partition_axis",
            BATCHING + prefix + "sr_rng_state_data_axis",
        )

        post_rht_amax = (BATCHING + prefix + "_post_rht_amax",)
        rht_matrix = (BATCHING + prefix + "_rht_matrix_1", BATCHING + prefix + "_rht_matrix_2")

        return SdyShardingRule(
            (input_spec, scale, amax, sr_rng_state, post_rht_amax, rht_matrix),
            (
                scale_rules.rowwise_out_spec,
                scale_rules.colwise_out_spec,
                scale_rules.rowwise_scale_spec,
                scale_rules.colwise_scale_spec,
                amax,
                dbias,
            ),
            **scale_rules.factor_sizes,
        )

678

679
680
681
682
# Register the inner/outer primitives and their partitioning rules (see base.register_primitive).
register_primitive(BaseDBiasQuantizePrimitive)


class DBiasQuantizePrimitive(BaseDBiasQuantizePrimitive):
    """Subclass of BaseDBiasQuantizePrimitive for DBias quantization.

    No change in functionality from the base primitive, but named differently
    for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS.
    """
684
685
686


class QuantizePrimitive(BaseDBiasQuantizePrimitive):
    """Subclass of BaseDBiasQuantizePrimitive for quantization without dbias.

    No change in functionality from the base primitive, but named differently
    for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS.
    """
688
689


690
691
692
def _jax_quantize(
    x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1
):
    """Pure-JAX quantize fallback.

    Delegates to ``quantizer.quantize`` when a quantizer is given; otherwise the
    input is passed through unquantized, wrapped in a NoScaleTensor if needed so
    callers always receive a tensor wrapper.
    """
    if quantizer is not None:
        return quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
    # No quantizer: pass-through, wrapping plain arrays for a uniform return type.
    return x if isinstance(x, NoScaleTensor) else NoScaleTensor(data=x, amax=None)
698
699


700
701
702
def _jax_dbias(dx: Union[jnp.ndarray, NoScaleTensor], dtype=None, flatten_axis: int = -1):
    """Compute the bias gradient: sum of ``dx`` over all axes before ``flatten_axis``.

    Accumulates in float32 for accuracy, then casts the result to ``dtype``
    (defaulting to the input's dtype).
    """
    if isinstance(dx, NoScaleTensor):
        dx = dx.data
    # Normalize a negative flatten_axis into an absolute axis index.
    axis = dx.ndim + flatten_axis if flatten_axis < 0 else flatten_axis
    assert axis < dx.ndim, "Flatten axis out of bounds!"
    result_dtype = dtype or dx.dtype
    # fp32 accumulation over every leading (non-hidden) axis.
    accumulated = jnp.sum(dx.astype(jnp.float32), axis=tuple(range(axis)), keepdims=False)
    return accumulated.astype(result_dtype)
712
713
714
715
716
717


def _jax_quantize_dbias(
    x,
    quantizer: Quantizer = None,
    dq_dtype: Optional[jnp.dtype] = None,
    flatten_axis: int = -1,
):
    """Pure-JAX fallback computing both the quantized tensor and dbias.

    Returns ``(quantized_x, dbias)``; when no quantizer is supplied the input is
    passed through (wrapped as NoScaleTensor if needed) and dbias is None.
    """
    if quantizer is None:
        wrapped = x if isinstance(x, NoScaleTensor) else NoScaleTensor(data=x, amax=None)
        return wrapped, None
    quantized = quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
    dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
    return quantized, dbias


730
def _quantize_dbias_impl(
    x: Union[jnp.ndarray, NoScaleTensor],
    quantizer: Quantizer,
    is_dbias: bool = False,
    dq_dtype: Optional[jnp.dtype] = None,
    flatten_axis: int = -1,
    amax_scope: AmaxScope = AmaxScope.LOCAL,  # Only works when using current-scaling
    transpose_batch_sequence: bool = False,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
    """
    Cast wrapper around the TE quantize(+dbias) custom op.

    Returns a tuple ``(quantized_tensor, dbias)``.  ``dbias`` is None on the
    early-exit paths that don't compute it; on the custom-op path it is a dummy
    (1,)-shaped array when ``is_dbias`` is False.  Falls back to the pure-JAX
    implementation when the primitive is disabled or the layout is unsupported.
    """
    assert (dq_dtype is None) or (
        quantizer is not None
    ), "quantizer must be provided if dq_dtype is provided"

    # Normalize the input so the rest of the function can rely on .data/.amax.
    if isinstance(x, jnp.ndarray):
        x = NoScaleTensor(data=x, amax=None)

    # Early-exit for non-quantized call
    dq_dtype = dq_dtype or x.data.dtype
    if quantizer is None:
        dbias = None
        if is_dbias:
            dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis)
        return x, dbias

    # If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE,
    # fall back on the native-JAX quantize implementation
    PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive
    # Colwise-only is only supported by the custom op for NVFP4 + RHT.
    is_unsupported = quantizer.q_layout.is_colwise_only and not (
        quantizer.scaling_mode == ScalingMode.NVFP4_1D_SCALING
        and hasattr(quantizer, "use_rht")
        and quantizer.use_rht
    )
    if is_unsupported or not PrimitiveClass.enabled():
        if is_dbias:
            return _jax_quantize_dbias(
                x,
                quantizer=quantizer,
                dq_dtype=dq_dtype,
                flatten_axis=flatten_axis,
            )
        return (
            _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
            None,
        )

    # TE/common custom quantize op does not support dbias fusion with 1x quantization on arch < 100
    if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
        # Workaround: quantize without dbias fusion, compute dbias in JAX.
        out, _ = _quantize_dbias_impl(
            x=x,
            is_dbias=False,
            quantizer=quantizer,
            dq_dtype=dq_dtype,
            flatten_axis=flatten_axis,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
        )
        dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis)
        return out, dbias

    use_rht = False

    # Placeholder operands; overwritten below per scaling mode / RHT usage.
    scale = jnp.empty((1,), jnp.float32)
    post_rht_amax = None
    rht_matrix = jnp.empty((1, 1), jnp.bfloat16)
    amax = x.amax

    if hasattr(quantizer, "use_rht") and quantizer.use_rht:
        use_rht = True
        rht_matrix = get_rht_matrix()

        new_amax, post_rht_amax = calculate_post_rht_amax(
            x.data,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
            produce_regular_amax=amax is None,
            flatten_axis=flatten_axis,
        )
        if amax is None:
            # If amax is already calculated in a previous layer, we skip calculating it in the TE kernel
            # So here we only calculate and update amax when it is not provided from a previous layer (amax is None)
            amax = new_amax

    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        if amax is None:
            amax = calculate_amax(
                x.data,
                amax_scope=amax_scope,
                transpose_batch_sequence=transpose_batch_sequence,
            )
        scale = compute_scale_from_amax(amax, quantizer.q_dtype)
    elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        scale = quantizer.scale
        # Make sure to reset amax to zeros for DelayedScaling
        amax = jnp.zeros((1,), jnp.float32)
    elif quantizer.scaling_mode.is_nvfp4_scaling:
        if amax is None:
            amax = calculate_amax(
                x.data,
                amax_scope=amax_scope,
                transpose_batch_sequence=transpose_batch_sequence,
            )

    # Make sure amax is not None
    if amax is None:
        amax = jnp.zeros((1,), jnp.float32)

    # It is faster to use 1x quantization for tensor scaling
    is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100)
    force_1x_quantization = (
        quantizer.scaling_mode.is_tensor_scaling()
        and quantizer.q_layout.is_rowwise_colwise
        and is_1x_kernel_supported
    )
    q_layout = quantizer.q_layout
    if force_1x_quantization:
        q_layout = QuantizeLayout.ROWWISE

    sr_rng_state = None
    if quantizer.scaling_mode.is_nvfp4_scaling:
        # Only NVFP4 scaling modes support stochastic rounding
        if quantizer.stochastic_rounding_rng_state is not None:
            sr_rng_state = quantizer.stochastic_rounding_rng_state

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
        dbias,
    ) = PrimitiveClass.outer_primitive.bind(
        x.data,
        scale,
        amax,
        # Dummy rng states when stochastic rounding is disabled.
        (
            sr_rng_state
            if sr_rng_state is not None
            else jnp.empty((get_num_devices_in_mesh(), 1), jnp.uint32)
        ),
        post_rht_amax if post_rht_amax is not None else jnp.zeros((1,), jnp.float32),
        rht_matrix,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        q_layout=q_layout,
        flatten_axis=flatten_axis,
        scale_dtype=quantizer.get_scale_dtype(),
        # For NVFP4, dbias is computed in JAX below instead of fused in the kernel.
        is_dbias=is_dbias if not quantizer.scaling_mode.is_nvfp4_scaling else False,
        is_outer=True,
        stochastic_rounding=sr_rng_state is not None,
        use_rht=use_rht,
    )
    # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
    if quantizer.scaling_mode.is_tensor_scaling() and quantizer.q_layout.is_rowwise_colwise:
        colwise_scale_inv = rowwise_scale_inv
        if q_layout.is_rowwise_only:
            # Quantizer requires 2x quantization, but we are using 1x quantization
            # for performance reasons, so we need to generate the colwise data in JAX
            if flatten_axis < 0:
                flatten_axis += x.ndim
            colwise_casted_output = jnp.transpose(
                rowwise_casted_output, (*range(flatten_axis, x.ndim), *range(flatten_axis))
            )
    quantizer.update(updated_amax)

    if quantizer.scaling_mode.is_nvfp4_scaling and is_dbias:
        dbias = _jax_dbias(x, flatten_axis=flatten_axis)

    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        amax=updated_amax,
        colwise_amax=post_rht_amax,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=dq_dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=flatten_axis,
        colwise_has_rht_applied=use_rht,
    )
    return out, dbias.astype(dq_dtype)
917
918
919


def quantize(
    x: Union[jnp.ndarray, NoScaleTensor],
    quantizer: Quantizer,
    flatten_axis: int = -1,
    amax_scope: AmaxScope = AmaxScope.LOCAL,
    transpose_batch_sequence: bool = False,
) -> Tuple[ScaledTensor]:
    """Quantize an input tensor according to the given quantizer.

    Thin wrapper over the fused quantize/dbias implementation that discards
    the (unused) dbias output.

    Args:
        x: Input tensor to be quantized.
            Shape: (..., K) where K is the hidden size.
        quantizer: Quantizer for FP8 quantization of the output.
        flatten_axis: The quantization axis in which input data can be flattened
            to 2D for quantization. Defaults to -1.
        amax_scope: Indicate the scope to run amax calculation. This only works
            when using current-scaling. Default is AmaxScope.LOCAL.
        transpose_batch_sequence: forwarded to the fused implementation;
            presumably marks batch/sequence-transposed input layout — confirm
            against _quantize_dbias_impl.

    Returns:
        A ScaledTensor containing the quantized input tensor.
    """
    quantized, _unused_dbias = _quantize_dbias_impl(
        x,
        quantizer=quantizer,
        flatten_axis=flatten_axis,
        amax_scope=amax_scope,
        transpose_batch_sequence=transpose_batch_sequence,
    )
    return quantized


def quantize_dbias(
    dz: Union[jnp.ndarray, NoScaleTensor],
    quantizer: Quantizer,
    is_dbias: bool = True,
    flatten_axis: int = -1,
    amax_scope: AmaxScope = AmaxScope.LOCAL,
    transpose_batch_sequence: bool = False,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
    """Quantize input tensor and compute bias gradient.

    Args:
        dz: Input tensor to be quantized and used for bias gradient computation.
            Shape: (..., K) where K is the hidden size.
        quantizer: Quantizer for FP8 quantization of the output.
        is_dbias: If True, compute bias gradient. Defaults to True.
        flatten_axis: The quantization axis in which input data can be flattened
            to 2D for quantization. Defaults to -1.
        amax_scope: Indicate the scope to run amax calculation. This only works
            when using current-scaling. Default is AmaxScope.LOCAL.
        transpose_batch_sequence: forwarded to the fused implementation.

    Returns:
        A tuple containing:
        - A ScaledTensor containing the quantized input tensor.
            The ScaledTensor includes both the quantized data and scaling factors.
        - The bias gradient tensor.
            Shape: (K,) or empty if is_dbias is False.
    """
    # Delegate everything to the shared fused implementation.
    fused_kwargs = dict(
        quantizer=quantizer,
        is_dbias=is_dbias,
        flatten_axis=flatten_axis,
        amax_scope=amax_scope,
        transpose_batch_sequence=transpose_batch_sequence,
    )
    return _quantize_dbias_impl(dz, **fused_kwargs)
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026


class GroupedQuantizePrimitive(BasePrimitive):
    """
    Grouped cast primitive wrapping nvte_quantize applied per group
    (one scale/amax per group), exposed via the te_grouped_quantize_ffi target.
    """

    name = "te_grouped_quantize_ffi"
    multiple_results = True
    # out_dtype, scaling_mode, q_layout, flatten_axis, group_axis, scale_dtype
    impl_static_args = (
        3,
        4,
        5,
        6,
        7,
        8,
    )
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        group_sizes_aval,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        group_axis,
        scale_dtype,
    ):
        """
        te_grouped_quantize_p abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        # Grouped outputs are exposed flattened to 1D, so the output shape is
        # the total element count. Must be a tuple: jax.core.ShapedArray
        # requires a sequence for `shape` (a bare int raises).
        out_shape = (math.prod(x_aval.shape),)
        # TODO(Phuong): can scale_aval be None?
        assert scale_aval is None or scale_aval.dtype == jnp.float32

        assert out_dtype in ScalingMode(scaling_mode).get_compatible_q_dtypes(), (
            f"out_dtype {out_dtype} not compatible with scaling_mode {scaling_mode}. out_dtype must"
            f" be one of {ScalingMode(scaling_mode).get_compatible_q_dtypes()}"
        )

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_grouped_scale_shape_2x(
            x_aval.shape,
            group_sizes_aval.size,
            group_axis,
            is_padded=True,
            flatten_axis=flatten_axis,
        )

        if q_layout.has_rowwise:
            rowwise_out_shape = out_shape
        else:
            # Rowwise output not requested: emit 1-element placeholder buffers.
            rowwise_out_shape = (1,)
            rowwise_scale_inv_shape = (1,)
        rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype)

        # One amax per group.
        amax_aval = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32)

        if q_layout.has_colwise:
            colwise_out_shape = out_shape
        else:
            # Colwise output not requested: emit 1-element placeholder buffers.
            colwise_out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
        rowwise_scale_inv_aval = jax.core.ShapedArray(
            shape=rowwise_scale_inv_shape, dtype=scale_dtype
        )
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        return (
            rowwise_out_aval,
            colwise_out_aval,
            rowwise_scale_inv_aval,
            colwise_scale_inv_aval,
            amax_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_grouped_quantize_p outer primitive abstract
        """
        # Phuong: keeping outer abstract so that we can add fuse dbias later
        (
            rowwise_out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
        ) = GroupedQuantizePrimitive.abstract(*args, **kwargs)
        return rowwise_out, colwise_out, scale_inv, colwise_scale_inv, updated_amax

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        group_sizes,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        group_axis,
        scale_dtype,
    ):
        """
        te_grouped_quantize_p lowering rules
        """
        del out_dtype, scale_dtype
        x_aval, scale_aval, group_sizes_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval.dtype == jnp.float32
        assert group_sizes_aval.dtype == jnp.int32
        # The FFI kernel only supports grouping along the leading axis.
        assert group_axis == 0
        return ffi.ffi_lowering(GroupedQuantizePrimitive.name)(
            ctx,
            x,
            scale,
            group_sizes,
            scaling_mode=scaling_mode.value,
            q_layout=q_layout.value.value,
            flatten_axis=flatten_axis,
        )

    @staticmethod
    def impl(
        x,
        scale,
        group_sizes,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        group_axis,
        scale_dtype,
    ):
        """
        te_grouped_quantize_p implementation
        """
        assert GroupedQuantizePrimitive.inner_primitive is not None
        (
            rowwise_out,
            colwise_out,
            rowwise_scale_inv,
            colwise_scale_inv,
            updated_amax,
        ) = GroupedQuantizePrimitive.inner_primitive.bind(
            x,
            scale,
            group_sizes,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
            group_axis=group_axis,
            scale_dtype=scale_dtype,
        )
        return (rowwise_out, colwise_out, rowwise_scale_inv, colwise_scale_inv, updated_amax)


register_primitive(GroupedQuantizePrimitive)


def grouped_quantize(
    x: jnp.ndarray,
    quantizer: GroupedQuantizer,
    group_sizes: jnp.ndarray = None,
    amax: jnp.ndarray = None,
    flatten_axis: int = -1,
) -> GroupedScaledTensor1x:
    """Quantize a tensor in grouped manner.

    The input is partitioned into groups along the group axis (fixed to 0) and
    each group is quantized with its own scale.

    Args:
        x: Input tensor to quantize
        quantizer: The quantizer to use for quantization
        group_sizes: Array of ints containing the size of each group (default: None)
        amax: The amax of x; if None, it is auto-generated. (default: None)
        flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)

    Returns:
        A GroupedScaledTensor1x containing the quantized data

    Note:
        - If group_sizes is not provided, the tensor will be split into equal-sized groups
          along the group_axis
        - The group_axis is currently fixed to 0
        - The quantizer's q_layout determines whether row-wise, column-wise, or both
          quantization is applied
    """
    if quantizer is None:
        # No quantizer: pass the data through as an unscaled wrapper.
        return x if isinstance(x, NoScaleTensor) else NoScaleTensor(data=x, amax=None)

    # TODO(Phuong): add support for flatten_axis = -2
    assert flatten_axis in (
        -1,
        x.ndim - 1,
    ), f"Only flatten_axis = -1 is supported for now, got {flatten_axis}"
    group_axis = 0

    if group_sizes is None:
        # Default: one row per group.
        group_sizes = jnp.ones(x.shape[group_axis], dtype=jnp.int32)

    if not GroupedQuantizePrimitive.enabled():
        # Primitive disabled: fall back to the pure-JAX quantizer path.
        return quantizer.quantize(
            x, flatten_axis=flatten_axis, group_sizes=group_sizes, group_axis=group_axis
        )

    num_groups = group_sizes.size
    input_shape = x.shape
    assert num_groups == len(
        quantizer.quantizers
    ), f"n_groups={num_groups} != n_quantizers = {len(quantizer.quantizers)}"
    scale = jnp.empty((num_groups,), jnp.float32)

    if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        # Delayed scaling: the per-group scales come straight from each member quantizer.
        for idx, member in enumerate(quantizer.quantizers):
            scale = scale.at[idx].set(member.scale[0])

    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        # Current scaling: derive one scale per group from the group's amax.
        per_row_amax = (
            amax if amax is not None else jnp.max(jnp.abs(x), axis=range(group_axis + 1, x.ndim))
        )
        row_group_ids = jnp.repeat(
            jnp.arange(num_groups), group_sizes, total_repeat_length=x.shape[group_axis]
        )
        per_group_amax = jax.ops.segment_max(per_row_amax, row_group_ids, num_segments=num_groups)
        for idx in range(num_groups):
            group_scale = compute_scale_from_amax(per_group_amax[idx], quantizer.q_dtype)
            scale = scale.at[idx].set(group_scale[0])

    is_tensor_scaling = quantizer.scaling_mode in (
        ScalingMode.DELAYED_TENSOR_SCALING,
        ScalingMode.CURRENT_TENSOR_SCALING,
    )
    # WAR for tensor_scaling as TE/Common does not support q_layout = COLWISE yet,
    # so we perform ROWWISE_COLWISE and use the colwise tensor output.
    apply_colwise_war = is_tensor_scaling and quantizer.q_layout.is_colwise_only
    q_layout = QuantizeLayout.ROWWISE_COLWISE if apply_colwise_war else quantizer.q_layout

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
    ) = GroupedQuantizePrimitive.outer_primitive.bind(
        x,
        scale,
        group_sizes,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        q_layout=q_layout,
        flatten_axis=flatten_axis,
        group_axis=group_axis,
        scale_dtype=quantizer.get_scale_dtype(),
    )

    # For DelayedScaling2x and CurrentScaling2x, the scale buffer
    # is shared between rowwise and colwise.
    if (is_tensor_scaling and quantizer.q_layout.is_rowwise_colwise) or apply_colwise_war:
        colwise_scale_inv = rowwise_scale_inv

    # TODO(Phuong): store the whole updated_amax in the grouped_quantize instead?
    if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        for idx, member in enumerate(quantizer.quantizers):
            member.update(updated_amax[idx].reshape((1,)))

    return ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=flatten_axis,
        group_sizes=group_sizes,
        original_shape=input_shape,
        group_axis=group_axis,
    )
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299


def grouped_dbias(grad: jnp.ndarray, group_sizes: jnp.ndarray) -> jnp.ndarray:
    """
    Compute the grouped bias gradient.

    Rows of ``grad`` are summed per group (accumulating in float32 for
    accuracy) and the result is cast back to the input dtype.

    Args:
        grad: jnp.ndarray of shape (M, N)
        group_sizes: jnp.ndarray of shape(num_groups,), sum(group_sizes) == M

    Returns:
        dbias: jnp.ndarray of shape (num_groups, N)
    """
    assert grad.ndim == 2, "Input grad must be a 2D tensor."
    assert group_sizes.ndim == 1, "group_sizes must be a 1D tensor."

    # Map each row of grad to the index of the group it belongs to.
    row_to_group = jnp.repeat(
        jnp.arange(group_sizes.size), group_sizes, total_repeat_length=grad.shape[0]
    )
    # Sum rows within each group in fp32, then cast back to the input dtype.
    summed = jax.ops.segment_sum(
        grad.astype(jnp.float32), row_to_group, num_segments=group_sizes.shape[0]
    )
    return summed.astype(grad.dtype)