# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for activation"""
from typing import Sequence, Union, Callable, Optional, Tuple
import operator
from functools import reduce, partial
from packaging import version

import jax
import jax.numpy as jnp
from jax import dtypes
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.sharding import PartitionSpec

import transformer_engine_jax
from transformer_engine_jax import NVTE_Activation_Type

from .base import BasePrimitive, register_primitive
from .misc import (
    jax_dtype_to_te_dtype,
    te_dtype_to_jax_dtype,
    get_padded_spec,
    check_valid_batch_dims,
    multidim_transpose,
    try_apply_delayed_scaling_2x_war,
    should_apply_1x_fused_dbias_war_for_arch_l_100,
    NamedSharding,
)
from .quantization import _jax_dbias, _quantize_dbias_impl

from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor, ScaledTensorFactory
from ..quantize import (
    Quantizer,
    QuantizeLayout,
    DelayedScaleQuantizer,
    ScalingMode,
)

if version.parse(jax.__version__) >= version.parse("0.5.0"):
    from jax import ffi  # pylint: disable=ungrouped-imports
else:
    from jax.extend import ffi  # pylint: disable=ungrouped-imports

__all__ = ["act_lu", "dact_lu", "quantize_dact_dbias"]


ActivationEnum = {
    ("gelu",): NVTE_Activation_Type.GELU,
    ("gelu", "linear"): NVTE_Activation_Type.GEGLU,
    ("silu",): NVTE_Activation_Type.SILU,
    ("silu", "linear"): NVTE_Activation_Type.SWIGLU,
    ("relu",): NVTE_Activation_Type.RELU,
    ("relu", "linear"): NVTE_Activation_Type.REGLU,
    ("quick_gelu",): NVTE_Activation_Type.QGELU,
    ("quick_gelu", "linear"): NVTE_Activation_Type.QGEGLU,
    ("squared_relu",): NVTE_Activation_Type.SRELU,
    ("squared_relu", "linear"): NVTE_Activation_Type.SREGLU,
}
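# Keys are tuples of activation names so they can be passed directly as the
# `activation_type` argument of act_lu/dact_lu below; gated variants append "linear",
# e.g. ActivationEnum[("silu", "linear")] is NVTE_Activation_Type.SWIGLU.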


def _convert_to_activation_function(fn_or_string):
    """Convert a string to an activation function."""
    if fn_or_string == "linear":
        return lambda x: x
    if fn_or_string == "quick_gelu":
        return lambda x: jax.nn.sigmoid(1.702 * x) * x
    if fn_or_string == "squared_relu":
        return lambda x: reduce(operator.mul, [jax.nn.relu(x), jax.nn.relu(x)])
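    # Any other string (e.g. "gelu", "silu", "relu") is looked up directly on jax.nn.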
    if isinstance(fn_or_string, str):
        return getattr(jax.nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string
    raise ValueError(f"Unsupported activation function: {fn_or_string}")


class ActLuPrimitive(BasePrimitive):
    """
    ActLu Primitive
    """

    name = "te_act_lu_ffi"
    multiple_results = True
    impl_static_args = (
        2,
        3,
        4,
        5,
        6,
        7,
        8,
        9,
    )  # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, scale_shapes, is_outer
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        *,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_outer,
    ):
        """
        te_act_lu_p abstract
        """
        del act_enum, scale_shapes
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval is None or scale_aval.dtype == jnp.float32
        assert x_aval.shape[-2] == act_len, (
            "activation input should be replicated by act_len in the -2 axis, got input shape"
            f" {x_aval.shape} and act_len {act_len}"
        )

        out_shape = (*x_aval.shape[:-2], x_aval.shape[-1])  # Exclude act dim
        out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)

        updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(out_shape, is_padded=not is_outer, flatten_axis=-1)
        if not is_2x:
            out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype)
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        return out_aval, colwise_out_aval, scale_inv_aval, colwise_scale_inv_aval, updated_amax_aval

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        *,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_outer,
    ):
        """
        te_act_lu_p lowering rules
        """
        del out_dtype, scale_dtype, scale_shapes, act_len, is_outer
        x_aval, scale_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval is None or scale_aval.dtype == jnp.float32

        out = ffi.ffi_lowering(ActLuPrimitive.name)(
            ctx, x, scale, act_enum=act_enum, scaling_mode=scaling_mode.value, is_2x=is_2x
        )
        return out

    @staticmethod
    def impl(
        x,
        scale,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_outer,
    ):
        """
        to describe implementation
        """
        del is_outer
        assert ActLuPrimitive.inner_primitive is not None
        # Bind the inner primitive with is_outer=False so the scale buffers keep their
        # padded shapes; the padding is sliced off below before returning to the caller.

        out, colwise_out, scale_inv, colwise_scale_inv, updated_amax = (
            ActLuPrimitive.inner_primitive.bind(
                x,
                scale,
                out_dtype=out_dtype,
                act_enum=act_enum,
                act_len=act_len,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                scale_shapes=scale_shapes,
                is_outer=False,
            )
        )
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(out.shape, is_padded=False, flatten_axis=-1)
        # Slice out padding for MXFP8, noop for DelayedScaling
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
        )
        if is_2x:
            colwise_scale_inv = jax.lax.slice(
                colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
            )

        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_outer,
    ):
        """
        to describe batch rules for vmap
        """
        del is_outer
        check_valid_batch_dims(batch_dims)
        assert ActLuPrimitive.outer_primitive is not None
        x, scale = batched_args
        x_bdim, scale_bdim = batch_dims
        amax_bdim = scale_bdim

        out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim
        return (
            ActLuPrimitive.outer_primitive.bind(
                x,
                scale,
                out_dtype=out_dtype,
                act_enum=act_enum,
                act_len=act_len,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                scale_shapes=scale_shapes,
                is_outer=True,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del (
            out_dtype,
            result_infos,
            act_enum,
            scale_dtype,
            scale_shapes,
            act_len,
            is_outer,
        )  # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        scale_spec = get_padded_spec(arg_infos[1])

        out_spec = (*x_spec[:-2], x_spec[-1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")

        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1)
            else:
                colwise_out_spec = out_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out"
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = out_spec

        if is_2x:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax")
        colwise_scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
        )

        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer  # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        scale_spec = get_padded_spec(arg_infos[1])

        out_spec = (*x_spec[:-2], x_spec[-1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")

        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1)
            else:
                colwise_out_spec = out_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out"
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = out_spec

        if is_2x:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax")
        colwise_scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
        )

        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
        )

        def sharded_impl(x, scale):
            local_x, local_colwise_x, local_scale_inv, local_colwise_scale_inv, local_amax = (
                ActLuPrimitive.impl(
                    x,
                    scale,
                    out_dtype=out_dtype,
                    act_enum=act_enum,
                    act_len=act_len,
                    scaling_mode=scaling_mode,
                    is_2x=is_2x,
                    scale_dtype=scale_dtype,
                    scale_shapes=scale_shapes,
                    is_outer=True,
                )
            )

            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            else:
                global_updated_amax = local_amax

            return (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                global_updated_amax,
            )

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_outer,
        mesh,
        value_types,
        result_types,
    ):
        del out_dtype, act_enum, act_len, scale_dtype, scale_shapes, is_outer, mesh, result_types

        x_rank = len(value_types[0].shape)
        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            x_rank - 1, unique_var="i", flatten_axis=-2
        )
        x_axes = scale_rules.input_spec + (f"x{x_rank-1}",)
        out = (*x_axes[:-2], x_axes[-1])
        scale_inv = scale_rules.rowwise_rule
        colwise_scale_inv = scale_rules.colwise_rule

        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out = tuple(
                    multidim_transpose(x_axes, static_axis_boundary=-1, transpose_axis=-2)
                )
            else:
                colwise_out = out
        else:
            colwise_out = ("j",)
            colwise_scale_inv = ("k",)

        # amax is always a unit tensor.
        amax = ("l",)

        return SdyShardingRule(
            (
                x_axes,
                "…1",
            ),
            (out, colwise_out, scale_inv, colwise_scale_inv, amax),
            **scale_rules.factor_sizes,
        )


register_primitive(ActLuPrimitive)


# TODO(Jeremy): replace is_2x with q_layout
class DActLuDBiasQuantizePrimitive(BasePrimitive):
    """
    DActLu DBias Quantize Primitive
    """

    name = "te_dact_dbias_quantize_ffi"
    multiple_results = True
    # out_dtype, scaling_mode, is_2x, scale_dtype, scale_shapes, is_dbias, act_enum, act_len, is_outer
    impl_static_args = (3, 4, 5, 6, 7, 8, 9, 10, 11)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        dz_aval,
        x_aval,
        scale_aval,
        *,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
    ):
        """
        te_dact_dbias_quantize_p abstract
        """
        del act_enum, scale_shapes
        dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dz_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_dtype
        assert x_aval.shape[-2] == act_len, (
            "activation input should be replicated by act_len in the -2 axis, got input shape"
            f" {x_aval.shape} and act_len {act_len}"
        )
        assert scale_aval.dtype == jnp.float32
        ir_hidden_size = dz_aval.shape[-1]
        gi_hidden_size = act_len * x_aval.shape[-1]
        assert act_len * ir_hidden_size == gi_hidden_size
        out_shape = x_aval.shape
        out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)

        updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=-2)
        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out_shape = multidim_transpose(out_shape, transpose_axis=-2)
            else:
                colwise_out_shape = out_shape
        else:
            colwise_out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        if is_dbias:
            dbias_shape = (act_len, ir_hidden_size)
            (wkspace_info,) = transformer_engine_jax.get_dact_dbias_quantize_workspace_sizes(
                x_aval.size // gi_hidden_size,
                gi_hidden_size,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
                scaling_mode,
                is_2x,
            )
            wkspace_shape = wkspace_info[0]
            wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
        else:
            dbias_shape = (1,)
            wkspace_shape = (1,)
            wkspace_dtype = jnp.float32
        dbias_aval = jax.core.ShapedArray(shape=dbias_shape, dtype=dz_dtype)
        wkspace_aval = jax.core.ShapedArray(shape=wkspace_shape, dtype=wkspace_dtype)
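        # The workspace is only consumed by the C++ kernel; outer_abstract below drops it
        # from the outer primitive's outputs.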

        return (
            out_aval,
            colwise_out_aval,
            scale_inv_aval,
            colwise_scale_inv_aval,
            updated_amax_aval,
            dbias_aval,
            wkspace_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dact_dbias_quantize_p outer abstract
        """
        (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = (
            DActLuDBiasQuantizePrimitive.abstract(*args, **kwargs)
        )
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def lowering(
        ctx,
        dz,
        x,
        scale,
        *,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
    ):
        """
        te_dact_dbias_quantize_p lowering rules
        """
        del out_dtype, scale_dtype, scale_shapes, act_len, is_outer
        dz_aval, x_aval, scale_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_aval.dtype
        assert scale_aval.dtype == jnp.float32
        return ffi.ffi_lowering(DActLuDBiasQuantizePrimitive.name)(
            ctx,
            dz,
            x,
            scale,
            scaling_mode=scaling_mode.value,
            is_2x=is_2x,
            is_dbias=is_dbias,
            act_enum=int(act_enum),
        )

    @staticmethod
    def impl(
        dz,
        x,
        scale,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
    ):
        """
        te_dact_dbias_quantize_p impl
        """
        del is_outer
        assert DActLuDBiasQuantizePrimitive.inner_primitive is not None
        (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = (
            DActLuDBiasQuantizePrimitive.inner_primitive.bind(
                dz,
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                scale_shapes=scale_shapes,
                is_dbias=is_dbias,
                act_enum=act_enum,
                act_len=act_len,
                is_outer=False,
            )
        )
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x.shape, is_padded=False, flatten_axis=-2)
        # Slice out padding for MXFP8, noop for DelayedScaling
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
        )
        if is_2x:
            colwise_scale_inv = jax.lax.slice(
                colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
            )
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
    ):
        """
        to describe batch rules for vmap
        """
        del is_outer
        check_valid_batch_dims(batch_dims)
        assert DActLuDBiasQuantizePrimitive.outer_primitive is not None
        dz, x, scale = batched_args
        _, x_bdim, scale_bdim = batch_dims

        out_bdims = (
            x_bdim,  # rowwise output
            x_bdim,  # colwise output
            scale_bdim,  # rowwise scale_inv
            scale_bdim,  # colwise scale_inv
            scale_bdim,  # amax
            x_bdim,  # dbias
        )
        return (
            DActLuDBiasQuantizePrimitive.outer_primitive.bind(
                dz,
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                scale_shapes=scale_shapes,
                is_dbias=is_dbias,
                act_enum=act_enum,
                act_len=act_len,
                is_outer=True,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del out_dtype, result_infos, act_enum
        del scale_dtype, scale_shapes, act_len, is_outer
        x_spec = get_padded_spec(arg_infos[1])
        scale_spec = get_padded_spec(arg_infos[2])

        out_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out"
        )
        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2)
            else:
                colwise_x_spec = x_spec
        else:
            colwise_x_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_x_spec), desc="DActLuDBiasQuantizePrimitive.colwise_out"
        )

        dbias_spec = x_spec[-2:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="DActLuDBiasQuantizePrimitive.dbias",
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if is_2x:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="DActLuDBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="DActLuDBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="DActLuDBiasQuantizePrimitive.colwise_scale_inv",
        )
        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer
        x_spec = get_padded_spec(arg_infos[1])
        scale_spec = get_padded_spec(arg_infos[2])

        out_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out"
        )

        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2)
            else:
                colwise_x_spec = x_spec
        else:
            colwise_x_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_x_spec), desc="DActLuDBiasQuantizePrimitive.colwise_out"
        )

        dbias_spec = x_spec[-2:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="DActLuDBiasQuantizePrimitive.dbias",
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if is_2x:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="DActLuDBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="DActLuDBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="DActLuDBiasQuantizePrimitive.colwise_scale_inv",
        )

        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

        def sharded_impl(dz, x, scale):
            (out, colwise_out, scale_inv, colwise_scale_inv, local_amax, local_dbias) = (
                DActLuDBiasQuantizePrimitive.impl(
                    dz,
                    x,
                    scale,
                    out_dtype=out_dtype,
                    scaling_mode=scaling_mode,
                    is_2x=is_2x,
                    scale_dtype=scale_dtype,
                    scale_shapes=scale_shapes,
                    is_dbias=is_dbias,
                    act_enum=act_enum,
                    act_len=act_len,
                    is_outer=True,
                )
            )
            if is_dbias:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            else:
                global_dbias = local_dbias

            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            else:
                global_updated_amax = local_amax

            return out, colwise_out, scale_inv, colwise_scale_inv, global_updated_amax, global_dbias

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
        mesh,
        value_types,
        result_types,
    ):
        del out_dtype, scale_dtype, scale_shapes, act_enum, act_len, is_outer, mesh, result_types

        x_rank = len(value_types[1].shape)
        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            x_rank, unique_var="i", flatten_axis=-2
        )
        x_axes = scale_rules.input_spec
        out = x_axes
        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=-2))
            else:
                colwise_out = tuple(x_axes)
        else:
            colwise_out = ("j",)

        dbias = x_axes[-2:] if is_dbias else ("k",)
        amax = ("…4",)

        return SdyShardingRule(
            (("…0",), tuple(x_axes), ("…2",)),
            (out, colwise_out, scale_rules.rowwise_rule, scale_rules.colwise_rule, amax, dbias),
            **scale_rules.factor_sizes,
        )


register_primitive(DActLuDBiasQuantizePrimitive)


def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, ScaledTensor]:
    """
    JAX native activation implementation
    """
    act_len = len(activation_type)
    assert inputs.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {inputs.shape} and act_len {act_len}"
    )

    x = jnp.split(inputs, act_len, axis=-2)
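    # Apply each activation to its slice of the act axis and multiply the slices together;
    # for gated activations such as ("silu", "linear") this computes act(x_0) * x_1.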
    acts = []
    for idx, act_fn in enumerate(activation_type):
        x_i = _convert_to_activation_function(act_fn)(x[idx])
        acts.append(x_i)
    x = reduce(operator.mul, acts)
    x = jnp.squeeze(x, axis=-2)
    if quantizer:
        return quantizer.quantize(x, flatten_axis=-1)
    return x


def _jax_quantize_dact_dbias(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    is_dbias: bool = True,
    quantizer: Optional[Quantizer] = None,
):
    """
    JAX implementation of dact_lu and dbias with optional quantization
    """
    act_len = len(activation_type)
    assert x.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {x.shape} and act_len {act_len}"
    )

    _, vjp_func = jax.vjp(
        partial(_jax_act_lu, activation_type=activation_type), x.astype(jnp.float32)
    )
    (dx,) = vjp_func(dz.astype(jnp.float32))

    dbias = None
    if is_dbias:
        dbias = _jax_dbias(dx, dtype=x.dtype, flatten_axis=-2)

    if quantizer is not None:
        dx = quantizer.quantize(dx, dq_dtype=x.dtype, flatten_axis=-2)
    else:
        dx = dx.astype(x.dtype)

    return dx, dbias


def act_lu(
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Optional[Quantizer] = None,
) -> Union[jnp.ndarray, ScaledTensor]:
    """Activation with optional quantization.

    Args:
        x: Input tensor to be processed.
            Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
        activation_type: Type of activation function to apply.
        quantizer: Optional quantizer for FP8 quantization of the output.

    Returns:
        If quantizer is None:
            The activated input tensor with the same dtype as input.
        If quantizer is provided:
            A ScaledTensor containing the quantized activated input.
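    Example (illustrative sketch; ``batch``, ``seqlen``, ``hidden`` are placeholder sizes):
        x = jnp.ones((batch, seqlen, 2, hidden))  # gated input, act axis of length 2
        y = act_lu(x, activation_type=("silu", "linear"))
        # y.shape == (batch, seqlen, hidden)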
    """
    act_type_id = ActivationEnum[activation_type].value
    act_len = len(activation_type)
    assert x.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {x.shape} and act_len {act_len}"
    )

    if not ActLuPrimitive.enabled():
        return _jax_act_lu(x, activation_type, quantizer)

    # TE/common does not support colwise-only quantization yet
    if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
        return _jax_act_lu(x, activation_type, quantizer)

    # TE/common does not support 2x quantization for DelayedScaling yet
    war_output = try_apply_delayed_scaling_2x_war(
        f=act_lu, x=x, activation_type=activation_type, quantizer=quantizer
    )
    if war_output is not None:
        return war_output

    scale = jnp.empty((1,), jnp.float32)
    output_shape = (*x.shape[:-2], x.shape[-1])
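    # The activation collapses the act axis, so the (quantized) output drops the -2 dimension.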

    if quantizer is None:
        out, _, _, _, _ = ActLuPrimitive.outer_primitive.bind(
            x,
            scale,
            out_dtype=x.dtype,
            act_enum=act_type_id,
            act_len=act_len,
            scaling_mode=ScalingMode.NO_SCALING.value,
            is_2x=False,
            scale_dtype=jnp.float32,
            scale_shapes=((), ()),
            is_outer=True,
        )
        out = out.reshape(output_shape)
        return out

    if isinstance(quantizer, DelayedScaleQuantizer):
        scale = quantizer.scale

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
    ) = ActLuPrimitive.outer_primitive.bind(
        x,
        scale,
        out_dtype=quantizer.q_dtype,
        act_enum=act_type_id,
        act_len=act_len,
        scaling_mode=quantizer.scaling_mode.value,
        is_2x=quantizer.is_2x2x(),
        scale_dtype=quantizer.get_scale_dtype(),
        # output does not have act axis
        scale_shapes=quantizer.get_scale_shapes(output_shape, flatten_axis=-1),
        is_outer=True,
    )

    quantizer.update(updated_amax)

    return ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
    )


def quantize_dact_dbias(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]] = ("gelu",),
    is_dbias: bool = True,
    quantizer: Optional[Quantizer] = None,
) -> Tuple[ScaledTensor, jnp.ndarray]:
    """Compute gradients of activation and bias with optional quantization.

    Args:
        dz: Gradient of the output with respect to the activation output.
        x: Input tensor that was processed by the forward pass.
            Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
        activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",).
        is_dbias: If True, compute bias gradient. Defaults to True.
        quantizer: Optional quantizer for FP8 quantization of the output.

    Returns:
        Tuple[ScaledTensor, jnp.ndarray]: A tuple containing:
        - The gradient of the activation with respect to the input.
        - The gradient of the activation with respect to the bias.
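    Example (illustrative sketch; shapes are placeholders):
        dz = jnp.ones((batch, seqlen, hidden))    # gradient of the activation output
        x = jnp.ones((batch, seqlen, 2, hidden))  # forward input of a gated activation
        dx, dbias = quantize_dact_dbias(dz, x, ("silu", "linear"), quantizer=None)
        # dx.shape == x.shape and dbias.shape == (2, hidden)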
    """

    act_len = len(activation_type)
    assert x.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {x.shape} and act_len {act_len}"
    )

    if not DActLuDBiasQuantizePrimitive.enabled():
        return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)

    # TE/common does not support colwise-only quantization yet
    if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
        return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)

    # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet
    if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
        out = dact_lu(dz, x, activation_type, quantizer=None)
        return _quantize_dbias_impl(out, quantizer, is_dbias=True, flatten_axis=-2)

    is_gated = act_len == 2
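    # Gated activations carry two act-axis slices: the activation input and its linear gate.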
    # TE/common does not support DelayedScaling2x for gated-act yet
    if is_gated:
        war_output = try_apply_delayed_scaling_2x_war(
            f=quantize_dact_dbias,
            dz=dz,
            x=x,
            activation_type=activation_type,
            is_dbias=is_dbias,
            quantizer=quantizer,
            flatten_axis=-2,
        )
        if war_output is not None:
            return war_output

    scale = jnp.empty((), jnp.float32)

    act_type_id = ActivationEnum[activation_type]

    if quantizer is None:
        output, _, _, _, _, _ = DActLuDBiasQuantizePrimitive.outer_primitive.bind(
            dz,
            x,
            scale,
            # outputs float32 for dbias accumulation
            out_dtype=(jnp.float32 if is_dbias else x.dtype),
            # default value for no scaling, TE/common ignore this value when scale is unset
            scaling_mode=ScalingMode.NO_SCALING.value,
            is_2x=False,  # unused
            scale_dtype=jnp.float32,  # unused
            scale_shapes=((), ()),  # unused
            is_dbias=False,
            act_enum=act_type_id,
            act_len=act_len,
            is_outer=True,
        )
        dbias = None
        if is_dbias:
            dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2)
        return output.astype(x.dtype), dbias

    if isinstance(quantizer, DelayedScaleQuantizer):
        scale = quantizer.scale

    # TE/common dact_dbias_quantize does not support gated act yet
    if is_dbias and is_gated:
        dgated = dact_lu(
            dz.astype(jnp.float32), x.astype(jnp.float32), activation_type=activation_type
        )
        out, dbias = _quantize_dbias_impl(
            dgated, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2
        )
        return out, dbias

    out_shape = x.shape

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
        dbias,
    ) = DActLuDBiasQuantizePrimitive.outer_primitive.bind(
        dz,
        x,
        scale,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        is_2x=quantizer.is_2x2x(),
        scale_dtype=quantizer.get_scale_dtype(),
        # output has act axis
        scale_shapes=quantizer.get_scale_shapes(out_shape, flatten_axis=-2),
        is_dbias=is_dbias,
        act_enum=act_type_id,
        act_len=act_len,
        is_outer=True,
    )

    # For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise
    if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING and quantizer.is_2x2x():
        colwise_scale_inv = rowwise_scale_inv

    quantizer.update(updated_amax)

    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=-2,  # as output has act axis
    )

    return out, dbias


def dact_lu(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Optional[Quantizer] = None,
) -> Union[jnp.ndarray, ScaledTensor]:
    """
    Backward pass for activation with optional quantization.

    Args:
        dz: Gradient tensor from upstream.
        x: Input tensor that was used in forward pass.
        activation_type: Type of activation function that was applied.
        quantizer: Optional quantizer for FP8 quantization of the output gradient.

    Returns:
        The gradient of the activation with respect to the input.
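    Example (illustrative sketch; shapes are placeholders):
        dx = dact_lu(dz, x, ("gelu",))  # dx.shape == x.shape == (..., 1, hidden)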
    """
    output, _ = quantize_dact_dbias(
        dz=dz,
        x=x,
        activation_type=activation_type,
        is_dbias=False,
        quantizer=quantizer,
    )
    return output