# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for quantization"""

import operator
from functools import reduce
from typing import Tuple, Optional
import math
from packaging import version

import jax
import jax.numpy as jnp
from jax import dtypes
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.sharding import PartitionSpec

import transformer_engine_jax

from .base import BasePrimitive, register_primitive
from .misc import (
    get_padded_spec,
    check_valid_batch_dims,
    te_dtype_to_jax_dtype,
    jax_dtype_to_te_dtype,
    multidim_transpose,
    should_apply_1x_fused_dbias_war_for_arch_l_100,
    get_min_device_compute_capability,
    NamedSharding,
)
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import (
    ScaledTensor2x,
    ScaledTensor,
    ScaledTensorFactory,
    GroupedScaledTensor1x,
    Quantizer,
    GroupedQuantizer,
    QuantizeLayout,
    DelayedScaleQuantizer,
    ScalingMode,
    compute_scale_from_amax,
)

if version.parse(jax.__version__) >= version.parse("0.5.0"):
    from jax import ffi  # pylint: disable=ungrouped-imports
else:
    from jax.extend import ffi  # pylint: disable=ungrouped-imports


__all__ = ["quantize", "quantize_dbias", "grouped_quantize", "grouped_dbias"]


class BaseDBiasQuantizePrimitive(BasePrimitive):
    """
    Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias
    """

    name = "te_dbias_quantize_ffi"
    multiple_results = True
    impl_static_args = (
        2,
        3,
        4,
        5,
        6,
        7,
        8,
    )  # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, is_dbias, is_outer
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        out_shape = x_aval.shape
        assert scale_aval is None or scale_aval.dtype == jnp.float32

        if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            rowwise_out_shape = out_shape
        else:
            rowwise_out_shape = (1,)
        rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype)

        updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=flatten_axis)

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if ScalingMode(scaling_mode).is_tensor_scaling():
                colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis)
            else:
                colwise_out_shape = out_shape
        else:
            colwise_out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        if is_dbias:
            dbias_shape = x_aval.shape[flatten_axis:]
            gi_hidden_size = reduce(operator.mul, x_aval.shape[flatten_axis:], 1)
            (wkspace_info,) = transformer_engine_jax.get_dbias_quantize_workspace_sizes(
                x_aval.size // gi_hidden_size,
                gi_hidden_size,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
                scaling_mode,
                QuantizeLayout(
                    q_layout
                ),  # For now until we have auto-decoding for QuantizeLayout enum
            )
            wkspace_shape = wkspace_info[0]
            wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
        else:
            dbias_shape = (1,)
            wkspace_shape = (1,)
            wkspace_dtype = jnp.float32
        dbias_aval = jax.core.ShapedArray(shape=dbias_shape, dtype=dtype)
        wkspace_aval = jax.core.ShapedArray(shape=wkspace_shape, dtype=wkspace_dtype)

        return (
            rowwise_out_aval,
            colwise_out_aval,
            scale_inv_aval,
            colwise_scale_inv_aval,
            updated_amax_aval,
            dbias_aval,
            wkspace_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dbias_quantize_p outer primitive abstract
        """
        (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
            _,
        ) = BaseDBiasQuantizePrimitive.abstract(*args, **kwargs)
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p lowering rules
        """
        del out_dtype, scale_dtype, is_outer
        x_aval, scale_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval.dtype == jnp.float32
        return ffi.ffi_lowering(BaseDBiasQuantizePrimitive.name)(
            ctx,
            x,
            scale,
            scaling_mode=scaling_mode.value,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
            is_dbias=is_dbias,
        )

    @staticmethod
    def impl(
        x,
        scale,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p implementation
        """
        del is_outer
        assert BaseDBiasQuantizePrimitive.inner_primitive is not None
        (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
            _,
        ) = BaseDBiasQuantizePrimitive.inner_primitive.bind(
            x,
            scale,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
            scale_dtype=scale_dtype,
            is_dbias=is_dbias,
            is_outer=False,
        )
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x.shape, is_padded=False, flatten_axis=flatten_axis)
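        # The inner primitive returns padded scale buffers (is_padded=not is_outer in
        # abstract); slice them down to the unpadded shapes computed above.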
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
        )
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_scale_inv = jax.lax.slice(
                colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
            )
        return (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
        )  # Exclude wkspace

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
    ):
        """
        to describe batch rules for vmap
        """
        del is_outer
        check_valid_batch_dims(batch_dims)
        assert BaseDBiasQuantizePrimitive.outer_primitive is not None
        x, scale = batched_args
        x_bdim, scale_bdim = batch_dims
        amax_bdim = scale_bdim

        out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim, x_bdim
        return (
            BaseDBiasQuantizePrimitive.outer_primitive.bind(
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                q_layout=q_layout,
                flatten_axis=flatten_axis,
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del (out_dtype, result_infos, scale_dtype, is_outer)  # Unused.

        x_spec = get_padded_spec(arg_infos[0])
        scale_spec = get_padded_spec(arg_infos[1])
        out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*x_spec),
            desc="BaseDBiasQuantizePrimitive.out_sharding",
        )
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if ScalingMode(scaling_mode).is_tensor_scaling():
                colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
            else:
                colwise_out_spec = x_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_out_spec),
            desc="BaseDBiasQuantizePrimitive.colwise_out_sharding",
        )

        dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="BaseDBiasQuantizePrimitive.dbias_sharding",
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="BaseDBiasQuantizePrimitive.colwise_scale_inv",
        )

        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer

        x_spec = get_padded_spec(arg_infos[0])
        scale_spec = get_padded_spec(arg_infos[1])
        out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*x_spec),
            desc="BaseDBiasQuantizePrimitive.out_sharding",
        )
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if ScalingMode(scaling_mode).is_tensor_scaling():
                colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
            else:
                colwise_out_spec = x_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_out_spec),
            desc="BaseDBiasQuantizePrimitive.colwise_out_sharding",
        )

        dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="BaseDBiasQuantizePrimitive.dbias_sharding",
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="BaseDBiasQuantizePrimitive.colwise_scale_inv",
        )

        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

        def sharded_impl(x, scale):
            (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                local_amax,
                local_dbias,
            ) = BaseDBiasQuantizePrimitive.impl(
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                q_layout=q_layout,
                flatten_axis=flatten_axis,
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
                is_outer=True,
            )

            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
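                # Reduce amax across every mesh axis except pipeline-parallel so all
                # replicas derive the same delayed-scaling scale.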
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            else:
                global_updated_amax = local_amax

            if is_dbias:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            else:
                global_dbias = local_dbias

            return (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                global_updated_amax,
                global_dbias,
            )

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        mesh,
        value_types,
        result_types,
    ):
        del out_dtype, scale_dtype, is_outer, mesh, result_types

        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            len(value_types[0].shape),
            unique_var="BaseDBiasQuantizePrimitive_i",
            flatten_axis=flatten_axis,
        )

        x_axes = scale_rules.input_spec
        colwise_scale_inv = scale_rules.colwise_rule

        out = x_axes
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if ScalingMode(scaling_mode).is_tensor_scaling():
                colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis))
            else:
                colwise_out = x_axes
        else:
            colwise_out = ("j",)
            colwise_scale_inv = ("k",)

        dbias = x_axes[flatten_axis:] if is_dbias else ("l",)
        amax = ("m",)

        return SdyShardingRule(
            (x_axes, ("…1",)),
            (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias),
            **scale_rules.factor_sizes,
        )


register_primitive(BaseDBiasQuantizePrimitive)


class DBiasQuantizePrimitive(BaseDBiasQuantizePrimitive):
    """Subclass of BaseDBiasQuantizePrimitive for DBias quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE."""


class QuantizePrimitive(BaseDBiasQuantizePrimitive):
    """Subclass of BaseDBiasQuantizePrimitive for quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE."""


def _jax_quantize(
    x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1
):
    if quantizer is None:
        return x
    return quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)


def _jax_dbias(dx: jnp.ndarray, dtype=None, flatten_axis: int = -1):
    assert flatten_axis < 0
    dtype = dtype or dx.dtype
    dbias = jnp.sum(
        dx.astype(jnp.float32),
        axis=tuple(range(dx.ndim + flatten_axis)),
        keepdims=False,
    )
    return dbias.astype(dtype)
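
# Example (illustrative shapes): for dx of shape (B, S, K) and flatten_axis=-1,
# _jax_dbias reduces over the leading (B, S) axes and returns a (K,)-shaped bias
# gradient in the requested dtype.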


def _jax_quantize_dbias(
    x,
    quantizer: Quantizer = None,
    dq_dtype: Optional[jnp.dtype] = None,
    flatten_axis: int = -1,
):
    if quantizer is None:
        return x, None
    return (
        quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
        _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis),
    )


def _quantize_dbias_impl(
    x: jnp.ndarray,
    quantizer: Quantizer,
    is_dbias: bool = False,
    dq_dtype: Optional[jnp.dtype] = None,
    flatten_axis: int = -1,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
    """
    Cast wrapper
    Return FP8 tensor
    """
    assert (dq_dtype is None) or (
        quantizer is not None
    ), "quantizer must be provided if dq_dtype is provided"

    dq_dtype = dq_dtype or x.dtype

    PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive
    if not PrimitiveClass.enabled():
        if is_dbias:
            return _jax_quantize_dbias(
                x,
                quantizer=quantizer,
                dq_dtype=dq_dtype,
                flatten_axis=flatten_axis,
            )
        return (
            _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
            None,
        )

    # TE/common doesn't support colwise only quantization yet
    if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
        if is_dbias:
            return _jax_quantize_dbias(
                x,
                quantizer=quantizer,
                dq_dtype=dq_dtype,
                flatten_axis=flatten_axis,
            )
        return (
            _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
            None,
        )
    scale = jnp.empty((), jnp.float32)

    # TE/common dbias_quantize does not support 1x on arch < 100
    if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
        out, _ = _quantize_dbias_impl(
            x=x,
            is_dbias=False,
            quantizer=quantizer,
            dq_dtype=dq_dtype,
            flatten_axis=flatten_axis,
        )
        dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
        return out, dbias

    if quantizer is None:
        if is_dbias:
            return x, _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
        return x, None

    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        # Globally reduce amax across all devices for current scaling so we have a single global scale.
        # This differs from the PyTorch implementation which uses a local amax and scale per-device and persists this
        # until the tensor is dequantized (e.g. in the GEMM).
        amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32)
        scale = compute_scale_from_amax(amax, quantizer.q_dtype)
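        # compute_scale_from_amax yields, roughly, scale = fp8_max(q_dtype) / amax
        # (guarded against zero/non-finite amax), so the largest observed magnitude
        # maps onto the FP8 representable maximum.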

    if isinstance(quantizer, DelayedScaleQuantizer):
        scale = quantizer.scale

    is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100)
    # It is faster to use 1x quantization for tensor scaling
    force_1x_quantization = (
        quantizer.scaling_mode.is_tensor_scaling()
        and quantizer.is_2x2x()
        and is_1x_kernel_supported
    )

    q_layout = quantizer.q_layout
    if force_1x_quantization:
        q_layout = QuantizeLayout.ROWWISE

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
        dbias,
    ) = PrimitiveClass.outer_primitive.bind(
        x,
        scale,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        q_layout=q_layout.value,
        flatten_axis=flatten_axis,
        scale_dtype=quantizer.get_scale_dtype(),
        is_dbias=is_dbias,
        is_outer=True,
    )
    # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
    if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x():
        colwise_scale_inv = rowwise_scale_inv

        if q_layout == QuantizeLayout.ROWWISE:
            # Quantizer requires 2x quantization, but we are using 1x quantization
            # for performance reasons, so we need to generate the colwise data in JAX
            if flatten_axis < 0:
                flatten_axis += x.ndim
            colwise_casted_output = jnp.transpose(
                rowwise_casted_output, (*range(flatten_axis, x.ndim), *range(flatten_axis))
            )

    quantizer.update(updated_amax)

    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=dq_dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=flatten_axis,
    )
    return out, dbias.astype(dq_dtype)


def quantize(
    x: jnp.ndarray,
    quantizer: Quantizer,
    flatten_axis: int = -1,
) -> ScaledTensor:
    """Quantize input tensor according to the quantizer.

    Args:
        x: Input tensor to be quantized.
            Shape: (..., K) where K is the hidden size.
        quantizer: Quantizer for FP8 quantization of the output.
        flatten_axis: The axis along which the input can be flattened to 2D for
            quantization. Defaults to -1.

    Returns:
        A ScaledTensor containing the quantized input tensor.
    """
    out, _ = _quantize_dbias_impl(
        x,
        quantizer=quantizer,
        flatten_axis=flatten_axis,
    )
    return out
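
# A minimal usage sketch (illustrative shapes; `my_quantizer` stands for any
# configured Quantizer instance constructed elsewhere):
#
#     x = jnp.zeros((16, 128), jnp.bfloat16)
#     scaled = quantize(x, quantizer=my_quantizer)
#     # `scaled` carries the quantized data plus the inverse scale(s) for dequantization.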


def quantize_dbias(
    dz: jnp.ndarray,
    quantizer: Quantizer,
    is_dbias: bool = True,
    flatten_axis: int = -1,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
    """Quantize input tensor and compute bias gradient.

    Args:
        dz: Input tensor to be quantized and used for bias gradient computation.
            Shape: (..., K) where K is the hidden size.
        quantizer: Quantizer for FP8 quantization of the output.
        is_dbias: If True, compute bias gradient. Defaults to True.
        flatten_axis: The axis along which the input can be flattened to 2D for
            quantization. Defaults to -1.

    Returns:
        A tuple containing:
        - A ScaledTensor containing the quantized input tensor.
            The ScaledTensor includes both the quantized data and scaling factors.
        - The bias gradient tensor.
            Shape: (K,) or empty if is_dbias is False.
    """
    return _quantize_dbias_impl(
        dz, quantizer=quantizer, is_dbias=is_dbias, flatten_axis=flatten_axis
    )
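
# Sketch of the fused backward use case (illustrative shapes): for an upstream
# gradient dz of shape (M, K) with flatten_axis=-1, quantize_dbias returns the
# quantized dz together with dbias = dz.sum(axis=0) of shape (K,), accumulated in
# float32 and cast back to the dequantize dtype.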


class GroupedQuantizePrimitive(BasePrimitive):
    """
    Cast Primitive wrapping the TE/common grouped quantization kernel
    """

    name = "te_grouped_quantize_ffi"
    multiple_results = True
    impl_static_args = (
        3,
        4,
        5,
        6,
        7,
        8,
    )  # out_dtype, scaling_mode, q_layout, flatten_axis, group_axis, scale_dtype
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        group_sizes_aval,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        group_axis,
        scale_dtype,
    ):
        """
        te_grouped_quantize_p abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        out_shape = (math.prod(x_aval.shape),)  # grouped outputs are flattened to 1D
        # TODO(Phuong): can scale_aval be None?
        assert scale_aval is None or scale_aval.dtype == jnp.float32

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_grouped_scale_shape_2x(
            x_aval.shape,
            group_sizes_aval.size,
            group_axis,
            is_padded=True,
            flatten_axis=flatten_axis,
        )

        if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            rowwise_out_shape = out_shape
        else:
            rowwise_out_shape = (1,)
            rowwise_scale_inv_shape = (1,)
        rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype)

        amax_aval = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32)

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_out_shape = out_shape
        else:
            colwise_out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
        rowwise_scale_inv_aval = jax.core.ShapedArray(
            shape=rowwise_scale_inv_shape, dtype=scale_dtype
        )
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        return (
            rowwise_out_aval,
            colwise_out_aval,
            rowwise_scale_inv_aval,
            colwise_scale_inv_aval,
            amax_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_grouped_quantize_p outer primitive abstract
        """
        # Phuong: keeping outer abstract so that we can add fuse dbias later
        (
            rowwise_out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
        ) = GroupedQuantizePrimitive.abstract(*args, **kwargs)
        return rowwise_out, colwise_out, scale_inv, colwise_scale_inv, updated_amax

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        group_sizes,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        group_axis,
        scale_dtype,
    ):
        """
        te_grouped_quantize_p lowering rules
        """
        del out_dtype, scale_dtype
        x_aval, scale_aval, group_sizes_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval.dtype == jnp.float32
        assert group_sizes_aval.dtype == jnp.int32
        assert group_axis == 0
        return ffi.ffi_lowering(GroupedQuantizePrimitive.name)(
            ctx,
            x,
            scale,
            group_sizes,
            scaling_mode=scaling_mode.value,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
        )

    @staticmethod
    def impl(
        x,
        scale,
        group_sizes,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        group_axis,
        scale_dtype,
    ):
        """
        te_grouped_quantize_p implementation
        """
        assert GroupedQuantizePrimitive.inner_primitive is not None
        (
            rowwise_out,
            colwise_out,
            rowwise_scale_inv,
            colwise_scale_inv,
            updated_amax,
        ) = GroupedQuantizePrimitive.inner_primitive.bind(
            x,
            scale,
            group_sizes,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
            group_axis=group_axis,
            scale_dtype=scale_dtype,
        )
        return (rowwise_out, colwise_out, rowwise_scale_inv, colwise_scale_inv, updated_amax)


register_primitive(GroupedQuantizePrimitive)


def grouped_quantize(
    x: jnp.ndarray,
    quantizer: GroupedQuantizer,
    group_sizes: jnp.ndarray = None,
    flatten_axis: int = -1,
) -> GroupedScaledTensor1x:
    """Quantize a tensor in grouped manner.

    This function quantizes a tensor by splitting it into groups along a specified axis
    and applying quantization to each group separately. The groups can be either specified
    explicitly through group_sizes or automatically split along the group_axis.

    Args:
        x: Input tensor to quantize
        quantizer: The quantizer to use for quantization
        group_sizes: Array of ints containing the size of each group (default: None)
        flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)

    Returns:
        A GroupedScaledTensor1x containing the quantized data

    Note:
        - If group_sizes is not provided, every slice along the group_axis becomes its
          own group (i.e. group sizes of 1)
        - The group_axis is currently fixed to 0
        - The quantizer's q_layout determines whether row-wise, column-wise, or both
          quantization is applied
    """

    if quantizer is None:
        return x

    # TODO(Phuong): add support for flatten_axis = -2
    assert flatten_axis in (
        -1,
        x.ndim - 1,
    ), f"Only flatten_axis = -1 is supported for now, got {flatten_axis}"
    group_axis = 0

    if group_sizes is None:
        group_sizes = jnp.ones(x.shape[group_axis], dtype=jnp.int32)

    if not GroupedQuantizePrimitive.enabled():
        return quantizer.quantize(
            x, flatten_axis=flatten_axis, group_sizes=group_sizes, group_axis=group_axis
        )
    n_groups = group_sizes.size
    original_shape = x.shape
    assert n_groups == len(
        quantizer.quantizers
    ), f"n_groups={n_groups} != n_quantizers = {len(quantizer.quantizers)}"
    scale = jnp.empty((n_groups,), jnp.float32)

    if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        for i, quantizer_i in enumerate(quantizer.quantizers):
            scale = scale.at[i].set(quantizer_i.scale[0])

    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        row_amax = jnp.max(jnp.abs(x), axis=range(group_axis + 1, x.ndim))
        segment_ids = jnp.repeat(
            jnp.arange(n_groups), group_sizes, total_repeat_length=x.shape[group_axis]
        )
        grouped_amax = jax.ops.segment_max(row_amax, segment_ids, num_segments=n_groups)
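        # This Python loop unrolls at trace time; n_groups (= group_sizes.size) is
        # static, so each group's scale is scattered into `scale` exactly once.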
        for i in range(n_groups):
            tmp_scale = compute_scale_from_amax(grouped_amax[i], quantizer.q_dtype)
            scale = scale.at[i].set(tmp_scale[0])

    is_tensor_scaling = quantizer.scaling_mode in (
        ScalingMode.DELAYED_TENSOR_SCALING,
        ScalingMode.CURRENT_TENSOR_SCALING,
    )
    # WAR for tensor scaling, as TE/common does not support q_layout = COLWISE yet:
    # perform ROWWISE_COLWISE quantization and use only the colwise output
    apply_colwise_war = is_tensor_scaling and quantizer.q_layout == QuantizeLayout.COLWISE
    q_layout = QuantizeLayout.ROWWISE_COLWISE if apply_colwise_war else quantizer.q_layout
    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
    ) = GroupedQuantizePrimitive.outer_primitive.bind(
        x,
        scale,
        group_sizes,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        q_layout=q_layout.value,
        flatten_axis=flatten_axis,
        group_axis=group_axis,
        scale_dtype=quantizer.get_scale_dtype(),
    )

    # For DelayedScaling2x and CurrentScaling2x, the scale buffer
    # is shared between rowwise and colwise
    if (is_tensor_scaling and quantizer.is_2x2x()) or apply_colwise_war:
        colwise_scale_inv = rowwise_scale_inv

    # TODO(Phuong): store the whole updated_amax in the grouped_quantize instead?
    if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        for i, quantizer_i in enumerate(quantizer.quantizers):
            quantizer_i.update(updated_amax[i].reshape((1,)))

    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=flatten_axis,
        group_sizes=group_sizes,
        original_shape=original_shape,
        group_axis=group_axis,
    )
    return out
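
# Illustrative example: for x of shape (M, K) with group_axis=0 and
# group_sizes = [m0, m1, ...] summing to M, rows [0, m0) are quantized with
# quantizer.quantizers[0], rows [m0, m0 + m1) with quantizers[1], and so on; the
# quantized outputs come back flattened to 1D (see GroupedQuantizePrimitive.abstract).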


def grouped_dbias(grad: jnp.ndarray, group_sizes: jnp.ndarray) -> jnp.ndarray:
    """
    Compute the grouped bias gradient.

    Args:
        grad: jnp.ndarray of shape (M, N)
        group_sizes: jnp.ndarray of shape (num_groups,), with sum(group_sizes) == M

    Returns:
        dbias: jnp.ndarray of shape (num_groups, N)
    """
    assert grad.ndim == 2, "Input grad must be a 2D tensor."
    assert group_sizes.ndim == 1, "group_sizes must be a 1D tensor."

    segment_ids = jnp.repeat(
        jnp.arange(group_sizes.size), group_sizes, total_repeat_length=grad.shape[0]
    )
    grad_fp32 = grad.astype(jnp.float32)
    dbias_fp32 = jax.ops.segment_sum(grad_fp32, segment_ids, num_segments=group_sizes.shape[0])
    dbias = dbias_fp32.astype(grad.dtype)
    return dbias
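
# Equivalent unfused reference (illustrative): for group g spanning rows
# [start_g, end_g), dbias[g] == grad[start_g:end_g].astype(jnp.float32).sum(axis=0),
# cast back to grad.dtype.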