quantization.py 38 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
#
# See LICENSE for license information.
"""JAX/TE custom ops for quantization"""
5
6
import operator
from functools import reduce
7
from typing import Tuple, Optional, Union
8
import math
9
10
from enum import Enum

11

12
import jax
13
import jax.numpy as jnp
14
from jax import dtypes, ffi
15
from jax.experimental.custom_partitioning import SdyShardingRule
16
from jax.sharding import PartitionSpec
17

18
import transformer_engine_jax
19
20
21
22
23

from .base import BasePrimitive, register_primitive
from .misc import (
    get_padded_spec,
    check_valid_batch_dims,
24
    te_dtype_to_jax_dtype,
25
    jax_dtype_to_te_dtype,
26
27
    multidim_transpose,
    should_apply_1x_fused_dbias_war_for_arch_l_100,
28
    get_min_device_compute_capability,
29
    NamedSharding,
30
)
31
32
33
34
35
36
from ..sharding import (
    all_reduce_max_along_all_axes_except_PP,
    all_reduce_sum_along_dp_fsdp,
    global_mesh_resource,
    lax_paral_op,
)
37
from ..quantize import (
38
39
40
41
    ScaledTensor2x,
    ScaledTensor,
    ScaledTensorFactory,
    GroupedScaledTensor1x,
42
    Quantizer,
43
    GroupedQuantizer,
44
45
46
    QuantizeLayout,
    ScalingMode,
    compute_scale_from_amax,
47
    NoScaleTensor,
48
)
49
50


51
__all__ = ["quantize", "quantize_dbias", "grouped_quantize", "grouped_dbias"]
52
53


