# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for activation"""
from typing import Sequence, Union, Callable, Optional, Tuple
import operator
from functools import reduce, partial
from packaging import version

import jax
import jax.numpy as jnp
from jax import dtypes
from jax.sharding import PartitionSpec

import transformer_engine_jax
from transformer_engine_jax import NVTE_Activation_Type

from .base import BasePrimitive, register_primitive
from .misc import (
    jax_dtype_to_te_dtype,
    te_dtype_to_jax_dtype,
    get_padded_spec,
    check_valid_batch_dims,
    multidim_transpose,
    try_apply_delayed_scaling_2x_war,
    should_apply_1x_fused_dbias_war_for_arch_l_100,
    NamedSharding,
)
from .quantization import _jax_dbias, _quantize_dbias_impl

from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor, ScaledTensorFactory
from ..quantize import (
    Quantizer,
    QuantizeLayout,
    DelayedScaleQuantizer,
    ScalingMode,
)

if version.parse(jax.__version__) >= version.parse("0.5.0"):
    from jax import ffi  # pylint: disable=ungrouped-imports
else:
    from jax.extend import ffi  # pylint: disable=ungrouped-imports

__all__ = ["act_lu", "dact_lu", "quantize_dact_dbias"]


ActivationEnum = {
    ("gelu",): NVTE_Activation_Type.GELU,
    ("gelu", "linear"): NVTE_Activation_Type.GEGLU,
    ("silu",): NVTE_Activation_Type.SILU,
    ("silu", "linear"): NVTE_Activation_Type.SWIGLU,
    ("relu",): NVTE_Activation_Type.RELU,
    ("relu", "linear"): NVTE_Activation_Type.REGLU,
    ("quick_gelu",): NVTE_Activation_Type.QGELU,
    ("quick_gelu", "linear"): NVTE_Activation_Type.QGEGLU,
    ("squared_relu",): NVTE_Activation_Type.SRELU,
    ("squared_relu", "linear"): NVTE_Activation_Type.SREGLU,
}
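# Keys are the tuple form of `activation_type`; for example,
# ActivationEnum[("silu", "linear")] is NVTE_Activation_Type.SWIGLU and
# ActivationEnum[("gelu",)] is NVTE_Activation_Type.GELU.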


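# Resolves an `activation_type` entry to a callable: "linear", "quick_gelu", and
# "squared_relu" are special-cased below, any other string is looked up on jax.nn
# (e.g. "gelu" -> jax.nn.gelu), and callables are passed through unchanged.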
def _convert_to_activation_function(fn_or_string):
    """Convert a string to an activation function."""
    if fn_or_string == "linear":
        return lambda x: x
    if fn_or_string == "quick_gelu":
        return lambda x: jax.nn.sigmoid(1.702 * x) * x
    if fn_or_string == "squared_relu":
        return lambda x: reduce(operator.mul, [jax.nn.relu(x), jax.nn.relu(x)])
    if isinstance(fn_or_string, str):
        return getattr(jax.nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string
    raise ValueError(f"Unsupported activation function: {fn_or_string}")


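# ActLuPrimitive wraps the fused TE/common activation kernel (te_act_lu_ffi). It can
# emit a plain activation output or, under a quantizer's scaling mode, quantized rowwise
# and (optionally, via is_2x) colwise outputs together with their scale-inverse tensors
# and an updated amax.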
class ActLuPrimitive(BasePrimitive):
    """
    ActLu Primitive
    """

    name = "te_act_lu_ffi"
    multiple_results = True
    impl_static_args = (
        2,
        3,
        4,
        5,
        6,
        7,
        8,
        9,
    )  # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, scale_shapes, is_outer
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        *,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_outer,
    ):
        """
        te_act_lu_p abstract
        """
        del act_enum, scale_shapes
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval is None or scale_aval.dtype == jnp.float32
        assert x_aval.shape[-2] == act_len, (
            "activation input should be replicated by act_len in the -2 axis, got input shape"
            f" {x_aval.shape} and act_len {act_len}"
        )

        out_shape = (*x_aval.shape[:-2], x_aval.shape[-1])  # Exclude act dim
        out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)

        updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(out_shape, is_padded=not is_outer, flatten_axis=-1)
        if not is_2x:
            out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype)
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        return out_aval, colwise_out_aval, scale_inv_aval, colwise_scale_inv_aval, updated_amax_aval

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        *,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_outer,
    ):
        """
        te_act_lu_p lowering rules
        """
        del out_dtype, scale_dtype, scale_shapes, act_len, is_outer
        x_aval, scale_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval is None or scale_aval.dtype == jnp.float32

        out = ffi.ffi_lowering(ActLuPrimitive.name)(
            ctx, x, scale, act_enum=act_enum, scaling_mode=scaling_mode, is_2x=is_2x
        )
        return out

    @staticmethod
    def impl(
        x,
        scale,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_outer,
    ):
        """
        to describe implementation
        """
        del is_outer
        assert ActLuPrimitive.inner_primitive is not None

        out, colwise_out, scale_inv, colwise_scale_inv, updated_amax = (
            ActLuPrimitive.inner_primitive.bind(
                x,
                scale,
                out_dtype=out_dtype,
                act_enum=act_enum,
                act_len=act_len,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                scale_shapes=scale_shapes,
                is_outer=False,
            )
        )
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(out.shape, is_padded=False, flatten_axis=-1)
        # Slice out padding for MXFP8, noop for DelayedScaling
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
        )
        if is_2x:
            colwise_scale_inv = jax.lax.slice(
                colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
            )

        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_outer,
    ):
        """
        to describe batch rules for vmap
        """
        del is_outer
        check_valid_batch_dims(batch_dims)
        assert ActLuPrimitive.outer_primitive is not None
        x, scale = batched_args
        x_bdim, scale_bdim = batch_dims
        amax_bdim = scale_bdim

        out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim
        return (
            ActLuPrimitive.outer_primitive.bind(
                x,
                scale,
                out_dtype=out_dtype,
                act_enum=act_enum,
                act_len=act_len,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                scale_shapes=scale_shapes,
                is_outer=True,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del (
            out_dtype,
            result_infos,
            act_enum,
            scale_dtype,
            scale_shapes,
            act_len,
            is_outer,
        )  # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        scale_spec = get_padded_spec(arg_infos[1])

        out_spec = (*x_spec[:-2], x_spec[-1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")

        if is_2x:
            if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
                colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1)
            else:
                colwise_out_spec = out_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out"
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
            scale_inv_spec = out_spec

        if is_2x:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax")
        colwise_scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
        )

        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer  # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        scale_spec = get_padded_spec(arg_infos[1])

        out_spec = (*x_spec[:-2], x_spec[-1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")

        if is_2x:
            if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
                colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1)
            else:
                colwise_out_spec = out_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out"
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
            scale_inv_spec = out_spec

        if is_2x:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax")
        colwise_scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
        )

        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
        )

        def sharded_impl(x, scale):
            local_x, local_colwise_x, local_scale_inv, local_colwise_scale_inv, local_amax = (
                ActLuPrimitive.impl(
                    x,
                    scale,
                    out_dtype=out_dtype,
                    act_enum=act_enum,
                    act_len=act_len,
                    scaling_mode=scaling_mode,
                    is_2x=is_2x,
                    scale_dtype=scale_dtype,
                    scale_shapes=scale_shapes,
                    is_outer=True,
                )
            )

            if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            else:
                global_updated_amax = local_amax

            return (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                global_updated_amax,
            )

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(ActLuPrimitive)


# TODO(Jeremy): replace is_2x with q_layout
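# DActLuDBiasQuantizePrimitive wraps te_dact_dbias_quantize_ffi, which fuses the
# activation backward pass with an optional dbias reduction and quantization of the
# resulting gradient.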
class DActLuDBiasQuantizePrimitive(BasePrimitive):
    """
    DActLu DBias Cast Transpose Primitive
    """

    name = "te_dact_dbias_quantize_ffi"
    multiple_results = True
    # out_dtype, scaling_mode, is_2x, scale_dtype, scale_shapes, is_dbias, act_enum, act_len, is_outer
    impl_static_args = (3, 4, 5, 6, 7, 8, 9, 10, 11)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        dz_aval,
        x_aval,
        scale_aval,
        *,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
    ):
        """
        te_dact_dbias_quantize_p abstract
        """
        del act_enum, scale_shapes
        dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dz_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_dtype
        assert x_aval.shape[-2] == act_len, (
            "activation input should be replicated by act_len in the -2 axis, got input shape"
            f" {x_aval.shape} and act_len {act_len}"
        )
        assert scale_aval.dtype == jnp.float32
        ir_hidden_size = dz_aval.shape[-1]
        gi_hidden_size = act_len * x_aval.shape[-1]
        assert act_len * ir_hidden_size == gi_hidden_size
        out_shape = x_aval.shape
        out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)

        updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=-2)
        if is_2x:
            if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
                colwise_out_shape = multidim_transpose(out_shape, transpose_axis=-2)
            else:
                colwise_out_shape = out_shape
        else:
            colwise_out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        if is_dbias:
            dbias_shape = (act_len, ir_hidden_size)
            (wkspace_info,) = transformer_engine_jax.get_dact_dbias_quantize_workspace_sizes(
                x_aval.size // gi_hidden_size,
                gi_hidden_size,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
                scaling_mode,
                is_2x,
            )
            wkspace_shape = wkspace_info[0]
            wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
        else:
            dbias_shape = (1,)
            wkspace_shape = (1,)
            wkspace_dtype = jnp.float32
        dbias_aval = jax.core.ShapedArray(shape=dbias_shape, dtype=dz_dtype)
        wkspace_aval = jax.core.ShapedArray(shape=wkspace_shape, dtype=wkspace_dtype)

        return (
            out_aval,
            colwise_out_aval,
            scale_inv_aval,
            colwise_scale_inv_aval,
            updated_amax_aval,
            dbias_aval,
            wkspace_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dact_dbias_quantize_p outer abstract
        """
        (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = (
            DActLuDBiasQuantizePrimitive.abstract(*args, **kwargs)
        )
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def lowering(
        ctx,
        dz,
        x,
        scale,
        *,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
    ):
        """
        te_dact_dbias_quantize_p lowering rules
        """
        del out_dtype, scale_dtype, scale_shapes, act_len, is_outer
        dz_aval, x_aval, scale_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_aval.dtype
        assert scale_aval.dtype == jnp.float32
        return ffi.ffi_lowering(DActLuDBiasQuantizePrimitive.name)(
            ctx,
            dz,
            x,
            scale,
            scaling_mode=scaling_mode,
            is_2x=is_2x,
            is_dbias=is_dbias,
            act_enum=int(act_enum),
        )

    @staticmethod
    def impl(
        dz,
        x,
        scale,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
    ):
        """
        te_dact_dbias_quantize_p impl
        """
        del is_outer
        assert DActLuDBiasQuantizePrimitive.inner_primitive is not None
        (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = (
            DActLuDBiasQuantizePrimitive.inner_primitive.bind(
                dz,
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                scale_shapes=scale_shapes,
                is_dbias=is_dbias,
                act_enum=act_enum,
                act_len=act_len,
                is_outer=False,
            )
        )
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x.shape, is_padded=False, flatten_axis=-2)
        # Slice out padding for MXFP8, noop for DelayedScaling
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
        )
        if is_2x:
            colwise_scale_inv = jax.lax.slice(
                colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
            )
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
    ):
        """
        to describe batch rules for vmap
        """
        del is_outer
        check_valid_batch_dims(batch_dims)
        assert DActLuDBiasQuantizePrimitive.outer_primitive is not None
        dz, x, scale = batched_args
        _, x_bdim, scale_bdim = batch_dims

        out_bdims = (
            x_bdim,  # rowwise output
            scale_bdim,  # rowwise scale_inv
            x_bdim,  # colwise output
            scale_bdim,  # colwise scale_inv
            scale_bdim,  # amax
            x_bdim,  # dbias
        )
        return (
            DActLuDBiasQuantizePrimitive.outer_primitive.bind(
                dz,
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                scale_shapes=scale_shapes,
                is_dbias=is_dbias,
                act_enum=act_enum,
                act_len=act_len,
                is_outer=True,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del out_dtype, result_infos, act_enum
        del scale_dtype, scale_shapes, act_len, is_outer
        x_spec = get_padded_spec(arg_infos[1])
        scale_spec = get_padded_spec(arg_infos[2])

        out_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out"
        )
        if is_2x:
            if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
                colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2)
            else:
                colwise_x_spec = x_spec
        else:
            colwise_x_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_x_spec), desc="DActLuDBiasQuantizePrimitive.colwise_out"
        )

        dbias_spec = x_spec[-2:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="DActLuDBiasQuantizePrimitive.dbias",
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if is_2x:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="DActLuDBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="DActLuDBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="DActLuDBiasQuantizePrimitive.colwise_scale_inv",
        )
        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer
        x_spec = get_padded_spec(arg_infos[1])
        scale_spec = get_padded_spec(arg_infos[2])

        out_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out"
        )

        if is_2x:
            if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
                colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2)
            else:
                colwise_x_spec = x_spec
        else:
            colwise_x_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_x_spec), desc="DActLuDBiasQuantizePrimitive.colwise_out"
        )

        dbias_spec = x_spec[-2:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="DActLuDBiasQuantizePrimitive.dbias",
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if is_2x:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="DActLuDBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="DActLuDBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="DActLuDBiasQuantizePrimitive.colwise_scale_inv",
        )

        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)

        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

        def sharded_impl(dz, x, scale):
            (out, colwise_out, scale_inv, colwise_scale_inv, local_amax, local_dbias) = (
                DActLuDBiasQuantizePrimitive.impl(
                    dz,
                    x,
                    scale,
                    out_dtype=out_dtype,
                    scaling_mode=scaling_mode,
                    is_2x=is_2x,
                    scale_dtype=scale_dtype,
                    scale_shapes=scale_shapes,
                    is_dbias=is_dbias,
                    act_enum=act_enum,
                    act_len=act_len,
                    is_outer=True,
                )
            )
            if is_dbias:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            else:
                global_dbias = local_dbias

            if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            else:
                global_updated_amax = local_amax

            return out, colwise_out, scale_inv, colwise_scale_inv, global_updated_amax, global_dbias

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(DActLuDBiasQuantizePrimitive)


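# Pure-JAX fallback for the fused activation kernel: the input is split along the -2
# (activation) axis, each piece is run through its activation function, the pieces are
# multiplied together, and the act axis is squeezed away before optional quantization.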
def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, ScaledTensor]:
    """
    JAX native activation implementation
    """
    act_len = len(activation_type)
    assert inputs.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {inputs.shape} and act_len {act_len}"
    )

    x = jnp.split(inputs, act_len, axis=-2)
    acts = []
    for idx, act_fn in enumerate(activation_type):
        x_i = _convert_to_activation_function(act_fn)(x[idx])
        acts.append(x_i)
    x = reduce(operator.mul, acts)
    x = jnp.squeeze(x, axis=-2)
    if quantizer:
        return quantizer.quantize(x, flatten_axis=-1)
    return x


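# Pure-JAX fallback for the fused backward kernel: the activation gradient is obtained
# from jax.vjp of `_jax_act_lu`, after which dbias and quantization are applied.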
def _jax_quantize_dact_dbias(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    is_dbias: bool = True,
    quantizer: Optional[Quantizer] = None,
):
    """
    JAX implementation of dact_lu and dbias with optional quantization
    """
    act_len = len(activation_type)
    assert x.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {x.shape} and act_len {act_len}"
    )

    _, vjp_func = jax.vjp(
        partial(_jax_act_lu, activation_type=activation_type), x.astype(jnp.float32)
    )
    (dx,) = vjp_func(dz.astype(jnp.float32))

    dbias = None
    if is_dbias:
        dbias = _jax_dbias(dx, dtype=x.dtype, flatten_axis=-2)

    if quantizer is not None:
        dx = quantizer.quantize(dx, dq_dtype=x.dtype, flatten_axis=-2)
    else:
        dx = dx.astype(x.dtype)

    return dx, dbias


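# Minimal usage sketch (illustrative only; the shapes and dtype below are hypothetical):
#   x = jnp.zeros((batch, seq, 2, hidden), jnp.bfloat16)      # gated input, ACT_DIM == 2
#   y = act_lu(x, activation_type=("silu", "linear"))          # y.shape == (batch, seq, hidden)
# Passing a Quantizer instead returns a ScaledTensor with the quantized activation.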
def act_lu(
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Optional[Quantizer] = None,
) -> Union[jnp.ndarray, ScaledTensor]:
    """Activation with optional quantization.

    Args:
        x: Input tensor to be processed.
            Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
        activation_type: Type of activation function to apply.
        quantizer: Optional quantizer for FP8 quantization of the output.

    Returns:
        If quantizer is None:
            The activated input tensor with the same dtype as input.
        If quantizer is provided:
            A ScaledTensor containing the quantized activated input.
    """
    act_type_id = ActivationEnum[activation_type].value
    act_len = len(activation_type)
    assert x.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {x.shape} and act_len {act_len}"
    )

    if not ActLuPrimitive.enabled():
        return _jax_act_lu(x, activation_type, quantizer)

    # TE/common does not support colwise-only quantization yet
    if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
        return _jax_act_lu(x, activation_type, quantizer)

    # TE/common does not support 2x quantization for DelayedScaling yet
    war_output = try_apply_delayed_scaling_2x_war(
        f=act_lu, x=x, activation_type=activation_type, quantizer=quantizer
    )
    if war_output is not None:
        return war_output

    scale = jnp.empty((1,), jnp.float32)
    output_shape = (*x.shape[:-2], x.shape[-1])

    if quantizer is None:
        out, _, _, _, _ = ActLuPrimitive.outer_primitive.bind(
            x,
            scale,
            out_dtype=x.dtype,
            act_enum=act_type_id,
            act_len=act_len,
            scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value,
            is_2x=False,
            scale_dtype=jnp.float32,
            scale_shapes=((), ()),
            is_outer=True,
        )
        out = out.reshape(output_shape)
        return out

    if isinstance(quantizer, DelayedScaleQuantizer):
        scale = quantizer.scale

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
    ) = ActLuPrimitive.outer_primitive.bind(
        x,
        scale,
        out_dtype=quantizer.q_dtype,
        act_enum=act_type_id,
        act_len=act_len,
        scaling_mode=quantizer.scaling_mode.value,
        is_2x=quantizer.is_2x2x(),
        scale_dtype=quantizer.get_scale_dtype(),
        # output does not have act axis
        scale_shapes=quantizer.get_scale_shapes(output_shape, flatten_axis=-1),
        is_outer=True,
    )

    quantizer.update(updated_amax)

    return ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
    )


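# Shape sketch (illustrative only): for a gated activation, dz has shape
# (batch, seq, hidden) and x has shape (batch, seq, 2, hidden); the returned dgrad
# matches x's shape and, when is_dbias=True, dbias has shape (act_len, hidden).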
def quantize_dact_dbias(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]] = ("gelu",),
    is_dbias: bool = True,
    quantizer: Optional[Quantizer] = None,
) -> Tuple[ScaledTensor, jnp.ndarray]:
    """Compute gradients of activation and bias with optional quantization.

    Args:
        dz: Gradient of the output with respect to the activation output.
        x: Input tensor that was processed by the forward pass.
            Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
        activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",).
        is_dbias: If True, compute bias gradient. Defaults to True.
        quantizer: Optional quantizer for FP8 quantization of the output.

    Returns:
        Tuple[ScaledTensor, jnp.ndarray]: A tuple containing:
        - The gradient of the activation with respect to the input.
        - The gradient of the activation with respect to the bias.
    """

    act_len = len(activation_type)
    assert x.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {x.shape} and act_len {act_len}"
    )

    if not DActLuDBiasQuantizePrimitive.enabled():
        return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)

    # TE/common does not support colwise-only quantization yet
    if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
        return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)

    # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet
    if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
        out = dact_lu(dz, x, activation_type, quantizer=None)
        return _quantize_dbias_impl(out, quantizer, is_dbias=True, flatten_axis=-2)

    is_gated = act_len == 2
    # TE/common does not support DelayedScaling2x for gated-act yet
    if is_gated:
        war_output = try_apply_delayed_scaling_2x_war(
            f=quantize_dact_dbias,
            dz=dz,
            x=x,
            activation_type=activation_type,
            is_dbias=is_dbias,
            quantizer=quantizer,
            flatten_axis=-2,
        )
        if war_output is not None:
            return war_output

    scale = jnp.empty((), jnp.float32)

    act_type_id = ActivationEnum[activation_type]

    if quantizer is None:
        output, _, _, _, _, _ = DActLuDBiasQuantizePrimitive.outer_primitive.bind(
            dz,
            x,
            scale,
            # outputs float32 for dbias accumulation
            out_dtype=(jnp.float32 if is_dbias else x.dtype),
            # default value for no scaling, TE/common ignore this value when scale is unset
            scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value,
            is_2x=False,  # unused
            scale_dtype=jnp.float32,  # unused
            scale_shapes=((), ()),  # unused
            is_dbias=False,
            act_enum=act_type_id,
            act_len=act_len,
            is_outer=True,
        )
        dbias = None
        if is_dbias:
            dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2)
        return output.astype(x.dtype), dbias

    if isinstance(quantizer, DelayedScaleQuantizer):
        scale = quantizer.scale

    # TE/common dact_dbias_quantize does not support gated act yet
    if is_dbias and is_gated:
        dgated = dact_lu(
            dz.astype(jnp.float32), x.astype(jnp.float32), activation_type=activation_type
        )
        out, dbias = _quantize_dbias_impl(
            dgated, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2
        )
        return out, dbias

    out_shape = x.shape

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
        dbias,
    ) = DActLuDBiasQuantizePrimitive.outer_primitive.bind(
        dz,
        x,
        scale,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        is_2x=quantizer.is_2x2x(),
        scale_dtype=quantizer.get_scale_dtype(),
        # output has act axis
        scale_shapes=quantizer.get_scale_shapes(out_shape, flatten_axis=-2),
        is_dbias=is_dbias,
        act_enum=act_type_id,
        act_len=act_len,
        is_outer=True,
    )

    # For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise
    if quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING and quantizer.is_2x2x():
        colwise_scale_inv = rowwise_scale_inv

    quantizer.update(updated_amax)

    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=-2,  # as output has act axis
    )

    return out, dbias


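# Minimal usage sketch (illustrative only; shapes are hypothetical):
#   dx = dact_lu(dz, x, activation_type=("silu", "linear"))
# with dz of shape (batch, seq, hidden) and x of shape (batch, seq, 2, hidden); dx
# matches x's shape. A Quantizer may be supplied to receive a quantized ScaledTensor.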
def dact_lu(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Optional[Quantizer] = None,
) -> Union[jnp.ndarray, ScaledTensor]:
    """
    Backward pass for activation with optional quantization.

    Args:
        dz: Gradient tensor from upstream.
        x: Input tensor that was used in forward pass.
        activation_type: Type of activation function that was applied.
        quantizer: Optional quantizer for FP8 quantization of the output gradient.

    Returns:
        The gradient of the activation with respect to the input.
    """
    output, _ = quantize_dact_dbias(
        dz=dz,
        x=x,
        activation_type=activation_type,
        is_dbias=False,
        quantizer=quantizer,
    )
    return output