# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for quantization"""
import operator
from functools import reduce
from typing import Tuple, Optional
import math
from packaging import version

import jax
import jax.numpy as jnp
from jax import dtypes
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.sharding import PartitionSpec

import transformer_engine_jax

from .base import BasePrimitive, register_primitive
from .misc import (
    get_padded_spec,
    check_valid_batch_dims,
    te_dtype_to_jax_dtype,
    jax_dtype_to_te_dtype,
    multidim_transpose,
    should_apply_1x_fused_dbias_war_for_arch_l_100,
    get_min_device_compute_capability,
    NamedSharding,
)
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import (
    ScaledTensor2x,
    ScaledTensor,
    ScaledTensorFactory,
    GroupedScaledTensor1x,
    Quantizer,
    GroupedQuantizer,
    QuantizeLayout,
    ScalingMode,
    compute_scale_from_amax,
)

if version.parse(jax.__version__) >= version.parse("0.5.0"):
    from jax import ffi  # pylint: disable=ungrouped-imports
else:
    from jax.extend import ffi  # pylint: disable=ungrouped-imports

__all__ = ["quantize", "quantize_dbias", "grouped_quantize", "grouped_dbias"]


class BaseDBiasQuantizePrimitive(BasePrimitive):
    """
    Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias
    """

    name = "te_dbias_quantize_ffi"
    multiple_results = True
    impl_static_args = (
        2,
        3,
        4,
        5,
        6,
        7,
        8,
    )  # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, is_dbias, is_outer
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        out_shape = x_aval.shape
        assert scale_aval is None or scale_aval.dtype == jnp.float32

        if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            rowwise_out_shape = out_shape
        else:
            rowwise_out_shape = (1,)
        rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype)

        updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=flatten_axis)

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if ScalingMode(scaling_mode).is_tensor_scaling():
                colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis)
            else:
                colwise_out_shape = out_shape
        else:
            colwise_out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        if is_dbias:
            dbias_shape = x_aval.shape[flatten_axis:]
            gi_hidden_size = reduce(operator.mul, x_aval.shape[flatten_axis:], 1)
            (wkspace_info,) = transformer_engine_jax.get_dbias_quantize_workspace_sizes(
                x_aval.size // gi_hidden_size,
                gi_hidden_size,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
                scaling_mode,
                QuantizeLayout(
                    q_layout
                ),  # For now until we have auto-decoding for QuantizeLayout enum
            )
            wkspace_shape = wkspace_info[0]
            wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
        else:
            dbias_shape = (1,)
            wkspace_shape = (1,)
            wkspace_dtype = jnp.float32
        dbias_aval = jax.core.ShapedArray(shape=dbias_shape, dtype=dtype)
        wkspace_aval = jax.core.ShapedArray(shape=wkspace_shape, dtype=wkspace_dtype)

        return (
            rowwise_out_aval,
            colwise_out_aval,
            scale_inv_aval,
            colwise_scale_inv_aval,
            updated_amax_aval,
            dbias_aval,
            wkspace_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dbias_quantize_p outer primitive abstract
        """
        (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
            _,
        ) = BaseDBiasQuantizePrimitive.abstract(*args, **kwargs)
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p lowering rules
        """
        del out_dtype, scale_dtype, is_outer
        x_aval, scale_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval.dtype == jnp.float32
        return ffi.ffi_lowering(BaseDBiasQuantizePrimitive.name)(
            ctx,
            x,
            scale,
            scaling_mode=scaling_mode.value,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
            is_dbias=is_dbias,
        )

    @staticmethod
    def impl(
        x,
        scale,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p implementation
        """
        del is_outer
        assert BaseDBiasQuantizePrimitive.inner_primitive is not None
        (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
            _,
        ) = BaseDBiasQuantizePrimitive.inner_primitive.bind(
            x,
            scale,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
            scale_dtype=scale_dtype,
            is_dbias=is_dbias,
            is_outer=False,
        )
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x.shape, is_padded=False, flatten_axis=flatten_axis)
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
        )
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_scale_inv = jax.lax.slice(
                colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
            )
        return (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
        )  # Exclude wkspace

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
    ):
        """
        to describe batch rules for vmap
        """
        del is_outer
        check_valid_batch_dims(batch_dims)
        assert BaseDBiasQuantizePrimitive.outer_primitive is not None
        x, scale = batched_args
        x_bdim, scale_bdim = batch_dims
        amax_bdim = scale_bdim

        out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim, x_bdim
        return (
            BaseDBiasQuantizePrimitive.outer_primitive.bind(
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                q_layout=q_layout,
                flatten_axis=flatten_axis,
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del (out_dtype, result_infos, scale_dtype, is_outer)  # Unused.

        x_spec = get_padded_spec(arg_infos[0])
        scale_spec = get_padded_spec(arg_infos[1])
        out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*x_spec),
            desc="BaseDBiasQuantizePrimitive.out_sharding",
        )
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if ScalingMode(scaling_mode).is_tensor_scaling():
                colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
            else:
                colwise_out_spec = x_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_out_spec),
            desc="BaseDBiasQuantizePrimitive.colwise_out_sharding",
        )

        dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="BaseDBiasQuantizePrimitive.dbias_sharding",
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="BaseDBiasQuantizePrimitive.colwise_scale_inv",
        )

        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer

        x_spec = get_padded_spec(arg_infos[0])
        scale_spec = get_padded_spec(arg_infos[1])
        out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*x_spec),
            desc="BaseDBiasQuantizePrimitive.out_sharding",
        )
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if ScalingMode(scaling_mode).is_tensor_scaling():
                colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
            else:
                colwise_out_spec = x_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_out_spec),
            desc="BaseDBiasQuantizePrimitive.colwise_out_sharding",
        )

        dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="BaseDBiasQuantizePrimitive.dbias_sharding",
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="BaseDBiasQuantizePrimitive.colwise_scale_inv",
        )

        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

        def sharded_impl(x, scale):
            (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                local_amax,
                local_dbias,
            ) = BaseDBiasQuantizePrimitive.impl(
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                q_layout=q_layout,
                flatten_axis=flatten_axis,
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
                is_outer=True,
            )

            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            else:
                global_updated_amax = local_amax

            if is_dbias:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            else:
                global_dbias = local_dbias

            return (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                global_updated_amax,
                global_dbias,
            )

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        mesh,
        value_types,
        result_types,
    ):
        del out_dtype, scale_dtype, is_outer, mesh, result_types

        prefix = "BaseDBiasQuantizePrimitive_"
        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            len(value_types[0].shape),
            unique_var=prefix + "x",
            flatten_axis=flatten_axis,
        )

        x_axes = scale_rules.input_spec
        colwise_scale_inv = scale_rules.colwise_rule

        out = x_axes
        colwise_out = (prefix + "out_colwise",)
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if ScalingMode(scaling_mode).is_tensor_scaling():
                colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis))
            else:
                colwise_out = x_axes

        dbias = x_axes[flatten_axis:] if is_dbias else (prefix + "dbias",)
        amax = (prefix + "amax",)

        return SdyShardingRule(
            (x_axes, ("…1",)),
            (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias),
        )


register_primitive(BaseDBiasQuantizePrimitive)


class DBiasQuantizePrimitive(BaseDBiasQuantizePrimitive):
    """Subclass of BaseDBiasQuantizePrimitive for DBias quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""


class QuantizePrimitive(BaseDBiasQuantizePrimitive):
    """Subclass of BaseDBiasQuantizePrimitive for quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""


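# Pure-JAX fallbacks: _quantize_dbias_impl routes through these helpers when the TE
# custom call is disabled (e.g. via NVTE_JAX_CUSTOM_CALLS, see the class docstrings
# above) or when the requested quantize layout is not supported by the fused kernel.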
def _jax_quantize(
    x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1
):
    if quantizer is None:
        return x
    return quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)


def _jax_dbias(dx: jnp.ndarray, dtype=None, flatten_axis: int = -1):
    sum_axis = dx.ndim + flatten_axis if flatten_axis < 0 else flatten_axis
    assert sum_axis < dx.ndim, "Flatten axis out of bounds!"
    dtype = dtype or dx.dtype
    dbias = jnp.sum(
        dx.astype(jnp.float32),
        axis=tuple(range(sum_axis)),
        keepdims=False,
    )
    return dbias.astype(dtype)
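
# Worked example (illustrative) for _jax_dbias: with dx of shape (B, S, H) and
# flatten_axis=-1, sum_axis is dx.ndim - 1, so the reduction runs over axes (0, 1)
# and dbias has shape (H,), i.e. the bias gradient accumulated over all token dims.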


def _jax_quantize_dbias(
    x,
    quantizer: Quantizer = None,
    dq_dtype: Optional[jnp.dtype] = None,
    flatten_axis: int = -1,
):
    if quantizer is None:
        return x, None
    return (
        quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
        _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis),
    )


def _quantize_dbias_impl(
    x: jnp.ndarray,
    quantizer: Quantizer,
    is_dbias: bool = False,
    dq_dtype: Optional[jnp.dtype] = None,
    flatten_axis: int = -1,
    noop_scaled_tensor: bool = False,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
    """
    Cast wrapper around the TE quantize primitives; returns the quantized (FP8)
    tensor and, when is_dbias is set, the fused bias gradient.
    """
    assert (dq_dtype is None) or (
        quantizer is not None
    ), "quantizer must be provided if dq_dtype is provided"

    # Early-exit for non-quantized call
    dq_dtype = dq_dtype or x.dtype
    if quantizer is None:
        dbias = None
        if is_dbias:
            dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
        if noop_scaled_tensor:
            # Return a dummy ScaledTensor2x to ensure .get_rowwise_tensor() and
            # .get_colwise_tensor() always work.
            return (
                ScaledTensorFactory.create_2x(
                    x,
                    None,
                    x,
                    None,
                    scaling_mode=ScalingMode.NO_SCALING,
                    dq_dtype=x.dtype,
                    data_layout="NN",
                    flatten_axis=flatten_axis,
                ),
                dbias,
            )
        return x, dbias

    # If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE,
    # fall back on the native-JAX quantize implementation
    PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive
    if quantizer.q_layout == QuantizeLayout.COLWISE or not PrimitiveClass.enabled():
        if is_dbias:
            return _jax_quantize_dbias(
                x,
                quantizer=quantizer,
                dq_dtype=dq_dtype,
                flatten_axis=flatten_axis,
            )
        return (
            _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
            None,
        )

    # TE/common custom quantize op does not support dbias fusion with 1x quantization on arch < 100
    if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
        out, _ = _quantize_dbias_impl(
            x=x,
            is_dbias=False,
            quantizer=quantizer,
            dq_dtype=dq_dtype,
            flatten_axis=flatten_axis,
        )
        dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
        return out, dbias

    scale = jnp.empty((), jnp.float32)
    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        # Globally reduce amax across all devices for current scaling so we have a single global scale.
        # This differs from the PyTorch implementation which uses a local amax and scale per-device and persists this
        # until the tensor is dequantized (e.g. in the GEMM).
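        # Note (illustrative): compute_scale_from_amax follows the usual FP8 recipe,
        # roughly scale = float8_max / amax, so the largest observed value maps near
        # the top of the representable FP8 range.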
        amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32)
        scale = compute_scale_from_amax(amax, quantizer.q_dtype)
    elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        scale = quantizer.scale

    # It is faster to use 1x quantization for tensor scaling
    is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100)
    force_1x_quantization = (
        quantizer.scaling_mode.is_tensor_scaling()
        and quantizer.is_2x2x()
        and is_1x_kernel_supported
    )
    q_layout = quantizer.q_layout
    if force_1x_quantization:
        q_layout = QuantizeLayout.ROWWISE

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
        dbias,
    ) = PrimitiveClass.outer_primitive.bind(
        x,
        scale,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        q_layout=q_layout.value,
        flatten_axis=flatten_axis,
        scale_dtype=quantizer.get_scale_dtype(),
        is_dbias=is_dbias,
        is_outer=True,
    )
    # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
    if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x():
        colwise_scale_inv = rowwise_scale_inv

        if q_layout == QuantizeLayout.ROWWISE:
            # Quantizer requires 2x quantization, but we are using 1x quantization
            # for performance reasons, so we need to generate the colwise data in JAX
            if flatten_axis < 0:
                flatten_axis += x.ndim
            colwise_casted_output = jnp.transpose(
                rowwise_casted_output, (*range(flatten_axis, x.ndim), *range(flatten_axis))
            )

    quantizer.update(updated_amax)

    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=dq_dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=flatten_axis,
    )
    return out, dbias.astype(dq_dtype)


def quantize(
    x: jnp.ndarray,
    quantizer: Quantizer,
    flatten_axis: int = -1,
    noop_scaled_tensor: bool = False,
) -> ScaledTensor:
    """Quantize input tensor according to the quantizer.

    Args:
        x: Input tensor to be quantized.
            Shape: (..., K) where K is the hidden size.
        quantizer: Quantizer for FP8 quantization of the output.
        flatten_axis: The axis along which the input can be flattened to 2D for
            quantization. Defaults to -1.
        noop_scaled_tensor: If True, wraps the output into a dummy ScaledTensor2x when quantizer
            is None.

    Returns:
        A ScaledTensor containing the quantized input tensor.
    """
    out, _ = _quantize_dbias_impl(
        x,
        quantizer=quantizer,
        flatten_axis=flatten_axis,
        noop_scaled_tensor=noop_scaled_tensor,
    )
    return out
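
# Illustrative usage sketch (the quantizer construction is assumed, not defined in
# this module):
#
#     x = jnp.zeros((128, 1024), jnp.bfloat16)
#     scaled = quantize(x, quantizer=qz)  # qz: a Quantizer built elsewhere in TE/JAX
#
# With quantizer=None and noop_scaled_tensor=True, the unquantized input comes back
# wrapped in a dummy ScaledTensor2x so .get_rowwise_tensor()/.get_colwise_tensor()
# still work.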


def quantize_dbias(
    dz: jnp.ndarray,
    quantizer: Quantizer,
    is_dbias: bool = True,
    flatten_axis: int = -1,
    noop_scaled_tensor: bool = False,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
    """Quantize input tensor and compute bias gradient.

    Args:
        dz: Input tensor to be quantized and used for bias gradient computation.
            Shape: (..., K) where K is the hidden size.
        quantizer: Quantizer for FP8 quantization of the output.
        is_dbias: If True, compute bias gradient. Defaults to True.
        flatten_axis: The axis along which the input can be flattened to 2D for
            quantization. Defaults to -1.
        noop_scaled_tensor: If True, wraps the unquantized output into a dummy ScaledTensor2x when
            quantizer is None.

    Returns:
        A tuple containing:
        - A ScaledTensor containing the quantized input tensor.
            The ScaledTensor includes both the quantized data and scaling factors.
        - The bias gradient tensor.
            Shape: (K,) or empty if is_dbias is False.
    """
    return _quantize_dbias_impl(
        dz,
        quantizer=quantizer,
        is_dbias=is_dbias,
        flatten_axis=flatten_axis,
        noop_scaled_tensor=noop_scaled_tensor,
    )
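
# Illustrative usage sketch (hypothetical names): in a backward pass where dz is the
# incoming gradient and a fused bias gradient is wanted,
#
#     scaled_dz, dbias = quantize_dbias(dz, quantizer=qz)
#
# dbias equals dz summed over all axes before flatten_axis (shape dz.shape[-1:] for
# the default flatten_axis=-1), cast to the dequantize dtype.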


class GroupedQuantizePrimitive(BasePrimitive):
    """
    Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias
    """

    name = "te_grouped_quantize_ffi"
    multiple_results = True
    impl_static_args = (
        3,
        4,
        5,
        6,
        7,
        8,
    )  # out_dtype, scaling_mode, q_layout, flatten_axis, group_axis, scale_dtype
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        group_sizes_aval,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        group_axis,
        scale_dtype,
    ):
        """
        te_grouped_quantize_p abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        out_shape = (math.prod(x_aval.shape),)  # grouped outputs are flattened to 1D
        # TODO(Phuong): can scale_aval be None?
        assert scale_aval is None or scale_aval.dtype == jnp.float32

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_grouped_scale_shape_2x(
            x_aval.shape,
            group_sizes_aval.size,
            group_axis,
            is_padded=True,
            flatten_axis=flatten_axis,
        )

        if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            rowwise_out_shape = out_shape
        else:
            rowwise_out_shape = (1,)
            rowwise_scale_inv_shape = (1,)
        rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype)

        amax_aval = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32)

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_out_shape = out_shape
        else:
            colwise_out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
        rowwise_scale_inv_aval = jax.core.ShapedArray(
            shape=rowwise_scale_inv_shape, dtype=scale_dtype
        )
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        return (
            rowwise_out_aval,
            colwise_out_aval,
            rowwise_scale_inv_aval,
            colwise_scale_inv_aval,
            amax_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_grouped_quantize_p outer primitive abstract
        """
        # Phuong: keeping outer abstract so that we can add fuse dbias later
        (
            rowwise_out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
        ) = GroupedQuantizePrimitive.abstract(*args, **kwargs)
        return rowwise_out, colwise_out, scale_inv, colwise_scale_inv, updated_amax

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        group_sizes,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        group_axis,
        scale_dtype,
    ):
        """
        te_grouped_quantize_p lowering rules
        """
        del out_dtype, scale_dtype
        x_aval, scale_aval, group_sizes_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval.dtype == jnp.float32
        assert group_sizes_aval.dtype == jnp.int32
        assert group_axis == 0
        return ffi.ffi_lowering(GroupedQuantizePrimitive.name)(
            ctx,
            x,
            scale,
            group_sizes,
            scaling_mode=scaling_mode.value,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
        )

    @staticmethod
    def impl(
        x,
        scale,
        group_sizes,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        group_axis,
        scale_dtype,
    ):
        """
        te_grouped_quantize_p implementation
        """
        assert GroupedQuantizePrimitive.inner_primitive is not None
        (
            rowwise_out,
            colwise_out,
            rowwise_scale_inv,
            colwise_scale_inv,
            updated_amax,
        ) = GroupedQuantizePrimitive.inner_primitive.bind(
            x,
            scale,
            group_sizes,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
            group_axis=group_axis,
            scale_dtype=scale_dtype,
        )
        return (rowwise_out, colwise_out, rowwise_scale_inv, colwise_scale_inv, updated_amax)


register_primitive(GroupedQuantizePrimitive)


def grouped_quantize(
    x: jnp.ndarray,
    quantizer: GroupedQuantizer,
    group_sizes: jnp.ndarray = None,
    amax: jnp.ndarray = None,
    flatten_axis: int = -1,
) -> GroupedScaledTensor1x:
    """Quantize a tensor in grouped manner.

    This function quantizes a tensor by splitting it into groups along a specified axis
    and applying quantization to each group separately. The groups can be either specified
    explicitly through group_sizes or automatically split along the group_axis.

    Args:
        x: Input tensor to quantize
        quantizer: The quantizer to use for quantization
        group_sizes: Array of ints containing the size of each group (default: None)
        amax: The amax of x; if None, it is auto-generated. (default: None)
        flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)

    Returns:
        A GroupedScaledTensor1x containing the quantized data

    Note:
        - If group_sizes is not provided, the tensor will be split into equal-sized groups
          along the group_axis
        - The group_axis is currently fixed to 0
        - The quantizer's q_layout determines whether row-wise, column-wise, or both
          quantization is applied
    """

    if quantizer is None:
        return x

    # TODO(Phuong): add support for flatten_axis = -2
    assert flatten_axis in (
        -1,
        x.ndim - 1,
    ), f"Only flatten_axis = -1 is supported for now, got {flatten_axis}"
    group_axis = 0
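
    # With no explicit group_sizes, every row along group_axis becomes its own group
    # of size 1, so n_groups below equals x.shape[group_axis].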

    if group_sizes is None:
        group_sizes = jnp.ones(x.shape[group_axis], dtype=jnp.int32)

    if not GroupedQuantizePrimitive.enabled():
        return quantizer.quantize(
            x, flatten_axis=flatten_axis, group_sizes=group_sizes, group_axis=group_axis
        )
    n_groups = group_sizes.size
    original_shape = x.shape
    assert n_groups == len(
        quantizer.quantizers
    ), f"n_groups={n_groups} != n_quantizers = {len(quantizer.quantizers)}"
    scale = jnp.empty((n_groups,), jnp.float32)

    if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        for i, quantizer_i in enumerate(quantizer.quantizers):
            scale = scale.at[i].set(quantizer_i.scale[0])

    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        if amax is not None:
            row_amax = amax
        else:
            row_amax = jnp.max(jnp.abs(x), axis=range(group_axis + 1, x.ndim))
        segment_ids = jnp.repeat(
            jnp.arange(n_groups), group_sizes, total_repeat_length=x.shape[group_axis]
        )
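        # e.g. group_sizes=[2, 3] over 5 rows yields segment_ids=[0, 0, 1, 1, 1],
        # tagging each row with the index of the group it belongs to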
        grouped_amax = jax.ops.segment_max(row_amax, segment_ids, num_segments=n_groups)
        for i in range(n_groups):
            tmp_scale = compute_scale_from_amax(grouped_amax[i], quantizer.q_dtype)
            scale = scale.at[i].set(tmp_scale[0])

    is_tensor_scaling = quantizer.scaling_mode in (
        ScalingMode.DELAYED_TENSOR_SCALING,
        ScalingMode.CURRENT_TENSOR_SCALING,
    )
    # WAR for tensor scaling, as TE/common does not support q_layout = COLWISE yet,
    # so we perform ROWWISE_COLWISE quantization and use the colwise tensor output
    apply_colwise_war = is_tensor_scaling and quantizer.q_layout == QuantizeLayout.COLWISE
    q_layout = QuantizeLayout.ROWWISE_COLWISE if apply_colwise_war else quantizer.q_layout
    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
    ) = GroupedQuantizePrimitive.outer_primitive.bind(
        x,
        scale,
        group_sizes,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        q_layout=q_layout.value,
        flatten_axis=flatten_axis,
        group_axis=group_axis,
        scale_dtype=quantizer.get_scale_dtype(),
    )

    # For DelayedScaling2x and CurrentScaling2x, the scale buffer
    # is shared between rowwise and colwise
    if (is_tensor_scaling and quantizer.is_2x2x()) or apply_colwise_war:
        colwise_scale_inv = rowwise_scale_inv

    # TODO(Phuong): store the whole updated_amax in the grouped_quantize instead?
    if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        for i, quantizer_i in enumerate(quantizer.quantizers):
            quantizer_i.update(updated_amax[i].reshape((1,)))

    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=flatten_axis,
        group_sizes=group_sizes,
        original_shape=original_shape,
        group_axis=group_axis,
    )
    return out
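
# Illustrative usage sketch (hypothetical names): for an MoE-style activation where
# the first 3 rows belong to expert 0 and the next 5 rows to expert 1,
#
#     group_sizes = jnp.array([3, 5], dtype=jnp.int32)
#     scaled = grouped_quantize(x, quantizer=gqz, group_sizes=group_sizes)
#
# where gqz is a GroupedQuantizer holding one sub-quantizer per group
# (len(gqz.quantizers) == group_sizes.size).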


def grouped_dbias(grad: jnp.ndarray, group_sizes: jnp.ndarray) -> jnp.ndarray:
    """
    Compute the grouped bias gradient.

    Args:
        grad: jnp.ndarray of shape (M, N)
        group_sizes: jnp.ndarray of shape(num_groups,), sum(group_sizes) == M

    Returns:
        dbias: jnp.ndarray of shape (num_groups, N)
    """
    assert grad.ndim == 2, "Input grad must be a 2D tensor."
    assert group_sizes.ndim == 1, "group_sizes must be a 1D tensor."

    segment_ids = jnp.repeat(
        jnp.arange(group_sizes.size), group_sizes, total_repeat_length=grad.shape[0]
    )
    grad_fp32 = grad.astype(jnp.float32)
    dbias_fp32 = jax.ops.segment_sum(grad_fp32, segment_ids, num_segments=group_sizes.shape[0])
    dbias = dbias_fp32.astype(grad.dtype)
    return dbias
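
# Worked example (illustrative) for grouped_dbias: with grad of shape (5, N) and
# group_sizes=[2, 3], segment_ids is [0, 0, 1, 1, 1], so dbias has shape (2, N):
# row 0 sums grad[0:2] and row 1 sums grad[2:5], accumulated in float32 before the
# cast back to grad.dtype.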