54
class BaseDBiasQuantizePrimitive(BasePrimitive):
    """
    Cast primitive that quantizes a tensor and can optionally emit dbias.

    Wraps nvte_quantize and nvte_quantize_dbias.
    """

    # FFI target name; lowering() hands it to ffi.ffi_lowering().
    name = "te_dbias_quantize_ffi"
    multiple_results = True
    # Positional indices 3..9 of impl() are static, in this order:
    # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, is_dbias, is_outer
    impl_static_args = tuple(range(3, 10))
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        amax_aval,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p abstract

        Builds the abstract values of the seven primitive outputs, in order:
        rowwise data, colwise data, rowwise scale_inv, colwise scale_inv,
        updated amax, dbias and workspace. Outputs that are not requested by
        ``q_layout`` / ``is_dbias`` get the placeholder shape (1,).
        """
        in_dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert in_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval is None or scale_aval.dtype == jnp.float32

        x_shape = x_aval.shape
        mode = ScalingMode(scaling_mode)
        wants_rowwise = q_layout in (
            QuantizeLayout.ROWWISE.value,
            QuantizeLayout.ROWWISE_COLWISE.value,
        )
        wants_colwise = q_layout in (
            QuantizeLayout.COLWISE.value,
            QuantizeLayout.ROWWISE_COLWISE.value,
        )

        # Padded scale shapes are requested unless this is the outer primitive.
        rowwise_scale_shape, colwise_scale_shape = mode.get_scale_shape_2x(
            x_shape, is_padded=not is_outer, flatten_axis=flatten_axis
        )

        rowwise_shape = x_shape if wants_rowwise else (1,)
        if not wants_colwise:
            colwise_shape = (1,)
            colwise_scale_shape = (1,)
        elif mode.is_tensor_scaling():
            # Tensor-scaling modes describe the colwise data with the transposed shape.
            colwise_shape = multidim_transpose(x_shape, transpose_axis=flatten_axis)
        else:
            colwise_shape = x_shape

        if is_dbias:
            trailing_dims = x_shape[flatten_axis:]
            hidden_size = reduce(operator.mul, trailing_dims, 1)
            (workspace_info,) = transformer_engine_jax.get_dbias_quantize_workspace_sizes(
                x_aval.size // hidden_size,
                hidden_size,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
                scaling_mode,
                QuantizeLayout(
                    q_layout
                ),  # For now until we have auto-decoding for QuantizeLayout enum
            )
            dbias_shape = trailing_dims
            workspace_shape = workspace_info[0]
            workspace_dtype = te_dtype_to_jax_dtype(workspace_info[1])
        else:
            dbias_shape = (1,)
            workspace_shape = (1,)
            workspace_dtype = jnp.float32

        make_aval = jax.core.ShapedArray
        return (
            make_aval(shape=rowwise_shape, dtype=out_dtype),
            make_aval(shape=colwise_shape, dtype=out_dtype),
            make_aval(shape=rowwise_scale_shape, dtype=scale_dtype),
            make_aval(shape=colwise_scale_shape, dtype=scale_dtype),
            amax_aval,  # updated amax has the same aval as the incoming amax
            make_aval(shape=dbias_shape, dtype=in_dtype),
            make_aval(shape=workspace_shape, dtype=workspace_dtype),
        )
152
153

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dbias_quantize_p outer primitive abstract

        Same avals as abstract(), with the trailing workspace aval removed.
        """
        *visible_avals, _workspace_aval = BaseDBiasQuantizePrimitive.abstract(*args, **kwargs)
        out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias = visible_avals
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        amax,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p lowering rules

        Validates the operand dtypes and emits the FFI call registered under
        ``BaseDBiasQuantizePrimitive.name``.
        """
        del out_dtype, scale_dtype, is_outer

        x_aval, scale_aval, amax_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval.dtype == amax_aval.dtype and amax_aval.dtype == jnp.float32

        lowering_rule = ffi.ffi_lowering(
            BaseDBiasQuantizePrimitive.name,
            operand_output_aliases={2: 4},  # donate amax buffer to updated_amax
        )
        ffi_attrs = {
            "scaling_mode": scaling_mode.value,
            "q_layout": q_layout,
            "flatten_axis": flatten_axis,
            "is_dbias": is_dbias,
        }
        return lowering_rule(ctx, x, scale, amax, **ffi_attrs)
204
205

    @staticmethod
    def impl(
        x,
        scale,
        amax,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p implementation

        Binds the inner primitive, slices the scale_inv outputs down to their
        unpadded shapes and drops the workspace output.
        """
        del is_outer
        inner = BaseDBiasQuantizePrimitive.inner_primitive
        assert inner is not None

        inner_outputs = inner.bind(
            x,
            scale,
            amax,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
            scale_dtype=scale_dtype,
            is_dbias=is_dbias,
            is_outer=False,
        )
        out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _ = inner_outputs

        def _crop(padded, target_shape):
            # Keep the leading `target_shape` region of `padded`.
            return jax.lax.slice(padded, [0] * len(target_shape), target_shape)

        rowwise_target, colwise_target = ScalingMode(scaling_mode).get_scale_shape_2x(
            x.shape, is_padded=False, flatten_axis=flatten_axis
        )
        scale_inv = _crop(scale_inv, rowwise_target)
        has_colwise = q_layout in (
            QuantizeLayout.COLWISE.value,
            QuantizeLayout.ROWWISE_COLWISE.value,
        )
        if has_colwise:
            colwise_scale_inv = _crop(colwise_scale_inv, colwise_target)

        # Exclude wkspace
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias
261
262

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
    ):
        """
        to describe batch rules for vmap

        Re-binds the outer primitive on the batched operands and returns its
        outputs together with the batch dimension of each of the six outputs.
        """
        # The re-bind below always targets the outer primitive, so the incoming
        # flag is discarded and is_outer=True is passed explicitly.
        del is_outer
        check_valid_batch_dims(batch_dims)
        assert BaseDBiasQuantizePrimitive.outer_primitive is not None
        x, scale, amax = batched_args
        x_bdim, scale_bdim, amax_bdim = batch_dims

        # Order: out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias
        out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim, x_bdim
        return (
            BaseDBiasQuantizePrimitive.outer_primitive.bind(
                x,
                scale,
                amax,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                q_layout=q_layout,
                flatten_axis=flatten_axis,
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
                # abstract() declares is_outer as a required keyword-only
                # parameter; omitting it makes the bind fail.
                is_outer=True,
            ),
            out_bdims,
        )
299
300

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        """
        Derive the shardings of the six outer outputs from the operand shardings.

        Returns shardings for (out, colwise_out, scale_inv, colwise_scale_inv,
        amax, dbias), all based on the partition specs of x (operand 0) and
        amax (operand 2).
        """
        del (out_dtype, result_infos, scale_dtype, is_outer)  # Unused.

        x_spec = get_padded_spec(arg_infos[0])
        amax_spec = get_padded_spec(arg_infos[2])
        has_colwise = q_layout in (
            QuantizeLayout.COLWISE.value,
            QuantizeLayout.ROWWISE_COLWISE.value,
        )

        if not has_colwise:
            colwise_out_spec = (None,)
        elif ScalingMode(scaling_mode).is_tensor_scaling():
            colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
        else:
            colwise_out_spec = x_spec

        dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)

        # Only MXFP8 1D scaling shards scale_inv like x; otherwise it is replicated.
        if scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec
        else:
            scale_inv_spec = (None,)
        colwise_scale_inv_spec = scale_inv_spec if has_colwise else (None,)

        def _named(spec, desc):
            return NamedSharding(mesh, PartitionSpec(*spec), desc=desc)

        return (
            _named(x_spec, "BaseDBiasQuantizePrimitive.out_sharding"),
            _named(colwise_out_spec, "BaseDBiasQuantizePrimitive.colwise_out_sharding"),
            _named(scale_inv_spec, "BaseDBiasQuantizePrimitive.scale_inv"),
            _named(colwise_scale_inv_spec, "BaseDBiasQuantizePrimitive.colwise_scale_inv"),
            _named(amax_spec, "BaseDBiasQuantizePrimitive.amax"),
            _named(dbias_spec, "BaseDBiasQuantizePrimitive.dbias_sharding"),
        )
