# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for quantization"""
import operator
from functools import reduce
from typing import Tuple, Optional
import math
from packaging import version

import jax
import jax.numpy as jnp
from jax import dtypes
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.sharding import PartitionSpec

import transformer_engine_jax

from .base import BasePrimitive, register_primitive
from .misc import (
    get_padded_spec,
    check_valid_batch_dims,
    te_dtype_to_jax_dtype,
    jax_dtype_to_te_dtype,
    multidim_transpose,
    should_apply_1x_fused_dbias_war_for_arch_l_100,
    get_min_device_compute_capability,
    NamedSharding,
)
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import (
    ScaledTensor2x,
    ScaledTensor,
    ScaledTensorFactory,
    GroupedScaledTensor1x,
    Quantizer,
    GroupedQuantizer,
    QuantizeLayout,
    ScalingMode,
    compute_scale_from_amax,
)

if version.parse(jax.__version__) >= version.parse("0.5.0"):
    from jax import ffi  # pylint: disable=ungrouped-imports
else:
    from jax.extend import ffi  # pylint: disable=ungrouped-imports

__all__ = ["quantize", "quantize_dbias", "grouped_quantize", "grouped_dbias"]


class BaseDBiasQuantizePrimitive(BasePrimitive):
    """
    Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias
    """

    name = "te_dbias_quantize_ffi"
    multiple_results = True
    impl_static_args = (
        3,
        4,
        5,
        6,
        7,
        8,
        9,
    )  # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, is_dbias, is_outer
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        amax_aval,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        out_shape = x_aval.shape
        assert scale_aval is None or scale_aval.dtype == jnp.float32

        if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            rowwise_out_shape = out_shape
        else:
            rowwise_out_shape = (1,)
        rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype)

        updated_amax_aval = amax_aval

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=flatten_axis)

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if ScalingMode(scaling_mode).is_tensor_scaling():
                colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis)
            else:
                colwise_out_shape = out_shape
        else:
            colwise_out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        if is_dbias:
            dbias_shape = x_aval.shape[flatten_axis:]
            gi_hidden_size = reduce(operator.mul, x_aval.shape[flatten_axis:], 1)
            (wkspace_info,) = transformer_engine_jax.get_dbias_quantize_workspace_sizes(
                x_aval.size // gi_hidden_size,
                gi_hidden_size,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
                scaling_mode,
                QuantizeLayout(q_layout),  # For now until we have auto-decoding for QuantizeLayout enum
            )
            wkspace_shape = wkspace_info[0]
            wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
        else:
            dbias_shape = (1,)
            wkspace_shape = (1,)
            wkspace_dtype = jnp.float32
        dbias_aval = jax.core.ShapedArray(shape=dbias_shape, dtype=dtype)
        wkspace_aval = jax.core.ShapedArray(shape=wkspace_shape, dtype=wkspace_dtype)

        return (
            rowwise_out_aval,
            colwise_out_aval,
            scale_inv_aval,
            colwise_scale_inv_aval,
            updated_amax_aval,
            dbias_aval,
            wkspace_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dbias_quantize_p outer primitive abstract
        """
        (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
            _,
        ) = BaseDBiasQuantizePrimitive.abstract(*args, **kwargs)
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        amax,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p lowering rules
        """
        del out_dtype, scale_dtype, is_outer
        x_aval, scale_aval, amax_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval.dtype == amax_aval.dtype == jnp.float32
        return ffi.ffi_lowering(
            BaseDBiasQuantizePrimitive.name,
            operand_output_aliases={2: 4},  # donate amax buffer to updated_amax
        )(
            ctx,
            x,
            scale,
            amax,
            scaling_mode=scaling_mode.value,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
            is_dbias=is_dbias,
        )

    @staticmethod
    def impl(
        x,
        scale,
        amax,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p implementation
        """
        del is_outer
        assert BaseDBiasQuantizePrimitive.inner_primitive is not None
        (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
            _,
        ) = BaseDBiasQuantizePrimitive.inner_primitive.bind(
            x,
            scale,
            amax,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
            scale_dtype=scale_dtype,
            is_dbias=is_dbias,
            is_outer=False,
        )
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x.shape, is_padded=False, flatten_axis=flatten_axis)
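        # The inner primitive returns scale buffers padded to the kernel's alignment
        # requirements; slice them back to the unpadded shapes expected by callers.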
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
        )
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_scale_inv = jax.lax.slice(
                colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
            )
        return (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
        )  # Exclude wkspace

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
    ):
        """
        to describe batch rules for vmap
        """
        del is_outer
        check_valid_batch_dims(batch_dims)
        assert BaseDBiasQuantizePrimitive.outer_primitive is not None
        x, scale, amax = batched_args
        x_bdim, scale_bdim, amax_bdim = batch_dims

        out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim, x_bdim
        return (
            BaseDBiasQuantizePrimitive.outer_primitive.bind(
                x,
                scale,
                amax,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                q_layout=q_layout,
                flatten_axis=flatten_axis,
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
                is_outer=True,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del (out_dtype, result_infos, scale_dtype, is_outer)  # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        amax_spec = get_padded_spec(arg_infos[2])
        out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*x_spec),
            desc="BaseDBiasQuantizePrimitive.out_sharding",
        )
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if ScalingMode(scaling_mode).is_tensor_scaling():
                colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
            else:
                colwise_out_spec = x_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_out_spec),
            desc="BaseDBiasQuantizePrimitive.colwise_out_sharding",
        )

        dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="BaseDBiasQuantizePrimitive.dbias_sharding",
        )

        scale_inv_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="BaseDBiasQuantizePrimitive.colwise_scale_inv",
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax"
        )

        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer
        x_spec = get_padded_spec(arg_infos[0])
        amax_spec = get_padded_spec(arg_infos[2])
        out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*x_spec),
            desc="BaseDBiasQuantizePrimitive.out_sharding",
        )
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if ScalingMode(scaling_mode).is_tensor_scaling():
                colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
            else:
                colwise_out_spec = x_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_out_spec),
            desc="BaseDBiasQuantizePrimitive.colwise_out_sharding",
        )

        dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="BaseDBiasQuantizePrimitive.dbias_sharding",
        )

        scale_inv_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="BaseDBiasQuantizePrimitive.colwise_scale_inv",
        )

        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

        def sharded_impl(x, scale, amax):
            (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                local_amax,
                local_dbias,
            ) = BaseDBiasQuantizePrimitive.impl(
                x,
                scale,
                amax,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                q_layout=q_layout,
                flatten_axis=flatten_axis,
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
                is_outer=True,
            )

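            # With delayed tensor scaling, each shard observes only a local amax;
            # reduce it across all mesh axes except pipeline-parallel so every
            # replica derives the same scale at the next update.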
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            else:
                global_updated_amax = local_amax

            if is_dbias:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            else:
                global_dbias = local_dbias

            return (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                global_updated_amax,
                global_dbias,
            )

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        mesh,
        value_types,
        result_types,
    ):
        del out_dtype, scale_dtype, is_outer, mesh, result_types

        prefix = "BaseDBiasQuantizePrimitive_"
        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            len(value_types[0].shape),
            unique_var=prefix + "x",
            flatten_axis=flatten_axis,
        )

        x_axes = scale_rules.input_spec
        colwise_scale_inv = scale_rules.colwise_rule

        out = x_axes
        colwise_out = (prefix + "out_colwise",)
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if ScalingMode(scaling_mode).is_tensor_scaling():
                colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis))
            else:
                colwise_out = x_axes

        dbias = x_axes[flatten_axis:] if is_dbias else (prefix + "dbias",)
        amax = (prefix + "amax",)

        return SdyShardingRule(
            (x_axes, ("…1",), amax),
            (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias),
        )

register_primitive(BaseDBiasQuantizePrimitive)


class DBiasQuantizePrimitive(BaseDBiasQuantizePrimitive):
    """Subclass of BaseDBiasQuantizePrimitive for DBias quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""


class QuantizePrimitive(BaseDBiasQuantizePrimitive):
    """Subclass of BaseDBiasQuantizePrimitive for quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""


def _jax_quantize(
    x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1
):
    if quantizer is None:
        return x
    return quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)


def _jax_dbias(dx: jnp.ndarray, dtype=None, flatten_axis: int = -1):
    sum_axis = dx.ndim + flatten_axis if flatten_axis < 0 else flatten_axis
    assert sum_axis < dx.ndim, "Flatten axis out of bounds!"
    dtype = dtype or dx.dtype
    dbias = jnp.sum(
        dx.astype(jnp.float32),
        axis=tuple(range(sum_axis)),
        keepdims=False,
    )
    return dbias.astype(dtype)


def _jax_quantize_dbias(
    x,
    quantizer: Quantizer = None,
    dq_dtype: Optional[jnp.dtype] = None,
    flatten_axis: int = -1,
):
    if quantizer is None:
        return x, None
    return (
        quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
        _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis),
    )


def _quantize_dbias_impl(
    x: jnp.ndarray,
    quantizer: Quantizer,
    is_dbias: bool = False,
    dq_dtype: Optional[jnp.dtype] = None,
    flatten_axis: int = -1,
    noop_scaled_tensor: bool = False,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
    """
    Cast wrapper
    Return FP8 tensor
    """
    assert (dq_dtype is None) or (
        quantizer is not None
    ), "quantizer must be provided if dq_dtype is provided"

    # Early-exit for non-quantized call
    dq_dtype = dq_dtype or x.dtype
    if quantizer is None:
        dbias = None
        if is_dbias:
            dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
        if noop_scaled_tensor:
            # Return a dummy ScaledTensor2x to ensure .get_rowwise_tensor() and
            # .get_colwise_tensor() always work.
            return (
                ScaledTensorFactory.create_2x(
                    x,
                    None,
                    x,
                    None,
                    scaling_mode=ScalingMode.NO_SCALING,
                    dq_dtype=x.dtype,
                    data_layout="NN",
                    flatten_axis=flatten_axis,
                ),
                dbias,
            )
        return x, dbias

    # If the TE/common custom quantize op is disabled, or if the quantizer layout is
    # COLWISE, fall back on the native-JAX quantize implementation.
    PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive
    if quantizer.q_layout == QuantizeLayout.COLWISE or not PrimitiveClass.enabled():
        if is_dbias:
            return _jax_quantize_dbias(
                x,
                quantizer=quantizer,
                dq_dtype=dq_dtype,
                flatten_axis=flatten_axis,
            )
        return (
            _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
            None,
        )

    # The TE/common custom quantize op does not support dbias fusion with 1x
    # quantization on arch < 100.
    if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
        out, _ = _quantize_dbias_impl(
            x=x,
            is_dbias=False,
            quantizer=quantizer,
            dq_dtype=dq_dtype,
            flatten_axis=flatten_axis,
        )
        dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
        return out, dbias

    scale = jnp.empty((), jnp.float32)
    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        # Globally reduce amax across all devices for current scaling so we have a
        # single global scale. This differs from the PyTorch implementation, which
        # uses a local amax and scale per device and persists them until the tensor
        # is dequantized (e.g. in the GEMM).
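        # compute_scale_from_amax then converts the reduced amax into a scale
        # (roughly the FP8 dtype's representable maximum divided by amax).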
        amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32)
        scale = compute_scale_from_amax(amax, quantizer.q_dtype)
    elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        scale = quantizer.scale

    # Make sure amax is initialized to zero
    amax = jnp.zeros((1,), jnp.float32)

    # It is faster to use 1x quantization for tensor scaling
    is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100)
    force_1x_quantization = (
        quantizer.scaling_mode.is_tensor_scaling()
        and quantizer.is_2x2x()
        and is_1x_kernel_supported
    )
    q_layout = quantizer.q_layout
    if force_1x_quantization:
        q_layout = QuantizeLayout.ROWWISE

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
        dbias,
    ) = PrimitiveClass.outer_primitive.bind(
        x,
        scale,
        amax,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        q_layout=q_layout.value,
        flatten_axis=flatten_axis,
        scale_dtype=quantizer.get_scale_dtype(),
        is_dbias=is_dbias,
        is_outer=True,
    )
    # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
    if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x():
        colwise_scale_inv = rowwise_scale_inv

        if q_layout == QuantizeLayout.ROWWISE:
            # The quantizer requires 2x quantization, but we used 1x quantization
            # for performance reasons, so we need to generate the colwise data in JAX
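            # (e.g. x.ndim == 3 with flatten_axis == 2 gives transpose axes (2, 0, 1))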
            if flatten_axis < 0:
                flatten_axis += x.ndim
            colwise_casted_output = jnp.transpose(
                rowwise_casted_output, (*range(flatten_axis, x.ndim), *range(flatten_axis))
            )

    quantizer.update(updated_amax)

    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=dq_dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=flatten_axis,
    )
    return out, dbias.astype(dq_dtype)


def quantize(
    x: jnp.ndarray,
    quantizer: Quantizer,
    flatten_axis: int = -1,
    noop_scaled_tensor: bool = False,
) -> Tuple[ScaledTensor]:
    """Quantize input tensor according to the quantizer.

    Args:
        x: Input tensor to be quantized.
            Shape: (..., K) where K is the hidden size.
        quantizer: Quantizer for FP8 quantization of the output.
        flatten_axis: The axis along which the input can be flattened to 2D for
            quantization. Defaults to -1.
        noop_scaled_tensor: If True, wraps the output into a dummy ScaledTensor2x when
            quantizer is None.

    Returns:
        A ScaledTensor containing the quantized input tensor.
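
    Example (illustrative sketch; ``my_quantizer`` stands in for a ``Quantizer``
    built elsewhere, e.g. by the quantizer factory in ``..quantize``)::

        x = jnp.zeros((128, 512), jnp.bfloat16)
        scaled = quantize(x, quantizer=my_quantizer)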
    """
    out, _ = _quantize_dbias_impl(
        x,
        quantizer=quantizer,
        flatten_axis=flatten_axis,
        noop_scaled_tensor=noop_scaled_tensor,
    )
    return out


def quantize_dbias(
    dz: jnp.ndarray,
    quantizer: Quantizer,
    is_dbias: bool = True,
    flatten_axis: int = -1,
    noop_scaled_tensor: bool = False,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
    """Quantize input tensor and compute bias gradient.

    Args:
        dz: Input tensor to be quantized and used for bias gradient computation.
            Shape: (..., K) where K is the hidden size.
        quantizer: Quantizer for FP8 quantization of the output.
        is_dbias: If True, compute bias gradient. Defaults to True.
        flatten_axis: The axis along which the input can be flattened to 2D for
            quantization. Defaults to -1.
        noop_scaled_tensor: If True, wraps the unquantized output into a dummy
            ScaledTensor2x when quantizer is None.

    Returns:
        A tuple containing:
        - A ScaledTensor containing the quantized input tensor.
            The ScaledTensor includes both the quantized data and scaling factors.
        - The bias gradient tensor.
            Shape: (K,) or empty if is_dbias is False.
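
    Example (illustrative sketch; ``my_quantizer`` stands in for a ``Quantizer``
    built elsewhere)::

        dz = jnp.zeros((32, 128, 512), jnp.bfloat16)
        scaled, dbias = quantize_dbias(dz, quantizer=my_quantizer)
        # dbias has shape (512,): dz reduced over all leading axes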
    """
    return _quantize_dbias_impl(
        dz,
        quantizer=quantizer,
        is_dbias=is_dbias,
        flatten_axis=flatten_axis,
        noop_scaled_tensor=noop_scaled_tensor,
    )


class GroupedQuantizePrimitive(BasePrimitive):
    """
    Cast Primitive wrapping the TE grouped quantize op (te_grouped_quantize_ffi)
    """

    name = "te_grouped_quantize_ffi"
    multiple_results = True
    impl_static_args = (
        3,
        4,
        5,
        6,
        7,
        8,
    )  # out_dtype, scaling_mode, q_layout, flatten_axis, group_axis, scale_dtype
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        group_sizes_aval,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        group_axis,
        scale_dtype,
    ):
        """
        te_grouped_quantize_p abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        out_shape = (math.prod(x_aval.shape),)
        # TODO(Phuong): can scale_aval be None?
        assert scale_aval is None or scale_aval.dtype == jnp.float32

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_grouped_scale_shape_2x(
            x_aval.shape,
            group_sizes_aval.size,
            group_axis,
            is_padded=True,
            flatten_axis=flatten_axis,
        )

        if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            rowwise_out_shape = out_shape
        else:
            rowwise_out_shape = (1,)
            rowwise_scale_inv_shape = (1,)
        rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype)

        amax_aval = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32)

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_out_shape = out_shape
        else:
            colwise_out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
        rowwise_scale_inv_aval = jax.core.ShapedArray(
            shape=rowwise_scale_inv_shape, dtype=scale_dtype
        )
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        return (
            rowwise_out_aval,
            colwise_out_aval,
            rowwise_scale_inv_aval,
            colwise_scale_inv_aval,
            amax_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_grouped_quantize_p outer primitive abstract
        """
        # Phuong: keeping outer abstract so that we can add fused dbias later
        (
            rowwise_out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
        ) = GroupedQuantizePrimitive.abstract(*args, **kwargs)
        return rowwise_out, colwise_out, scale_inv, colwise_scale_inv, updated_amax

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        group_sizes,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        group_axis,
        scale_dtype,
    ):
        """
        te_grouped_quantize_p lowering rules
        """
        del out_dtype, scale_dtype
        x_aval, scale_aval, group_sizes_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval.dtype == jnp.float32
        assert group_sizes_aval.dtype == jnp.int32
        assert group_axis == 0
        return ffi.ffi_lowering(GroupedQuantizePrimitive.name)(
            ctx,
            x,
            scale,
            group_sizes,
            scaling_mode=scaling_mode.value,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
        )

    @staticmethod
    def impl(
        x,
        scale,
        group_sizes,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        group_axis,
        scale_dtype,
    ):
        """
        te_grouped_quantize_p implementation
        """
        assert GroupedQuantizePrimitive.inner_primitive is not None
        (
            rowwise_out,
            colwise_out,
            rowwise_scale_inv,
            colwise_scale_inv,
            updated_amax,
        ) = GroupedQuantizePrimitive.inner_primitive.bind(
            x,
            scale,
            group_sizes,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
            group_axis=group_axis,
            scale_dtype=scale_dtype,
        )
        return (rowwise_out, colwise_out, rowwise_scale_inv, colwise_scale_inv, updated_amax)


register_primitive(GroupedQuantizePrimitive)


def grouped_quantize(
    x: jnp.ndarray,
    quantizer: GroupedQuantizer,
    group_sizes: jnp.ndarray = None,
    amax: jnp.ndarray = None,
    flatten_axis: int = -1,
) -> GroupedScaledTensor1x:
    """Quantize a tensor in grouped manner.

    This function quantizes a tensor by splitting it into groups along a specified axis
    and applying quantization to each group separately. The groups can be either specified
    explicitly through group_sizes or automatically split along the group_axis.

    Args:
        x: Input tensor to quantize
        quantizer: The quantizer to use for quantization
        group_sizes: Array of ints containing the size of each group (default: None)
        amax: The amax of x; if None, it is auto-generated. (default: None)
        flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)

    Returns:
        A GroupedScaledTensor1x containing the quantized data

    Note:
        - If group_sizes is not provided, each row along the group_axis becomes its
          own group (i.e. group_sizes defaults to all ones)
        - The group_axis is currently fixed to 0
        - The quantizer's q_layout determines whether row-wise, column-wise, or both
          quantization is applied
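
    Example (illustrative sketch; ``my_grouped_quantizer`` stands in for a
    ``GroupedQuantizer`` holding one quantizer per group)::

        x = jnp.zeros((12, 512), jnp.bfloat16)
        group_sizes = jnp.array([4, 8], dtype=jnp.int32)  # sums to x.shape[0]
        out = grouped_quantize(x, my_grouped_quantizer, group_sizes=group_sizes)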
    """

    if quantizer is None:
        return x

    # TODO(Phuong): add support for flatten_axis = -2
    assert flatten_axis in (
        -1,
        x.ndim - 1,
    ), f"Only flatten_axis = -1 is supported for now, got {flatten_axis}"
    group_axis = 0

    if group_sizes is None:
        group_sizes = jnp.ones(x.shape[group_axis], dtype=jnp.int32)

    if not GroupedQuantizePrimitive.enabled():
        return quantizer.quantize(
            x, flatten_axis=flatten_axis, group_sizes=group_sizes, group_axis=group_axis
        )
    n_groups = group_sizes.size
    original_shape = x.shape
    assert n_groups == len(
        quantizer.quantizers
    ), f"n_groups={n_groups} != n_quantizers = {len(quantizer.quantizers)}"
    scale = jnp.empty((n_groups,), jnp.float32)

    if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        for i, quantizer_i in enumerate(quantizer.quantizers):
            scale = scale.at[i].set(quantizer_i.scale[0])

    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        if amax is not None:
            row_amax = amax
        else:
            row_amax = jnp.max(jnp.abs(x), axis=range(group_axis + 1, x.ndim))
        segment_ids = jnp.repeat(
            jnp.arange(n_groups), group_sizes, total_repeat_length=x.shape[group_axis]
        )
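        # e.g. group_sizes = [2, 3] -> segment_ids = [0, 0, 1, 1, 1]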
        grouped_amax = jax.ops.segment_max(row_amax, segment_ids, num_segments=n_groups)
        for i in range(n_groups):
            tmp_scale = compute_scale_from_amax(grouped_amax[i], quantizer.q_dtype)
            scale = scale.at[i].set(tmp_scale[0])

    is_tensor_scaling = quantizer.scaling_mode in (
        ScalingMode.DELAYED_TENSOR_SCALING,
        ScalingMode.CURRENT_TENSOR_SCALING,
    )
    # WAR for tensor_scaling as TE/Common does not support q_layout = COLWISE yet
    # so we perform ROWWISE_COLWISE quantization and use the colwise tensor output
    apply_colwise_war = is_tensor_scaling and quantizer.q_layout == QuantizeLayout.COLWISE
    q_layout = QuantizeLayout.ROWWISE_COLWISE if apply_colwise_war else quantizer.q_layout
    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
    ) = GroupedQuantizePrimitive.outer_primitive.bind(
        x,
        scale,
        group_sizes,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        q_layout=q_layout.value,
        flatten_axis=flatten_axis,
        group_axis=group_axis,
        scale_dtype=quantizer.get_scale_dtype(),
    )

    # For DelayedScaling2x and CurrentScaling2x, the scale buffer
    # is shared between rowwise and colwise
    if (is_tensor_scaling and quantizer.is_2x2x()) or apply_colwise_war:
        colwise_scale_inv = rowwise_scale_inv

    # TODO(Phuong): store the whole updated_amax in the grouped_quantize instead?
    if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        for i, quantizer_i in enumerate(quantizer.quantizers):
            quantizer_i.update(updated_amax[i].reshape((1,)))

    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=flatten_axis,
        group_sizes=group_sizes,
        original_shape=original_shape,
        group_axis=group_axis,
    )
    return out


def grouped_dbias(grad: jnp.ndarray, group_sizes: jnp.ndarray) -> jnp.ndarray:
    """
    Compute the grouped bias gradient.

    Args:
        grad: jnp.ndarray of shape (M, N)
        group_sizes: jnp.ndarray of shape (num_groups,), with sum(group_sizes) == M

    Returns:
        dbias: jnp.ndarray of shape (num_groups, N)
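
    Example (illustrative sketch)::

        grad = jnp.ones((5, 4), jnp.bfloat16)
        group_sizes = jnp.array([2, 3], dtype=jnp.int32)
        dbias = grouped_dbias(grad, group_sizes)  # shape (2, 4)
        # dbias[0] == grad[0:2].sum(axis=0), dbias[1] == grad[2:5].sum(axis=0)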
    """
    assert grad.ndim == 2, "Input grad must be a 2D tensor."
    assert group_sizes.ndim == 1, "group_sizes must be a 1D tensor."

    segment_ids = jnp.repeat(
        jnp.arange(group_sizes.size), group_sizes, total_repeat_length=grad.shape[0]
    )
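    # e.g. group_sizes = [2, 3] -> segment_ids = [0, 0, 1, 1, 1]; segment_sum then
    # accumulates each row of grad into its group's row of the bias gradient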
    grad_fp32 = grad.astype(jnp.float32)
    dbias_fp32 = jax.ops.segment_sum(grad_fp32, segment_ids, num_segments=group_sizes.shape[0])
    dbias = dbias_fp32.astype(grad.dtype)
    return dbias