# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for quantization"""
from typing import Tuple, Optional
from packaging import version

import jax
import jax.numpy as jnp
from jax import dtypes
from jax.sharding import PartitionSpec

import transformer_engine_jax

from .base import BasePrimitive, register_primitive
from .misc import (
    get_padded_spec,
    check_valid_batch_dims,
    te_dtype_to_jax_dtype,
    jax_dtype_to_te_dtype,
    multidim_transpose,
    should_apply_1x_fused_dbias_war_for_arch_l_100,
    NamedSharding,
)
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor2x, ScaledTensor, ScaledTensorFactory
from ..quantize import Quantizer, QuantizeAxis, DelayedScaleQuantizer, ScalingMode

if version.parse(jax.__version__) >= version.parse("0.5.0"):
    from jax import ffi  # pylint: disable=ungrouped-imports
else:
    from jax.extend import ffi  # pylint: disable=ungrouped-imports

__all__ = ["quantize", "quantize_dbias"]


class DBiasQuantizePrimitive(BasePrimitive):
    """
    Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias
    """

    name = "te_dbias_quantize_ffi"
    multiple_results = True
    impl_static_args = (
        2,
        3,
        4,
        5,
        6,
        7,
        8,
    )  # out_dtype, scaling_mode, q_axis, scale_dtype, scale_shapes, is_dbias, is_outer
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        *,
        out_dtype,
        scaling_mode,
        q_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p abstract
        """
        del scale_shapes
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval is None or scale_aval.dtype == jnp.float32

        rowwise_out_aval = jax.core.ShapedArray(shape=(1,), dtype=out_dtype)

        if q_axis in (QuantizeAxis.ROWWISE.value, QuantizeAxis.ROWWISE_COLWISE.value):
            rowwise_out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)

        updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer)

        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)

        colwise_out_aval = jax.core.ShapedArray(shape=(1,), dtype=out_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(shape=(1,), dtype=scale_dtype)

        dbias_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
        wkspace_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
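        # Avals not selected by q_axis / is_dbias keep these dummy (1,)-shaped
        # placeholders, so the primitive always returns the same seven results.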
        if q_axis in (QuantizeAxis.COLWISE.value, QuantizeAxis.ROWWISE_COLWISE.value):
            t_shape = multidim_transpose(x_aval.shape)
            if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
                # Don't transpose output for MXFP8
                t_shape = x_aval.shape
            colwise_out_aval = x_aval.update(shape=t_shape, dtype=out_dtype)
            colwise_scale_inv_aval = jax.core.ShapedArray(
                shape=colwise_scale_inv_shape, dtype=scale_dtype
            )

        if is_dbias:
            gi_hidden_size = x_aval.shape[-1]
            dbias_shape = (gi_hidden_size,)
            dbias_aval = x_aval.update(shape=dbias_shape, dtype=dtype)
            (wkspace_info,) = transformer_engine_jax.get_dbias_quantize_workspace_sizes(
                x_aval.size // gi_hidden_size,
                gi_hidden_size,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
            )
            wkspace_aval = x_aval.update(
                shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
            )

        return (
            rowwise_out_aval,
            colwise_out_aval,
            scale_inv_aval,
            colwise_scale_inv_aval,
            updated_amax_aval,
            dbias_aval,
            wkspace_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dbias_quantize_p outer primitive abstract
        """
        (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
            _,
        ) = DBiasQuantizePrimitive.abstract(*args, **kwargs)
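        # The workspace aval is only needed by the inner primitive and is dropped
        # from the outer primitive's results.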
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        *,
        out_dtype,
        scaling_mode,
        q_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p lowering rules
        """
        del out_dtype, scale_dtype, scale_shapes, is_outer
        x_aval, scale_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval.dtype == jnp.float32
        return ffi.ffi_lowering(DBiasQuantizePrimitive.name)(
            ctx,
            x,
            scale,
            scaling_mode=scaling_mode,
            q_axis=q_axis,
            is_dbias=is_dbias,
        )

    @staticmethod
    def impl(
        x,
        scale,
        out_dtype,
        scaling_mode,
        q_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p implementation
        """
        del is_outer
        assert DBiasQuantizePrimitive.inner_primitive is not None
        (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
            _,
        ) = DBiasQuantizePrimitive.inner_primitive.bind(
            x,
            scale,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            q_axis=q_axis,
            scale_dtype=scale_dtype,
            scale_shapes=scale_shapes,
            is_dbias=is_dbias,
            is_outer=False,
        )
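
        # The inner primitive returns padded scale-inverse buffers; for MXFP8 1D
        # scaling, slice them back to the unpadded shapes computed below.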
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x.shape, is_padded=False)
        if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
            if q_axis in (QuantizeAxis.ROWWISE.value, QuantizeAxis.ROWWISE_COLWISE.value):
                scale_inv = jax.lax.slice(
                    scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
                )
            if q_axis in (QuantizeAxis.COLWISE.value, QuantizeAxis.ROWWISE_COLWISE.value):
                colwise_scale_inv = jax.lax.slice(
                    colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
                )
        return (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
        )  # Exclude wkspace

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        scaling_mode,
        q_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
    ):
        """
        to describe batch rules for vmap
        """
        del is_outer
        check_valid_batch_dims(batch_dims)
        assert DBiasQuantizePrimitive.outer_primitive is not None
        x, scale = batched_args
        x_bdim, scale_bdim = batch_dims
        amax_bdim = scale_bdim

        out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim, x_bdim
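        # Batch dims per output: the rowwise/colwise data and dbias follow x, while
        # the scale-inverses and amax follow scale.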
        return (
            DBiasQuantizePrimitive.outer_primitive.bind(
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                q_axis=q_axis,
                scale_dtype=scale_dtype,
                scale_shapes=scale_shapes,
                is_dbias=is_dbias,
                is_outer=True,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        scaling_mode,
        q_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del (out_dtype, result_infos, scale_dtype, scale_shapes, is_dbias, is_outer)  # Unused.
        x_spec = get_padded_spec(arg_infos[0])
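        # The rowwise output keeps the input sharding. The colwise output follows
        # the transposed input spec for delayed tensor scaling (where it is
        # physically transposed) and the input spec otherwise.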
        out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*x_spec[:-1], x_spec[-1]),
            desc="DBiasQuantizePrimitive.out_sharding",
        )
        if q_axis in (QuantizeAxis.COLWISE.value, QuantizeAxis.ROWWISE_COLWISE.value):
            if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
                colwise_out_spec = multidim_transpose(x_spec)
            else:
                colwise_out_spec = x_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_out_spec),
            desc="DBiasQuantizePrimitive.colwise_out_sharding",
        )
        scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*get_padded_spec(arg_infos[1])),
            desc="DBiasQuantizePrimitive.scale_inv",
        )
        amax_sharding = scale_inv_sharding.duplicate_with_new_description(
            desc="DBiasQuantizePrimitive.amax_sharding"
        )
        if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
            scale_inv_sharding = NamedSharding(
                mesh, PartitionSpec(*x_spec), desc="DBiasQuantizePrimitive.scale_inv"
            )
        colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
            "DBiasQuantizePrimitive.colwise_scale_inv"
        )
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(x_spec[-1]),
            desc="DBiasQuantizePrimitive.dbias_sharding",
        )
        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        scaling_mode,
        q_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer
        x_spec = get_padded_spec(arg_infos[0])
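        # The output shardings below mirror infer_sharding_from_operands above.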
        out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*x_spec[:-1], x_spec[-1]),
            desc="DBiasQuantizePrimitive.out_sharding",
        )
        if q_axis in (QuantizeAxis.COLWISE.value, QuantizeAxis.ROWWISE_COLWISE.value):
            if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
                colwise_out_spec = multidim_transpose(x_spec)
            else:
                colwise_out_spec = x_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_out_spec),
            desc="DBiasQuantizePrimitive.colwise_out_sharding",
        )
        scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*get_padded_spec(arg_infos[1])),
            desc="DBiasQuantizePrimitive.scale_inv",
        )
        amax_sharding = scale_inv_sharding.duplicate_with_new_description(
            desc="DBiasQuantizePrimitive.amax_sharding"
        )
        if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
            scale_inv_sharding = NamedSharding(
                mesh, PartitionSpec(*x_spec), desc="DBiasQuantizePrimitive.scale_inv"
            )
        colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
            "DBiasQuantizePrimitive.colwise_scale_inv"
        )
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(x_spec[-1]),
            desc="DBiasQuantizePrimitive.dbias_sharding",
        )
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

        def sharded_impl(x, scale):
            (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                local_amax,
                local_dbias,
            ) = DBiasQuantizePrimitive.impl(
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                q_axis=q_axis,
                scale_dtype=scale_dtype,
                scale_shapes=scale_shapes,
                is_dbias=is_dbias,
                is_outer=True,
            )
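
            # For delayed tensor scaling the local amax is reduced (max) across all
            # mesh axes except pipeline parallelism; the local dbias partial sums
            # are reduced (sum) over the DP/FSDP axes.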
            if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            else:
                global_updated_amax = local_amax

            if is_dbias:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            else:
                global_dbias = local_dbias

            return (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                global_updated_amax,
                global_dbias,
            )

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(DBiasQuantizePrimitive)


def _jax_quantize(x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None):
    if quantizer is None:
        return x
    return quantizer.quantize(x, dq_dtype=dq_dtype)


def _jax_dbias(dx: jnp.ndarray):
    dbias = jnp.sum(
        dx.astype(jnp.float32),
        axis=tuple(range(dx.ndim - 1)),
        keepdims=False,
    )
    dbias = dbias.ravel()  # C++ function returns a 1D array for dbias
    return dbias.astype(dx.dtype)


def _jax_quantize_dbias(
    x,
    quantizer: Quantizer = None,
    dq_dtype: Optional[jnp.dtype] = None,
):
    if quantizer is None:
        return x, None
    return quantizer.quantize(x, dq_dtype=dq_dtype), _jax_dbias(x)


def _quantize_impl(
    x: jnp.ndarray,
    quantizer: Quantizer,
    is_dbias: bool = False,
    dq_dtype: Optional[jnp.dtype] = None,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
    """
    Shared implementation for quantize and quantize_dbias.

    Casts the input tensor to FP8 (optionally computing the bias gradient) and
    falls back to pure-JAX implementations when the fused TE primitive cannot
    be used for the requested configuration.
    """
    assert (dq_dtype is None) or (
        quantizer is not None
    ), "quantizer must be provided if dq_dtype is provided"

    if not DBiasQuantizePrimitive.enabled():
        if is_dbias:
            return _jax_quantize_dbias(
                x,
                quantizer=quantizer,
                dq_dtype=dq_dtype,
            )
        return _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype), None

    # TE/common doesn't support colwise only quantization yet
    if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE:
        if is_dbias:
            return _jax_quantize_dbias(
                x,
                quantizer=quantizer,
                dq_dtype=dq_dtype,
            )
        return _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype), None
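    # Dummy scale input; only delayed-scaling quantizers carry a real scale, which
    # is filled in below.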
    scale = jnp.empty((), jnp.float32)

    # TE/common dbias_quantize does not support 1x on arch < 100
    if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
        out, _ = _quantize_impl(
            x=x,
            is_dbias=False,
            quantizer=quantizer,
            dq_dtype=dq_dtype,
        )
        dbias = _jax_dbias(x)
        return out, dbias

    if quantizer is None:
        if is_dbias:
            return x, _jax_dbias(x)
        return x, None

    if isinstance(quantizer, DelayedScaleQuantizer):
        scale = quantizer.scale

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
        dbias,
    ) = DBiasQuantizePrimitive.outer_primitive.bind(
        x,
        scale,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        q_axis=quantizer.q_axis.value,
        scale_dtype=quantizer.get_scale_dtype(),
        scale_shapes=quantizer.get_scale_shapes(x.shape),
        is_dbias=is_dbias,
        is_outer=True,
    )
    # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
    if quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING and quantizer.is_2x2x():
        colwise_scale_inv = rowwise_scale_inv

    quantizer.update(updated_amax)

    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=dq_dtype if dq_dtype is not None else x.dtype,
        q_axis=quantizer.q_axis,
        layout=quantizer.get_layout(),
    )
    return out, dbias


# TODO(Phuong): do not expose dq_dtype to users
def quantize(
    x: jnp.ndarray,
    quantizer: Quantizer,
    dq_dtype: Optional[jnp.dtype] = None,
) -> ScaledTensor:
    """Quantize input tensor according to the quantizer.

    Args:
        x: Input tensor to be quantized.
            Shape: (..., K) where K is the hidden size.
        quantizer: Quantizer for FP8 quantization of the output.
        dq_dtype: Optional dtype for dequantization.
            If None, uses the same dtype as the input tensor.

    Returns:
        A ScaledTensor containing the quantized input tensor.
    """
    out, _ = _quantize_impl(
        x,
        quantizer=quantizer,
        dq_dtype=dq_dtype,
    )
    return out
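
# Usage sketch (illustrative only): given an already-constructed Quantizer
# (how it is built is outside the scope of this module), quantizing an
# activation tensor looks like:
#
#     x = jnp.zeros((128, 512), jnp.bfloat16)
#     scaled = quantize(x, quantizer=quantizer)  # returns a ScaledTensor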


# TODO(Phuong): do not expose dq_dtype to users
def quantize_dbias(
    dz: jnp.ndarray,
    quantizer: Quantizer,
    is_dbias: bool = True,
    dq_dtype: Optional[jnp.dtype] = None,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
    """Quantize input tensor and compute bias gradient.

    Args:
        dz: Input tensor to be quantized and used for bias gradient computation.
            Shape: (..., K) where K is the hidden size.
        quantizer: Quantizer for FP8 quantization of the output.
        is_dbias: If True, compute bias gradient. Defaults to True.
        dq_dtype: Optional dtype for dequantization.
            If None, uses the same dtype as the input tensor.

    Returns:
        A tuple containing:
        - A ScaledTensor containing the quantized input tensor.
            The ScaledTensor includes both the quantized data and scaling factors.
        - The bias gradient tensor.
            Shape: (K,) or empty if is_dbias is False.
    """
    return _quantize_impl(
        dz,
        quantizer=quantizer,
        is_dbias=is_dbias,
        dq_dtype=dq_dtype,
    )
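
# Usage sketch (illustrative only): in a backward pass, the incoming gradient `dz`
# can be cast to FP8 and reduced into the bias gradient in one fused call:
#
#     scaled_dz, dbias = quantize_dbias(dz, quantizer=quantizer)
#     # dbias has shape (dz.shape[-1],)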