369
370

    @staticmethod
    def partition(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        """
        Partition rule: build the output shardings and the per-shard implementation.

        The per-shard function runs impl() on the local shards, then reduces
        amax with all_reduce_max_along_all_axes_except_PP (delayed tensor
        scaling only) and dbias with all_reduce_sum_along_dp_fsdp (only when
        is_dbias is set).

        Returns:
            (mesh, sharded_impl, out_shardings, arg_shardings)
        """
        del result_infos, is_outer

        x_spec = get_padded_spec(arg_infos[0])
        amax_spec = get_padded_spec(arg_infos[2])
        has_colwise = q_layout in (
            QuantizeLayout.COLWISE.value,
            QuantizeLayout.ROWWISE_COLWISE.value,
        )

        if not has_colwise:
            colwise_out_spec = (None,)
        elif ScalingMode(scaling_mode).is_tensor_scaling():
            colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
        else:
            colwise_out_spec = x_spec

        dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)

        # Only MXFP8 1D scaling shards scale_inv like x; otherwise it is replicated.
        if scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec
        else:
            scale_inv_spec = (None,)
        colwise_scale_inv_spec = scale_inv_spec if has_colwise else (None,)

        def _named(spec, desc):
            return NamedSharding(mesh, PartitionSpec(*spec), desc=desc)

        out_shardings = (
            _named(x_spec, "BaseDBiasQuantizePrimitive.out_sharding"),
            _named(colwise_out_spec, "BaseDBiasQuantizePrimitive.colwise_out_sharding"),
            _named(scale_inv_spec, "BaseDBiasQuantizePrimitive.scale_inv"),
            _named(colwise_scale_inv_spec, "BaseDBiasQuantizePrimitive.colwise_scale_inv"),
            _named(amax_spec, "BaseDBiasQuantizePrimitive.amax"),
            _named(dbias_spec, "BaseDBiasQuantizePrimitive.dbias_sharding"),
        )
        arg_shardings = tuple(info.sharding for info in arg_infos)

        def sharded_impl(x, scale, amax):
            shard_outputs = BaseDBiasQuantizePrimitive.impl(
                x,
                scale,
                amax,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                q_layout=q_layout,
                flatten_axis=flatten_axis,
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
                is_outer=True,
            )
            (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                amax_out,
                dbias_out,
            ) = shard_outputs

            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                amax_out = all_reduce_max_along_all_axes_except_PP(amax_out, mesh)
            if is_dbias:
                dbias_out = all_reduce_sum_along_dp_fsdp(dbias_out, mesh)

            return (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                amax_out,
                dbias_out,
            )

        return mesh, sharded_impl, out_shardings, arg_shardings

483
484
485
486
487
488
489
490
491
492
493
494
495
    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        scale_dtype,
        is_dbias,
        is_outer,
        mesh,
        value_types,
        result_types,
    ):
        """
        Build the Shardy sharding rule for the three operands (x, scale, amax)
        and the six outer outputs.
        """
        del out_dtype, scale_dtype, is_outer, mesh, result_types

        prefix = "DBiasQuantize_"
        mode = ScalingMode(scaling_mode)
        scale_rules = mode.get_shardy_sharding_rules(
            value_types[0].shape,
            unique_var=prefix + "x",
            flatten_axis=flatten_axis,
        )
        x_axes = scale_rules.input_spec

        has_colwise = q_layout in (
            QuantizeLayout.COLWISE.value,
            QuantizeLayout.ROWWISE_COLWISE.value,
        )
        if not has_colwise:
            # Colwise data is not produced: give it its own, unrelated factor.
            colwise_out = (prefix + "out_colwise",)
        elif mode.is_tensor_scaling():
            colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis))
        else:
            colwise_out = x_axes

        if is_dbias:
            dbias = x_axes[flatten_axis:]
        else:
            dbias = (prefix + "dbias",)
        amax = (prefix + "amax",)

        operand_rules = (x_axes, ("…1",), amax)
        result_rules = (
            x_axes,
            colwise_out,
            scale_rules.rowwise_rule,
            scale_rules.colwise_rule,
            amax,
            dbias,
        )
        return SdyShardingRule(operand_rules, result_rules, **scale_rules.factor_sizes)

525

526
527
528
529
register_primitive(BaseDBiasQuantizePrimitive)


class DBiasQuantizePrimitive(BaseDBiasQuantizePrimitive):
    """Quantization primitive used for DBias quantization.

    Functionally identical to BaseDBiasQuantizePrimitive. It carries its own name so
    that it can be disabled at a finer granularity via NVTE_JAX_CUSTOM_CALLS.
    """
