# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for quantization"""
import operator
from functools import reduce
from typing import Tuple, Optional, Union
import math

from packaging import version

import jax
import jax.numpy as jnp
from jax import dtypes
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.sharding import PartitionSpec

import transformer_engine_jax

from .base import BasePrimitive, register_primitive
from .misc import (
    get_padded_spec,
    check_valid_batch_dims,
    te_dtype_to_jax_dtype,
    jax_dtype_to_te_dtype,
    multidim_transpose,
    should_apply_1x_fused_dbias_war_for_arch_l_100,
    get_min_device_compute_capability,
    NamedSharding,
)
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import (
    ScaledTensor2x,
    ScaledTensor,
    ScaledTensorFactory,
    GroupedScaledTensor1x,
    Quantizer,
    GroupedQuantizer,
    QuantizeLayout,
    ScalingMode,
    compute_scale_from_amax,
    NoScaleTensor,
)

if version.parse(jax.__version__) >= version.parse("0.5.0"):
    from jax import ffi  # pylint: disable=ungrouped-imports
else:
    from jax.extend import ffi  # pylint: disable=ungrouped-imports


__all__ = ["quantize", "quantize_dbias", "grouped_quantize", "grouped_dbias"]


class BaseDBiasQuantizePrimitive(BasePrimitive):
    """
    Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias
    """

    name = "te_dbias_quantize_ffi"
    multiple_results = True
    impl_static_args = (
        3,
        4,
        5,
        6,
        7,
        8,
        9,
    )  # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, is_dbias, is_outer
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        amax_aval,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        out_shape = x_aval.shape
        assert scale_aval is None or scale_aval.dtype == jnp.float32

        if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            rowwise_out_shape = out_shape
        else:
            rowwise_out_shape = (1,)
        rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype)

        updated_amax_aval = amax_aval

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=flatten_axis)

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if ScalingMode(scaling_mode).is_tensor_scaling():
                colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis)
            else:
                colwise_out_shape = out_shape
        else:
            colwise_out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        if is_dbias:
            dbias_shape = x_aval.shape[flatten_axis:]
            gi_hidden_size = reduce(operator.mul, x_aval.shape[flatten_axis:], 1)
            (wkspace_info,) = transformer_engine_jax.get_dbias_quantize_workspace_sizes(
                x_aval.size // gi_hidden_size,
                gi_hidden_size,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
                scaling_mode,
                QuantizeLayout(
                    q_layout
                ),  # For now until we have auto-decoding for QuantizeLayout enum
            )
            wkspace_shape = wkspace_info[0]
            wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
        else:
            dbias_shape = (1,)
            wkspace_shape = (1,)
            wkspace_dtype = jnp.float32
        dbias_aval = jax.core.ShapedArray(shape=dbias_shape, dtype=dtype)
        wkspace_aval = jax.core.ShapedArray(shape=wkspace_shape, dtype=wkspace_dtype)

        return (
            rowwise_out_aval,
            colwise_out_aval,
            scale_inv_aval,
            colwise_scale_inv_aval,
            updated_amax_aval,
            dbias_aval,
            wkspace_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dbias_quantize_p outer primitive abstract
        """
        (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
            _,
        ) = BaseDBiasQuantizePrimitive.abstract(*args, **kwargs)
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        amax,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p lowering rules
        """
        del out_dtype, scale_dtype, is_outer
        x_aval, scale_aval, amax_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval.dtype == amax_aval.dtype == jnp.float32
        return ffi.ffi_lowering(
            BaseDBiasQuantizePrimitive.name,
            operand_output_aliases={2: 4},  # donate amax buffer to updated_amax
        )(
            ctx,
            x,
            scale,
            amax,
            scaling_mode=scaling_mode.value,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
            is_dbias=is_dbias,
        )

    @staticmethod
    def impl(
        x,
        scale,
        amax,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p implementation
        """
        del is_outer
        assert BaseDBiasQuantizePrimitive.inner_primitive is not None
        (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
            _,
        ) = BaseDBiasQuantizePrimitive.inner_primitive.bind(
            x,
            scale,
            amax,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
            scale_dtype=scale_dtype,
            is_dbias=is_dbias,
            is_outer=False,
        )
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x.shape, is_padded=False, flatten_axis=flatten_axis)
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
        )
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_scale_inv = jax.lax.slice(
                colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
            )
        return (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
        )  # Exclude wkspace

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
    ):
        """
        to describe batch rules for vmap
        """
        del is_outer
        check_valid_batch_dims(batch_dims)
        assert BaseDBiasQuantizePrimitive.outer_primitive is not None
        x, scale, amax = batched_args
        x_bdim, scale_bdim, amax_bdim = batch_dims

        out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim, x_bdim
        return (
            BaseDBiasQuantizePrimitive.outer_primitive.bind(
                x,
                scale,
                amax,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                q_layout=q_layout,
                flatten_axis=flatten_axis,
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
                is_outer=True,  # required static arg of the primitive
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del (out_dtype, result_infos, scale_dtype, is_outer)  # Unused.

        x_spec = get_padded_spec(arg_infos[0])
        amax_spec = get_padded_spec(arg_infos[2])
        out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*x_spec),
            desc="BaseDBiasQuantizePrimitive.out_sharding",
        )
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if ScalingMode(scaling_mode).is_tensor_scaling():
                colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
            else:
                colwise_out_spec = x_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_out_spec),
            desc="BaseDBiasQuantizePrimitive.colwise_out_sharding",
        )

        dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="BaseDBiasQuantizePrimitive.dbias_sharding",
        )

        scale_inv_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="BaseDBiasQuantizePrimitive.colwise_scale_inv",
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax"
        )

        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer

        x_spec = get_padded_spec(arg_infos[0])
        amax_spec = get_padded_spec(arg_infos[2])
        out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*x_spec),
            desc="BaseDBiasQuantizePrimitive.out_sharding",
        )
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if ScalingMode(scaling_mode).is_tensor_scaling():
                colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
            else:
                colwise_out_spec = x_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_out_spec),
            desc="BaseDBiasQuantizePrimitive.colwise_out_sharding",
        )

        dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="BaseDBiasQuantizePrimitive.dbias_sharding",
        )

        scale_inv_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="BaseDBiasQuantizePrimitive.colwise_scale_inv",
        )

        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

        def sharded_impl(x, scale, amax):
            (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                local_amax,
                local_dbias,
            ) = BaseDBiasQuantizePrimitive.impl(
                x,
                scale,
                amax,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                q_layout=q_layout,
                flatten_axis=flatten_axis,
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
                is_outer=True,
            )

            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            else:
                global_updated_amax = local_amax

            if is_dbias:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            else:
                global_dbias = local_dbias

            return (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                global_updated_amax,
                global_dbias,
            )

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        mesh,
        value_types,
        result_types,
    ):
        del out_dtype, scale_dtype, is_outer, mesh, result_types

        prefix = "BaseDBiasQuantizePrimitive_"
        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            len(value_types[0].shape),
            unique_var=prefix + "x",
            flatten_axis=flatten_axis,
        )

        x_axes = scale_rules.input_spec
        colwise_scale_inv = scale_rules.colwise_rule

        out = x_axes
        colwise_out = (prefix + "out_colwise",)
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if ScalingMode(scaling_mode).is_tensor_scaling():
                colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis))
            else:
                colwise_out = x_axes

        dbias = x_axes[flatten_axis:] if is_dbias else (prefix + "dbias",)
        amax = (prefix + "amax",)

        return SdyShardingRule(
            (x_axes, ("…1",), amax),
            (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias),
        )


register_primitive(BaseDBiasQuantizePrimitive)


class DBiasQuantizePrimitive(BaseDBiasQuantizePrimitive):
    """Subclass of BaseDBiasQuantizePrimitive for DBias quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""


class QuantizePrimitive(BaseDBiasQuantizePrimitive):
    """Subclass of BaseDBiasQuantizePrimitive for quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""


def _jax_quantize(
    x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1
):
    if quantizer is None:
        if isinstance(x, NoScaleTensor):
            return x
        return NoScaleTensor(data=x, amax=None)
    return quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)


def _jax_dbias(dx: Union[jnp.ndarray, NoScaleTensor], dtype=None, flatten_axis: int = -1):
    if isinstance(dx, NoScaleTensor):
        dx = dx.data
    sum_axis = dx.ndim + flatten_axis if flatten_axis < 0 else flatten_axis
    assert sum_axis < dx.ndim, "Flatten axis out of bounds!"
    dtype = dtype or dx.dtype
    dbias = jnp.sum(
        dx.astype(jnp.float32),
        axis=tuple(range(sum_axis)),
        keepdims=False,
    )
    return dbias.astype(dtype)
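
# Shape sketch (illustrative): for dx of shape (B, S, H) with flatten_axis=-1,
# sum_axis resolves to 2, so the sum runs over axes (0, 1) and dbias has shape (H,),
# matching the bias of a layer with hidden size H.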


def _jax_quantize_dbias(
    x,
    quantizer: Quantizer = None,
    dq_dtype: Optional[jnp.dtype] = None,
    flatten_axis: int = -1,
):
    if quantizer is None:
        if isinstance(x, NoScaleTensor):
            return x, None
        return NoScaleTensor(data=x, amax=None), None
    return (
        quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
        _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis),
    )


def _quantize_dbias_impl(
    x: Union[jnp.ndarray, NoScaleTensor],
    quantizer: Quantizer,
    is_dbias: bool = False,
    dq_dtype: Optional[jnp.dtype] = None,
    flatten_axis: int = -1,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
    """
    Cast wrapper around the quantize primitives.
    Returns the quantized (FP8) tensor and, if is_dbias is True, the bias gradient of x.
    """
    assert (dq_dtype is None) or (
        quantizer is not None
    ), "quantizer must be provided if dq_dtype is provided"

    if isinstance(x, jnp.ndarray):
        x = NoScaleTensor(data=x, amax=None)

    # Early-exit for non-quantized call
    dq_dtype = dq_dtype or x.data.dtype
    if quantizer is None:
        dbias = None
        if is_dbias:
            dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis)
        return x, dbias

    # If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE,
    # fall back on the native-JAX quantize implementation
    PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive
    if quantizer.q_layout == QuantizeLayout.COLWISE or not PrimitiveClass.enabled():
        if is_dbias:
            return _jax_quantize_dbias(
                x,
                quantizer=quantizer,
                dq_dtype=dq_dtype,
                flatten_axis=flatten_axis,
            )
        return (
            _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
            None,
        )

    # TE/common custom quantize op does not support dbias fusion with 1x quantization on arch < 100
    if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
        out, _ = _quantize_dbias_impl(
            x=x,
            is_dbias=False,
            quantizer=quantizer,
            dq_dtype=dq_dtype,
            flatten_axis=flatten_axis,
        )
        dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis)
        return out, dbias

    scale = jnp.empty((), jnp.float32)
    amax = None
    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        # Globally reduce amax across all devices for current scaling so we have a single global scale.
        # This differs from the PyTorch implementation which uses a local amax and scale per-device
        # and persists this until the tensor is dequantized (e.g. in the GEMM).
        amax = x.amax
        if amax is None:
            amax = jnp.amax(jnp.abs(x.data), keepdims=True).astype(jnp.float32).reshape((1,))
        scale = compute_scale_from_amax(amax, quantizer.q_dtype)
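        # compute_scale_from_amax maps the observed amax to a quantization scale,
        # conceptually scale ≈ fp8_max / amax for the target FP8 dtype (see ..quantize
        # for the exact formula); e.g. for float8_e4m3 (max ≈ 448), amax = 2.0 yields
        # a scale near 224.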
    elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        scale = quantizer.scale

    # Make sure amax is init with zero
    if amax is None:
        amax = jnp.zeros((1,), jnp.float32)

    # It is faster to use 1x quantization for tensor scaling
    is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100)
    force_1x_quantization = (
        quantizer.scaling_mode.is_tensor_scaling()
        and quantizer.is_2x2x()
        and is_1x_kernel_supported
    )
    q_layout = quantizer.q_layout
    if force_1x_quantization:
        q_layout = QuantizeLayout.ROWWISE

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
        dbias,
    ) = PrimitiveClass.outer_primitive.bind(
        x.data,
        scale,
        amax,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        q_layout=q_layout.value,
        flatten_axis=flatten_axis,
        scale_dtype=quantizer.get_scale_dtype(),
        is_dbias=is_dbias,
        is_outer=True,
    )
    # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
    if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x():
        colwise_scale_inv = rowwise_scale_inv

        if q_layout == QuantizeLayout.ROWWISE:
            # Quantizer requires 2x quantization, but we are using 1x quantization
            # for performance reasons, so we need to generate the colwise data in JAX
            if flatten_axis < 0:
                flatten_axis += x.ndim
            colwise_casted_output = jnp.transpose(
                rowwise_casted_output, (*range(flatten_axis, x.ndim), *range(flatten_axis))
            )

    quantizer.update(updated_amax)

    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=dq_dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=flatten_axis,
    )
    return out, dbias.astype(dq_dtype)


def quantize(
    x: Union[jnp.ndarray, NoScaleTensor],
    quantizer: Quantizer,
    flatten_axis: int = -1,
) -> ScaledTensor:
    """Quantize input tensor according to the quantizer.

    Args:
        x: Input tensor to be quantized.
            Shape: (..., K) where K is the hidden size.
        quantizer: Quantizer for FP8 quantization of the output. The input is returned
            unquantized (as a NoScaleTensor) if the quantizer is None.
        flatten_axis: The quantization axis in which input data can be flattened to 2D
            for quantization. Defaults to -1.

    Returns:
        A ScaledTensor containing the quantized input tensor.
    """
    out, _ = _quantize_dbias_impl(
        x,
        quantizer=quantizer,
        flatten_axis=flatten_axis,
    )
    return out
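
# A minimal usage sketch (illustrative only; constructing a quantizer depends on the
# active recipe, and `my_quantizer` below is a hypothetical stand-in):
#
#     x = jnp.ones((128, 1024), jnp.bfloat16)
#     scaled = quantize(x, quantizer=my_quantizer)  # -> ScaledTensor
#     # `scaled` carries the FP8 data plus the scale-inverse factors needed to
#     # dequantize it later (e.g. inside a GEMM).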


def quantize_dbias(
    dz: Union[jnp.ndarray, NoScaleTensor],
    quantizer: Quantizer,
    is_dbias: bool = True,
    flatten_axis: int = -1,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
    """Quantize input tensor and compute bias gradient.

    Args:
        dz: Input tensor to be quantized and used for bias gradient computation.
            Shape: (..., K) where K is the hidden size.
        quantizer: Quantizer for FP8 quantization of the output.
        is_dbias: If True, compute bias gradient. Defaults to True.
        flatten_axis: The quantization axis in which input data can be flattened to 2D
            for quantization. Defaults to -1.

    Returns:
        A tuple containing:
        - A ScaledTensor containing the quantized input tensor.
            The ScaledTensor includes both the quantized data and scaling factors.
        - The bias gradient tensor.
            Shape: (K,) or empty if is_dbias is False.
    """
    return _quantize_dbias_impl(
        dz,
        quantizer=quantizer,
        is_dbias=is_dbias,
        flatten_axis=flatten_axis,
    )
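
# A minimal usage sketch (illustrative only; `my_quantizer` is a hypothetical stand-in):
#
#     dz = jnp.ones((128, 1024), jnp.bfloat16)
#     scaled_dz, dbias = quantize_dbias(dz, quantizer=my_quantizer)
#     # dbias has shape (1024,): dz summed over all leading axes, i.e. the bias gradient.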


class GroupedQuantizePrimitive(BasePrimitive):
    """
    Cast Primitive for grouped quantization, wrapping the TE/common grouped quantize op
    """

    name = "te_grouped_quantize_ffi"
    multiple_results = True
    impl_static_args = (
        3,
        4,
        5,
        6,
        7,
        8,
    )  # out_dtype, scaling_mode, q_layout, flatten_axis, group_axis, scale_dtype
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        group_sizes_aval,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        group_axis,
        scale_dtype,
    ):
        """
        te_grouped_quantize_p abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        out_shape = math.prod(x_aval.shape)
        # TODO(Phuong): can scale_aval be None?
        assert scale_aval is None or scale_aval.dtype == jnp.float32

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_grouped_scale_shape_2x(
            x_aval.shape,
            group_sizes_aval.size,
            group_axis,
            is_padded=True,
            flatten_axis=flatten_axis,
        )

        if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            rowwise_out_shape = out_shape
        else:
            rowwise_out_shape = (1,)
            rowwise_scale_inv_shape = (1,)
        rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype)

        amax_aval = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32)

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_out_shape = out_shape
        else:
            colwise_out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
        rowwise_scale_inv_aval = jax.core.ShapedArray(
            shape=rowwise_scale_inv_shape, dtype=scale_dtype
        )
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        return (
            rowwise_out_aval,
            colwise_out_aval,
            rowwise_scale_inv_aval,
            colwise_scale_inv_aval,
            amax_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_grouped_quantize_p outer primitive abstract
        """
        # Phuong: keeping outer abstract so that we can add fuse dbias later
        (
            rowwise_out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
        ) = GroupedQuantizePrimitive.abstract(*args, **kwargs)
        return rowwise_out, colwise_out, scale_inv, colwise_scale_inv, updated_amax

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        group_sizes,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        group_axis,
        scale_dtype,
    ):
        """
        te_grouped_quantize_p lowering rules
        """
        del out_dtype, scale_dtype
        x_aval, scale_aval, group_sizes_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval.dtype == jnp.float32
        assert group_sizes_aval.dtype == jnp.int32
        assert group_axis == 0
        return ffi.ffi_lowering(GroupedQuantizePrimitive.name)(
            ctx,
            x,
            scale,
            group_sizes,
            scaling_mode=scaling_mode.value,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
        )

    @staticmethod
    def impl(
        x,
        scale,
        group_sizes,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        group_axis,
        scale_dtype,
    ):
        """
        te_grouped_quantize_p implementation
        """
        assert GroupedQuantizePrimitive.inner_primitive is not None
        (
            rowwise_out,
            colwise_out,
            rowwise_scale_inv,
            colwise_scale_inv,
            updated_amax,
        ) = GroupedQuantizePrimitive.inner_primitive.bind(
            x,
            scale,
            group_sizes,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
            group_axis=group_axis,
            scale_dtype=scale_dtype,
        )
        return (rowwise_out, colwise_out, rowwise_scale_inv, colwise_scale_inv, updated_amax)


register_primitive(GroupedQuantizePrimitive)


def grouped_quantize(
    x: jnp.ndarray,
    quantizer: GroupedQuantizer,
    group_sizes: jnp.ndarray = None,
    amax: jnp.ndarray = None,
    flatten_axis: int = -1,
) -> GroupedScaledTensor1x:
    """Quantize a tensor in grouped manner.

    This function quantizes a tensor by splitting it into groups along a specified axis
    and applying quantization to each group separately. The groups can be either specified
    explicitly through group_sizes or automatically split along the group_axis.

    Args:
        x: Input tensor to quantize
        quantizer: The quantizer to use for quantization
        group_sizes: Array of ints containing the size of each group (default: None)
        amax: The amax of x; if None, it is auto-generated. (default: None)
        flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)

    Returns:
        A GroupedScaledTensor1x containing the quantized data

    Note:
        - If group_sizes is not provided, the tensor will be split into equal-sized groups
          along the group_axis
        - The group_axis is currently fixed to 0
        - The quantizer's q_layout determines whether row-wise, column-wise, or both
          quantization is applied
    """

    if quantizer is None:
        if isinstance(x, NoScaleTensor):
            return x
        return NoScaleTensor(data=x, amax=None)

    # TODO(Phuong): add support for flatten_axis = -2
    assert flatten_axis in (
        -1,
        x.ndim - 1,
    ), f"Only flatten_axis = -1 is supported for now, got {flatten_axis}"
    group_axis = 0

    if group_sizes is None:
        group_sizes = jnp.ones(x.shape[group_axis], dtype=jnp.int32)

    if not GroupedQuantizePrimitive.enabled():
        return quantizer.quantize(
            x, flatten_axis=flatten_axis, group_sizes=group_sizes, group_axis=group_axis
        )
    n_groups = group_sizes.size
    original_shape = x.shape
    assert n_groups == len(
        quantizer.quantizers
    ), f"n_groups={n_groups} != n_quantizers = {len(quantizer.quantizers)}"
    scale = jnp.empty((n_groups,), jnp.float32)

    if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        for i, quantizer_i in enumerate(quantizer.quantizers):
            scale = scale.at[i].set(quantizer_i.scale[0])

    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        if amax is not None:
            row_amax = amax
        else:
            row_amax = jnp.max(jnp.abs(x), axis=range(group_axis + 1, x.ndim))
        segment_ids = jnp.repeat(
            jnp.arange(n_groups), group_sizes, total_repeat_length=x.shape[group_axis]
        )
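        # E.g. group_sizes=[2, 3] over 5 rows gives segment_ids=[0, 0, 1, 1, 1],
        # tagging each row with its group index before the per-group amax reduction.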
        grouped_amax = jax.ops.segment_max(row_amax, segment_ids, num_segments=n_groups)
        for i in range(n_groups):
            tmp_scale = compute_scale_from_amax(grouped_amax[i], quantizer.q_dtype)
            scale = scale.at[i].set(tmp_scale[0])

    is_tensor_scaling = quantizer.scaling_mode in (
        ScalingMode.DELAYED_TENSOR_SCALING,
        ScalingMode.CURRENT_TENSOR_SCALING,
    )
    # WAR for tensor_scaling, as TE/Common does not support q_layout = COLWISE yet,
    # so we perform ROWWISE_COLWISE quantization and use the colwise tensor output
    apply_colwise_war = is_tensor_scaling and quantizer.q_layout == QuantizeLayout.COLWISE
    q_layout = QuantizeLayout.ROWWISE_COLWISE if apply_colwise_war else quantizer.q_layout
    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
    ) = GroupedQuantizePrimitive.outer_primitive.bind(
        x,
        scale,
        group_sizes,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        q_layout=q_layout.value,
        flatten_axis=flatten_axis,
        group_axis=group_axis,
        scale_dtype=quantizer.get_scale_dtype(),
    )

    # For DelayedScaling2x and CurrentScaling2x, the scale buffer
    # is shared between rowwise and colwise
    if (is_tensor_scaling and quantizer.is_2x2x()) or apply_colwise_war:
        colwise_scale_inv = rowwise_scale_inv

    # TODO(Phuong): store the whole updated_amax in the grouped_quantize instead?
    if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        for i, quantizer_i in enumerate(quantizer.quantizers):
            quantizer_i.update(updated_amax[i].reshape((1,)))

    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=flatten_axis,
        group_sizes=group_sizes,
        original_shape=original_shape,
        group_axis=group_axis,
    )
    return out
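
# A minimal usage sketch (illustrative only; `my_grouped_quantizer` is a hypothetical
# stand-in constructed with one sub-quantizer per group):
#
#     x = jnp.ones((8, 1024), jnp.bfloat16)
#     group_sizes = jnp.array([3, 5], dtype=jnp.int32)  # rows per group, sums to 8
#     out = grouped_quantize(x, my_grouped_quantizer, group_sizes=group_sizes)
#     # `out` is a GroupedScaledTensor1x holding per-group quantized data and scales.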


def grouped_dbias(grad: jnp.ndarray, group_sizes: jnp.ndarray) -> jnp.ndarray:
    """
    Compute the grouped bias gradient.

    Args:
        grad: jnp.ndarray of shape (M, N)
        group_sizes: jnp.ndarray of shape(num_groups,), sum(group_sizes) == M

    Returns:
        dbias: jnp.ndarray of shape (num_groups, N)
    """
    assert grad.ndim == 2, "Input grad must be a 2D tensor."
    assert group_sizes.ndim == 1, "group_sizes must be a 1D tensor."

    segment_ids = jnp.repeat(
        jnp.arange(group_sizes.size), group_sizes, total_repeat_length=grad.shape[0]
    )
    grad_fp32 = grad.astype(jnp.float32)
    dbias_fp32 = jax.ops.segment_sum(grad_fp32, segment_ids, num_segments=group_sizes.shape[0])
    dbias = dbias_fp32.astype(grad.dtype)
    return dbias
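
# A worked example (illustrative values): with grad of shape (5, 4) and
# group_sizes=[2, 3], rows 0-1 are summed into dbias[0] and rows 2-4 into dbias[1]:
#
#     grad = jnp.ones((5, 4), jnp.bfloat16)
#     group_sizes = jnp.array([2, 3], dtype=jnp.int32)
#     dbias = grouped_dbias(grad, group_sizes)
#     # dbias has shape (2, 4): [[2., 2., 2., 2.], [3., 3., 3., 3.]]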