# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for quantization"""
import operator
from functools import reduce
from typing import Tuple, Optional, Union
import math


import jax
import jax.numpy as jnp
from jax import dtypes, ffi
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.sharding import PartitionSpec

import transformer_engine_jax

from .amax import AmaxScope, calculate_amax, calculate_post_rht_amax
from .base import BasePrimitive, register_primitive
from .misc import (
    get_padded_spec,
    check_valid_batch_dims,
    te_dtype_to_jax_dtype,
    jax_dtype_to_te_dtype,
    multidim_transpose,
    should_apply_1x_fused_dbias_war_for_arch_l_100,
    get_min_device_compute_capability,
    NamedSharding,
)
from ..sharding import (
    all_reduce_max_along_all_axes_except_PP,
    all_reduce_sum_along_dp_fsdp,
    get_num_devices_in_mesh,
)
from ..quantize import (
    ScaledTensor2x,
    ScaledTensor,
    ScaledTensorFactory,
    GroupedScaledTensor1x,
    Quantizer,
    GroupedQuantizer,
    QuantizeLayout,
    ScalingMode,
    compute_scale_from_amax,
    NoScaleTensor,
    get_rht_matrix,
)


__all__ = ["quantize", "quantize_dbias", "grouped_quantize", "grouped_dbias"]


class BaseDBiasQuantizePrimitive(BasePrimitive):
    """
    Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias
    """

    name = "te_dbias_quantize_ffi"
    multiple_results = True
    impl_static_args = (
        6,  # out_dtype
        7,  # scaling_mode
        8,  # q_layout
        9,  # flatten_axis
        10,  # scale_dtype
        11,  # is_dbias
        12,  # is_outer
        13,  # stochastic_rounding
        14,  # use_rht
    )
    inner_primitive = None
    outer_primitive = None

    @staticmethod
76
77
78
    def abstract(
        x_aval,
        scale_aval,
79
        amax_aval,
80
81
82
        sr_rng_state_aval,
        post_rht_amax_aval,
        rht_matrix_aval,
83
84
85
        *,
        out_dtype,
        scaling_mode,
86
87
        q_layout,
        flatten_axis,
88
89
90
        scale_dtype,
        is_dbias,
        is_outer,
91
92
        stochastic_rounding,
        use_rht,
93
    ):
94
        """
95
        te_dbias_quantize_p abstract
96
97
98
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
99
        out_shape = x_aval.shape
100
        assert scale_aval is None or scale_aval.dtype == jnp.float32
        if stochastic_rounding:
            assert ScalingMode(
                scaling_mode
            ).is_nvfp4_scaling, "stochastic_rounding can only be used with NVFP4 scaling modes"
            # JAX doesn't support 64-bit by default so use 4x uint32 instead of 2x int64
            assert sr_rng_state_aval is not None and sr_rng_state_aval.dtype == jnp.uint32, (
                "sr_rng_state must be a uint32 array when stochastic_rounding is True but"
                f" received {sr_rng_state_aval}"
            )
110
            if is_outer and get_num_devices_in_mesh() > 1:
111
                assert (
112
                    sr_rng_state_aval.shape[0] == get_num_devices_in_mesh()
                    and sr_rng_state_aval.shape[1] == 4
                ), (
                    "sr_rng_state must be of shape (num_devices, 4) when stochastic_rounding is"
                    f" True and is_outer is True but received {sr_rng_state_aval.shape}"
                )
            else:
                # We cannot assert the shape is exactly (4,) here because if the quantized data
                # is not perfectly sharded across all devices then we will have extra rng state
                # here. For example, this could occur when the weights are not sharded when using
                # data parallelism. However, this is okay because the extra rng state will simply
                # not be used and each device still has a unique rng state.
                assert sr_rng_state_aval.size >= 4, (
                    "Sharded sr_rng_state must have at least 4 elements per device when"
                    f" stochastic_rounding is True but received {sr_rng_state_aval.shape}"
                )
124

        if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            rowwise_out_shape = out_shape
        else:
            rowwise_out_shape = (1,)
        rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype)
130

        assert out_dtype in ScalingMode(scaling_mode).get_compatible_q_dtypes(), (
            f"out_dtype {out_dtype} not compatible with scaling_mode {scaling_mode}. out_dtype must"
            f" be one of {ScalingMode(scaling_mode).get_compatible_q_dtypes()}"
        )

136
        updated_amax_aval = amax_aval
137

        if use_rht:
            assert (
                x_aval.dtype == jnp.bfloat16
            ), "x must be of dtype bfloat16 to be eligible for RHT cast fusion."

            if flatten_axis < 0:
                flatten_axis += len(x_aval.shape)
            rows = reduce(operator.mul, x_aval.shape[:flatten_axis], 1)
            cols = reduce(operator.mul, x_aval.shape[flatten_axis:], 1)
            assert rows % 64 == 0 and cols % 128 == 0, (
                "Rows must be multiple of 64 and cols multiple of 128 when use_rht is True to be"
                f" eligible for RHT cast fusion. Received rows {rows} and cols {cols} of 2D shape"
                f" from original shape of {x_aval.shape} with flatten_axis {flatten_axis}."
            )

            assert (
                rht_matrix_aval is not None
                and rht_matrix_aval.dtype == jnp.bfloat16
                and rht_matrix_aval.shape == (16, 16)
            ), "rht_matrix must be of shape (16, 16) and dtype bfloat16"
            assert (
                post_rht_amax_aval is not None
                and post_rht_amax_aval.dtype == jnp.float32
                and post_rht_amax_aval.size == 1
            ), "post_rht_amax must be of dtype float32"

164
165
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(
            x_aval.shape,
            is_padded=not is_outer,
            flatten_axis=flatten_axis,
            broadcast_2d_scale_shape_to_1d=True,
        )
172

173
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
174
            if ScalingMode(scaling_mode).is_colwise_transposed:
                colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis)
            else:
                colwise_out_shape = out_shape
        else:
            colwise_out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
182
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
183
184
185
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )
186
187

        if is_dbias:
188
189
            dbias_shape = x_aval.shape[flatten_axis:]
            gi_hidden_size = reduce(operator.mul, x_aval.shape[flatten_axis:], 1)
            (wkspace_info,) = transformer_engine_jax.get_dbias_quantize_workspace_sizes(
                x_aval.size // gi_hidden_size,
                gi_hidden_size,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
195
                jax_dtype_to_te_dtype(scale_dtype),
196
197
198
199
                scaling_mode,
                QuantizeLayout(
                    q_layout
                ),  # For now until we have auto-decoding for QuantizeLayout enum
200
            )
            wkspace_shape = wkspace_info[0]
            wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
        else:
            dbias_shape = (1,)
            wkspace_shape = (1,)
            wkspace_dtype = jnp.float32
        dbias_aval = jax.core.ShapedArray(shape=dbias_shape, dtype=dtype)
        wkspace_aval = jax.core.ShapedArray(shape=wkspace_shape, dtype=wkspace_dtype)

        return (
            rowwise_out_aval,
            colwise_out_aval,
            scale_inv_aval,
            colwise_scale_inv_aval,
            updated_amax_aval,
            dbias_aval,
            wkspace_aval,
        )
219
220

    @staticmethod
221
    def outer_abstract(*args, **kwargs):
222
        """
223
        te_dbias_quantize_p outer primitive abstract
224
        """
        (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
            _,
233
        ) = BaseDBiasQuantizePrimitive.abstract(*args, **kwargs)
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
241
        amax,
242
243
244
        sr_rng_state,
        post_rht_amax,
        rht_matrix,
245
246
247
        *,
        out_dtype,
        scaling_mode,
248
249
        q_layout,
        flatten_axis,
250
251
252
        scale_dtype,
        is_dbias,
        is_outer,
253
254
        stochastic_rounding,
        use_rht,
    ):
        """
        te_dbias_quantize_p lowering rules
        """
259
        del out_dtype, scale_dtype, is_outer
260
        x_aval, scale_aval, amax_aval, _, _, _ = ctx.avals_in
261
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval.dtype == amax_aval.dtype == jnp.float32
        return ffi.ffi_lowering(
            BaseDBiasQuantizePrimitive.name,
            operand_output_aliases={2: 4},  # donate amax buffer to updated_amax
        )(
267
268
269
            ctx,
            x,
            scale,
270
            amax,
271
272
273
            sr_rng_state,
            post_rht_amax,
            rht_matrix,
274
            scaling_mode=scaling_mode.value,
275
276
            q_layout=q_layout,
            flatten_axis=flatten_axis,
277
            is_dbias=is_dbias,
278
279
            stochastic_rounding=stochastic_rounding,
            use_rht=use_rht,
280
        )
281
282

    @staticmethod
283
284
285
    def impl(
        x,
        scale,
286
        amax,
287
288
289
        sr_rng_state,
        post_rht_amax,
        rht_matrix,
290
291
        out_dtype,
        scaling_mode,
292
293
        q_layout,
        flatten_axis,
294
295
296
        scale_dtype,
        is_dbias,
        is_outer,
297
298
        stochastic_rounding,
        use_rht,
299
    ):
300
        """
301
        te_dbias_quantize_p implementation
302
        """
303
        del is_outer
304
        assert BaseDBiasQuantizePrimitive.inner_primitive is not None
        (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
            _,
313
        ) = BaseDBiasQuantizePrimitive.inner_primitive.bind(
314
315
            x,
            scale,
316
            amax,
317
318
319
            sr_rng_state,
            post_rht_amax,
            rht_matrix,
320
321
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
322
323
            q_layout=q_layout,
            flatten_axis=flatten_axis,
324
325
326
            scale_dtype=scale_dtype,
            is_dbias=is_dbias,
            is_outer=False,
327
328
            stochastic_rounding=stochastic_rounding,
            use_rht=use_rht,
329
        )
330
331
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
332
333
334
        ).get_scale_shape_2x(
            x.shape, is_padded=False, flatten_axis=flatten_axis, broadcast_2d_scale_shape_to_1d=True
        )
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
        )
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_scale_inv = jax.lax.slice(
                colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
            )
        return (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
        )  # Exclude wkspace
350
351

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        scaling_mode,
358
359
        q_layout,
        flatten_axis,
360
361
362
        scale_dtype,
        is_dbias,
        is_outer,
363
364
        stochastic_rounding,
        use_rht,
    ):
        """
        to describe batch rules for vmap
        """
        del is_outer
370
        check_valid_batch_dims(batch_dims)
371
        assert BaseDBiasQuantizePrimitive.outer_primitive is not None
372
373
        x, scale, amax, sr_rng_state, post_rht_amax, rht_matrix = batched_args
        x_bdim, scale_bdim, amax_bdim, _, _, _ = batch_dims
374

375
        out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim, x_bdim
376
        return (
377
            BaseDBiasQuantizePrimitive.outer_primitive.bind(
378
379
                x,
                scale,
380
                amax,
381
382
383
                sr_rng_state,
                post_rht_amax,
                rht_matrix,
384
385
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
386
387
                q_layout=q_layout,
                flatten_axis=flatten_axis,
388
389
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
390
391
                stochastic_rounding=stochastic_rounding,
                use_rht=use_rht,
392
            ),
393
394
            out_bdims,
        )
395
396

    @staticmethod
397
398
399
    def infer_sharding_from_operands(
        out_dtype,
        scaling_mode,
400
401
        q_layout,
        flatten_axis,
402
403
404
        scale_dtype,
        is_dbias,
        is_outer,
405
406
        stochastic_rounding,
        use_rht,
407
408
409
410
        mesh,
        arg_infos,
        result_infos,
    ):
        del (
            out_dtype,
            result_infos,
            scale_dtype,
            is_outer,
            stochastic_rounding,
            use_rht,
        )  # Unused.
419

420
        x_spec = get_padded_spec(arg_infos[0])
421
        amax_spec = get_padded_spec(arg_infos[2])
422
423
        out_sharding = NamedSharding(
            mesh,
424
            PartitionSpec(*x_spec),
425
            desc="BaseDBiasQuantizePrimitive.out_sharding",
426
        )
427
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
428
            if ScalingMode(scaling_mode).is_colwise_transposed:
429
                colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
            else:
                colwise_out_spec = x_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_out_spec),
437
            desc="BaseDBiasQuantizePrimitive.colwise_out_sharding",
438
        )
439
440
441

        dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
442
            mesh,
443
            PartitionSpec(*dbias_spec),
444
            desc="BaseDBiasQuantizePrimitive.dbias_sharding",
445
        )
446

447
        scale_inv_spec = colwise_scale_inv_spec = (None,)
448
        if ScalingMode(scaling_mode).is_block_scaling:
449
450
451
            scale_inv_spec = x_spec

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if (
                ScalingMode(scaling_mode).is_block_scaling
                and ScalingMode(scaling_mode).is_colwise_transposed
            ):
                colwise_scale_inv_spec = multidim_transpose(
                    scale_inv_spec, transpose_axis=flatten_axis
                )
            else:
                colwise_scale_inv_spec = scale_inv_spec
461
462

        scale_inv_sharding = NamedSharding(
463
            mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv"
464
        )
465
        colwise_scale_inv_sharding = NamedSharding(
466
            mesh,
467
            PartitionSpec(*colwise_scale_inv_spec),
468
            desc="BaseDBiasQuantizePrimitive.colwise_scale_inv",
469
        )
470
471
472
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax"
        )
473

        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )
482
483

    @staticmethod
484
485
486
    def partition(
        out_dtype,
        scaling_mode,
487
488
        q_layout,
        flatten_axis,
489
490
491
        scale_dtype,
        is_dbias,
        is_outer,
492
493
        stochastic_rounding,
        use_rht,
        mesh,
        arg_infos,
        result_infos,
    ):
498
        del result_infos, is_outer  # Unused.
499

500
        x_spec = get_padded_spec(arg_infos[0])
501
        amax_spec = get_padded_spec(arg_infos[2])
502
503
        out_sharding = NamedSharding(
            mesh,
504
            PartitionSpec(*x_spec),
505
            desc="BaseDBiasQuantizePrimitive.out_sharding",
506
        )
507

508
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
509
            if ScalingMode(scaling_mode).is_colwise_transposed:
510
                colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
            else:
                colwise_out_spec = x_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_out_spec),
518
            desc="BaseDBiasQuantizePrimitive.colwise_out_sharding",
519
        )
520
521
522

        dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
523
            mesh,
524
            PartitionSpec(*dbias_spec),
525
            desc="BaseDBiasQuantizePrimitive.dbias_sharding",
526
        )
527

528
        scale_inv_spec = colwise_scale_inv_spec = (None,)
529
        if ScalingMode(scaling_mode).is_block_scaling:
530
531
532
            scale_inv_spec = x_spec

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if (
                ScalingMode(scaling_mode).is_block_scaling
                and ScalingMode(scaling_mode).is_colwise_transposed
            ):
                colwise_scale_inv_spec = multidim_transpose(
                    scale_inv_spec, transpose_axis=flatten_axis
                )
            else:
                colwise_scale_inv_spec = scale_inv_spec
542
543

        scale_inv_sharding = NamedSharding(
544
            mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv"
545
        )
546
        amax_sharding = NamedSharding(
547
            mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax"
548
        )
549
        colwise_scale_inv_sharding = NamedSharding(
550
            mesh,
551
            PartitionSpec(*colwise_scale_inv_spec),
552
            desc="BaseDBiasQuantizePrimitive.colwise_scale_inv",
553
        )
554

        arg_shardings = list(arg_i.sharding for arg_i in arg_infos)
        arg_shardings[3] = NamedSharding(
            mesh,
            PartitionSpec(tuple(x for x in x_spec if x is not None), None),
            desc="BaseDBiasQuantizePrimitive.sr_rng_state",
        )
        arg_shardings = tuple(arg_shardings)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )
570

571
        def sharded_impl(x, scale, amax, sr_rng_state, post_rht_amax, rht_matrix):
572
573
574
            if sr_rng_state.size > 4:
                # See comment in abstract method for explanation of why we cannot assert exact shape
                sr_rng_state = sr_rng_state.flatten()[:4]
            (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                local_amax,
                local_dbias,
582
            ) = BaseDBiasQuantizePrimitive.impl(
583
584
                x,
                scale,
585
                amax,
586
587
588
                sr_rng_state,
                post_rht_amax,
                rht_matrix,
589
590
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
591
592
                q_layout=q_layout,
                flatten_axis=flatten_axis,
593
594
595
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
                is_outer=True,
596
597
                stochastic_rounding=stochastic_rounding,
                use_rht=use_rht,
598
            )
599

600
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            else:
                global_updated_amax = local_amax

            if is_dbias:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            else:
                global_dbias = local_dbias

            return (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                global_updated_amax,
                global_dbias,
            )
618
619
620

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
630
631
        stochastic_rounding,
        use_rht,
632
633
634
635
        mesh,
        value_types,
        result_types,
    ):
        del (
            out_dtype,
            scale_dtype,
            is_outer,
            stochastic_rounding,
            use_rht,
            mesh,
            result_types,
        )
645

646
        prefix = "DBiasQuantize_"
647
        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
648
            value_types[0].shape,
            unique_var=prefix + "x",
650
            flatten_axis=flatten_axis,
651
            broadcast_2d_scale_shape_to_1d=True,
        )

        x_axes = scale_rules.input_spec

        out = x_axes
        colwise_out = (prefix + "out_colwise",)
658
        colwise_scale_inv = (prefix + "colwise_scale_inv",)
659
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
660
661
            colwise_scale_inv = scale_rules.colwise_rule
            if ScalingMode(scaling_mode).is_colwise_transposed:
662
                colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis))
663
664
665
                colwise_scale_inv = tuple(
                    multidim_transpose(colwise_scale_inv, transpose_axis=flatten_axis)
                )
666
667
668
            else:
                colwise_out = x_axes

        dbias = x_axes[flatten_axis:] if is_dbias else (prefix + "dbias",)
        amax = (prefix + "amax",)
        sr_rng_state = (prefix + "sr_rng_state_partition_axis", prefix + "sr_rng_state_data_axis")

        post_rht_amax = (prefix + "post_rht_amax",)
        rht_matrix = (prefix + "rht_matrix_1", prefix + "rht_matrix_2")
675
676

        return SdyShardingRule(
677
            (x_axes, ("…1",), amax, sr_rng_state, post_rht_amax, rht_matrix),
678
            (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias),
679
            **scale_rules.factor_sizes,
680
681
        )

register_primitive(BaseDBiasQuantizePrimitive)


class DBiasQuantizePrimitive(BaseDBiasQuantizePrimitive):
    """Subclass of BaseDBiasQuantizePrimitive for DBias quantization. No change in functionality
    from the base primitive but named differently for use in more granular disabling of primitives
    via NVTE_JAX_CUSTOM_CALLS."""


class QuantizePrimitive(BaseDBiasQuantizePrimitive):
    """Subclass of BaseDBiasQuantizePrimitive for quantization without dbias. No change in
    functionality from the base primitive but named differently for use in more granular disabling
    of primitives via NVTE_JAX_CUSTOM_CALLS."""


def _jax_quantize(
    x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1
):
    if quantizer is None:
        if isinstance(x, NoScaleTensor):
            return x
        return NoScaleTensor(data=x, amax=None)
    return quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)


def _jax_dbias(dx: Union[jnp.ndarray, NoScaleTensor], dtype=None, flatten_axis: int = -1):
    if isinstance(dx, NoScaleTensor):
        dx = dx.data
    sum_axis = dx.ndim + flatten_axis if flatten_axis < 0 else flatten_axis
    assert sum_axis < dx.ndim, "Flatten axis out of bounds!"
    dtype = dtype or dx.dtype
    dbias = jnp.sum(
        dx.astype(jnp.float32),
        axis=tuple(range(sum_axis)),
        keepdims=False,
    )
    return dbias.astype(dtype)


def _jax_quantize_dbias(
    x,
    quantizer: Quantizer = None,
    dq_dtype: Optional[jnp.dtype] = None,
    flatten_axis: int = -1,
):
    if quantizer is None:
        if isinstance(x, NoScaleTensor):
            return x, None
        return NoScaleTensor(data=x, amax=None), None
    return (
        quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
        _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis),
    )


734
def _quantize_dbias_impl(
735
    x: Union[jnp.ndarray, NoScaleTensor],
736
737
738
    quantizer: Quantizer,
    is_dbias: bool = False,
    dq_dtype: Optional[jnp.dtype] = None,
739
    flatten_axis: int = -1,
740
    amax_scope: AmaxScope = AmaxScope.LOCAL,  # Only works when using current-scaling
741
    transpose_batch_sequence: bool = False,
742
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
    """
    Cast wrapper
    Return FP8 tensor
    """
    assert (dq_dtype is None) or (
        quantizer is not None
    ), "quantizer must be provided if dq_dtype is provided"

751
752
753
    if isinstance(x, jnp.ndarray):
        x = NoScaleTensor(data=x, amax=None)

    # Early-exit for non-quantized call
755
    dq_dtype = dq_dtype or x.data.dtype
    if quantizer is None:
        dbias = None
758
        if is_dbias:
759
            dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis)
        return x, dbias
761

    # If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE,
    # fall back on the native-JAX quantize implementation
    PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive
    is_unsupported = quantizer.q_layout == QuantizeLayout.COLWISE and not (
        quantizer.scaling_mode == ScalingMode.NVFP4_1D_SCALING
        and hasattr(quantizer, "use_rht")
        and quantizer.use_rht
769
770
    )
    if is_unsupported or not PrimitiveClass.enabled():
        if is_dbias:
            return _jax_quantize_dbias(
                x,
                quantizer=quantizer,
                dq_dtype=dq_dtype,
776
                flatten_axis=flatten_axis,
777
            )
        return (
            _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
            None,
        )

    # TE/common custom quantize op does not support dbias fusion with 1x quantization on arch < 100
784
    if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
785
        out, _ = _quantize_dbias_impl(
            x=x,
            is_dbias=False,
            quantizer=quantizer,
            dq_dtype=dq_dtype,
790
            flatten_axis=flatten_axis,
791
            amax_scope=amax_scope,
792
            transpose_batch_sequence=transpose_batch_sequence,
793
        )
794
        dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis)
795
796
        return out, dbias

797
798
    use_rht = False

799
    scale = jnp.empty((1,), jnp.float32)
    post_rht_amax = None
    rht_matrix = jnp.empty((1, 1), jnp.bfloat16)
    amax = x.amax

804
    if hasattr(quantizer, "use_rht") and quantizer.use_rht:
        use_rht = True
        rht_matrix = get_rht_matrix()

        new_amax, post_rht_amax = calculate_post_rht_amax(
            x.data,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
            produce_regular_amax=amax is None,
            flatten_axis=flatten_axis,
        )
        if amax is None:
            # If amax is already calculated in a previous layer, we skip calculating it in the TE kernel
            # So here we only calculate and update amax when it is not provided from a previous layer (amax is None)
            amax = new_amax

820
    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
821
        if amax is None:
822
            amax = calculate_amax(
823
824
                x.data,
                amax_scope=amax_scope,
825
                transpose_batch_sequence=transpose_batch_sequence,
826
            )
827
        scale = compute_scale_from_amax(amax, quantizer.q_dtype)
    elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
829
        scale = quantizer.scale
        # Make sure to reset amax to zeros for DelayedScaling
        amax = jnp.zeros((1,), jnp.float32)
    elif quantizer.scaling_mode.is_nvfp4_scaling:
        if amax is None:
            amax = calculate_amax(
                x.data,
                amax_scope=amax_scope,
                transpose_batch_sequence=transpose_batch_sequence,
            )
839

840
    # Make sure amax is not None
841
842
    if amax is None:
        amax = jnp.zeros((1,), jnp.float32)
843

844
    # It is faster to use 1x quantization for tensor scaling
    is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100)
    force_1x_quantization = (
        quantizer.scaling_mode.is_tensor_scaling()
        and quantizer.is_2x2x()
        and is_1x_kernel_supported
    )
851
    q_layout = quantizer.q_layout
852

853
854
855
    if force_1x_quantization:
        q_layout = QuantizeLayout.ROWWISE

    sr_rng_state = None
    if quantizer.scaling_mode.is_nvfp4_scaling:
        # Only NVFP4 scaling modes support stochastic rounding
        if quantizer.stochastic_rounding_rng_state is not None:
            sr_rng_state = quantizer.stochastic_rounding_rng_state

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
        dbias,
869
    ) = PrimitiveClass.outer_primitive.bind(
870
        x.data,
871
        scale,
872
        amax,
        (
            sr_rng_state
            if sr_rng_state is not None
            else jnp.empty((get_num_devices_in_mesh(), 1), jnp.uint32)
        ),
878
879
        post_rht_amax if post_rht_amax is not None else jnp.zeros((1,), jnp.float32),
        rht_matrix,
880
881
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
882
        q_layout=q_layout.value,
883
        flatten_axis=flatten_axis,
884
        scale_dtype=quantizer.get_scale_dtype(),
885
        is_dbias=is_dbias if not quantizer.scaling_mode.is_nvfp4_scaling else False,
886
        is_outer=True,
887
888
        stochastic_rounding=sr_rng_state is not None,
        use_rht=use_rht,
889
890
    )
    # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
891
    if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x():
892
893
        colwise_scale_inv = rowwise_scale_inv

        if q_layout == QuantizeLayout.ROWWISE:
            # Quantizer requires 2x quantization, but we are using 1x quantization
            # for performance reasons, so we need to generate the colwise data in JAX
            if flatten_axis < 0:
                flatten_axis += x.ndim
            colwise_casted_output = jnp.transpose(
                rowwise_casted_output, (*range(flatten_axis, x.ndim), *range(flatten_axis))
            )
902
    quantizer.update(updated_amax)
903
904
    if quantizer.scaling_mode.is_nvfp4_scaling and is_dbias:
        dbias = _jax_dbias(x, flatten_axis=flatten_axis)

    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
911
912
        amax=updated_amax,
        colwise_amax=post_rht_amax,
913
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=dq_dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=flatten_axis,
918
        colwise_has_rht_applied=use_rht,
919
    )
920
    return out, dbias.astype(dq_dtype)
921
922
923


def quantize(
924
    x: Union[jnp.ndarray, NoScaleTensor],
925
    quantizer: Quantizer,
926
    flatten_axis: int = -1,
927
    amax_scope: AmaxScope = AmaxScope.LOCAL,
928
    transpose_batch_sequence: bool = False,
) -> ScaledTensor:
    """Quantize input tensor according to the quantizer.

    Args:
        x: Input tensor to be quantized.
            Shape: (..., K) where K is the hidden size.
        quantizer: Quantizer for FP8 quantization of the output.
936
937
        flatten_axis: The quantization axis in which input data can be flattened to 2D for
            quantization. Defaults to -1.
        amax_scope: Indicates the scope in which to run the amax calculation. This only works
            when using current-scaling. Default is AmaxScope.LOCAL.

    Returns:
        A ScaledTensor containing the quantized input tensor.
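
    Example:
        A minimal usage sketch; ``my_quantizer`` is assumed to be an already-configured
        ``Quantizer`` instance (it is not defined in this module).

            scaled = quantize(x, quantizer=my_quantizer)
            # `scaled` carries the quantized data together with its scale-inverse factors.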
    """
944
    out, _ = _quantize_dbias_impl(
945
946
        x,
        quantizer=quantizer,
947
        flatten_axis=flatten_axis,
948
        amax_scope=amax_scope,
949
        transpose_batch_sequence=transpose_batch_sequence,
    )
    return out


def quantize_dbias(
955
    dz: Union[jnp.ndarray, NoScaleTensor],
956
957
    quantizer: Quantizer,
    is_dbias: bool = True,
958
    flatten_axis: int = -1,
959
    amax_scope: AmaxScope = AmaxScope.LOCAL,
960
    transpose_batch_sequence: bool = False,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
    """Quantize input tensor and compute bias gradient.

    Args:
        dz: Input tensor to be quantized and used for bias gradient computation.
            Shape: (..., K) where K is the hidden size.
        quantizer: Quantizer for FP8 quantization of the output.
        is_dbias: If True, compute bias gradient. Defaults to True.
        flatten_axis: The quantization axis in which input data can be flattened to 2D for
            quantization. Defaults to -1.
        amax_scope: Indicates the scope in which to run the amax calculation. This only works
            when using current-scaling. Default is AmaxScope.LOCAL.

    Returns:
        A tuple containing:
        - A ScaledTensor containing the quantized input tensor.
            The ScaledTensor includes both the quantized data and scaling factors.
        - The bias gradient tensor.
            Shape: (K,) or empty if is_dbias is False.
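
    Example:
        A minimal usage sketch; ``my_quantizer`` is assumed to be an already-configured
        ``Quantizer`` instance and ``dz`` an incoming gradient of shape (..., K).

            scaled_dz, dbias = quantize_dbias(dz, quantizer=my_quantizer)
            # `scaled_dz` is the quantized gradient; `dbias` is the bias gradient of shape (K,).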
    """
981
    return _quantize_dbias_impl(
        dz,
        quantizer=quantizer,
        is_dbias=is_dbias,
        flatten_axis=flatten_axis,
986
        amax_scope=amax_scope,
987
        transpose_batch_sequence=transpose_batch_sequence,
988
    )


class GroupedQuantizePrimitive(BasePrimitive):
    """
    Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias
    """

    name = "te_grouped_quantize_ffi"
    multiple_results = True
    impl_static_args = (
        3,
        4,
        5,
        6,
        7,
        8,
    )  # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        group_sizes_aval,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        group_axis,
        scale_dtype,
    ):
        """
        te_dbias_quantize_p abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        out_shape = (math.prod(x_aval.shape),)
        # TODO(Phuong): can scale_aval be None?
        assert scale_aval is None or scale_aval.dtype == jnp.float32

1031
1032
1033
1034
1035
        assert out_dtype in ScalingMode(scaling_mode).get_compatible_q_dtypes(), (
            f"out_dtype {out_dtype} not compatible with scaling_mode {scaling_mode}. out_dtype must"
            f" be one of {ScalingMode(scaling_mode).get_compatible_q_dtypes()}"
        )

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_grouped_scale_shape_2x(
            x_aval.shape,
            group_sizes_aval.size,
            group_axis,
            is_padded=True,
            flatten_axis=flatten_axis,
        )

        if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            rowwise_out_shape = out_shape
        else:
            rowwise_out_shape = (1,)
            rowwise_scale_inv_shape = (1,)
        rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype)

        amax_aval = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32)

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_out_shape = out_shape
        else:
            colwise_out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
        rowwise_scale_inv_aval = jax.core.ShapedArray(
            shape=rowwise_scale_inv_shape, dtype=scale_dtype
        )
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        return (
            rowwise_out_aval,
            colwise_out_aval,
            rowwise_scale_inv_aval,
            colwise_scale_inv_aval,
            amax_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_grouped_quantize_p outer primitive abstract
        """
        # Phuong: keeping outer_abstract so that we can add fused dbias later
        (
            rowwise_out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
        ) = GroupedQuantizePrimitive.abstract(*args, **kwargs)
        return rowwise_out, colwise_out, scale_inv, colwise_scale_inv, updated_amax

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        group_sizes,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        group_axis,
        scale_dtype,
    ):
        """
        te_grouped_quantize_p lowering rules
        """
        del out_dtype, scale_dtype
        x_aval, scale_aval, group_sizes_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval.dtype == jnp.float32
        assert group_sizes_aval.dtype == jnp.int32
        assert group_axis == 0
        return ffi.ffi_lowering(GroupedQuantizePrimitive.name)(
            ctx,
            x,
            scale,
            group_sizes,
            scaling_mode=scaling_mode.value,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
        )

    @staticmethod
    def impl(
        x,
        scale,
        group_sizes,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        group_axis,
        scale_dtype,
    ):
        """
        te_grouped_quantize_p implementation
        """
        assert GroupedQuantizePrimitive.inner_primitive is not None
        (
            rowwise_out,
            colwise_out,
            rowwise_scale_inv,
            colwise_scale_inv,
            updated_amax,
        ) = GroupedQuantizePrimitive.inner_primitive.bind(
            x,
            scale,
            group_sizes,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
            group_axis=group_axis,
            scale_dtype=scale_dtype,
        )
        return (rowwise_out, colwise_out, rowwise_scale_inv, colwise_scale_inv, updated_amax)


register_primitive(GroupedQuantizePrimitive)


def grouped_quantize(
    x: jnp.ndarray,
    quantizer: GroupedQuantizer,
    group_sizes: jnp.ndarray = None,
    amax: jnp.ndarray = None,
    flatten_axis: int = -1,
) -> GroupedScaledTensor1x:
    """Quantize a tensor in grouped manner.

    This function quantizes a tensor by splitting it into groups along a specified axis
    and applying quantization to each group separately. The groups can be either specified
    explicitly through group_sizes or automatically split along the group_axis.

    Args:
        x: Input tensor to quantize
        quantizer: The quantizer to use for quantization
        group_sizes: Array of ints containing the size of each group (default: None)
        amax: The amax of x; if None, it is auto-generated. (default: None)
        flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)

    Returns:
        A GroupedScaledTensor1x containing the quantized data

    Note:
        - If group_sizes is not provided, the tensor will be split into equal-sized groups
          along the group_axis
        - The group_axis is currently fixed to 0
        - The quantizer's q_layout determines whether row-wise, column-wise, or both
          quantization is applied
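
    Example:
        An illustrative sketch; ``grouped_quantizer`` is assumed to be a configured
        ``GroupedQuantizer`` with one sub-quantizer per group (names and shapes are examples only).

            # x: (5, K) with rows 0-1 in group 0 and rows 2-4 in group 1
            group_sizes = jnp.array([2, 3], dtype=jnp.int32)
            grouped = grouped_quantize(x, grouped_quantizer, group_sizes=group_sizes)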
    """

    if quantizer is None:
        if isinstance(x, NoScaleTensor):
            return x
        return NoScaleTensor(data=x, amax=None)

    # TODO(Phuong): add support for flatten_axis = -2
    assert flatten_axis in (
        -1,
        x.ndim - 1,
    ), f"Only flatten_axis = -1 is supported for now, got {flatten_axis}"
    group_axis = 0

    if group_sizes is None:
        group_sizes = jnp.ones(x.shape[group_axis], dtype=jnp.int32)
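        # With this default, every row along group_axis forms its own group of size 1.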

    if not GroupedQuantizePrimitive.enabled():
        return quantizer.quantize(
            x, flatten_axis=flatten_axis, group_sizes=group_sizes, group_axis=group_axis
        )
    n_groups = group_sizes.size
    original_shape = x.shape
    assert n_groups == len(
        quantizer.quantizers
    ), f"n_groups={n_groups} != n_quantizers = {len(quantizer.quantizers)}"
    scale = jnp.empty((n_groups,), jnp.float32)

    if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        for i, quantizer_i in enumerate(quantizer.quantizers):
            scale = scale.at[i].set(quantizer_i.scale[0])

    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        if amax is not None:
            row_amax = amax
        else:
            row_amax = jnp.max(jnp.abs(x), axis=range(group_axis + 1, x.ndim))
        segment_ids = jnp.repeat(
            jnp.arange(n_groups), group_sizes, total_repeat_length=x.shape[group_axis]
        )
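        # For example, group_sizes = [2, 1] over 3 rows yields segment_ids = [0, 0, 1], so each
        # row's amax is folded into its group's slot by the segment_max below.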
        grouped_amax = jax.ops.segment_max(row_amax, segment_ids, num_segments=n_groups)
        for i in range(n_groups):
            tmp_scale = compute_scale_from_amax(grouped_amax[i], quantizer.q_dtype)
            scale = scale.at[i].set(tmp_scale[0])

    is_tensor_scaling = quantizer.scaling_mode in (
        ScalingMode.DELAYED_TENSOR_SCALING,
        ScalingMode.CURRENT_TENSOR_SCALING,
    )
    # WAR for tensor_scaling as TE/Common does not support q_layout = COLWISE yet,
    # so we perform ROWWISE_COLWISE quantization and use the colwise tensor output.
    apply_colwise_war = is_tensor_scaling and quantizer.q_layout == QuantizeLayout.COLWISE
    q_layout = QuantizeLayout.ROWWISE_COLWISE if apply_colwise_war else quantizer.q_layout
    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
    ) = GroupedQuantizePrimitive.outer_primitive.bind(
        x,
        scale,
        group_sizes,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        q_layout=q_layout.value,
        flatten_axis=flatten_axis,
        group_axis=group_axis,
        scale_dtype=quantizer.get_scale_dtype(),
    )

    # For DelayedScaling2x and CurrentScaling2x, the scale buffer
    # is shared between rowwise and colwise
    if (is_tensor_scaling and quantizer.is_2x2x()) or apply_colwise_war:
        colwise_scale_inv = rowwise_scale_inv

    # TODO(Phuong): store the whole updated_amax in the grouped_quantize instead?
    if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        for i, quantizer_i in enumerate(quantizer.quantizers):
            quantizer_i.update(updated_amax[i].reshape((1,)))

    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=flatten_axis,
        group_sizes=group_sizes,
        original_shape=original_shape,
        group_axis=group_axis,
    )
    return out


def grouped_dbias(grad: jnp.ndarray, group_sizes: jnp.ndarray) -> jnp.ndarray:
    """
    Compute the grouped bias gradient.

    Args:
        grad: jnp.ndarray of shape (M, N)
        group_sizes: jnp.ndarray of shape(num_groups,), sum(group_sizes) == M

    Returns:
        dbias: jnp.ndarray of shape (num_groups, N)
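
    Example:
        A small worked sketch (values are illustrative only).

            grad = jnp.ones((3, 4), dtype=jnp.bfloat16)
            group_sizes = jnp.array([2, 1], dtype=jnp.int32)
            dbias = grouped_dbias(grad, group_sizes)
            # dbias has shape (2, 4): row 0 sums grad rows 0-1, row 1 equals grad row 2.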
    """
    assert grad.ndim == 2, "Input grad must be a 2D tensor."
    assert group_sizes.ndim == 1, "group_sizes must be a 1D tensor."

    segment_ids = jnp.repeat(
        jnp.arange(group_sizes.size), group_sizes, total_repeat_length=grad.shape[0]
    )
    grad_fp32 = grad.astype(jnp.float32)
    dbias_fp32 = jax.ops.segment_sum(grad_fp32, segment_ids, num_segments=group_sizes.shape[0])
    dbias = dbias_fp32.astype(grad.dtype)
    return dbias