531
532
533


class QuantizePrimitive(BaseDBiasQuantizePrimitive):
    """Quantization primitive used for quantization without dbias.

    Functionally identical to BaseDBiasQuantizePrimitive. It carries its own name so
    that it can be disabled at a finer granularity via NVTE_JAX_CUSTOM_CALLS.
    """
535
536


537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
class AmaxScope(Enum):
    """
    Amax Scope Enum

    Selects which mesh axes the per-shard amax is max-reduced over inside
    AmaxCalculationPrimitive.partition.
    """

    # No cross-device reduction is applied to the per-shard amax.
    LOCAL = 1
    # pmax over the global mesh resource's tp_resource and tpsp_resource axes.
    TPSP = 2
    # pmax over the global mesh resource's fsdp_resource axis.
    FSDP = 3


class AmaxCalculationPrimitive(BasePrimitive):
    """
    Amax Calculation Primitive with custom_partitioning

    Computes max(|x|) as a float32 array of shape (1,). Under partitioning the
    per-shard value is optionally max-reduced over the mesh axes selected by
    ``amax_scope``.
    """

    name = "jax_local_amax"
    multiple_results = False
    impl_static_args = (1,)  # amax_scope
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        *,
        amax_scope,
    ):
        """
        amax calculation abstract
        """
        del amax_scope

        in_dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert in_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]

        return jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)

    @staticmethod
    def impl(
        x,
        amax_scope,
    ):
        """
        amax calculation implementation
        """
        del amax_scope
        abs_max = jnp.amax(jnp.abs(x), keepdims=True)
        return abs_max.astype(jnp.float32).reshape((1,))

    @staticmethod
    def infer_sharding_from_operands(
        amax_scope,
        mesh,
        arg_infos,
        result_infos,
    ):
        """
        amax calculation infer_sharding_from_operands
        """
        del (amax_scope, arg_infos, result_infos)  # Unused.
        return NamedSharding(
            mesh,
            PartitionSpec(None),
            desc="AmaxCalculationPrimitive.out_sharding",
        )

    @staticmethod
    def partition(
        amax_scope,
        mesh,
        arg_infos,
        result_infos,
    ):
        """
        amax calculation partition
        """
        del result_infos

        out_sharding = NamedSharding(
            mesh,
            PartitionSpec(None),
            desc="AmaxCalculationPrimitive.out_sharding",
        )

        def sharded_impl(x):
            amax = AmaxCalculationPrimitive.impl(
                x,
                amax_scope=amax_scope,
            )
            if amax_scope is AmaxScope.TPSP:  # Run AR across TP/SP
                resources = global_mesh_resource()
                for axis_resource in (resources.tp_resource, resources.tpsp_resource):
                    amax = lax_paral_op(amax, jax.lax.pmax, axis_resource, mesh)
            elif amax_scope is AmaxScope.FSDP:  # Run AR across FSDP
                resources = global_mesh_resource()
                amax = lax_paral_op(amax, jax.lax.pmax, resources.fsdp_resource, mesh)
            return amax

        arg_shardings = tuple(info.sharding for info in arg_infos)
        return mesh, sharded_impl, out_sharding, arg_shardings

    @staticmethod
    def shardy_sharding_rule(amax_scope, mesh, value_types, result_types):
        """
        amax calculation shardy_sharding_rule
        """
        del amax_scope, mesh, result_types
        prefix = "AmaxCal"
        rank = len(value_types[0].shape)
        # One independent factor per input axis, plus a separate one for the output.
        input_spec = tuple(f"{prefix}_{axis}" for axis in range(rank))
        output_spec = (f"{prefix}_amax",)
        return SdyShardingRule((input_spec,), (output_spec,))


register_primitive(AmaxCalculationPrimitive, outer_only=True)


657
658
659
def _jax_quantize(
    x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1
):
    """Quantize ``x`` with ``quantizer.quantize``; without a quantizer return it unquantized.

    When ``quantizer`` is None the input comes back as a NoScaleTensor: unchanged if it
    already is one, otherwise wrapped with ``amax=None``.
    """
    if quantizer is not None:
        return quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
    return x if isinstance(x, NoScaleTensor) else NoScaleTensor(data=x, amax=None)
665
666


