# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for quantization"""
import operator
from functools import reduce
from typing import Tuple, Optional, Union
import math

import jax
import jax.numpy as jnp
from jax import dtypes, ffi
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.sharding import PartitionSpec

import transformer_engine_jax

from .base import BasePrimitive, register_primitive
from .misc import (
    get_padded_spec,
    check_valid_batch_dims,
    te_dtype_to_jax_dtype,
    jax_dtype_to_te_dtype,
    multidim_transpose,
    should_apply_1x_fused_dbias_war_for_arch_l_100,
    get_min_device_compute_capability,
    NamedSharding,
)
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import (
    ScaledTensor2x,
    ScaledTensor,
    ScaledTensorFactory,
    GroupedScaledTensor1x,
    Quantizer,
    GroupedQuantizer,
    QuantizeLayout,
    ScalingMode,
    compute_scale_from_amax,
    NoScaleTensor,
)


__all__ = ["quantize", "quantize_dbias", "grouped_quantize", "grouped_dbias"]


class BaseDBiasQuantizePrimitive(BasePrimitive):
    """
    Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias
    """

    name = "te_dbias_quantize_ffi"
    multiple_results = True
    impl_static_args = (
        3,
        4,
        5,
        6,
        7,
        8,
        9,
    )  # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, is_dbias, is_outer
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        amax_aval,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        out_shape = x_aval.shape
        assert scale_aval is None or scale_aval.dtype == jnp.float32

        if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            rowwise_out_shape = out_shape
        else:
            rowwise_out_shape = (1,)
        rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype)

        updated_amax_aval = amax_aval

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=flatten_axis)

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if ScalingMode(scaling_mode).is_tensor_scaling():
                colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis)
            else:
                colwise_out_shape = out_shape
        else:
            colwise_out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        if is_dbias:
            dbias_shape = x_aval.shape[flatten_axis:]
            gi_hidden_size = reduce(operator.mul, x_aval.shape[flatten_axis:], 1)
            (wkspace_info,) = transformer_engine_jax.get_dbias_quantize_workspace_sizes(
                x_aval.size // gi_hidden_size,
                gi_hidden_size,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
                scaling_mode,
                QuantizeLayout(
                    q_layout
                ),  # For now until we have auto-decoding for QuantizeLayout enum
            )
            wkspace_shape = wkspace_info[0]
            wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
        else:
            dbias_shape = (1,)
            wkspace_shape = (1,)
            wkspace_dtype = jnp.float32
        dbias_aval = jax.core.ShapedArray(shape=dbias_shape, dtype=dtype)
        wkspace_aval = jax.core.ShapedArray(shape=wkspace_shape, dtype=wkspace_dtype)

        return (
            rowwise_out_aval,
            colwise_out_aval,
            scale_inv_aval,
            colwise_scale_inv_aval,
            updated_amax_aval,
            dbias_aval,
            wkspace_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dbias_quantize_p outer primitive abstract
        """
        (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
            _,
        ) = BaseDBiasQuantizePrimitive.abstract(*args, **kwargs)
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        amax,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p lowering rules
        """
        del out_dtype, scale_dtype, is_outer
        x_aval, scale_aval, amax_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval.dtype == amax_aval.dtype == jnp.float32
        return ffi.ffi_lowering(
            BaseDBiasQuantizePrimitive.name,
            operand_output_aliases={2: 4},  # donate amax buffer to updated_amax
        )(
            ctx,
            x,
            scale,
            amax,
            scaling_mode=scaling_mode.value,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
            is_dbias=is_dbias,
        )

    @staticmethod
    def impl(
        x,
        scale,
        amax,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p implementation
        """
        del is_outer
        assert BaseDBiasQuantizePrimitive.inner_primitive is not None
        (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
            _,
        ) = BaseDBiasQuantizePrimitive.inner_primitive.bind(
            x,
            scale,
            amax,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
            scale_dtype=scale_dtype,
            is_dbias=is_dbias,
            is_outer=False,
        )
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x.shape, is_padded=False, flatten_axis=flatten_axis)
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
        )
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_scale_inv = jax.lax.slice(
                colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
            )
        return (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
        )  # Exclude wkspace

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
    ):
        """
        to describe batch rules for vmap
        """
        del is_outer
        check_valid_batch_dims(batch_dims)
        assert BaseDBiasQuantizePrimitive.outer_primitive is not None
        x, scale, amax = batched_args
        x_bdim, scale_bdim, amax_bdim = batch_dims

        out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim, x_bdim
        return (
            BaseDBiasQuantizePrimitive.outer_primitive.bind(
                x,
                scale,
                amax,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                q_layout=q_layout,
                flatten_axis=flatten_axis,
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
                is_outer=True,  # outer primitive always binds with is_outer=True
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del (out_dtype, result_infos, scale_dtype, is_outer)  # Unused.

        x_spec = get_padded_spec(arg_infos[0])
        amax_spec = get_padded_spec(arg_infos[2])
        out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*x_spec),
            desc="BaseDBiasQuantizePrimitive.out_sharding",
        )
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if ScalingMode(scaling_mode).is_tensor_scaling():
                colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
            else:
                colwise_out_spec = x_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_out_spec),
            desc="BaseDBiasQuantizePrimitive.colwise_out_sharding",
        )

        dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="BaseDBiasQuantizePrimitive.dbias_sharding",
        )

        scale_inv_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="BaseDBiasQuantizePrimitive.colwise_scale_inv",
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax"
        )

        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer

        x_spec = get_padded_spec(arg_infos[0])
        amax_spec = get_padded_spec(arg_infos[2])
        out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*x_spec),
            desc="BaseDBiasQuantizePrimitive.out_sharding",
        )
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if ScalingMode(scaling_mode).is_tensor_scaling():
                colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
            else:
                colwise_out_spec = x_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_out_spec),
            desc="BaseDBiasQuantizePrimitive.colwise_out_sharding",
        )

        dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="BaseDBiasQuantizePrimitive.dbias_sharding",
        )

        scale_inv_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="BaseDBiasQuantizePrimitive.colwise_scale_inv",
        )

        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

        def sharded_impl(x, scale, amax):
            (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                local_amax,
                local_dbias,
            ) = BaseDBiasQuantizePrimitive.impl(
                x,
                scale,
                amax,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                q_layout=q_layout,
                flatten_axis=flatten_axis,
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
                is_outer=True,
            )

            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            else:
                global_updated_amax = local_amax

            if is_dbias:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            else:
                global_dbias = local_dbias

            return (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                global_updated_amax,
                global_dbias,
            )

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        mesh,
        value_types,
        result_types,
    ):
        del out_dtype, scale_dtype, is_outer, mesh, result_types

        prefix = "BaseDBiasQuantizePrimitive_"
        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            len(value_types[0].shape),
            unique_var=prefix + "x",
            flatten_axis=flatten_axis,
        )

        x_axes = scale_rules.input_spec
        colwise_scale_inv = scale_rules.colwise_rule

        out = x_axes
        colwise_out = (prefix + "out_colwise",)
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if ScalingMode(scaling_mode).is_tensor_scaling():
                colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis))
            else:
                colwise_out = x_axes

        dbias = x_axes[flatten_axis:] if is_dbias else (prefix + "dbias",)
        amax = (prefix + "amax",)

        return SdyShardingRule(
            (x_axes, ("…1",), amax),
            (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias),
        )


register_primitive(BaseDBiasQuantizePrimitive)


class DBiasQuantizePrimitive(BaseDBiasQuantizePrimitive):
    """Subclass of BaseDBiasQuantizePrimitive for DBias quantization. No change in functionality
    from the base primitive, but named differently for use in more granular disabling of
    primitives via NVTE_JAX_CUSTOM_CALLS."""


class QuantizePrimitive(BaseDBiasQuantizePrimitive):
    """Subclass of BaseDBiasQuantizePrimitive for quantization without dbias. No change in
    functionality from the base primitive, but named differently for use in more granular
    disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""


def _jax_quantize(
    x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1
):
    if quantizer is None:
        if isinstance(x, NoScaleTensor):
            return x
        return NoScaleTensor(data=x, amax=None)
    return quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)


def _jax_dbias(dx: Union[jnp.ndarray, NoScaleTensor], dtype=None, flatten_axis: int = -1):
    if isinstance(dx, NoScaleTensor):
        dx = dx.data
    sum_axis = dx.ndim + flatten_axis if flatten_axis < 0 else flatten_axis
    assert sum_axis < dx.ndim, "Flatten axis out of bounds!"
    dtype = dtype or dx.dtype
    dbias = jnp.sum(
        dx.astype(jnp.float32),
        axis=tuple(range(sum_axis)),
        keepdims=False,
    )
    return dbias.astype(dtype)
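
# Shape reference (illustrative): for dx of shape (B, S, K) and the default flatten_axis=-1,
# the sum runs over the leading (B, S) axes and the returned dbias has shape (K,).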


def _jax_quantize_dbias(
    x,
    quantizer: Quantizer = None,
    dq_dtype: Optional[jnp.dtype] = None,
    flatten_axis: int = -1,
):
    if quantizer is None:
        if isinstance(x, NoScaleTensor):
            return x, None
        return NoScaleTensor(data=x, amax=None), None
    return (
        quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
        _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis),
    )


def _quantize_dbias_impl(
    x: Union[jnp.ndarray, NoScaleTensor],
    quantizer: Quantizer,
    is_dbias: bool = False,
    dq_dtype: Optional[jnp.dtype] = None,
    flatten_axis: int = -1,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
    """
    Cast wrapper; returns an FP8 tensor and (optionally) the bias gradient.
    """
    assert (dq_dtype is None) or (
        quantizer is not None
    ), "quantizer must be provided if dq_dtype is provided"

    if isinstance(x, jnp.ndarray):
        x = NoScaleTensor(data=x, amax=None)

    # Early-exit for non-quantized call
    dq_dtype = dq_dtype or x.data.dtype
    if quantizer is None:
        dbias = None
        if is_dbias:
            dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis)
        return x, dbias

    # If the TE/common custom quantize op is disabled, or if the quantizer layout is COLWISE,
    # fall back on the native-JAX quantize implementation
    PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive
    if quantizer.q_layout == QuantizeLayout.COLWISE or not PrimitiveClass.enabled():
        if is_dbias:
            return _jax_quantize_dbias(
                x,
                quantizer=quantizer,
                dq_dtype=dq_dtype,
                flatten_axis=flatten_axis,
            )
        return (
            _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
            None,
        )

    # The TE/common custom quantize op does not support dbias fusion with 1x quantization on arch < 100
    if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
        out, _ = _quantize_dbias_impl(
            x=x,
            is_dbias=False,
            quantizer=quantizer,
            dq_dtype=dq_dtype,
            flatten_axis=flatten_axis,
        )
        dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis)
        return out, dbias

    scale = jnp.empty((), jnp.float32)
    amax = None
    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        # Globally reduce amax across all devices for current scaling so we have a single global
        # scale. This differs from the PyTorch implementation, which uses a local amax and scale
        # per device and persists them until the tensor is dequantized (e.g. in the GEMM).
        amax = x.amax
        if amax is None:
            amax = jnp.amax(jnp.abs(x.data), keepdims=True).astype(jnp.float32).reshape((1,))
        scale = compute_scale_from_amax(amax, quantizer.q_dtype)
    elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        scale = quantizer.scale

    # Make sure amax is initialized with zero
    if amax is None:
        amax = jnp.zeros((1,), jnp.float32)

    # It is faster to use 1x quantization for tensor scaling
    is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100)
    force_1x_quantization = (
        quantizer.scaling_mode.is_tensor_scaling()
        and quantizer.is_2x2x()
        and is_1x_kernel_supported
    )
    q_layout = quantizer.q_layout
    if force_1x_quantization:
        q_layout = QuantizeLayout.ROWWISE

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
        dbias,
    ) = PrimitiveClass.outer_primitive.bind(
        x.data,
        scale,
        amax,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        q_layout=q_layout.value,
        flatten_axis=flatten_axis,
        scale_dtype=quantizer.get_scale_dtype(),
        is_dbias=is_dbias,
        is_outer=True,
    )
    # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
    if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x():
        colwise_scale_inv = rowwise_scale_inv
        if q_layout == QuantizeLayout.ROWWISE:
            # The quantizer requires 2x quantization, but we used 1x quantization
            # for performance reasons, so we need to generate the colwise data in JAX
            if flatten_axis < 0:
                flatten_axis += x.ndim
            colwise_casted_output = jnp.transpose(
                rowwise_casted_output, (*range(flatten_axis, x.ndim), *range(flatten_axis))
            )

    quantizer.update(updated_amax)

    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=dq_dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=flatten_axis,
    )
    return out, dbias.astype(dq_dtype)
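
# Note on current tensor scaling (an informal sketch of the usual FP8 scale recipe):
# compute_scale_from_amax derives the scale from the representable range of the target
# dtype, roughly scale = fp8_max / amax (e.g. fp8_max = 448 for float8_e4m3fn), so that
# the largest observed magnitude lands near the edge of the FP8 range.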


def quantize(
    x: Union[jnp.ndarray, NoScaleTensor],
    quantizer: Quantizer,
    flatten_axis: int = -1,
) -> ScaledTensor:
    """Quantize input tensor according to the quantizer.

    Args:
        x: Input tensor to be quantized.
            Shape: (..., K) where K is the hidden size.
        quantizer: Quantizer for FP8 quantization of the output.
        flatten_axis: The quantization axis along which the input data can be flattened
            to 2D for quantization. Defaults to -1.

    Returns:
        A ScaledTensor containing the quantized input tensor.
    """
    out, _ = _quantize_dbias_impl(
        x,
        quantizer=quantizer,
        flatten_axis=flatten_axis,
    )
    return out


def quantize_dbias(
    dz: Union[jnp.ndarray, NoScaleTensor],
    quantizer: Quantizer,
    is_dbias: bool = True,
    flatten_axis: int = -1,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
    """Quantize input tensor and compute bias gradient.

    Args:
        dz: Input tensor to be quantized and used for bias gradient computation.
            Shape: (..., K) where K is the hidden size.
        quantizer: Quantizer for FP8 quantization of the output.
        is_dbias: If True, compute bias gradient. Defaults to True.
        flatten_axis: The quantization axis along which the input data can be flattened
            to 2D for quantization. Defaults to -1.

    Returns:
        A tuple containing:
        - A ScaledTensor containing the quantized input tensor.
            The ScaledTensor includes both the quantized data and scaling factors.
        - The bias gradient tensor.
            Shape: (K,) or empty if is_dbias is False.
    """
    return _quantize_dbias_impl(
        dz,
        quantizer=quantizer,
        is_dbias=is_dbias,
        flatten_axis=flatten_axis,
    )
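
# Example usage (illustrative sketch, same assumptions as the quantize() example above):
#
#     scaled_dz, dbias = quantize_dbias(dz, quantizer)
#     # scaled_dz: quantized gradient tensor; dbias: bias gradient of shape (K,)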


class GroupedQuantizePrimitive(BasePrimitive):
    """
    Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias
    """

    name = "te_grouped_quantize_ffi"
    multiple_results = True
    impl_static_args = (
        3,
        4,
        5,
        6,
        7,
        8,
    )  # out_dtype, scaling_mode, q_layout, flatten_axis, group_axis, scale_dtype
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        group_sizes_aval,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        group_axis,
        scale_dtype,
    ):
        """
        te_grouped_quantize_p abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        out_shape = (math.prod(x_aval.shape),)
        # TODO(Phuong): can scale_aval be None?
        assert scale_aval is None or scale_aval.dtype == jnp.float32

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_grouped_scale_shape_2x(
            x_aval.shape,
            group_sizes_aval.size,
            group_axis,
            is_padded=True,
            flatten_axis=flatten_axis,
        )

        if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            rowwise_out_shape = out_shape
        else:
            rowwise_out_shape = (1,)
            rowwise_scale_inv_shape = (1,)
        rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype)

        amax_aval = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32)

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_out_shape = out_shape
        else:
            colwise_out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
        rowwise_scale_inv_aval = jax.core.ShapedArray(
            shape=rowwise_scale_inv_shape, dtype=scale_dtype
        )
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        return (
            rowwise_out_aval,
            colwise_out_aval,
            rowwise_scale_inv_aval,
            colwise_scale_inv_aval,
            amax_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_grouped_quantize_p outer primitive abstract
        """
        # Phuong: keeping outer abstract so that we can add fuse dbias later
        (
            rowwise_out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
        ) = GroupedQuantizePrimitive.abstract(*args, **kwargs)
        return rowwise_out, colwise_out, scale_inv, colwise_scale_inv, updated_amax

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        group_sizes,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        group_axis,
        scale_dtype,
    ):
        """
        te_grouped_quantize_p lowering rules
        """
        del out_dtype, scale_dtype
        x_aval, scale_aval, group_sizes_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval.dtype == jnp.float32
        assert group_sizes_aval.dtype == jnp.int32
        assert group_axis == 0
        return ffi.ffi_lowering(GroupedQuantizePrimitive.name)(
            ctx,
            x,
            scale,
            group_sizes,
            scaling_mode=scaling_mode.value,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
        )

    @staticmethod
    def impl(
        x,
        scale,
        group_sizes,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        group_axis,
        scale_dtype,
    ):
        """
        te_grouped_quantize_p implementation
        """
        assert GroupedQuantizePrimitive.inner_primitive is not None
        (
            rowwise_out,
            colwise_out,
            rowwise_scale_inv,
            colwise_scale_inv,
            updated_amax,
        ) = GroupedQuantizePrimitive.inner_primitive.bind(
            x,
            scale,
            group_sizes,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
            group_axis=group_axis,
            scale_dtype=scale_dtype,
        )
        return (rowwise_out, colwise_out, rowwise_scale_inv, colwise_scale_inv, updated_amax)


register_primitive(GroupedQuantizePrimitive)


def grouped_quantize(
    x: jnp.ndarray,
    quantizer: GroupedQuantizer,
    group_sizes: jnp.ndarray = None,
    amax: jnp.ndarray = None,
    flatten_axis: int = -1,
) -> GroupedScaledTensor1x:
    """Quantize a tensor in grouped manner.

    This function quantizes a tensor by splitting it into groups along a specified axis
    and applying quantization to each group separately. The groups can be either specified
    explicitly through group_sizes or automatically split along the group_axis.

    Args:
        x: Input tensor to quantize
        quantizer: The quantizer to use for quantization
        group_sizes: Array of ints containing the size of each group (default: None)
        amax: The amax of x; if None, it is auto-generated. (default: None)
        flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)

    Returns:
        A GroupedScaledTensor1x containing the quantized data

    Note:
        - If group_sizes is not provided, the tensor will be split into equal-sized groups
          along the group_axis
        - The group_axis is currently fixed to 0
        - The quantizer's q_layout determines whether row-wise, column-wise, or both
          quantization is applied
    """

    if quantizer is None:
        if isinstance(x, NoScaleTensor):
            return x
        return NoScaleTensor(data=x, amax=None)

    # TODO(Phuong): add support for flatten_axis = -2
    assert flatten_axis in (
        -1,
        x.ndim - 1,
    ), f"Only flatten_axis = -1 is supported for now, got {flatten_axis}"
    group_axis = 0

    if group_sizes is None:
        group_sizes = jnp.ones(x.shape[group_axis], dtype=jnp.int32)
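        # With this default, every row along group_axis forms its own group of size 1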

    if not GroupedQuantizePrimitive.enabled():
        return quantizer.quantize(
            x, flatten_axis=flatten_axis, group_sizes=group_sizes, group_axis=group_axis
        )
    n_groups = group_sizes.size
    original_shape = x.shape
    assert n_groups == len(
        quantizer.quantizers
    ), f"n_groups={n_groups} != n_quantizers = {len(quantizer.quantizers)}"
    scale = jnp.empty((n_groups,), jnp.float32)

    if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        for i, quantizer_i in enumerate(quantizer.quantizers):
            scale = scale.at[i].set(quantizer_i.scale[0])

    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        if amax is not None:
            row_amax = amax
        else:
            row_amax = jnp.max(jnp.abs(x), axis=range(group_axis + 1, x.ndim))
        segment_ids = jnp.repeat(
            jnp.arange(n_groups), group_sizes, total_repeat_length=x.shape[group_axis]
        )
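        # e.g. group_sizes = [2, 1, 3] -> segment_ids = [0, 0, 1, 2, 2, 2]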
        grouped_amax = jax.ops.segment_max(row_amax, segment_ids, num_segments=n_groups)
        for i in range(n_groups):
            tmp_scale = compute_scale_from_amax(grouped_amax[i], quantizer.q_dtype)
            scale = scale.at[i].set(tmp_scale[0])

    is_tensor_scaling = quantizer.scaling_mode in (
        ScalingMode.DELAYED_TENSOR_SCALING,
        ScalingMode.CURRENT_TENSOR_SCALING,
    )
    # WAR for tensor scaling, as TE/common does not support q_layout = COLWISE yet:
    # we perform ROWWISE_COLWISE quantization and use the colwise tensor output
    apply_colwise_war = is_tensor_scaling and quantizer.q_layout == QuantizeLayout.COLWISE
    q_layout = QuantizeLayout.ROWWISE_COLWISE if apply_colwise_war else quantizer.q_layout
    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
    ) = GroupedQuantizePrimitive.outer_primitive.bind(
        x,
        scale,
        group_sizes,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        q_layout=q_layout.value,
        flatten_axis=flatten_axis,
        group_axis=group_axis,
        scale_dtype=quantizer.get_scale_dtype(),
    )

    # For DelayedScaling2x and CurrentScaling2x, the scale buffer
    # is shared between rowwise and colwise
    if (is_tensor_scaling and quantizer.is_2x2x()) or apply_colwise_war:
        colwise_scale_inv = rowwise_scale_inv

    # TODO(Phuong): store the whole updated_amax in the grouped_quantize instead?
    if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        for i, quantizer_i in enumerate(quantizer.quantizers):
            quantizer_i.update(updated_amax[i].reshape((1,)))

    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=flatten_axis,
        group_sizes=group_sizes,
        original_shape=original_shape,
        group_axis=group_axis,
    )
    return out
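
# Example usage (illustrative sketch; assumes `quantizer` is a GroupedQuantizer holding one
# sub-quantizer per group — exact construction may differ):
#
#     x = jnp.zeros((12, 128), jnp.bfloat16)
#     group_sizes = jnp.array([4, 8], dtype=jnp.int32)  # must sum to x.shape[0]
#     grouped = grouped_quantize(x, quantizer, group_sizes=group_sizes)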


def grouped_dbias(grad: jnp.ndarray, group_sizes: jnp.ndarray) -> jnp.ndarray:
    """
    Compute the grouped bias gradient.

    Args:
        grad: jnp.ndarray of shape (M, N)
        group_sizes: jnp.ndarray of shape(num_groups,), sum(group_sizes) == M

    Returns:
        dbias: jnp.ndarray of shape (num_groups, N)
    """
    assert grad.ndim == 2, "Input grad must be a 2D tensor."
    assert group_sizes.ndim == 1, "group_sizes must be a 1D tensor."

    segment_ids = jnp.repeat(
        jnp.arange(group_sizes.size), group_sizes, total_repeat_length=grad.shape[0]
    )
    grad_fp32 = grad.astype(jnp.float32)
    dbias_fp32 = jax.ops.segment_sum(grad_fp32, segment_ids, num_segments=group_sizes.shape[0])
    dbias = dbias_fp32.astype(grad.dtype)
    return dbias
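
# Worked example (illustrative): for grad of shape (6, N) and group_sizes = [2, 1, 3],
# rows 0-1 are summed into dbias[0], row 2 into dbias[1], and rows 3-5 into dbias[2],
# giving dbias of shape (3, N) in grad's dtype.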