"src/git@developer.sourcefind.cn:jerrrrry/infinicore.git" did not exist on "3c31dc6cc547548cd56095d9467d409fb0ec5ce4"
quantization.py 43 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for quantization"""
import operator
from functools import reduce
from typing import Tuple, Optional, Union
import math


import jax
import jax.numpy as jnp
from jax import dtypes, ffi
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.sharding import PartitionSpec

import transformer_engine_jax

from .amax import AmaxScope, calculate_amax, calculate_post_rht_amax
from .base import BasePrimitive, register_primitive
from .misc import (
    get_padded_spec,
    check_valid_batch_dims,
    te_dtype_to_jax_dtype,
    jax_dtype_to_te_dtype,
    multidim_transpose,
    should_apply_1x_fused_dbias_war_for_arch_l_100,
    get_min_device_compute_capability,
    NamedSharding,
)
from ..sharding import (
    all_reduce_max_along_all_axes_except_PP,
    all_reduce_sum_along_dp_fsdp,
    num_of_devices,
)
from ..quantize import (
    ScaledTensor2x,
    ScaledTensor,
    ScaledTensorFactory,
    GroupedScaledTensor1x,
    Quantizer,
    GroupedQuantizer,
    QuantizeLayout,
    ScalingMode,
    compute_scale_from_amax,
    NoScaleTensor,
    get_rht_matrix,
    should_use_rht,
)


__all__ = ["quantize", "quantize_dbias", "grouped_quantize", "grouped_dbias"]


class BaseDBiasQuantizePrimitive(BasePrimitive):
    """
    Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias
    """

    name = "te_dbias_quantize_ffi"
    multiple_results = True
    impl_static_args = (
        6,  # out_dtype
        7,  # scaling_mode
        8,  # q_layout
        9,  # flatten_axis
        10,  # scale_dtype
        11,  # is_dbias
        12,  # is_outer
        13,  # stochastic_rounding
        14,  # use_rht
    )
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        amax_aval,
        sr_rng_state_aval,
        post_rht_amax_aval,
        rht_matrix_aval,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        stochastic_rounding,
        use_rht,
    ):
        """
        te_dbias_quantize_p abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        out_shape = x_aval.shape
        assert scale_aval is None or scale_aval.dtype == jnp.float32
        if stochastic_rounding:
            assert ScalingMode(
                scaling_mode
            ).is_nvfp4_scaling, "stochastic_rounding can only be used with NVFP4 scaling modes"
            # JAX doesn't support 64-bit by default so use 4x uint32 instead of 2x int64
            assert sr_rng_state_aval is not None and sr_rng_state_aval.dtype == jnp.uint32, (
                "sr_rng_state must be a uint32 array when stochastic_rounding is True but"
                f" received {sr_rng_state_aval}"
            )
            if is_outer:
                assert (
                    sr_rng_state_aval.shape[0] == num_of_devices()
                    and sr_rng_state_aval.shape[1] == 4
                ), (
                    "sr_rng_state must be of shape (num_devices, 4) when stochastic_rounding is"
                    f" True and is_outer is True but received {sr_rng_state_aval.shape}"
                )
            else:
                assert sr_rng_state_aval.shape == (4,), (
                    "Sharded sr_rng_state must be of shape (4,) per device when"
                    f" stochastic_rounding is True but received {sr_rng_state_aval.shape}"
                )

        if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            rowwise_out_shape = out_shape
        else:
            rowwise_out_shape = (1,)
        rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype)

        assert out_dtype in ScalingMode(scaling_mode).get_compatible_q_dtypes(), (
            f"out_dtype {out_dtype} not compatible with scaling_mode {scaling_mode}. out_dtype must"
            f" be one of {ScalingMode(scaling_mode).get_compatible_q_dtypes()}"
        )

        updated_amax_aval = amax_aval

        if use_rht:
            assert (
                x_aval.dtype == jnp.bfloat16
            ), "x must be of dtype bfloat16 to be eligible for RHT cast fusion."

            if flatten_axis < 0:
                flatten_axis += len(x_aval.shape)
            rows = reduce(operator.mul, x_aval.shape[:flatten_axis], 1)
            cols = reduce(operator.mul, x_aval.shape[flatten_axis:], 1)
            assert rows % 64 == 0 and cols % 128 == 0, (
                "Rows must be multiple of 64 and cols multiple of 128 when use_rht is True to be"
                f" eligible for RHT cast fusion. Received rows {rows} and cols {cols} of 2D shape"
                f" from original shape of {x_aval.shape} with flatten_axis {flatten_axis}."
            )

            assert (
                rht_matrix_aval is not None
                and rht_matrix_aval.dtype == jnp.bfloat16
                and rht_matrix_aval.shape == (16, 16)
            ), "rht_matrix must be of shape (16, 16) and dtype bfloat16"
            assert (
                post_rht_amax_aval is not None
                and post_rht_amax_aval.dtype == jnp.float32
                and post_rht_amax_aval.size == 1
            ), "post_rht_amax must be of dtype float32"

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(
            x_aval.shape,
            is_padded=not is_outer,
            flatten_axis=flatten_axis,
            broadcast_2d_scale_shape_to_1d=True,
        )

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if ScalingMode(scaling_mode).is_colwise_transposed:
                colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis)
            else:
                colwise_out_shape = out_shape
        else:
            colwise_out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        if is_dbias:
            dbias_shape = x_aval.shape[flatten_axis:]
            gi_hidden_size = reduce(operator.mul, x_aval.shape[flatten_axis:], 1)
            (wkspace_info,) = transformer_engine_jax.get_dbias_quantize_workspace_sizes(
                x_aval.size // gi_hidden_size,
                gi_hidden_size,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
                jax_dtype_to_te_dtype(scale_dtype),
                scaling_mode,
                QuantizeLayout(
                    q_layout
                ),  # For now until we have auto-decoding for QuantizeLayout enum
            )
            wkspace_shape = wkspace_info[0]
            wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
        else:
            dbias_shape = (1,)
            wkspace_shape = (1,)
            wkspace_dtype = jnp.float32
        dbias_aval = jax.core.ShapedArray(shape=dbias_shape, dtype=dtype)
        wkspace_aval = jax.core.ShapedArray(shape=wkspace_shape, dtype=wkspace_dtype)

        return (
            rowwise_out_aval,
            colwise_out_aval,
            scale_inv_aval,
            colwise_scale_inv_aval,
            updated_amax_aval,
            dbias_aval,
            wkspace_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dbias_quantize_p outer primitive abstract
        """
        (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
            _,
        ) = BaseDBiasQuantizePrimitive.abstract(*args, **kwargs)
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        amax,
        sr_rng_state,
        post_rht_amax,
        rht_matrix,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        stochastic_rounding,
        use_rht,
    ):
        """
        te_dbias_quantize_p lowering rules
        """
        del out_dtype, scale_dtype, is_outer
        x_aval, scale_aval, amax_aval, _, _, _ = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval.dtype == amax_aval.dtype == jnp.float32
        return ffi.ffi_lowering(
            BaseDBiasQuantizePrimitive.name,
            operand_output_aliases={2: 4},  # donate amax buffer to updated_amax
        )(
            ctx,
            x,
            scale,
            amax,
            sr_rng_state,
            post_rht_amax,
            rht_matrix,
            scaling_mode=scaling_mode.value,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
            is_dbias=is_dbias,
            stochastic_rounding=stochastic_rounding,
            use_rht=use_rht,
        )

    @staticmethod
    def impl(
        x,
        scale,
        amax,
        sr_rng_state,
        post_rht_amax,
        rht_matrix,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        stochastic_rounding,
        use_rht,
    ):
        """
        te_dbias_quantize_p implementation
        """
        del is_outer
        assert BaseDBiasQuantizePrimitive.inner_primitive is not None
        (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
            _,
        ) = BaseDBiasQuantizePrimitive.inner_primitive.bind(
            x,
            scale,
            amax,
            sr_rng_state,
            post_rht_amax,
            rht_matrix,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
            scale_dtype=scale_dtype,
            is_dbias=is_dbias,
            is_outer=False,
            stochastic_rounding=stochastic_rounding,
            use_rht=use_rht,
        )
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(
            x.shape, is_padded=False, flatten_axis=flatten_axis, broadcast_2d_scale_shape_to_1d=True
        )
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
        )
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_scale_inv = jax.lax.slice(
                colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
            )
        return (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
        )  # Exclude wkspace

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        stochastic_rounding,
        use_rht,
    ):
        """
        to describe batch rules for vmap
        """
        del is_outer
        check_valid_batch_dims(batch_dims)
        assert BaseDBiasQuantizePrimitive.outer_primitive is not None
        x, scale, amax, sr_rng_state, post_rht_amax, rht_matrix = batched_args
        x_bdim, scale_bdim, amax_bdim, _, _, _ = batch_dims

        out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim, x_bdim
        return (
            BaseDBiasQuantizePrimitive.outer_primitive.bind(
                x,
                scale,
                amax,
                sr_rng_state,
                post_rht_amax,
                rht_matrix,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                q_layout=q_layout,
                flatten_axis=flatten_axis,
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
                stochastic_rounding=stochastic_rounding,
                use_rht=use_rht,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        stochastic_rounding,
        use_rht,
        mesh,
        arg_infos,
        result_infos,
    ):
        del (
            out_dtype,
            result_infos,
            scale_dtype,
            is_outer,
            stochastic_rounding,
            use_rht,
        )  # Unused.

        x_spec = get_padded_spec(arg_infos[0])
        amax_spec = get_padded_spec(arg_infos[2])
        out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*x_spec),
            desc="BaseDBiasQuantizePrimitive.out_sharding",
        )
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if ScalingMode(scaling_mode).is_colwise_transposed:
                colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
            else:
                colwise_out_spec = x_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_out_spec),
            desc="BaseDBiasQuantizePrimitive.colwise_out_sharding",
        )

        dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="BaseDBiasQuantizePrimitive.dbias_sharding",
        )

        scale_inv_spec = colwise_scale_inv_spec = (None,)
        if ScalingMode(scaling_mode).is_block_scaling:
            scale_inv_spec = x_spec

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if (
                ScalingMode(scaling_mode).is_block_scaling
                and ScalingMode(scaling_mode).is_colwise_transposed
            ):
                colwise_scale_inv_spec = multidim_transpose(
                    scale_inv_spec, transpose_axis=flatten_axis
                )
            else:
                colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="BaseDBiasQuantizePrimitive.colwise_scale_inv",
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax"
        )

        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        stochastic_rounding,
        use_rht,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer  # Unused.

        x_spec = get_padded_spec(arg_infos[0])
        amax_spec = get_padded_spec(arg_infos[2])
        out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*x_spec),
            desc="BaseDBiasQuantizePrimitive.out_sharding",
        )

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if ScalingMode(scaling_mode).is_colwise_transposed:
                colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
            else:
                colwise_out_spec = x_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_out_spec),
            desc="BaseDBiasQuantizePrimitive.colwise_out_sharding",
        )

        dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="BaseDBiasQuantizePrimitive.dbias_sharding",
        )

        scale_inv_spec = colwise_scale_inv_spec = (None,)
        if ScalingMode(scaling_mode).is_block_scaling:
            scale_inv_spec = x_spec

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if (
                ScalingMode(scaling_mode).is_block_scaling
                and ScalingMode(scaling_mode).is_colwise_transposed
            ):
                colwise_scale_inv_spec = multidim_transpose(
                    scale_inv_spec, transpose_axis=flatten_axis
                )
            else:
                colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="BaseDBiasQuantizePrimitive.colwise_scale_inv",
        )

        # TODO(jberchtold): Assert the sr_rng state is sharded along all mesh axes
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

        def sharded_impl(x, scale, amax, sr_rng_state, post_rht_amax, rht_matrix):
            (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                local_amax,
                local_dbias,
            ) = BaseDBiasQuantizePrimitive.impl(
                x,
                scale,
                amax,
                sr_rng_state,
                post_rht_amax,
                rht_matrix,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                q_layout=q_layout,
                flatten_axis=flatten_axis,
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
                is_outer=True,
                stochastic_rounding=stochastic_rounding,
                use_rht=use_rht,
            )

            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            else:
                global_updated_amax = local_amax

            if is_dbias:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            else:
                global_dbias = local_dbias

            return (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                global_updated_amax,
                global_dbias,
            )

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        stochastic_rounding,
        use_rht,
        mesh,
        value_types,
        result_types,
    ):
        del (
            out_dtype,
            scale_dtype,
            is_outer,
            stochastic_rounding,
            use_rht,
            mesh,
            result_types,
        )

        prefix = "DBiasQuantize_"
        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            value_types[0].shape,
            unique_var=prefix + "x",
            flatten_axis=flatten_axis,
            broadcast_2d_scale_shape_to_1d=True,
        )

        x_axes = scale_rules.input_spec

        out = x_axes
        colwise_out = (prefix + "out_colwise",)
        colwise_scale_inv = (prefix + "colwise_scale_inv",)
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_scale_inv = scale_rules.colwise_rule
            if ScalingMode(scaling_mode).is_colwise_transposed:
                colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis))
                colwise_scale_inv = tuple(
                    multidim_transpose(colwise_scale_inv, transpose_axis=flatten_axis)
                )
            else:
                colwise_out = x_axes

        dbias = x_axes[flatten_axis:] if is_dbias else (prefix + "dbias",)
        amax = (prefix + "amax",)
        sr_rng_state = (prefix + "sr_rng_state_partition_axis", prefix + "sr_rng_state_data_axis")

        post_rht_amax = (prefix + "post_rht_amax",)
        rht_matrix = (prefix + "rht_matrix_1", prefix + "rht_matrix_2")

        return SdyShardingRule(
            (x_axes, ("…1",), amax, sr_rng_state, post_rht_amax, rht_matrix),
            (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias),
            **scale_rules.factor_sizes,
        )


register_primitive(BaseDBiasQuantizePrimitive)


class DBiasQuantizePrimitive(BaseDBiasQuantizePrimitive):
    """Subclass of BaseDBiasQuantizePrimitive for DBias quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""


class QuantizePrimitive(BaseDBiasQuantizePrimitive):
    """Subclass of BaseDBiasQuantizePrimitive for quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""


def _jax_quantize(
    x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1
):
    if quantizer is None:
        if isinstance(x, NoScaleTensor):
            return x
        return NoScaleTensor(data=x, amax=None)
    return quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)


def _jax_dbias(dx: Union[jnp.ndarray, NoScaleTensor], dtype=None, flatten_axis: int = -1):
    if isinstance(dx, NoScaleTensor):
        dx = dx.data
    sum_axis = dx.ndim + flatten_axis if flatten_axis < 0 else flatten_axis
    assert sum_axis < dx.ndim, "Flatten axis out of bounds!"
    dtype = dtype or dx.dtype
    dbias = jnp.sum(
        dx.astype(jnp.float32),
        axis=tuple(range(sum_axis)),
        keepdims=False,
    )
    return dbias.astype(dtype)
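# Illustrative sketch (not part of the module API; shapes are hypothetical): _jax_dbias sums
# the incoming gradient over every axis before `flatten_axis`, so a (batch, seq, hidden)
# gradient collapses to a (hidden,) bias gradient:
#
#   dgrad = jnp.ones((2, 4, 8), jnp.bfloat16)
#   _jax_dbias(dgrad, flatten_axis=-1)   # shape (8,), every entry equals 2 * 4 = 8
#   _jax_dbias(dgrad, flatten_axis=-2)   # shape (4, 8), summed over the batch axis only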


def _jax_quantize_dbias(
    x,
    quantizer: Quantizer = None,
    dq_dtype: Optional[jnp.dtype] = None,
    flatten_axis: int = -1,
):
    if quantizer is None:
        if isinstance(x, NoScaleTensor):
            return x, None
        return NoScaleTensor(data=x, amax=None), None
    return (
        quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
        _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis),
    )


def _quantize_dbias_impl(
    x: Union[jnp.ndarray, NoScaleTensor],
    quantizer: Quantizer,
    is_dbias: bool = False,
    dq_dtype: Optional[jnp.dtype] = None,
    flatten_axis: int = -1,
    amax_scope: AmaxScope = AmaxScope.LOCAL,  # Only works when using current-scaling
    transpose_batch_sequence: bool = False,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
    """
    Cast wrapper
    Return FP8 tensor
    """
    assert (dq_dtype is None) or (
        quantizer is not None
    ), "quantizer must be provided if dq_dtype is provided"

    if isinstance(x, jnp.ndarray):
        x = NoScaleTensor(data=x, amax=None)

    # Early-exit for non-quantized call
    dq_dtype = dq_dtype or x.data.dtype
    if quantizer is None:
        dbias = None
        if is_dbias:
            dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis)
        return x, dbias

    # If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE,
    # fall back on the native-JAX quantize implementation
    PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive
    is_unsupported = (
        quantizer.q_layout == QuantizeLayout.COLWISE
        and quantizer.scaling_mode != ScalingMode.NVFP4_1D_SCALING
    )
    if is_unsupported or not PrimitiveClass.enabled():
        if is_dbias:
            return _jax_quantize_dbias(
                x,
                quantizer=quantizer,
                dq_dtype=dq_dtype,
                flatten_axis=flatten_axis,
            )
        return (
            _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
            None,
        )

    # TE/common custom quantize op does not support dbias fusion with 1x quantization on arch < 100
    if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
        out, _ = _quantize_dbias_impl(
            x=x,
            is_dbias=False,
            quantizer=quantizer,
            dq_dtype=dq_dtype,
            flatten_axis=flatten_axis,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
        )
        dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis)
        return out, dbias

    use_rht = False

    scale = jnp.empty((1,), jnp.float32)
    post_rht_amax = None
    rht_matrix = jnp.empty((1, 1), jnp.bfloat16)
    amax = x.amax

    if should_use_rht(quantizer.scaling_mode, q_layout=quantizer.q_layout):
        use_rht = True
        rht_matrix = get_rht_matrix()

        new_amax, post_rht_amax = calculate_post_rht_amax(
            x.data,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
            produce_regular_amax=amax is None,
            flatten_axis=flatten_axis,
        )
        if amax is None:
            # If amax is already calculated in a previous layer, we skip calculating it in the TE kernel
            # So here we only calculate and update amax when it is not provided from a previous layer (amax is None)
            amax = new_amax

    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        if amax is None:
            amax = calculate_amax(
                x.data,
                amax_scope=amax_scope,
                transpose_batch_sequence=transpose_batch_sequence,
            )
        scale = compute_scale_from_amax(amax, quantizer.q_dtype)
    elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        scale = quantizer.scale
        # Make sure to reset amax to zeros for DelayedScaling
        amax = jnp.zeros((1,), jnp.float32)
    elif quantizer.scaling_mode.is_nvfp4_scaling:
        if amax is None:
            amax = calculate_amax(
                x.data,
                amax_scope=amax_scope,
                transpose_batch_sequence=transpose_batch_sequence,
            )

    # Make sure amax is not None
    if amax is None:
        amax = jnp.zeros((1,), jnp.float32)

    # It is faster to use 1x quantization for tensor scaling
    is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100)
    force_1x_quantization = (
        quantizer.scaling_mode.is_tensor_scaling()
        and quantizer.is_2x2x()
        and is_1x_kernel_supported
    )
    q_layout = quantizer.q_layout

    if force_1x_quantization:
        q_layout = QuantizeLayout.ROWWISE

    sr_rng_state = None
    if quantizer.scaling_mode.is_nvfp4_scaling:
        # Only NVFP4 scaling modes support stochastic rounding
        if quantizer.stochastic_rounding_rng_state is not None:
            sr_rng_state = quantizer.stochastic_rounding_rng_state

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
        dbias,
    ) = PrimitiveClass.outer_primitive.bind(
        x.data,
        scale,
        amax,
        sr_rng_state if sr_rng_state is not None else jnp.empty((num_of_devices(), 1), jnp.uint32),
        post_rht_amax if post_rht_amax is not None else jnp.zeros((1,), jnp.float32),
        rht_matrix,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        q_layout=q_layout.value,
        flatten_axis=flatten_axis,
        scale_dtype=quantizer.get_scale_dtype(),
        is_dbias=is_dbias if not quantizer.scaling_mode.is_nvfp4_scaling else False,
        is_outer=True,
        stochastic_rounding=sr_rng_state is not None,
        use_rht=use_rht,
    )
    # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
    if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x():
        colwise_scale_inv = rowwise_scale_inv

        if q_layout == QuantizeLayout.ROWWISE:
            # Quantizer requires 2x quantization, but we are using 1x quantization
            # for performance reasons, so we need to generate the colwise data in JAX
            if flatten_axis < 0:
                flatten_axis += x.ndim
            colwise_casted_output = jnp.transpose(
                rowwise_casted_output, (*range(flatten_axis, x.ndim), *range(flatten_axis))
            )
    quantizer.update(updated_amax)
    if quantizer.scaling_mode.is_nvfp4_scaling and is_dbias:
        dbias = _jax_dbias(x, flatten_axis=flatten_axis)

    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        amax=updated_amax,
        colwise_amax=post_rht_amax,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=dq_dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=flatten_axis,
    )
    return out, dbias.astype(dq_dtype)
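# Illustrative sketch (pure JAX, hypothetical shapes): when the tensor-scaling fast path above
# quantizes only rowwise, the colwise copy is reconstructed with the same transpose used in
# `_quantize_dbias_impl`, moving the axes from `flatten_axis` onward to the front:
#
#   x = jnp.zeros((2, 3, 8))   # e.g. (batch, seq, hidden)
#   flatten_axis = 2           # -1 normalized to x.ndim - 1
#   colwise = jnp.transpose(x, (*range(flatten_axis, x.ndim), *range(flatten_axis)))
#   # colwise.shape == (8, 2, 3)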


def quantize(
    x: Union[jnp.ndarray, NoScaleTensor],
    quantizer: Quantizer,
    flatten_axis: int = -1,
    amax_scope: AmaxScope = AmaxScope.LOCAL,
    transpose_batch_sequence: bool = False,
) -> Tuple[ScaledTensor]:
    """Quantize input tensor according to the quantizer.

    Args:
        x: Input tensor to be quantized.
            Shape: (..., K) where K is the hidden size.
        quantizer: Quantizer for FP8 quantization of the output.
        flatten_axis: The axis along which the input can be flattened to 2D for quantization.
            Defaults to -1.
        amax_scope: Indicates the scope over which to run the amax calculation. This only works
            when using current-scaling. Default is AmaxScope.LOCAL.

    Returns:
        A ScaledTensor containing the quantized input tensor.
    """
    out, _ = _quantize_dbias_impl(
        x,
        quantizer=quantizer,
        flatten_axis=flatten_axis,
        amax_scope=amax_scope,
        transpose_batch_sequence=transpose_batch_sequence,
    )
    return out
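# Minimal sketch of the current-scaling math that `quantize` relies on (pure JAX, independent
# of the TE kernel; the margin-free scale formula and float8_e4m3fn availability are
# assumptions for illustration only):
#
#   amax  = jnp.max(jnp.abs(x)).astype(jnp.float32)
#   scale = 448.0 / amax                        # 448 is the largest finite float8_e4m3 value
#   q     = (x * scale).astype(jnp.float8_e4m3fn)
#   x_dq  = q.astype(jnp.float32) / scale       # dequantize with the stored scale_inv = 1/scale
#
# The TE primitive additionally produces rowwise/colwise outputs, scale-inverse buffers,
# optional dbias fusion, and sharding handling, which this sketch omits.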


def quantize_dbias(
    dz: Union[jnp.ndarray, NoScaleTensor],
    quantizer: Quantizer,
    is_dbias: bool = True,
    flatten_axis: int = -1,
    amax_scope: AmaxScope = AmaxScope.LOCAL,
    transpose_batch_sequence: bool = False,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
    """Quantize input tensor and compute bias gradient.

    Args:
        dz: Input tensor to be quantized and used for bias gradient computation.
            Shape: (..., K) where K is the hidden size.
        quantizer: Quantizer for FP8 quantization of the output.
        is_dbias: If True, compute bias gradient. Defaults to True.
        flatten_axis: The axis along which the input can be flattened to 2D for quantization.
            Defaults to -1.
        amax_scope: Indicates the scope over which to run the amax calculation. This only works
            when using current-scaling. Default is AmaxScope.LOCAL.


    Returns:
        A tuple containing:
        - A ScaledTensor containing the quantized input tensor.
            The ScaledTensor includes both the quantized data and scaling factors.
        - The bias gradient tensor.
            Shape: (K,) or empty if is_dbias is False.
    """
    return _quantize_dbias_impl(
        dz,
        quantizer=quantizer,
        is_dbias=is_dbias,
        flatten_axis=flatten_axis,
        amax_scope=amax_scope,
        transpose_batch_sequence=transpose_batch_sequence,
    )


class GroupedQuantizePrimitive(BasePrimitive):
    """
    Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias
    """

    name = "te_grouped_quantize_ffi"
    multiple_results = True
    impl_static_args = (
        3,
        4,
        5,
        6,
        7,
        8,
    )  # out_dtype, scaling_mode, q_layout, flatten_axis, group_axis, scale_dtype
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        group_sizes_aval,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        group_axis,
        scale_dtype,
    ):
        """
        te_dbias_quantize_p abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        out_shape = math.prod(x_aval.shape)
        # TODO(Phuong): can scale_aval be None?
        assert scale_aval is None or scale_aval.dtype == jnp.float32

        assert out_dtype in ScalingMode(scaling_mode).get_compatible_q_dtypes(), (
            f"out_dtype {out_dtype} not compatible with scaling_mode {scaling_mode}. out_dtype must"
            f" be one of {ScalingMode(scaling_mode).get_compatible_q_dtypes()}"
        )

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_grouped_scale_shape_2x(
            x_aval.shape,
            group_sizes_aval.size,
            group_axis,
            is_padded=True,
            flatten_axis=flatten_axis,
        )

        if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            rowwise_out_shape = out_shape
        else:
            rowwise_out_shape = (1,)
            rowwise_scale_inv_shape = (1,)
        rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype)

        amax_aval = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32)

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_out_shape = out_shape
        else:
            colwise_out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
        rowwise_scale_inv_aval = jax.core.ShapedArray(
            shape=rowwise_scale_inv_shape, dtype=scale_dtype
        )
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        return (
            rowwise_out_aval,
            colwise_out_aval,
            rowwise_scale_inv_aval,
            colwise_scale_inv_aval,
            amax_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dbias_quantize_p outer primitive abstract
        """
        # Phuong: keeping outer abstract so that we can add fuse dbias later
        (
            rowwise_out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
        ) = GroupedQuantizePrimitive.abstract(*args, **kwargs)
        return rowwise_out, colwise_out, scale_inv, colwise_scale_inv, updated_amax

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        group_sizes,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        group_axis,
        scale_dtype,
    ):
        """
        te_dbias_quantize_p lowering rules
        """
        del out_dtype, scale_dtype
        x_aval, scale_aval, group_sizes_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval.dtype == jnp.float32
        assert group_sizes_aval.dtype == jnp.int32
        assert group_axis == 0
        return ffi.ffi_lowering(GroupedQuantizePrimitive.name)(
            ctx,
            x,
            scale,
            group_sizes,
            scaling_mode=scaling_mode.value,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
        )

    @staticmethod
    def impl(
        x,
        scale,
        group_sizes,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        group_axis,
        scale_dtype,
    ):
        """
        te_dbias_quantize_p implementation
        """
        assert GroupedQuantizePrimitive.inner_primitive is not None
        (
            rowwise_out,
            colwise_out,
            rowwise_scale_inv,
            colwise_scale_inv,
            updated_amax,
        ) = GroupedQuantizePrimitive.inner_primitive.bind(
            x,
            scale,
            group_sizes,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
            group_axis=group_axis,
            scale_dtype=scale_dtype,
        )
        return (rowwise_out, colwise_out, rowwise_scale_inv, colwise_scale_inv, updated_amax)


register_primitive(GroupedQuantizePrimitive)


def grouped_quantize(
    x: jnp.ndarray,
    quantizer: GroupedQuantizer,
    group_sizes: jnp.ndarray = None,
    amax: jnp.ndarray = None,
    flatten_axis: int = -1,
) -> GroupedScaledTensor1x:
    """Quantize a tensor in grouped manner.

    This function quantizes a tensor by splitting it into groups along a specified axis
    and applying quantization to each group separately. The groups can be either specified
    explicitly through group_sizes or automatically split along the group_axis.

    Args:
        x: Input tensor to quantize
        quantizer: The quantizer to use for quantization
        group_sizes: Array of ints containing the size of each group (default: None)
        amax: The amax of x; if None, it is auto-generated. (default: None)
        flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)

    Returns:
        A GroupedScaledTensor1x containing the quantized data

    Note:
        - If group_sizes is not provided, the tensor will be split into equal-sized groups
          along the group_axis
        - The group_axis is currently fixed to 0
        - The quantizer's q_layout determines whether row-wise, column-wise, or both
          quantization is applied
    """

    if quantizer is None:
        if isinstance(x, NoScaleTensor):
            return x
        return NoScaleTensor(data=x, amax=None)

    # TODO(Phuong): add support for flatten_axis = -2
    assert flatten_axis in (
        -1,
        x.ndim - 1,
    ), f"Only flatten_axis = -1 is supported for now, got {flatten_axis}"
    group_axis = 0

    if group_sizes is None:
        group_sizes = jnp.ones(x.shape[group_axis], dtype=jnp.int32)

    if not GroupedQuantizePrimitive.enabled():
        return quantizer.quantize(
            x, flatten_axis=flatten_axis, group_sizes=group_sizes, group_axis=group_axis
        )
    n_groups = group_sizes.size
    original_shape = x.shape
    assert n_groups == len(
        quantizer.quantizers
    ), f"n_groups={n_groups} != n_quantizers = {len(quantizer.quantizers)}"
    scale = jnp.empty((n_groups,), jnp.float32)

    if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        for i, quantizer_i in enumerate(quantizer.quantizers):
            scale = scale.at[i].set(quantizer_i.scale[0])

    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        if amax is not None:
            row_amax = amax
        else:
            row_amax = jnp.max(jnp.abs(x), axis=range(group_axis + 1, x.ndim))
        segment_ids = jnp.repeat(
            jnp.arange(n_groups), group_sizes, total_repeat_length=x.shape[group_axis]
        )
        segment_ids = jnp.repeat(
            jnp.arange(n_groups), group_sizes, total_repeat_length=x.shape[group_axis]
        )
        grouped_amax = jax.ops.segment_max(row_amax, segment_ids, num_segments=n_groups)
        for i in range(n_groups):
            tmp_scale = compute_scale_from_amax(grouped_amax[i], quantizer.q_dtype)
            scale = scale.at[i].set(tmp_scale[0])

    is_tensor_scaling = quantizer.scaling_mode in (
        ScalingMode.DELAYED_TENSOR_SCALING,
        ScalingMode.CURRENT_TENSOR_SCALING,
    )
    # WAR for tensor_scaling as TE/Common does not support q_layout = COLWISE yet
    # So we perform ROWWISE_COLWISE quantization and use the colwise tensor output
    apply_colwise_war = is_tensor_scaling and quantizer.q_layout == QuantizeLayout.COLWISE
    q_layout = QuantizeLayout.ROWWISE_COLWISE if apply_colwise_war else quantizer.q_layout
    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
    ) = GroupedQuantizePrimitive.outer_primitive.bind(
        x,
        scale,
        group_sizes,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        q_layout=q_layout.value,
        flatten_axis=flatten_axis,
        group_axis=group_axis,
        scale_dtype=quantizer.get_scale_dtype(),
    )

    # For DelayedScaling2x and CurrentScaling2x, the scale buffer
    # is shared between rowwise and colwise
    if is_tensor_scaling and quantizer.is_2x2x() or apply_colwise_war:
        colwise_scale_inv = rowwise_scale_inv

    # TODO(Phuong): store the whole updated_amax in the grouped_quantize instead?
    if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        for i, quantizer_i in enumerate(quantizer.quantizers):
            quantizer_i.update(updated_amax[i].reshape((1,)))

    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=flatten_axis,
        group_sizes=group_sizes,
        original_shape=original_shape,
        group_axis=group_axis,
    )
    return out
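# Illustrative sketch (hypothetical sizes): the per-group amax used by current scaling above
# maps each row to its group id and reduces with a segment max:
#
#   x = jnp.arange(12.0).reshape(6, 2)              # 6 rows of hidden size 2
#   group_sizes = jnp.array([2, 4], jnp.int32)       # rows 0-1 -> group 0, rows 2-5 -> group 1
#   row_amax = jnp.max(jnp.abs(x), axis=1)
#   ids = jnp.repeat(jnp.arange(2), group_sizes, total_repeat_length=6)
#   jax.ops.segment_max(row_amax, ids, num_segments=2)   # -> [3., 11.]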


def grouped_dbias(grad: jnp.ndarray, group_sizes: jnp.ndarray) -> jnp.ndarray:
    """
    Compute the grouped bias gradient.

    Args:
        grad: jnp.ndarray of shape (M, N)
        group_sizes: jnp.ndarray of shape(num_groups,), sum(group_sizes) == M

    Returns:
        dbias: jnp.ndarray of shape (num_groups, N)
    """
    assert grad.ndim == 2, "Input grad must be a 2D tensor."
    assert group_sizes.ndim == 1, "group_sizes must be a 1D tensor."

    segment_ids = jnp.repeat(
        jnp.arange(group_sizes.size), group_sizes, total_repeat_length=grad.shape[0]
    )
    grad_fp32 = grad.astype(jnp.float32)
    dbias_fp32 = jax.ops.segment_sum(grad_fp32, segment_ids, num_segments=group_sizes.shape[0])
    dbias = dbias_fp32.astype(grad.dtype)
    return dbias
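# Illustrative usage sketch (hypothetical values): grouped_dbias reduces each group of rows
# independently, so with group sizes (1, 2):
#
#   grad = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
#   group_sizes = jnp.array([1, 2], jnp.int32)
#   grouped_dbias(grad, group_sizes)   # -> [[1., 2.], [8., 10.]]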