667
668
669
def _jax_dbias(dx: Union[jnp.ndarray, NoScaleTensor], dtype=None, flatten_axis: int = -1):
    """Compute the bias gradient in pure JAX by summing ``dx`` over its leading axes.

    Args:
        dx: Gradient tensor, or a NoScaleTensor whose ``data`` is used.
        dtype: dtype of the returned dbias. Falls back to the dtype of ``dx``.
        flatten_axis: Axes before this one are summed away. Negative values count
            from the last axis.

    Returns:
        The sum of ``dx`` over axes ``[0, flatten_axis)``, accumulated in float32 and
        cast to ``dtype``.

    Raises:
        AssertionError: If the normalized flatten axis is outside ``[0, dx.ndim)``.
    """
    if isinstance(dx, NoScaleTensor):
        dx = dx.data

    sum_axis = dx.ndim + flatten_axis if flatten_axis < 0 else flatten_axis
    # Check both ends: a negative sum_axis would make range() empty, so nothing
    # would be reduced and the full tensor would be returned as "dbias".
    assert 0 <= sum_axis < dx.ndim, "Flatten axis out of bounds!"
    dtype = dtype or dx.dtype
    dbias = jnp.sum(
        dx.astype(jnp.float32),
        axis=tuple(range(sum_axis)),
        keepdims=False,
    )
    return dbias.astype(dtype)
679
680
681
682
683
684


def _jax_quantize_dbias(
    x,
    quantizer: Quantizer = None,
    dq_dtype: Optional[jnp.dtype] = None,
    flatten_axis: int = -1,
):
    """Pure-JAX quantize plus dbias.

    Returns a ``(tensor, dbias)`` pair. Without a quantizer the tensor is ``x`` as a
    NoScaleTensor (wrapped if necessary) and dbias is None.
    """
    if quantizer is not None:
        quantized = quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
        dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
        return quantized, dbias

    passthrough = x if isinstance(x, NoScaleTensor) else NoScaleTensor(data=x, amax=None)
    return passthrough, None


697
def _quantize_dbias_impl(
    x: Union[jnp.ndarray, NoScaleTensor],
    quantizer: Quantizer,
    is_dbias: bool = False,
    dq_dtype: Optional[jnp.dtype] = None,
    flatten_axis: int = -1,
    amax_scope: AmaxScope = AmaxScope.LOCAL,  # Only works when using current-scaling
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
    """
    Cast wrapper
    Return FP8 tensor

    Args:
        x: Tensor to quantize; a plain array is wrapped in a NoScaleTensor.
        quantizer: Quantizer to apply. If None, ``x`` is returned unquantized.
        is_dbias: Whether to also compute the bias gradient.
        dq_dtype: Dequantization dtype. Falls back to the dtype of ``x``.
        flatten_axis: Axis at which the input is flattened to 2D for quantization.
        amax_scope: Passed to AmaxCalculationPrimitive when the amax has to be
            computed for current tensor scaling.

    Returns:
        A ``(tensor, dbias)`` pair. On the non-quantized and native-JAX paths dbias
        is None when ``is_dbias`` is False; on the custom-op path it is the
        primitive's dbias output cast to ``dq_dtype``.
    """
    assert (dq_dtype is None) or (
        quantizer is not None
    ), "quantizer must be provided if dq_dtype is provided"

    if isinstance(x, jnp.ndarray):
        x = NoScaleTensor(data=x, amax=None)

    dq_dtype = dq_dtype or x.data.dtype

    # Early-exit for non-quantized call
    if quantizer is None:
        dbias = None
        if is_dbias:
            dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis)
        return x, dbias

    # If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE,
    # fall back on the native-JAX quantize implementation
    PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive
    if quantizer.q_layout == QuantizeLayout.COLWISE or not PrimitiveClass.enabled():
        if is_dbias:
            return _jax_quantize_dbias(
                x,
                quantizer=quantizer,
                dq_dtype=dq_dtype,
                flatten_axis=flatten_axis,
            )
        return (
            _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
            None,
        )

    # TE/common custom quantize op does not support dbias fusion with 1x quantization on arch < 100
    if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
        out, _ = _quantize_dbias_impl(
            x=x,
            is_dbias=False,
            quantizer=quantizer,
            dq_dtype=dq_dtype,
            flatten_axis=flatten_axis,
            # Forward the caller's scope; otherwise this path would always
            # compute the current-scaling amax with AmaxScope.LOCAL.
            amax_scope=amax_scope,
        )
        dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis)
        return out, dbias

    scale = jnp.empty((), jnp.float32)
    amax = None
    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        # Reduce amax across devices (per amax_scope) for current scaling so the scale is shared.
        # This differs from the PyTorch implementation which uses a local amax and scale per-device and persists this
        # until the tensor is dequantized (e.g. in the GEMM).
        amax = x.amax
        if amax is None:
            amax = AmaxCalculationPrimitive.outer_primitive.bind(
                x.data,
                amax_scope=amax_scope,
            )
        scale = compute_scale_from_amax(amax, quantizer.q_dtype)
    elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        scale = quantizer.scale

    # Make sure amax is init with zero
    if amax is None:
        amax = jnp.zeros((1,), jnp.float32)

    # It is faster to use 1x quantization for tensor scaling
    is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100)
    force_1x_quantization = (
        quantizer.scaling_mode.is_tensor_scaling()
        and quantizer.is_2x2x()
        and is_1x_kernel_supported
    )
    q_layout = quantizer.q_layout
    if force_1x_quantization:
        q_layout = QuantizeLayout.ROWWISE

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
        dbias,
    ) = PrimitiveClass.outer_primitive.bind(
        x.data,
        scale,
        amax,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        q_layout=q_layout.value,
        flatten_axis=flatten_axis,
        scale_dtype=quantizer.get_scale_dtype(),
        is_dbias=is_dbias,
        is_outer=True,
    )
    # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
    if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x():
        colwise_scale_inv = rowwise_scale_inv

        if q_layout == QuantizeLayout.ROWWISE:
            # Quantizer requires 2x quantization, but we are using 1x quantization
            # for performance reasons, so we need to generate the colwise data in JAX
            if flatten_axis < 0:
                flatten_axis += x.ndim
            colwise_casted_output = jnp.transpose(
                rowwise_casted_output, (*range(flatten_axis, x.ndim), *range(flatten_axis))
            )

    quantizer.update(updated_amax)

    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=dq_dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=flatten_axis,
    )
    return out, dbias.astype(dq_dtype)
