# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for quantization"""
import operator
from functools import reduce
from typing import Tuple, Optional, Union
import math

import jax
import jax.numpy as jnp
from jax import dtypes, ffi
from jax.experimental.custom_partitioning import SdyShardingRule, BATCHING
from jax.sharding import PartitionSpec

import transformer_engine_jax

from .amax import AmaxScope, calculate_amax, calculate_post_rht_amax
from .base import BasePrimitive, register_primitive
from .misc import (
    get_padded_spec,
    check_valid_batch_dims,
    te_dtype_to_jax_dtype,
    jax_dtype_to_te_dtype,
    multidim_transpose,
    should_apply_1x_fused_dbias_war_for_arch_l_100,
    get_min_device_compute_capability,
    NamedSharding,
)
from ..sharding import (
    all_reduce_max_along_all_axes_except_PP,
    all_reduce_sum_along_dp_fsdp,
    get_num_devices_in_mesh,
)
from ..quantize import (
    ScaledTensor2x,
    ScaledTensor,
    ScaledTensorFactory,
    GroupedScaledTensor1x,
    Quantizer,
    GroupedQuantizer,
    QuantizeLayout,
    ScalingMode,
    compute_scale_from_amax,
    NoScaleTensor,
    get_rht_matrix,
)


__all__ = ["quantize", "quantize_dbias", "grouped_quantize", "grouped_dbias"]


class BaseDBiasQuantizePrimitive(BasePrimitive):
    """
    Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias
    """

    name = "te_dbias_quantize_ffi"
    multiple_results = True
    impl_static_args = (
        6,  # out_dtype
        7,  # scaling_mode
        8,  # q_layout
        9,  # flatten_axis
        10,  # scale_dtype
        11,  # is_dbias
        12,  # is_outer
        13,  # stochastic_rounding
        14,  # use_rht
    )
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        amax_aval,
        sr_rng_state_aval,
        post_rht_amax_aval,
        rht_matrix_aval,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        stochastic_rounding,
        use_rht,
    ):
        """
        te_dbias_quantize_p abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        out_shape = x_aval.shape
        assert scale_aval is None or scale_aval.dtype == jnp.float32
        if stochastic_rounding:
            assert ScalingMode(
                scaling_mode
            ).is_nvfp4_scaling, "stochastic_rounding can only be used with NVFP4 scaling modes"
            # JAX doesn't support 64-bit by default so use 4x uint32 instead of 2x int64
            assert sr_rng_state_aval is not None and sr_rng_state_aval.dtype == jnp.uint32, (
                "sr_rng_state must be a uint32 array when stochastic_rounding is True but"
                f" received {sr_rng_state_aval}"
            )
            if is_outer and get_num_devices_in_mesh() > 1:
                assert (
                    sr_rng_state_aval.shape[0] == get_num_devices_in_mesh()
                    and sr_rng_state_aval.shape[1] == 4
                ), (
                    "sr_rng_state must be of shape (num_devices, 4) when stochastic_rounding is"
                    f" True and is_outer is True but received {sr_rng_state_aval.shape}"
                )
            else:
                # We cannot assert that the shape is exactly (4,) here: if the
                # quantized data is not perfectly sharded across all devices, extra
                # rng state remains, e.g. when the weights are not sharded under
                # data parallelism. This is okay because the extra rng state is
                # simply unused and each device still has a unique rng state.
                assert sr_rng_state_aval.size >= 4, (
                    "Sharded sr_rng_state must have at least 4 elements per device when"
                    f" stochastic_rounding is True but received {sr_rng_state_aval.shape}"
                )

        if QuantizeLayout(q_layout).has_rowwise:
            rowwise_out_shape = out_shape
        else:
            rowwise_out_shape = (1,)
        rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype)

        assert out_dtype in ScalingMode(scaling_mode).get_compatible_q_dtypes(), (
            f"out_dtype {out_dtype} not compatible with scaling_mode {scaling_mode}. out_dtype must"
            f" be one of {ScalingMode(scaling_mode).get_compatible_q_dtypes()}"
        )

        updated_amax_aval = amax_aval

        if use_rht:
            assert (
                x_aval.dtype == jnp.bfloat16
            ), "x must be of dtype bfloat16 to be eligible for RHT cast fusion."

            if flatten_axis < 0:
                flatten_axis += len(x_aval.shape)
            rows = reduce(operator.mul, x_aval.shape[:flatten_axis], 1)
            cols = reduce(operator.mul, x_aval.shape[flatten_axis:], 1)
            assert rows % 64 == 0 and cols % 128 == 0, (
                "Rows must be multiple of 64 and cols multiple of 128 when use_rht is True to be"
                f" eligible for RHT cast fusion. Received rows {rows} and cols {cols} of 2D shape"
                f" from original shape of {x_aval.shape} with flatten_axis {flatten_axis}."
            )

            assert (
                rht_matrix_aval is not None
                and rht_matrix_aval.dtype == jnp.bfloat16
                and rht_matrix_aval.shape == (16, 16)
            ), "rht_matrix must be of shape (16, 16) and dtype bfloat16"
            assert (
                post_rht_amax_aval is not None
                and post_rht_amax_aval.dtype == jnp.float32
                and post_rht_amax_aval.size == 1
            ), "post_rht_amax must be of dtype float32"

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(
            x_aval.shape,
            is_padded=not is_outer,
            flatten_axis=flatten_axis,
            broadcast_2d_scale_shape_to_1d=True,
        )

        if QuantizeLayout(q_layout).has_colwise:
            if ScalingMode(scaling_mode).is_colwise_transposed:
                colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis)
            else:
                colwise_out_shape = out_shape
        else:
            colwise_out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        if is_dbias:
            dbias_shape = x_aval.shape[flatten_axis:]
            gi_hidden_size = reduce(operator.mul, x_aval.shape[flatten_axis:], 1)
            (wkspace_info,) = transformer_engine_jax.get_dbias_quantize_workspace_sizes(
                x_aval.size // gi_hidden_size,
                gi_hidden_size,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
                jax_dtype_to_te_dtype(scale_dtype),
                scaling_mode,
                q_layout.value,
            )
            wkspace_shape = wkspace_info[0]
            wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
        else:
            dbias_shape = (1,)
            wkspace_shape = (1,)
            wkspace_dtype = jnp.float32
        dbias_aval = jax.core.ShapedArray(shape=dbias_shape, dtype=dtype)
        wkspace_aval = jax.core.ShapedArray(shape=wkspace_shape, dtype=wkspace_dtype)

        return (
            rowwise_out_aval,
            colwise_out_aval,
            scale_inv_aval,
            colwise_scale_inv_aval,
            updated_amax_aval,
            dbias_aval,
            wkspace_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dbias_quantize_p outer primitive abstract
        """
        (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
            _,
        ) = BaseDBiasQuantizePrimitive.abstract(*args, **kwargs)
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        amax,
        sr_rng_state,
        post_rht_amax,
        rht_matrix,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        stochastic_rounding,
        use_rht,
    ):
        """
        te_dbias_quantize_p lowering rules
        """
        del out_dtype, scale_dtype, is_outer
        x_aval, scale_aval, amax_aval, _, _, _ = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval.dtype == amax_aval.dtype == jnp.float32
        return ffi.ffi_lowering(
            BaseDBiasQuantizePrimitive.name,
            operand_output_aliases={2: 4},  # donate amax buffer to updated_amax
        )(
            ctx,
            x,
            scale,
            amax,
            sr_rng_state,
            post_rht_amax,
            rht_matrix,
            scaling_mode=scaling_mode.value,
            q_layout=q_layout.value.value,
            flatten_axis=flatten_axis,
            is_dbias=is_dbias,
            stochastic_rounding=stochastic_rounding,
            use_rht=use_rht,
        )

    @staticmethod
    def impl(
        x,
        scale,
        amax,
        sr_rng_state,
        post_rht_amax,
        rht_matrix,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        stochastic_rounding,
        use_rht,
    ):
        """
        te_dbias_quantize_p implementation
        """
        del is_outer
        assert BaseDBiasQuantizePrimitive.inner_primitive is not None
        (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
            _,
        ) = BaseDBiasQuantizePrimitive.inner_primitive.bind(
            x,
            scale,
            amax,
            sr_rng_state,
            post_rht_amax,
            rht_matrix,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
            scale_dtype=scale_dtype,
            is_dbias=is_dbias,
            is_outer=False,
            stochastic_rounding=stochastic_rounding,
            use_rht=use_rht,
        )
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(
            x.shape, is_padded=False, flatten_axis=flatten_axis, broadcast_2d_scale_shape_to_1d=True
        )
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
        )
        if q_layout.has_colwise:
            colwise_scale_inv = jax.lax.slice(
                colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
            )
        return (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
        )  # Exclude wkspace

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        stochastic_rounding,
        use_rht,
    ):
        """
        to describe batch rules for vmap
        """
        del is_outer
        check_valid_batch_dims(batch_dims)
        assert BaseDBiasQuantizePrimitive.outer_primitive is not None
        x, scale, amax, sr_rng_state, post_rht_amax, rht_matrix = batched_args
        x_bdim, scale_bdim, amax_bdim, _, _, _ = batch_dims

        out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim, x_bdim
        return (
            BaseDBiasQuantizePrimitive.outer_primitive.bind(
                x,
                scale,
                amax,
                sr_rng_state,
                post_rht_amax,
                rht_matrix,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                q_layout=q_layout,
                flatten_axis=flatten_axis,
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
                stochastic_rounding=stochastic_rounding,
                use_rht=use_rht,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        stochastic_rounding,
        use_rht,
        mesh,
        arg_infos,
        result_infos,
    ):
        del (
            out_dtype,
            result_infos,
            scale_dtype,
            is_outer,
            stochastic_rounding,
            use_rht,
        )  # Unused.

        x_spec = get_padded_spec(arg_infos[0])
        amax_spec = get_padded_spec(arg_infos[2])
        out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*x_spec),
            desc="BaseDBiasQuantizePrimitive.out_sharding",
        )
        if q_layout.has_colwise:
            if ScalingMode(scaling_mode).is_colwise_transposed:
                colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
            else:
                colwise_out_spec = x_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_out_spec),
            desc="BaseDBiasQuantizePrimitive.colwise_out_sharding",
        )

        dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="BaseDBiasQuantizePrimitive.dbias_sharding",
        )

        scale_inv_spec = colwise_scale_inv_spec = (None,)
        if ScalingMode(scaling_mode).is_block_scaling:
            scale_inv_spec = x_spec

        if q_layout.has_colwise:
            if (
                ScalingMode(scaling_mode).is_block_scaling
                and ScalingMode(scaling_mode).is_colwise_transposed
            ):
                colwise_scale_inv_spec = multidim_transpose(
                    scale_inv_spec, transpose_axis=flatten_axis
                )
            else:
                colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="BaseDBiasQuantizePrimitive.colwise_scale_inv",
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax"
        )

        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        stochastic_rounding,
        use_rht,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer  # Unused.

        x_spec = get_padded_spec(arg_infos[0])
        amax_spec = get_padded_spec(arg_infos[2])
        sr_rng_state_spec = get_padded_spec(arg_infos[3])
        out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*x_spec),
            desc="BaseDBiasQuantizePrimitive.out_sharding",
        )

        if q_layout.has_colwise:
            if ScalingMode(scaling_mode).is_colwise_transposed:
                colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
            else:
                colwise_out_spec = x_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_out_spec),
            desc="BaseDBiasQuantizePrimitive.colwise_out_sharding",
        )

        dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="BaseDBiasQuantizePrimitive.dbias_sharding",
        )

        scale_inv_spec = colwise_scale_inv_spec = (None,)
        if ScalingMode(scaling_mode).is_block_scaling:
            scale_inv_spec = x_spec

        if q_layout.has_colwise:
            if (
                ScalingMode(scaling_mode).is_block_scaling
                and ScalingMode(scaling_mode).is_colwise_transposed
            ):
                colwise_scale_inv_spec = multidim_transpose(
                    scale_inv_spec, transpose_axis=flatten_axis
                )
            else:
                colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="BaseDBiasQuantizePrimitive.colwise_scale_inv",
        )

        arg_shardings = list(arg_i.sharding for arg_i in arg_infos)
        if len(sr_rng_state_spec) > 1:
            # sr_rng_state shape [n_devices, state_per_device]
            sr_rng_state_spec = (*tuple(x for x in x_spec if x is not None), None)
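            # Sharding the leading (per-device) dimension over the same mesh axes
            # as x gives every local shard its own rng state; the trailing state
            # dimension stays replicated.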
            arg_shardings[3] = NamedSharding(
                mesh,
                PartitionSpec(*sr_rng_state_spec),
                desc="BaseDBiasQuantizePrimitive.sr_rng_state",
            )
        arg_shardings = tuple(arg_shardings)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

        def sharded_impl(x, scale, amax, sr_rng_state, post_rht_amax, rht_matrix):
            if sr_rng_state.size > 4:
                # See comment in abstract method for explanation of why we cannot assert exact shape
                sr_rng_state = sr_rng_state.flatten()[:4]
            (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                local_amax,
                local_dbias,
            ) = BaseDBiasQuantizePrimitive.impl(
                x,
                scale,
                amax,
                sr_rng_state,
                post_rht_amax,
                rht_matrix,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                q_layout=q_layout,
                flatten_axis=flatten_axis,
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
                is_outer=True,
                stochastic_rounding=stochastic_rounding,
                use_rht=use_rht,
            )

            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            else:
                global_updated_amax = local_amax

            if is_dbias:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            else:
                global_dbias = local_dbias

            return (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                global_updated_amax,
                global_dbias,
            )

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        stochastic_rounding,
        use_rht,
        mesh,
        value_types,
        result_types,
    ):
        del (
            out_dtype,
            scale_dtype,
            is_outer,
            stochastic_rounding,
            use_rht,
            mesh,
            result_types,
        )

        prefix = "DBiasQuantize"
        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            value_types[0].shape,
            unique_var=prefix,
            flatten_axis=flatten_axis,
            q_layout=q_layout,
            broadcast_2d_scale_shape_to_1d=True,
        )

        input_spec = scale_rules.input_spec
        dbias = input_spec[flatten_axis:] if is_dbias else (prefix + "_dbias",)
        amax = (BATCHING + prefix + "_amax",)
        scale = (BATCHING + prefix + "_scale",)
        sr_rng_state = (BATCHING + prefix + "_sr_rng_state",)
        if value_types[3].shape != [0]:
            sr_rng_state = (
                BATCHING + prefix + "_sr_rng_state_devices",
                prefix + "sr_rng_state_data",
            )

        post_rht_amax = (BATCHING + prefix + "_post_rht_amax",)
        rht_matrix = (BATCHING + prefix + "_rht_matrix_1", BATCHING + prefix + "_rht_matrix_2")

        return SdyShardingRule(
            (input_spec, scale, amax, sr_rng_state, post_rht_amax, rht_matrix),
            (
                scale_rules.rowwise_out_spec,
                scale_rules.colwise_out_spec,
                scale_rules.rowwise_scale_spec,
                scale_rules.colwise_scale_spec,
                amax,
                dbias,
            ),
            **scale_rules.factor_sizes,
        )


register_primitive(BaseDBiasQuantizePrimitive)


class DBiasQuantizePrimitive(BaseDBiasQuantizePrimitive):
    """Subclass of BaseDBiasQuantizePrimitive for DBias quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""


class QuantizePrimitive(BaseDBiasQuantizePrimitive):
    """Subclass of BaseDBiasQuantizePrimitive for quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""


def _jax_quantize(
    x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1
):
    if quantizer is None:
        if isinstance(x, NoScaleTensor):
            return x
        return NoScaleTensor(data=x, amax=None)
    return quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)


def _jax_dbias(dx: Union[jnp.ndarray, NoScaleTensor], dtype=None, flatten_axis: int = -1):
    if isinstance(dx, NoScaleTensor):
        dx = dx.data
    sum_axis = dx.ndim + flatten_axis if flatten_axis < 0 else flatten_axis
    assert sum_axis < dx.ndim, "Flatten axis out of bounds!"
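    # Summing over every axis before sum_axis collapses dx from (..., K) to (K,);
    # e.g. a (batch, seqlen, hidden) gradient with flatten_axis=-1 yields a
    # (hidden,) dbias.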
    dtype = dtype or dx.dtype
    dbias = jnp.sum(
        dx.astype(jnp.float32),
        axis=tuple(range(sum_axis)),
        keepdims=False,
    )
    return dbias.astype(dtype)


def _jax_quantize_dbias(
    x,
    quantizer: Quantizer = None,
    dq_dtype: Optional[jnp.dtype] = None,
    flatten_axis: int = -1,
):
    if quantizer is None:
        if isinstance(x, NoScaleTensor):
            return x, None
        return NoScaleTensor(data=x, amax=None), None
    return (
        quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
        _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis),
    )


def _quantize_dbias_impl(
    x: Union[jnp.ndarray, NoScaleTensor],
    quantizer: Quantizer,
    is_dbias: bool = False,
    dq_dtype: Optional[jnp.dtype] = None,
    flatten_axis: int = -1,
    amax_scope: AmaxScope = AmaxScope.LOCAL,  # Only works when using current-scaling
    transpose_batch_sequence: bool = False,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
    """
    Cast wrapper.
    Returns the quantized tensor and, when is_dbias is True, the bias gradient.
    """
    assert (dq_dtype is None) or (
        quantizer is not None
    ), "quantizer must be provided if dq_dtype is provided"

    if isinstance(x, jnp.ndarray):
        x = NoScaleTensor(data=x, amax=None)
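        # Wrapping plain arrays as NoScaleTensor lets the rest of this function
        # uniformly access x.data and x.amax.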

    # Early-exit for non-quantized call
    dq_dtype = dq_dtype or x.data.dtype
    if quantizer is None:
        dbias = None
        if is_dbias:
            dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis)
        return x, dbias

    # If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE,
    # fall back on the native-JAX quantize implementation
    PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive
    is_unsupported = quantizer.q_layout.is_colwise_only and not (
        quantizer.scaling_mode == ScalingMode.NVFP4_1D_SCALING
        and hasattr(quantizer, "use_rht")
        and quantizer.use_rht
    )
    if is_unsupported or not PrimitiveClass.enabled():
        if is_dbias:
            return _jax_quantize_dbias(
                x,
                quantizer=quantizer,
                dq_dtype=dq_dtype,
                flatten_axis=flatten_axis,
            )
        return (
            _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
            None,
        )

    # TE/common custom quantize op does not support dbias fusion with 1x quantization on arch < 100
    if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
        out, _ = _quantize_dbias_impl(
            x=x,
            is_dbias=False,
            quantizer=quantizer,
            dq_dtype=dq_dtype,
            flatten_axis=flatten_axis,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
        )
        dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis)
        return out, dbias

    use_rht = False

    scale = jnp.empty((1,), jnp.float32)
    post_rht_amax = None
    rht_matrix = jnp.empty((1, 1), jnp.bfloat16)
    amax = x.amax

    if hasattr(quantizer, "use_rht") and quantizer.use_rht:
        use_rht = True
        rht_matrix = get_rht_matrix()

        new_amax, post_rht_amax = calculate_post_rht_amax(
            x.data,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
            produce_regular_amax=amax is None,
            flatten_axis=flatten_axis,
        )
        if amax is None:
            # If amax is already calculated in a previous layer, we skip calculating it in the TE kernel
            # So here we only calculate and update amax when it is not provided from a previous layer (amax is None)
            amax = new_amax

    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        if amax is None:
            amax = calculate_amax(
                x.data,
                amax_scope=amax_scope,
                transpose_batch_sequence=transpose_batch_sequence,
            )
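        # Current scaling derives the scale from this tensor's amax on the fly
        # (roughly q_dtype_max / amax; see compute_scale_from_amax).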
        scale = compute_scale_from_amax(amax, quantizer.q_dtype, margin=0.0)
    elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        scale = quantizer.scale
        # Make sure to reset amax to zeros for DelayedScaling
        amax = jnp.zeros((1,), jnp.float32)
    elif quantizer.scaling_mode.is_nvfp4_scaling:
        if amax is None:
            amax = calculate_amax(
                x.data,
                amax_scope=amax_scope,
                transpose_batch_sequence=transpose_batch_sequence,
            )

    # Make sure amax is not None
    if amax is None:
        amax = jnp.zeros((1,), jnp.float32)

    # It is faster to use 1x quantization for tensor scaling
    is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100)
    force_1x_quantization = (
        quantizer.scaling_mode.is_tensor_scaling()
        and quantizer.q_layout.is_rowwise_colwise
        and is_1x_kernel_supported
    )
    q_layout = quantizer.q_layout

    if force_1x_quantization:
        q_layout = QuantizeLayout.ROWWISE
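        # The colwise output requested by the quantizer is materialized afterwards
        # with a jnp.transpose of the rowwise output (see below).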

    sr_rng_state = jnp.empty((0,), jnp.uint32)
    if quantizer.scaling_mode.is_nvfp4_scaling:
        # Only NVFP4 scaling modes support stochastic rounding
        if quantizer.stochastic_rounding_rng_state is not None:
            sr_rng_state = quantizer.stochastic_rounding_rng_state

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
        dbias,
    ) = PrimitiveClass.outer_primitive.bind(
        x.data,
        scale,
        amax,
        sr_rng_state,
        post_rht_amax if post_rht_amax is not None else jnp.zeros((1,), jnp.float32),
        rht_matrix,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        q_layout=q_layout,
        flatten_axis=flatten_axis,
        scale_dtype=quantizer.get_scale_dtype(),
        is_dbias=is_dbias if not quantizer.scaling_mode.is_nvfp4_scaling else False,
        is_outer=True,
        stochastic_rounding=sr_rng_state.size != 0,
        use_rht=use_rht,
    )
    # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
    if quantizer.scaling_mode.is_tensor_scaling() and quantizer.q_layout.is_rowwise_colwise:
        colwise_scale_inv = rowwise_scale_inv

        if q_layout.is_rowwise_only:
            # Quantizer requires 2x quantization, but we are using 1x quantization
            # for performance reasons, so we need to generate the colwise data in JAX
            if flatten_axis < 0:
                flatten_axis += x.ndim
            colwise_casted_output = jnp.transpose(
                rowwise_casted_output, (*range(flatten_axis, x.ndim), *range(flatten_axis))
            )
    quantizer.update(updated_amax)
    if quantizer.scaling_mode.is_nvfp4_scaling and is_dbias:
        dbias = _jax_dbias(x, flatten_axis=flatten_axis)

    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        amax=updated_amax,
        colwise_amax=post_rht_amax,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=dq_dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=flatten_axis,
        colwise_has_rht_applied=use_rht,
    )
    return out, dbias.astype(dq_dtype)


def quantize(
    x: Union[jnp.ndarray, NoScaleTensor],
    quantizer: Quantizer,
    flatten_axis: int = -1,
    amax_scope: AmaxScope = AmaxScope.LOCAL,
    transpose_batch_sequence: bool = False,
) -> ScaledTensor:
    """Quantize input tensor according to the quantizer.

    Args:
        x: Input tensor to be quantized.
            Shape: (..., K) where K is the hidden size.
        quantizer: Quantizer for FP8 quantization of the output.
        flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization.
            Defaults to -1.
        amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL.

    Returns:
        A ScaledTensor containing the quantized input tensor.
    """
    out, _ = _quantize_dbias_impl(
        x,
        quantizer=quantizer,
        flatten_axis=flatten_axis,
        amax_scope=amax_scope,
        transpose_batch_sequence=transpose_batch_sequence,
    )
    return out
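

# Minimal usage sketch for quantize() (illustrative only; `my_quantizer` stands in
# for a Quantizer built elsewhere, e.g. by the quantize recipe helpers):
#
#   x = jnp.ones((128, 512), jnp.bfloat16)
#   scaled = quantize(x, quantizer=my_quantizer)
#   # `scaled` holds the quantized data together with its inverse scale factors.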


def quantize_dbias(
    dz: Union[jnp.ndarray, NoScaleTensor],
    quantizer: Quantizer,
    is_dbias: bool = True,
    flatten_axis: int = -1,
    amax_scope: AmaxScope = AmaxScope.LOCAL,
    transpose_batch_sequence: bool = False,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
    """Quantize input tensor and compute bias gradient.

    Args:
        dz: Input tensor to be quantized and used for bias gradient computation.
            Shape: (..., K) where K is the hidden size.
        quantizer: Quantizer for FP8 quantization of the output.
        is_dbias: If True, compute bias gradient. Defaults to True.
        flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization.
            Defaults to -1.
        amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL.

    Returns:
        A tuple containing:
        - A ScaledTensor containing the quantized input tensor.
            The ScaledTensor includes both the quantized data and scaling factors.
        - The bias gradient tensor.
            Shape: (K,) or empty if is_dbias is False.
    """
    return _quantize_dbias_impl(
        dz,
        quantizer=quantizer,
        is_dbias=is_dbias,
        flatten_axis=flatten_axis,
        amax_scope=amax_scope,
        transpose_batch_sequence=transpose_batch_sequence,
    )
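

# Note: mathematically dbias is the sum of dz over all leading (non-hidden) axes;
# the pure-JAX fallback (_jax_dbias) accumulates in float32 before casting back to
# the output dtype.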


class GroupedQuantizePrimitive(BasePrimitive):
    """
    Cast Primitive for grouped quantization (one quantizer per group)
    """

    name = "te_grouped_quantize_ffi"
    multiple_results = True
    impl_static_args = (
        3,
        4,
        5,
        6,
        7,
        8,
    )  # out_dtype, scaling_mode, q_layout, flatten_axis, group_axis, scale_dtype
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        group_sizes_aval,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        group_axis,
        scale_dtype,
    ):
        """
        te_grouped_quantize_p abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        out_shape = (math.prod(x_aval.shape),)
        # TODO(Phuong): can scale_aval be None?
        assert scale_aval is None or scale_aval.dtype == jnp.float32

        assert out_dtype in ScalingMode(scaling_mode).get_compatible_q_dtypes(), (
            f"out_dtype {out_dtype} not compatible with scaling_mode {scaling_mode}. out_dtype must"
            f" be one of {ScalingMode(scaling_mode).get_compatible_q_dtypes()}"
        )

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_grouped_scale_shape_2x(
            x_aval.shape,
            group_sizes_aval.size,
            group_axis,
            is_padded=True,
            flatten_axis=flatten_axis,
        )

        if q_layout.has_rowwise:
            rowwise_out_shape = out_shape
        else:
            rowwise_out_shape = (1,)
            rowwise_scale_inv_shape = (1,)
        rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype)

        amax_aval = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32)

        if q_layout.has_colwise:
            colwise_out_shape = out_shape
        else:
            colwise_out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
        rowwise_scale_inv_aval = jax.core.ShapedArray(
            shape=rowwise_scale_inv_shape, dtype=scale_dtype
        )
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        return (
            rowwise_out_aval,
            colwise_out_aval,
            rowwise_scale_inv_aval,
            colwise_scale_inv_aval,
            amax_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_grouped_quantize_p outer primitive abstract
        """
        # Phuong: keeping outer_abstract so that we can add fused dbias later
        (
            rowwise_out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
        ) = GroupedQuantizePrimitive.abstract(*args, **kwargs)
        return rowwise_out, colwise_out, scale_inv, colwise_scale_inv, updated_amax

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        group_sizes,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        group_axis,
        scale_dtype,
    ):
        """
        te_grouped_quantize_p lowering rules
        """
        del out_dtype, scale_dtype
        x_aval, scale_aval, group_sizes_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval.dtype == jnp.float32
        assert group_sizes_aval.dtype == jnp.int32
        assert group_axis == 0
        return ffi.ffi_lowering(GroupedQuantizePrimitive.name)(
            ctx,
            x,
            scale,
            group_sizes,
            scaling_mode=scaling_mode.value,
            q_layout=q_layout.value.value,
            flatten_axis=flatten_axis,
        )

    @staticmethod
    def impl(
        x,
        scale,
        group_sizes,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        group_axis,
        scale_dtype,
    ):
        """
        te_grouped_quantize_p implementation
        """
        assert GroupedQuantizePrimitive.inner_primitive is not None
        (
            rowwise_out,
            colwise_out,
            rowwise_scale_inv,
            colwise_scale_inv,
            updated_amax,
        ) = GroupedQuantizePrimitive.inner_primitive.bind(
            x,
            scale,
            group_sizes,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
            group_axis=group_axis,
            scale_dtype=scale_dtype,
        )
        return (rowwise_out, colwise_out, rowwise_scale_inv, colwise_scale_inv, updated_amax)


register_primitive(GroupedQuantizePrimitive)


def grouped_quantize(
    x: jnp.ndarray,
    quantizer: GroupedQuantizer,
    group_sizes: jnp.ndarray = None,
    amax: jnp.ndarray = None,
    flatten_axis: int = -1,
) -> GroupedScaledTensor1x:
    """Quantize a tensor in grouped manner.

    This function quantizes a tensor by splitting it into groups along a specified axis
    and applying quantization to each group separately. The groups can be either specified
    explicitly through group_sizes or automatically split along the group_axis.

    Args:
        x: Input tensor to quantize
        quantizer: The quantizer to use for quantization
        group_sizes: Array of ints containing the size of each group (default: None)
        amax: The amax of x; if None, it is auto-generated. (default: None)
        flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)

    Returns:
        A GroupedScaledTensor1x containing the quantized data

    Note:
        - If group_sizes is not provided, the tensor will be split into equal-sized groups
          along the group_axis
        - The group_axis is currently fixed to 0
        - The quantizer's q_layout determines whether row-wise, column-wise, or both
          quantization is applied
    """

    if quantizer is None:
        if isinstance(x, NoScaleTensor):
            return x
        return NoScaleTensor(data=x, amax=None)

    # TODO(Phuong): add support for flatten_axis = -2
    assert flatten_axis in (
        -1,
        x.ndim - 1,
    ), f"Only flatten_axis = -1 is supported for now, got {flatten_axis}"
    group_axis = 0

    if group_sizes is None:
        group_sizes = jnp.ones(x.shape[group_axis], dtype=jnp.int32)

    if not GroupedQuantizePrimitive.enabled():
        return quantizer.quantize(
            x, flatten_axis=flatten_axis, group_sizes=group_sizes, group_axis=group_axis
        )
    n_groups = group_sizes.size
    original_shape = x.shape
    assert n_groups == len(
        quantizer.quantizers
    ), f"n_groups={n_groups} != n_quantizers = {len(quantizer.quantizers)}"
    scale = jnp.empty((n_groups,), jnp.float32)

    if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        for i, quantizer_i in enumerate(quantizer.quantizers):
            scale = scale.at[i].set(quantizer_i.scale[0])

    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        if amax is not None:
            row_amax = amax
        else:
            row_amax = jnp.max(jnp.abs(x), axis=tuple(range(group_axis + 1, x.ndim)))
        segment_ids = jnp.repeat(
            jnp.arange(n_groups), group_sizes, total_repeat_length=x.shape[group_axis]
        )
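        # e.g. group_sizes = [2, 1] gives segment_ids = [0, 0, 1]; segment_max then
        # reduces the per-row amax values into one amax per group.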
        grouped_amax = jax.ops.segment_max(row_amax, segment_ids, num_segments=n_groups)
        for i in range(n_groups):
            tmp_scale = compute_scale_from_amax(grouped_amax[i], quantizer.q_dtype, margin=0.0)
            scale = scale.at[i].set(tmp_scale[0])

    is_tensor_scaling = quantizer.scaling_mode in (
        ScalingMode.DELAYED_TENSOR_SCALING,
        ScalingMode.CURRENT_TENSOR_SCALING,
    )
    # WAR for tensor_scaling as TE/Common does not support q_layout = COLWISE yet
    # So we perform ROWWISE_COLWISE quantization and use the colwise tensor output
    apply_colwise_war = is_tensor_scaling and quantizer.q_layout.is_colwise_only
    q_layout = QuantizeLayout.ROWWISE_COLWISE if apply_colwise_war else quantizer.q_layout
    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
    ) = GroupedQuantizePrimitive.outer_primitive.bind(
        x,
        scale,
        group_sizes,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        q_layout=q_layout,
        flatten_axis=flatten_axis,
        group_axis=group_axis,
        scale_dtype=quantizer.get_scale_dtype(),
    )

    # For DelayedScaling2x and CurrentScaling2x, the scale buffer
    # is shared between rowwise and colwise
    if (is_tensor_scaling and quantizer.q_layout.is_rowwise_colwise) or apply_colwise_war:
        colwise_scale_inv = rowwise_scale_inv

    # TODO(Phuong): store the whole updated_amax in the grouped_quantize instead?
    if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        for i, quantizer_i in enumerate(quantizer.quantizers):
            quantizer_i.update(updated_amax[i].reshape((1,)))

    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=flatten_axis,
        group_sizes=group_sizes,
        original_shape=original_shape,
        group_axis=group_axis,
    )
    return out


def grouped_dbias(grad: jnp.ndarray, group_sizes: jnp.ndarray) -> jnp.ndarray:
    """
    Compute the grouped bias gradient.

    Args:
        grad: jnp.ndarray of shape (M, N)
        group_sizes: jnp.ndarray of shape(num_groups,), sum(group_sizes) == M

    Returns:
        dbias: jnp.ndarray of shape (num_groups, N)
    """
    assert grad.ndim == 2, "Input grad must be a 2D tensor."
    assert group_sizes.ndim == 1, "group_sizes must be a 1D tensor."

    segment_ids = jnp.repeat(
        jnp.arange(group_sizes.size), group_sizes, total_repeat_length=grad.shape[0]
    )
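    # e.g. group_sizes = [2, 1] gives segment_ids = [0, 0, 1]: rows 0-1 reduce into
    # group 0 and row 2 into group 1.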
    grad_fp32 = grad.astype(jnp.float32)
    dbias_fp32 = jax.ops.segment_sum(grad_fp32, segment_ids, num_segments=group_sizes.shape[0])
    dbias = dbias_fp32.astype(grad.dtype)
    return dbias
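

# Illustrative check (hypothetical values): for grad of shape (3, 4) with
# group_sizes = jnp.array([2, 1]), grouped_dbias returns shape (2, 4): row 0 is the
# sum of grad rows 0-1 and row 1 equals grad row 2.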