# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for activation"""
from typing import Sequence, Union, Callable, Optional, Tuple
import operator
from functools import reduce, partial
from packaging import version

import jax
import jax.numpy as jnp
from jax import dtypes
from jax.sharding import PartitionSpec

import transformer_engine_jax
from transformer_engine_jax import NVTE_Activation_Type

from .base import BasePrimitive, register_primitive
from .misc import (
    jax_dtype_to_te_dtype,
    te_dtype_to_jax_dtype,
    get_padded_spec,
    check_valid_batch_dims,
    multidim_transpose,
    try_apply_delayed_scaling_2x_war,
    should_apply_1x_fused_dbias_war_for_arch_l_100,
    NamedSharding,
)
from .quantization import _jax_quantize_dbias, _jax_dbias, quantize_dbias
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor, ScaledTensorFactory
from ..quantize import (
    Quantizer,
    QuantizeAxis,
    DelayedScaleQuantizer,
    ScalingMode,
)

if version.parse(jax.__version__) >= version.parse("0.5.0"):
    from jax import ffi  # pylint: disable=ungrouped-imports
else:
    from jax.extend import ffi  # pylint: disable=ungrouped-imports

__all__ = ["act_lu", "dact_lu", "quantize_dact_dbias"]


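# Maps a tuple of activation names to the TE/common activation enum. Two-element
# tuples paired with "linear" select the gated (GLU-style) variants, where the two
# halves of the last input axis are passed through their respective functions and
# multiplied together.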
ActivationEnum = {
    ("gelu",): NVTE_Activation_Type.GELU,
    ("gelu", "linear"): NVTE_Activation_Type.GEGLU,
    ("silu",): NVTE_Activation_Type.SILU,
    ("silu", "linear"): NVTE_Activation_Type.SWIGLU,
    ("relu",): NVTE_Activation_Type.RELU,
    ("relu", "linear"): NVTE_Activation_Type.REGLU,
    ("quick_gelu",): NVTE_Activation_Type.QGELU,
    ("quick_gelu", "linear"): NVTE_Activation_Type.QGEGLU,
    ("squared_relu",): NVTE_Activation_Type.SRELU,
    ("squared_relu", "linear"): NVTE_Activation_Type.SREGLU,
}


def _convert_to_activation_function(fn_or_string):
    """Convert a string to an activation function."""
    if fn_or_string == "linear":
        return lambda x: x
    if fn_or_string == "quick_gelu":
        return lambda x: jax.nn.sigmoid(1.702 * x) * x
    if fn_or_string == "squared_relu":
        return lambda x: reduce(operator.mul, [jax.nn.relu(x), jax.nn.relu(x)])
    if isinstance(fn_or_string, str):
        return getattr(jax.nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string
    raise ValueError(f"Unsupported activation function: {fn_or_string}")

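# A minimal sketch (commented out; shapes and names are illustrative) of how a
# gated activation tuple such as ("silu", "linear") is evaluated on the pure-JAX
# fallback path, assuming a hidden size of 128:
#
#     x = jnp.zeros((4, 2 * 128))                                # [..., act_len * hidden]
#     silu, linear = map(_convert_to_activation_function, ("silu", "linear"))
#     gate, up = jnp.split(x, 2, axis=-1)
#     y = silu(gate) * linear(up)                                # [..., hidden]
#
# _jax_act_lu below implements this split-apply-multiply pattern for an arbitrary
# number of activation components.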

class ActLuPrimitive(BasePrimitive):
    """
    ActLu Primitive
    """

    name = "te_act_lu_ffi"
    multiple_results = True
    impl_static_args = (
        2,
        3,
        4,
        5,
        6,
        7,
        8,
        9,
    )  # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, scale_shapes, is_outer
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        *,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_outer,
    ):
        """
        te_act_lu_p abstract
        """
        del act_enum, act_len, scale_shapes
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval is None or scale_aval.dtype == jnp.float32

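        # The callers reshape their input so that the activation components sit on
        # axis -2, i.e. x has shape (..., act_len, hidden); the fused kernel reduces
        # that axis, leaving a size-1 placeholder in the output shape.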
        out_shape = (
            *x_aval.shape[:-2],
            1,
            x_aval.shape[-1],
        )
        out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)
        updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(out_shape[:-2] + (out_shape[-1],), is_padded=not is_outer)

        if len(rowwise_scale_inv_shape) > 1:
            rowwise_scale_inv_shape = (
                rowwise_scale_inv_shape[:-1] + (1,) + rowwise_scale_inv_shape[-1:]
            )
        if len(colwise_scale_inv_shape) > 1:
            colwise_scale_inv_shape = (
                colwise_scale_inv_shape[:-1] + (1,) + colwise_scale_inv_shape[-1:]
            )

        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)

        colwise_out_aval = jax.core.ShapedArray(shape=(1,), dtype=out_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(shape=(1,), dtype=scale_dtype)
        if is_2x:
            colwise_out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype)
            colwise_scale_inv_aval = jax.core.ShapedArray(
                shape=colwise_scale_inv_shape, dtype=scale_dtype
            )

        return out_aval, colwise_out_aval, scale_inv_aval, colwise_scale_inv_aval, updated_amax_aval

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        *,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_outer,
    ):
        """
        te_act_lu_p lowering rules
        """
        del out_dtype, scale_dtype, scale_shapes, act_len, is_outer
        x_aval, scale_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval is None or scale_aval.dtype == jnp.float32

        out = ffi.ffi_lowering(ActLuPrimitive.name)(
            ctx, x, scale, act_enum=act_enum, scaling_mode=scaling_mode, is_2x=is_2x
        )
        return out

    @staticmethod
    def impl(
        x,
        scale,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_outer,
    ):
        """
        te_act_lu_p impl
        """
        del is_outer
        assert ActLuPrimitive.inner_primitive is not None

        out, colwise_out, scale_inv, colwise_scale_inv, updated_amax = (
            ActLuPrimitive.inner_primitive.bind(
                x,
                scale,
                out_dtype=out_dtype,
                act_enum=act_enum,
                act_len=act_len,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                scale_shapes=scale_shapes,
                is_outer=False,
            )
        )
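        # The inner primitive returns scale_inv in padded buffers (its abstract rule
        # runs with is_padded=True when is_outer=False); slice them back down to the
        # unpadded logical shapes before returning.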
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(out.shape[:-2] + (out.shape[-1],), is_padded=False)
        if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
            rowwise_scale_inv_shape = (
                rowwise_scale_inv_shape[:-1] + (1,) + rowwise_scale_inv_shape[-1:]
            )
            if is_2x:
                colwise_scale_inv_shape = (
                    colwise_scale_inv_shape[:-1] + (1,) + colwise_scale_inv_shape[-1:]
                )
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
        )
        if is_2x:
            colwise_scale_inv = jax.lax.slice(
                colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
            )
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_outer,
    ):
        """
        to describe batch rules for vmap
        """
        del is_outer
        check_valid_batch_dims(batch_dims)
        assert ActLuPrimitive.outer_primitive is not None
        x, scale = batched_args
        x_bdim, scale_bdim = batch_dims
        amax_bdim = scale_bdim

        out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim
        return (
            ActLuPrimitive.outer_primitive.bind(
                x,
                scale,
                out_dtype=out_dtype,
                act_enum=act_enum,
                act_len=act_len,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                scale_shapes=scale_shapes,
                is_outer=True,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del (
            out_dtype,
            result_infos,
            act_enum,
            scale_dtype,
            scale_shapes,
            act_len,
            is_outer,
        )  # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        out_spec = (*x_spec[:-2], None, x_spec[-2])
        out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")
        if is_2x:
            if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
                colwise_out_spec = multidim_transpose(out_spec)
            else:
                colwise_out_spec = out_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out"
        )
        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="ActLuPrimitive.scale_inv"
        )
        amax_sharding = scale_inv_sharding.duplicate_with_new_description("ActLuPrimitive.amax")

        if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
            scale_inv_sharding = NamedSharding(
                mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.scale_inv"
            )
        colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
            "ActLuPrimitive.colwise_scale_inv"
        )
        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer  # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        out_spec = (*x_spec[:-1], x_spec[-1])
        if act_len == 2 and x_spec[-1] is None:
            # Ensure last axis is partitioned and not the gating axis
            out_spec = (*x_spec[:-2], None, x_spec[-2])
        out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")
        if is_2x:
            if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
                colwise_out_spec = multidim_transpose(out_spec)
            else:
                colwise_out_spec = out_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out"
        )
        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="ActLuPrimitive.scale_inv"
        )
        amax_sharding = scale_inv_sharding.duplicate_with_new_description("ActLuPrimitive.amax")

        if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
            scale_inv_sharding = NamedSharding(
                mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.scale_inv"
            )
        colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
            "ActLuPrimitive.colwise_scale_inv"
        )
        arg_shardings = list(arg_i.sharding for arg_i in arg_infos)
        arg_shardings[0] = NamedSharding(mesh, PartitionSpec(*out_spec))
        arg_shardings = tuple(arg_shardings)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
        )

        def sharded_impl(x, scale):
            local_x, local_colwise_x, local_scale_inv, local_colwise_scale_inv, local_amax = (
                ActLuPrimitive.impl(
                    x,
                    scale,
                    out_dtype=out_dtype,
                    act_enum=act_enum,
                    act_len=act_len,
                    scaling_mode=scaling_mode,
                    is_2x=is_2x,
                    scale_dtype=scale_dtype,
                    scale_shapes=scale_shapes,
                    is_outer=True,
                )
            )

            if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            else:
                global_updated_amax = local_amax

            return (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                global_updated_amax,
            )

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(ActLuPrimitive)


class DActLuDBiasQuantizePrimitive(BasePrimitive):
    """
    DActLu DBias Quantize Primitive
    """

    name = "te_dact_dbias_quantize_ffi"
    multiple_results = True
    # out_dtype, scaling_mode, is_2x, scale_dtype, scale_shapes, is_dbias, act_enum, act_len, is_outer
    impl_static_args = (3, 4, 5, 6, 7, 8, 9, 10, 11)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        dz_aval,
        x_aval,
        scale_aval,
        *,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
    ):
        """
        te_dact_dbias_quantize_p abstract
        """
        del act_enum, scale_shapes
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dtype
        assert scale_aval.dtype == jnp.float32
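        # dz carries the post-activation hidden width, while x keeps the
        # pre-activation width, so for gated activations x is act_len times wider.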
        ir_hidden_size = dz_aval.shape[-1]
        gi_hidden_size = x_aval.shape[-1]
        assert act_len * ir_hidden_size == gi_hidden_size
        out_shape = x_aval.shape
        out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)
        updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer)

        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)

        colwise_out_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
        colwise_scale_inv_aval = jax.core.ShapedArray(shape=(1,), dtype=scale_dtype)

        dbias_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
        wkspace_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
        if is_2x:
            # Don't transpose output for MXFP8
            if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
                t_shape = out_shape
            else:
                t_shape = multidim_transpose(out_shape)
            colwise_out_aval = x_aval.update(shape=t_shape, dtype=out_dtype)
            colwise_scale_inv_aval = jax.core.ShapedArray(
                shape=colwise_scale_inv_shape, dtype=scale_dtype
            )

        if is_dbias:
            dbias_shape = gi_hidden_size
            dbias_aval = x_aval.update(shape=dbias_shape, dtype=dtype)
            (wkspace_info,) = transformer_engine_jax.get_dact_dbias_quantize_workspace_sizes(
                x_aval.size // gi_hidden_size,
                gi_hidden_size,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
                scaling_mode,
                is_2x,
            )
            wkspace_aval = x_aval.update(
                shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
            )

        return (
            out_aval,
            colwise_out_aval,
            scale_inv_aval,
            colwise_scale_inv_aval,
            updated_amax_aval,
            dbias_aval,
            wkspace_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dact_dbias_quantize_p outer abstract
        """
        (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = (
            DActLuDBiasQuantizePrimitive.abstract(*args, **kwargs)
        )
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def lowering(
        ctx,
        dz,
        x,
        scale,
        *,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
    ):
        """
        te_dact_dbias_quantize_p lowering rules
        """
        del out_dtype, scale_dtype, scale_shapes, act_len, is_outer
        dz_aval, x_aval, scale_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_aval.dtype
        assert scale_aval.dtype == jnp.float32
        return ffi.ffi_lowering(DActLuDBiasQuantizePrimitive.name)(
            ctx,
            dz,
            x,
            scale,
            scaling_mode=scaling_mode,
            is_2x=is_2x,
            is_dbias=is_dbias,
            act_enum=int(act_enum),
        )

    @staticmethod
    def impl(
        dz,
        x,
        scale,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
    ):
        """
        te_dact_dbias_quantize_p impl
        """
        del is_outer
        assert DActLuDBiasQuantizePrimitive.inner_primitive is not None
        (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = (
            DActLuDBiasQuantizePrimitive.inner_primitive.bind(
                dz,
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                scale_shapes=scale_shapes,
                is_dbias=is_dbias,
                act_enum=act_enum,
                act_len=act_len,
                is_outer=False,
            )
        )
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x.shape, is_padded=False)
        if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
            scale_inv = jax.lax.slice(
                scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
            )
            if is_2x:
                colwise_scale_inv = jax.lax.slice(
                    colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
                )
        return (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
        )  # Exclude wkspace

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
    ):
        """
        to describe batch rules for vmap
        """
        del is_outer
        check_valid_batch_dims(batch_dims)
        assert DActLuDBiasQuantizePrimitive.outer_primitive is not None
        dz, x, scale = batched_args
        _, x_bdim, scale_bdim = batch_dims

        out_bdims = (
            x_bdim,  # rowwise output
            x_bdim,  # colwise output
            scale_bdim,  # rowwise scale_inv
            scale_bdim,  # colwise scale_inv
            scale_bdim,  # amax
            x_bdim,  # dbias
        )
        return (
            DActLuDBiasQuantizePrimitive.outer_primitive.bind(
                dz,
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                scale_shapes=scale_shapes,
                is_dbias=is_dbias,
                act_enum=act_enum,
                act_len=act_len,
                is_outer=True,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del out_dtype, result_infos, act_enum
        del scale_dtype, scale_shapes, is_dbias, act_len, is_outer
        x_spec = get_padded_spec(arg_infos[1])

        out_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out"
        )
        if is_2x:
            if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
                colwise_x_spec = multidim_transpose(x_spec)
            else:
                colwise_x_spec = x_spec
        else:
            colwise_x_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_x_spec), desc="DActLuDBiasQuantizePrimitive.colwise_out"
        )

        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(x_spec[-1]),
            desc="DActLuDBiasQuantizePrimitive.dbias",
        )
        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(None), desc="DActLuDBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(None), desc="DActLuDBiasQuantizePrimitive.amax"
        )
        if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
            scale_inv_sharding = NamedSharding(
                mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.scale_inv"
            )
        colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
            "DActLuDBiasQuantizePrimitive.colwise_scale_inv"
        )
        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer
        x_spec = get_padded_spec(arg_infos[1])
        out_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out"
        )
        if is_2x:
            if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
                colwise_x_spec = multidim_transpose(x_spec)
            else:
                colwise_x_spec = x_spec
        else:
            colwise_x_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_x_spec), desc="DActLuDBiasQuantizePrimitive.colwise_out"
        )

        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(x_spec[-1]),
            desc="DActLuDBiasQuantizePrimitive.dbias",
        )
        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(None), desc="DActLuDBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(None), desc="DActLuDBiasQuantizePrimitive.amax"
        )
        if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
            scale_inv_sharding = NamedSharding(
                mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.scale_inv"
            )
        colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
            "DActLuDBiasQuantizePrimitive.colwise_scale_inv"
        )

        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        arg_shardings = (
            arg_shardings[1],
            arg_shardings[1],
            *arg_shardings[2:],
        )  # dz and x are the same
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

        def sharded_impl(dz, x, scale):
            (out, colwise_out, scale_inv, colwise_scale_inv, local_amax, local_dbias) = (
                DActLuDBiasQuantizePrimitive.impl(
                    dz,
                    x,
                    scale,
                    out_dtype=out_dtype,
                    scaling_mode=scaling_mode,
                    is_2x=is_2x,
                    scale_dtype=scale_dtype,
                    scale_shapes=scale_shapes,
                    is_dbias=is_dbias,
                    act_enum=act_enum,
                    act_len=act_len,
                    is_outer=True,
                )
            )
            if is_dbias:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            else:
                global_dbias = local_dbias

            if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            else:
                global_updated_amax = local_amax

            return out, colwise_out, scale_inv, colwise_scale_inv, global_updated_amax, global_dbias

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(DActLuDBiasQuantizePrimitive)


def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, ScaledTensor]:
    """
    JAX native activation implementation
    """
    x = jnp.split(inputs, len(activation_type), axis=-1)
    acts = []
    for idx, act_fn in enumerate(activation_type):
        x_i = _convert_to_activation_function(act_fn)(x[idx])
        acts.append(x_i)
    x = reduce(operator.mul, acts)
    if quantizer:
        return quantizer.quantize(x)
    return x


def _jax_quantize_dact_dbias(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    is_dbias: bool = True,
    quantizer: Optional[Quantizer] = None,
):
    """
    JAX implementation of dact_lu and dbias with optional quantization
    """
    _, vjp_func = jax.vjp(
        partial(_jax_act_lu, activation_type=activation_type), x.astype(jnp.float32)
    )
    (dx,) = vjp_func(dz.astype(jnp.float32))

    dbias = None
    if is_dbias:
        dbias = _jax_dbias(dx).astype(x.dtype)

    if quantizer is not None:
        dx = quantizer.quantize(dx, dq_dtype=x.dtype)
    else:
        dx = dx.astype(x.dtype)

    return dx, dbias

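# A minimal sketch (commented out; names are hypothetical) of the jax.vjp identity
# used in _jax_quantize_dact_dbias above. For an elementwise activation, the
# pulled-back gradient is simply dz scaled by the activation derivative:
#
#     act = _convert_to_activation_function("gelu")
#     _, vjp_fn = jax.vjp(act, x)
#     (dx,) = vjp_fn(dz)        # elementwise: dx == dz * d/dx gelu(x)
#
# The same VJP also covers gated activations by routing the gradient through the
# product computed in _jax_act_lu.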

def act_lu(
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Optional[Quantizer] = None,
) -> Union[jnp.ndarray, ScaledTensor]:
    """Activation with optional quantization.

    Args:
        x: Input tensor to be processed.
        activation_type: Type of activation function to apply.
        quantizer: Optional quantizer for FP8 quantization of the output.

    Returns:
        If quantizer is None:
            The activated input tensor with the same dtype as input.
        If quantizer is provided:
            A ScaledTensor containing the quantized activated input.
    """
    act_type_id = ActivationEnum[activation_type].value

    if not ActLuPrimitive.enabled():
        return _jax_act_lu(x, activation_type, quantizer)

    # TE/common does not support colwise-only quantization yet
    if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE:
        return _jax_act_lu(x, activation_type, quantizer)

    # TE/common does not support 2x quantization for DelayedScaling yet
    war_output = try_apply_delayed_scaling_2x_war(
        f=act_lu, x=x, activation_type=activation_type, quantizer=quantizer
    )
    if war_output is not None:
        return war_output

    scale = jnp.empty((1,), jnp.float32)
    output_shape = (*x.shape[:-1], x.shape[-1] // len(activation_type))

    if quantizer is None:
        x = x.reshape((-1, len(activation_type), x.shape[-1] // len(activation_type)))
        out, _, _, _, _ = ActLuPrimitive.outer_primitive.bind(
            x,
            scale,
            out_dtype=x.dtype,
            act_enum=act_type_id,
            act_len=len(activation_type),
            scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value,
            is_2x=False,
            scale_dtype=jnp.float32,
            scale_shapes=((), ()),
            is_outer=True,
        )
        out = out.reshape(output_shape)
        return out

    if isinstance(quantizer, DelayedScaleQuantizer):
        scale = quantizer.scale

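    # Expose the activation components on their own axis, (..., act_len, hidden),
    # which is the input layout ActLuPrimitive.abstract expects.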
    x = x.reshape((*x.shape[:-1], len(activation_type), x.shape[-1] // len(activation_type)))
    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
    ) = ActLuPrimitive.outer_primitive.bind(
        x,
        scale,
        out_dtype=quantizer.q_dtype,
        act_enum=act_type_id,
        act_len=len(activation_type),
        scaling_mode=quantizer.scaling_mode.value,
        is_2x=quantizer.is_2x2x(),
        scale_dtype=quantizer.get_scale_dtype(),
        scale_shapes=quantizer.get_scale_shapes(output_shape),
        is_outer=True,
    )

    rowwise_casted_output = rowwise_casted_output.reshape(output_shape)
    if len(rowwise_scale_inv.shape) > 1:
        rowwise_scale_inv = jnp.squeeze(rowwise_scale_inv, axis=-2)  # Remove act axis
    if quantizer.q_axis in (QuantizeAxis.COLWISE, QuantizeAxis.ROWWISE_COLWISE):
        colwise_output_shape = output_shape
        if quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
            colwise_output_shape = multidim_transpose(output_shape)
        colwise_casted_output = colwise_casted_output.reshape(colwise_output_shape)
        if len(colwise_scale_inv.shape) > 1:
            colwise_scale_inv = jnp.squeeze(colwise_scale_inv, axis=-2)  # Remove act axis

    quantizer.update(updated_amax)

    return ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_axis=quantizer.q_axis,
        layout=quantizer.get_layout(),
    )

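# Hedged usage sketch for act_lu (commented out; tensor shapes and the quantizer
# object are illustrative, not part of this module):
#
#     hidden = jnp.zeros((8, 2 * 1024), jnp.bfloat16)
#     y = act_lu(hidden, ("silu", "linear"))           # bf16 output, shape (8, 1024)
#
#     # With a Quantizer, the result is a ScaledTensor holding the quantized data:
#     # y_q = act_lu(hidden, ("silu", "linear"), quantizer=my_quantizer)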

def quantize_dact_dbias(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]] = ("gelu",),
    is_dbias: bool = True,
    quantizer: Optional[Quantizer] = None,
) -> Tuple[ScaledTensor, jnp.ndarray]:
    """Compute gradients of activation and bias with optional quantization.

    Args:
        dz: Gradient of the output with respect to the activation output.
        x: Input tensor that was processed by the forward pass.
            Shape: (..., ACT_DIM * K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
        activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",).
        is_dbias: If True, compute bias gradient. Defaults to True.
        quantizer: Optional quantizer for FP8 quantization of the output.

    Returns:
        Tuple[ScaledTensor, jnp.ndarray]: A tuple containing:
        - The gradient of the activation with respect to the input.
        - The gradient of the activation with respect to the bias.
    """

    if not DActLuDBiasQuantizePrimitive.enabled():
        return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)

    # TE/common does not support colwise-only quantization yet
    if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE:
        return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)

    # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet
    if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
        out, _ = quantize_dact_dbias(
            dz=dz, x=x, activation_type=activation_type, is_dbias=False, quantizer=None
        )
        return quantize_dbias(out, is_dbias=True, quantizer=quantizer)

    is_gated = len(activation_type) == 2
    # TE/common does not support DelayedScaling2x for gated-act yet
    if is_gated:
        war_output = try_apply_delayed_scaling_2x_war(
            f=quantize_dact_dbias,
            dz=dz,
            x=x,
            activation_type=activation_type,
            is_dbias=is_dbias,
            quantizer=quantizer,
        )
        if war_output is not None:
            return war_output

    scale = jnp.empty((), jnp.float32)

    act_type_id = ActivationEnum[activation_type]

    if quantizer is None:
        output, _, _, _, _, _ = DActLuDBiasQuantizePrimitive.outer_primitive.bind(
            dz,
            x,
            scale,
            # outputs float32 for dbias accumulation
            out_dtype=(jnp.float32 if is_dbias else x.dtype),
            # default value for no scaling, TE/common ignore this value when scale is unset
            scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value,
            is_2x=False,  # unused
            scale_dtype=jnp.float32,  # unused
            scale_shapes=((), ()),  # unused
            is_dbias=False,
            act_enum=act_type_id,
            act_len=len(activation_type),
            is_outer=True,
        )
        dbias = None
        if is_dbias:
            dbias = _jax_dbias(output).astype(x.dtype)
        return output.astype(x.dtype), dbias

    if isinstance(quantizer, DelayedScaleQuantizer):
        scale = quantizer.scale

    # TE/common dact_dbias_quantize does not support gated act yet
    if is_dbias and is_gated:
        dgated = dact_lu(
            dz.astype(jnp.float32), x.astype(jnp.float32), activation_type=activation_type
        )
        # TODO(Jeremy): Debug - TE's quantize_dbias produced nans in this case for distributed layernorm_mlp tests
        if quantizer.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
            out, dbias = _jax_quantize_dbias(dgated, quantizer=quantizer, dq_dtype=x.dtype)
        else:
            out, dbias = quantize_dbias(
                dgated,
                quantizer=quantizer,
                is_dbias=True,
                dq_dtype=x.dtype,
            )
        return out, dbias

    out_shape = x.shape

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
        dbias,
    ) = DActLuDBiasQuantizePrimitive.outer_primitive.bind(
        dz,
        x,
        scale,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        is_2x=quantizer.is_2x2x(),
        scale_dtype=quantizer.get_scale_dtype(),
        scale_shapes=quantizer.get_scale_shapes(out_shape),
        is_dbias=is_dbias,
        act_enum=act_type_id,
        act_len=len(activation_type),
        is_outer=True,
    )

    # For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise
    if quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING and quantizer.is_2x2x():
        colwise_scale_inv = rowwise_scale_inv

    quantizer.update(updated_amax)

    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_axis=quantizer.q_axis,
        layout=quantizer.get_layout(),
    )

    return out, dbias


def dact_lu(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Optional[Quantizer] = None,
) -> Union[jnp.ndarray, ScaledTensor]:
    """
    Backward pass for activation with optional quantization.

    Args:
        dz: Gradient tensor from upstream.
        x: Input tensor that was used in forward pass.
        activation_type: Type of activation function that was applied.
        quantizer: Optional quantizer for FP8 quantization of the output gradient.

    Returns:
        The gradient of the activation with respect to the input.
    """
    output, _ = quantize_dact_dbias(
        dz=dz,
        x=x,
        activation_type=activation_type,
        is_dbias=False,
        quantizer=quantizer,
    )
    return output
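

# Hedged usage sketch of the backward helpers (commented out; shapes and the
# quantizer object are illustrative): dact_lu pairs with act_lu above, and
# quantize_dact_dbias additionally returns the bias gradient.
#
#     x = jnp.zeros((8, 2 * 1024), jnp.bfloat16)       # forward-pass input
#     dz = jnp.zeros((8, 1024), jnp.bfloat16)          # upstream gradient
#     dx = dact_lu(dz, x, ("silu", "linear"))
#     dx_q, dbias = quantize_dact_dbias(
#         dz, x, ("silu", "linear"), is_dbias=True, quantizer=my_quantizer
#     )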