829
830
831


def quantize(
    x: Union[jnp.ndarray, NoScaleTensor],
    quantizer: Quantizer,
    flatten_axis: int = -1,
    amax_scope: AmaxScope = AmaxScope.LOCAL,
) -> Tuple[ScaledTensor]:
    """Quantize input tensor according to the quantizer.

    Args:
        x: Input tensor to be quantized.
            Shape: (..., K) where K is the hidden size.
        quantizer: Quantizer for FP8 quantization of the output.
        flatten_axis: The quantization axis in which input data can be flattened to 2D
            for quantization. Defaults to -1.
        amax_scope: Indicate the scope to run amax calculation. This only works when
            using current-scaling. Default is AmaxScope.LOCAL.

    Returns:
        A ScaledTensor containing the quantized input tensor.
    """
    # The shared implementation also returns a dbias slot, which is not needed here.
    quantized, _unused_dbias = _quantize_dbias_impl(
        x,
        quantizer=quantizer,
        flatten_axis=flatten_axis,
        amax_scope=amax_scope,
    )
    return quantized


def quantize_dbias(
    dz: Union[jnp.ndarray, NoScaleTensor],
    quantizer: Quantizer,
    is_dbias: bool = True,
    flatten_axis: int = -1,
    amax_scope: AmaxScope = AmaxScope.LOCAL,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
    """Quantize input tensor and compute bias gradient.

    This is a thin wrapper that forwards every argument to
    ``_quantize_dbias_impl`` and returns its result unchanged.

    Args:
        dz: Input tensor to be quantized and used for bias gradient computation.
            Shape: (..., K) where K is the hidden size.
        quantizer: Quantizer for FP8 quantization of the output.
        is_dbias: If True, compute bias gradient. Defaults to True.
        flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization.
            Defaults to -1.
        amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL.

    Returns:
        A tuple containing:
        - A ScaledTensor containing the quantized input tensor.
            The ScaledTensor includes both the quantized data and scaling factors.
        - The bias gradient tensor.
            Shape: (K,) or empty if is_dbias is False.
    """
    return _quantize_dbias_impl(
        dz,
        quantizer=quantizer,
        is_dbias=is_dbias,
        flatten_axis=flatten_axis,
        amax_scope=amax_scope,
    )


class GroupedQuantizePrimitive(BasePrimitive):
    """
    Grouped quantize primitive lowered to the ``te_grouped_quantize_ffi`` FFI target.

    Operands are the input tensor ``x``, the ``scale`` array and ``group_sizes``.
    The primitive yields five results: rowwise output, colwise output, rowwise
    scale_inv, colwise scale_inv and the updated amax.
    """

    name = "te_grouped_quantize_ffi"
    multiple_results = True
    # Positions (in impl's positional signature) of the non-array arguments.
    impl_static_args = (
        3,
        4,
        5,
        6,
        7,
        8,
    )  # out_dtype, scaling_mode, q_layout, flatten_axis, group_axis, scale_dtype
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        group_sizes_aval,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        group_axis,
        scale_dtype,
    ):
        """
        te_grouped_quantize_p abstract

        Computes the abstract values (shape/dtype) of the five outputs. Outputs
        for a layout that was not requested via ``q_layout`` are shrunk to a
        placeholder of shape ``(1,)``.
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        # Total element count of x; used as the shape of the quantized output buffers.
        out_shape = math.prod(x_aval.shape)
        # TODO(Phuong): can scale_aval be None?
        assert scale_aval is None or scale_aval.dtype == jnp.float32

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_grouped_scale_shape_2x(
            x_aval.shape,
            group_sizes_aval.size,
            group_axis,
            is_padded=True,
            flatten_axis=flatten_axis,
        )

        if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            rowwise_out_shape = out_shape
        else:
            # Rowwise data not requested: keep only placeholder buffers.
            rowwise_out_shape = (1,)
            rowwise_scale_inv_shape = (1,)
        rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype)

        # One amax entry per group.
        amax_aval = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32)

        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            colwise_out_shape = out_shape
        else:
            # Colwise data not requested: keep only placeholder buffers.
            colwise_out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
        rowwise_scale_inv_aval = jax.core.ShapedArray(
            shape=rowwise_scale_inv_shape, dtype=scale_dtype
        )
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        return (
            rowwise_out_aval,
            colwise_out_aval,
            rowwise_scale_inv_aval,
            colwise_scale_inv_aval,
            amax_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_grouped_quantize_p outer primitive abstract

        Currently returns exactly what ``abstract`` returns.
        """
        # Phuong: keeping outer abstract so that we can add fuse dbias later
        (
            rowwise_out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
        ) = GroupedQuantizePrimitive.abstract(*args, **kwargs)
        return rowwise_out, colwise_out, scale_inv, colwise_scale_inv, updated_amax

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        group_sizes,
        *,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        group_axis,
        scale_dtype,
    ):
        """
        te_grouped_quantize_p lowering rules

        Validates the operand dtypes and emits the FFI call. Only
        ``scaling_mode``, ``q_layout`` and ``flatten_axis`` are forwarded as
        FFI attributes; ``group_axis`` must be 0.
        """
        del out_dtype, scale_dtype
        x_aval, scale_aval, group_sizes_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval.dtype == jnp.float32
        assert group_sizes_aval.dtype == jnp.int32
        assert group_axis == 0
        return ffi.ffi_lowering(GroupedQuantizePrimitive.name)(
            ctx,
            x,
            scale,
            group_sizes,
            # scaling_mode arrives as the value of a ScalingMode member (see
            # abstract); that object's own .value is what the FFI call receives.
            scaling_mode=scaling_mode.value,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
        )

    @staticmethod
    def impl(
        x,
        scale,
        group_sizes,
        out_dtype,
        scaling_mode,
        q_layout,
        flatten_axis,
        group_axis,
        scale_dtype,
    ):
        """
        te_grouped_quantize_p implementation

        Binds the inner primitive with the given operands and static
        arguments and returns its five results as a tuple.
        """
        assert GroupedQuantizePrimitive.inner_primitive is not None
        (
            rowwise_out,
            colwise_out,
            rowwise_scale_inv,
            colwise_scale_inv,
            updated_amax,
        ) = GroupedQuantizePrimitive.inner_primitive.bind(
            x,
            scale,
            group_sizes,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
            flatten_axis=flatten_axis,
            group_axis=group_axis,
            scale_dtype=scale_dtype,
        )
        return (rowwise_out, colwise_out, rowwise_scale_inv, colwise_scale_inv, updated_amax)


# Register GroupedQuantizePrimitive through the shared helper imported from .base.
# NOTE(review): presumably this is what populates the class's inner_primitive /
# outer_primitive attributes used by impl() and grouped_quantize() — confirm in .base.
register_primitive(GroupedQuantizePrimitive)


def grouped_quantize(
    x: jnp.ndarray,
    quantizer: GroupedQuantizer,
    group_sizes: jnp.ndarray = None,
    amax: jnp.ndarray = None,
    flatten_axis: int = -1,
) -> GroupedScaledTensor1x:
    """Quantize a tensor in grouped manner.

    This function quantizes a tensor by splitting it into groups along a specified axis
    and applying quantization to each group separately. The groups can be either specified
    explicitly through group_sizes or automatically split along the group_axis.

    Args:
        x: Input tensor to quantize
        quantizer: The quantizer to use for quantization. If None, no quantization
            is performed and the input is returned wrapped in a NoScaleTensor.
        group_sizes: Array of ints containing the size of each group (default: None)
        amax: The amax of x; if None, it is auto-generated. (default: None)
            Only consulted for current tensor scaling on the primitive path, where
            it is reduced per group with one segment id per index of axis 0.
            NOTE(review): this implies one amax value per row of x - confirm with callers.
        flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)

    Returns:
        A GroupedScaledTensor1x containing the quantized data
        (or a NoScaleTensor when quantizer is None)

    Note:
        - If group_sizes is not provided, every index along the group_axis becomes
          its own group of size 1
        - The group_axis is currently fixed to 0
        - The quantizer's q_layout determines whether row-wise, column-wise, or both
          quantization is applied
    """

    if quantizer is None:
        if isinstance(x, NoScaleTensor):
            return x
        return NoScaleTensor(data=x, amax=None)

    # TODO(Phuong): add support for flatten_axis = -2
    assert flatten_axis in (
        -1,
        x.ndim - 1,
    ), f"Only flatten_axis = -1 is supported for now, got {flatten_axis}"
    group_axis = 0

    if group_sizes is None:
        group_sizes = jnp.ones(x.shape[group_axis], dtype=jnp.int32)

    # Fall back to the quantizer's own implementation when the custom primitive is disabled.
    if not GroupedQuantizePrimitive.enabled():
        return quantizer.quantize(
            x, flatten_axis=flatten_axis, group_sizes=group_sizes, group_axis=group_axis
        )
    n_groups = group_sizes.size
    original_shape = x.shape
    assert n_groups == len(
        quantizer.quantizers
    ), f"n_groups={n_groups} != n_quantizers = {len(quantizer.quantizers)}"
    # Per-group scale buffer; only filled in below for the tensor-scaling modes.
    scale = jnp.empty((n_groups,), jnp.float32)

    if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        # Delayed scaling: each per-group quantizer already carries its scale.
        for i, quantizer_i in enumerate(quantizer.quantizers):
            scale = scale.at[i].set(quantizer_i.scale[0])

    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        # Current scaling: derive each group's scale from the amax of its rows.
        if amax is not None:
            row_amax = amax
        else:
            row_amax = jnp.max(jnp.abs(x), axis=tuple(range(group_axis + 1, x.ndim)))
        # Map every index along the group axis to the group it belongs to.
        segment_ids = jnp.repeat(
            jnp.arange(n_groups), group_sizes, total_repeat_length=x.shape[group_axis]
        )
        grouped_amax = jax.ops.segment_max(row_amax, segment_ids, num_segments=n_groups)
        for i in range(n_groups):
            tmp_scale = compute_scale_from_amax(grouped_amax[i], quantizer.q_dtype)
            scale = scale.at[i].set(tmp_scale[0])

    is_tensor_scaling = quantizer.scaling_mode in (
        ScalingMode.DELAYED_TENSOR_SCALING,
        ScalingMode.CURRENT_TENSOR_SCALING,
    )
    # WAR for tensor_scaling as TE/Common does not support q_layout = COLWISE yet
    # So we perform ROWWISE_COLWISE and use the colwise_tensor_output
    apply_colwise_war = is_tensor_scaling and quantizer.q_layout == QuantizeLayout.COLWISE
    q_layout = QuantizeLayout.ROWWISE_COLWISE if apply_colwise_war else quantizer.q_layout
    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
    ) = GroupedQuantizePrimitive.outer_primitive.bind(
        x,
        scale,
        group_sizes,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        q_layout=q_layout.value,
        flatten_axis=flatten_axis,
        group_axis=group_axis,
        scale_dtype=quantizer.get_scale_dtype(),
    )

    # For DelayedScaling2x and CurrentScaling2x, the scale buffer
    # is shared between rowwise and colwise
    if (is_tensor_scaling and quantizer.is_2x2x()) or apply_colwise_war:
        colwise_scale_inv = rowwise_scale_inv

    # TODO(Phuong): store the whole updated_amax in the grouped_quantize instead?
    if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        for i, quantizer_i in enumerate(quantizer.quantizers):
            quantizer_i.update(updated_amax[i].reshape((1,)))

    # The tensor is built with the quantizer's own q_layout (not the WAR layout
    # used for the primitive call above).
    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=flatten_axis,
        group_sizes=group_sizes,
        original_shape=original_shape,
        group_axis=group_axis,
    )
    return out


def grouped_dbias(grad: jnp.ndarray, group_sizes: jnp.ndarray) -> jnp.ndarray:
    """
    Compute the grouped bias gradient.

    Rows of ``grad`` are assigned to groups in order according to
    ``group_sizes`` and summed per group. The accumulation is done in float32
    and the result is cast back to the dtype of ``grad``.

    Args:
        grad: jnp.ndarray of shape (M, N)
        group_sizes: jnp.ndarray of shape(num_groups,), sum(group_sizes) == M

    Returns:
        dbias: jnp.ndarray of shape (num_groups, N)
    """
    assert grad.ndim == 2, "Input grad must be a 2D tensor."
    assert group_sizes.ndim == 1, "group_sizes must be a 1D tensor."

    # segment_ids[r] is the index of the group that row r of grad belongs to.
    # total_repeat_length pins the output length to M so the shape is static.
    segment_ids = jnp.repeat(
        jnp.arange(group_sizes.size), group_sizes, total_repeat_length=grad.shape[0]
    )
    # Accumulate in float32, then cast back to the input dtype.
    grad_fp32 = grad.astype(jnp.float32)
    dbias_fp32 = jax.ops.segment_sum(grad_fp32, segment_ids, num_segments=group_sizes.shape[0])
    dbias = dbias_fp32.astype(grad.dtype)
    return dbias