"vscode:/vscode.git/clone" did not exist on "4abf6336ec65c270343eb895e7b18786e9274176"
activation.py 39.6 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for activation"""
from typing import Sequence, Union, Callable, Optional, Tuple
import operator
from functools import reduce, partial
from packaging import version

import jax
import jax.numpy as jnp
from jax import dtypes
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.sharding import PartitionSpec

import transformer_engine_jax
from transformer_engine_jax import NVTE_Activation_Type

from .base import BasePrimitive, register_primitive
from .misc import (
    jax_dtype_to_te_dtype,
    te_dtype_to_jax_dtype,
    get_padded_spec,
    check_valid_batch_dims,
    multidim_transpose,
    try_apply_delayed_scaling_2x_war,
    should_apply_1x_fused_dbias_war_for_arch_l_100,
    NamedSharding,
)
from .quantization import _jax_dbias, _quantize_dbias_impl
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor, ScaledTensorFactory
from ..quantize import (
    Quantizer,
    QuantizeLayout,
    DelayedScaleQuantizer,
    ScalingMode,
)

if version.parse(jax.__version__) >= version.parse("0.5.0"):
    from jax import ffi  # pylint: disable=ungrouped-imports
else:
    from jax.extend import ffi  # pylint: disable=ungrouped-imports

__all__ = ["act_lu", "dact_lu", "quantize_dact_dbias"]


ActivationEnum = {
    ("gelu",): NVTE_Activation_Type.GELU,
    ("gelu", "linear"): NVTE_Activation_Type.GEGLU,
    ("silu",): NVTE_Activation_Type.SILU,
    ("silu", "linear"): NVTE_Activation_Type.SWIGLU,
    ("relu",): NVTE_Activation_Type.RELU,
    ("relu", "linear"): NVTE_Activation_Type.REGLU,
    ("quick_gelu",): NVTE_Activation_Type.QGELU,
    ("quick_gelu", "linear"): NVTE_Activation_Type.QGEGLU,
    ("squared_relu",): NVTE_Activation_Type.SRELU,
    ("squared_relu", "linear"): NVTE_Activation_Type.SREGLU,
}
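
# Example lookups (a minimal sketch; keys are tuples of activation names, with
# "linear" as the second element marking the gated variants):
#
#     ActivationEnum[("silu", "linear")]   # -> NVTE_Activation_Type.SWIGLU
#     ActivationEnum[("gelu",)].value      # -> integer enum value passed to the FFI call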


def _convert_to_activation_function(fn_or_string):
    """Convert a string to an activation function."""
    if fn_or_string == "linear":
        return lambda x: x
    if fn_or_string == "quick_gelu":
        return lambda x: jax.nn.sigmoid(1.702 * x) * x
    if fn_or_string == "squared_relu":
        return lambda x: jnp.square(jax.nn.relu(x))
    if isinstance(fn_or_string, str):
        return getattr(jax.nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string
    raise ValueError(f"Cannot convert {fn_or_string} to an activation function")
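
# Usage sketch (hypothetical inputs; nothing here executes on import):
#
#     f = _convert_to_activation_function("silu")     # jax.nn.silu
#     g = _convert_to_activation_function("linear")   # identity
#     h = _convert_to_activation_function(jnp.tanh)   # callables pass through unchanged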


class ActLuPrimitive(BasePrimitive):
    """
    ActLu Primitive
    """

    name = "te_act_lu_ffi"
    multiple_results = True
    # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, is_outer
    impl_static_args = (2, 3, 4, 5, 6, 7, 8)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        *,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
    ):
        """
        te_act_lu_p abstract
        """
        del act_enum
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval is None or scale_aval.dtype == jnp.float32
        assert x_aval.shape[-2] == act_len, (
            "activation input should be replicated by act_len in the -2 axis, got input shape"
            f" {x_aval.shape} and act_len {act_len}"
        )

        assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, (
            "Current tensor scaling is not yet supported for fused activation and quantization."
            " Please do the activation in higher precision, then quantize with current tensor"
            " scaling."
        )

        out_shape = (*x_aval.shape[:-2], x_aval.shape[-1])  # Exclude act dim
        out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)

        updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(out_shape, is_padded=not is_outer, flatten_axis=-1)
        if not is_2x:
            out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype)
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        return out_aval, colwise_out_aval, scale_inv_aval, colwise_scale_inv_aval, updated_amax_aval

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        *,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
    ):
        """
        te_act_lu_p lowering rules
        """
        del out_dtype, scale_dtype, act_len, is_outer
        x_aval, scale_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval is None or scale_aval.dtype == jnp.float32

        out = ffi.ffi_lowering(ActLuPrimitive.name)(
            ctx, x, scale, act_enum=act_enum, scaling_mode=scaling_mode.value, is_2x=is_2x
        )
        return out

    @staticmethod
    def impl(
        x,
        scale,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
    ):
        """
        to describe implementation
        """
        del is_outer
188
        assert ActLuPrimitive.inner_primitive is not None

        out, colwise_out, scale_inv, colwise_scale_inv, updated_amax = (
            ActLuPrimitive.inner_primitive.bind(
                x,
                scale,
                out_dtype=out_dtype,
                act_enum=act_enum,
                act_len=act_len,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                is_outer=False,
            )
        )
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(out.shape, is_padded=False, flatten_axis=-1)
        # Slice out padding for MXFP8, noop for DelayedScaling
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
        )
        if is_2x:
            colwise_scale_inv = jax.lax.slice(
                colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
            )

        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
    ):
        """
        batch rule for vmap
        """
        del is_outer
        check_valid_batch_dims(batch_dims)
        assert ActLuPrimitive.outer_primitive is not None
        x, scale = batched_args
        x_bdim, scale_bdim = batch_dims
        amax_bdim = scale_bdim
        out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim
        return (
            ActLuPrimitive.outer_primitive.bind(
                x,
                scale,
                out_dtype=out_dtype,
                act_enum=act_enum,
                act_len=act_len,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                is_outer=True,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del (
            out_dtype,
            result_infos,
            act_enum,
            scale_dtype,
            act_len,
            is_outer,
        )  # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        scale_spec = get_padded_spec(arg_infos[1])

        out_spec = (*x_spec[:-2], x_spec[-1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")

        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1)
            else:
                colwise_out_spec = out_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out"
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = out_spec

        if is_2x:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax")
        colwise_scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
        )

        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer  # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        scale_spec = get_padded_spec(arg_infos[1])

        out_spec = (*x_spec[:-2], x_spec[-1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")

        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1)
            else:
                colwise_out_spec = out_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out"
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = out_spec

        if is_2x:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax")
        colwise_scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
        )

        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
        )

        def sharded_impl(x, scale):
            local_x, local_colwise_x, local_scale_inv, local_colwise_scale_inv, local_amax = (
                ActLuPrimitive.impl(
                    x,
                    scale,
                    out_dtype=out_dtype,
                    act_enum=act_enum,
                    act_len=act_len,
                    scaling_mode=scaling_mode,
                    is_2x=is_2x,
                    scale_dtype=scale_dtype,
                    is_outer=True,
                )
            )

            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            else:
                global_updated_amax = local_amax

            return (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                global_updated_amax,
            )

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
        mesh,
        value_types,
        result_types,
    ):
        del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types

        x_rank = len(value_types[0].shape)
        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            x_rank - 1, unique_var="ActLuPrimitive_i", flatten_axis=-2
        )
        x_axes = scale_rules.input_spec + (f"x{x_rank-1}",)
        out = (*x_axes[:-2], x_axes[-1])
        scale_inv = scale_rules.rowwise_rule
        colwise_scale_inv = scale_rules.colwise_rule

        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out = tuple(
                    multidim_transpose(x_axes, static_axis_boundary=-1, transpose_axis=-2)
                )
            else:
                colwise_out = out
        else:
            colwise_out = ("j",)
            colwise_scale_inv = ("k",)

        # amax is always a unit tensor.
        amax = ("l",)

        return SdyShardingRule(
            (
                x_axes,
                "…1",
            ),
            (out, colwise_out, scale_inv, colwise_scale_inv, amax),
            **scale_rules.factor_sizes,
        )

register_primitive(ActLuPrimitive)
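
# register_primitive wires up the inner/outer primitive pair declared above.
# A minimal call sketch once registration has run (hypothetical input; this
# mirrors the unquantized path taken by act_lu below):
#
#     x = jnp.ones((8, 1, 256), jnp.bfloat16)
#     out, *_ = ActLuPrimitive.outer_primitive.bind(
#         x, jnp.empty((1,), jnp.float32),
#         out_dtype=x.dtype, act_enum=ActivationEnum[("gelu",)].value, act_len=1,
#         scaling_mode=ScalingMode.NO_SCALING.value, is_2x=False,
#         scale_dtype=jnp.float32, is_outer=True,
#     )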


# TODO(Jeremy): replace is_2x with q_layout
class DActLuDBiasQuantizePrimitive(BasePrimitive):
    """
    DActLu DBias Quantize Primitive
    """

    name = "te_dact_dbias_quantize_ffi"
    multiple_results = True
    # out_dtype, scaling_mode, is_2x, scale_dtype, is_dbias, act_enum, act_len, is_outer
    impl_static_args = (3, 4, 5, 6, 7, 8, 9, 10)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        dz_aval,
        x_aval,
        scale_aval,
        *,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
    ):
        """
        te_dact_dbias_quantize_p abstract
        """
        del act_enum
        dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dz_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_dtype
        assert x_aval.shape[-2] == act_len, (
            "activation input should be replicated by act_len in the -2 axis, got input shape"
            f" {x_aval.shape} and act_len {act_len}"
        )
        assert scale_aval.dtype == jnp.float32

        assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, (
            "Current tensor scaling is not supported for fused dact and quantization. Please do"
            " the dact in higher precision, then quantize with current tensor scaling."
        )

        ir_hidden_size = dz_aval.shape[-1]
        gi_hidden_size = act_len * x_aval.shape[-1]
        assert act_len * ir_hidden_size == gi_hidden_size
        out_shape = x_aval.shape
        out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)
        updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=-2)
        if is_2x:
            if ScalingMode(scaling_mode).is_tensor_scaling():
                colwise_out_shape = multidim_transpose(out_shape, transpose_axis=-2)
            else:
                colwise_out_shape = out_shape
        else:
            colwise_out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        if is_dbias:
            dbias_shape = (act_len, ir_hidden_size)
            (wkspace_info,) = transformer_engine_jax.get_dact_dbias_quantize_workspace_sizes(
                x_aval.size // gi_hidden_size,
                gi_hidden_size,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
                scaling_mode,
                is_2x,
            )
            wkspace_shape = wkspace_info[0]
            wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
        else:
            dbias_shape = (1,)
            wkspace_shape = (1,)
            wkspace_dtype = jnp.float32
        dbias_aval = jax.core.ShapedArray(shape=dbias_shape, dtype=dz_dtype)
        wkspace_aval = jax.core.ShapedArray(shape=wkspace_shape, dtype=wkspace_dtype)
        return (
            out_aval,
            colwise_out_aval,
            scale_inv_aval,
            colwise_scale_inv_aval,
            updated_amax_aval,
            dbias_aval,
            wkspace_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dact_dbias_quantize_p outer abstract
        """
        (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = (
            DActLuDBiasQuantizePrimitive.abstract(*args, **kwargs)
        )
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def lowering(
        ctx,
        dz,
        x,
        scale,
        *,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
    ):
        """
        te_dact_dbias_quantize_p lowering rules
        """
        del out_dtype, scale_dtype, act_len, is_outer
        dz_aval, x_aval, scale_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_aval.dtype
        assert scale_aval.dtype == jnp.float32
        return ffi.ffi_lowering(DActLuDBiasQuantizePrimitive.name)(
            ctx,
            dz,
            x,
            scale,
            scaling_mode=scaling_mode.value,
            is_2x=is_2x,
            is_dbias=is_dbias,
            act_enum=int(act_enum),
        )

    @staticmethod
    def impl(
        dz,
        x,
        scale,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
    ):
        """
        te_dact_dbias_quantize_p impl
        """
        del is_outer
        assert DActLuDBiasQuantizePrimitive.inner_primitive is not None
        (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = (
            DActLuDBiasQuantizePrimitive.inner_primitive.bind(
                dz,
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
                act_enum=act_enum,
                act_len=act_len,
                is_outer=False,
            )
        )
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x.shape, is_padded=False, flatten_axis=-2)
        # Slice out padding for MXFP8, noop for DelayedScaling
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
        )
        if is_2x:
            colwise_scale_inv = jax.lax.slice(
                colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
            )
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
    ):
        """
        batch rule for vmap
        """
        del is_outer
        check_valid_batch_dims(batch_dims)
        assert DActLuDBiasQuantizePrimitive.outer_primitive is not None
        dz, x, scale = batched_args
        _, x_bdim, scale_bdim = batch_dims

        out_bdims = (
            x_bdim,  # rowwise output
            x_bdim,  # colwise output
            scale_bdim,  # rowwise scale_inv
            scale_bdim,  # colwise scale_inv
            scale_bdim,  # amax
            x_bdim,  # dbias
        )
        return (
            DActLuDBiasQuantizePrimitive.outer_primitive.bind(
                dz,
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
                act_enum=act_enum,
                act_len=act_len,
                is_outer=True,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del out_dtype, result_infos, act_enum
        del scale_dtype, act_len, is_outer
        x_spec = get_padded_spec(arg_infos[1])
        scale_spec = get_padded_spec(arg_infos[2])

        assert (
            scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value
        ), "Partitioned current tensor scaling is not yet supported."

        out_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out"
        )
        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2)
            else:
                colwise_x_spec = x_spec
        else:
            colwise_x_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_x_spec), desc="DActLuDBiasQuantizePrimitive.colwise_out"
        )

        dbias_spec = x_spec[-2:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="DActLuDBiasQuantizePrimitive.dbias",
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if is_2x:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="DActLuDBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="DActLuDBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="DActLuDBiasQuantizePrimitive.colwise_scale_inv",
        )
        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer
        x_spec = get_padded_spec(arg_infos[1])
        scale_spec = get_padded_spec(arg_infos[2])

        out_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out"
        )

        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2)
            else:
                colwise_x_spec = x_spec
        else:
            colwise_x_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_x_spec), desc="DActLuDBiasQuantizePrimitive.colwise_out"
        )

        dbias_spec = x_spec[-2:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="DActLuDBiasQuantizePrimitive.dbias",
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if is_2x:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="DActLuDBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="DActLuDBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="DActLuDBiasQuantizePrimitive.colwise_scale_inv",
        )

        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

        def sharded_impl(dz, x, scale):
            (out, colwise_out, scale_inv, colwise_scale_inv, local_amax, local_dbias) = (
                DActLuDBiasQuantizePrimitive.impl(
                    dz,
                    x,
                    scale,
                    out_dtype=out_dtype,
                    scaling_mode=scaling_mode,
                    is_2x=is_2x,
                    scale_dtype=scale_dtype,
                    is_dbias=is_dbias,
                    act_enum=act_enum,
                    act_len=act_len,
                    is_outer=True,
                )
            )
            if is_dbias:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            else:
                global_dbias = local_dbias

            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            else:
                global_updated_amax = local_amax

            return out, colwise_out, scale_inv, colwise_scale_inv, global_updated_amax, global_dbias

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
        mesh,
        value_types,
        result_types,
    ):
        del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types

        x_rank = len(value_types[1].shape)
        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            x_rank, unique_var="DActLuDbiasQuantizePrimitive_i", flatten_axis=-2
        )
        x_axes = scale_rules.input_spec
        out = x_axes
        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=-2))
            else:
                colwise_out = tuple(x_axes)
        else:
            colwise_out = ("j",)

        dbias = x_axes[-2:] if is_dbias else ("k",)
        amax = ("…4",)

        return SdyShardingRule(
            (("…0",), tuple(x_axes), ("…2",)),
            (out, colwise_out, scale_rules.rowwise_rule, scale_rules.colwise_rule, amax, dbias),
            **scale_rules.factor_sizes,
        )


register_primitive(DActLuDBiasQuantizePrimitive)


def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, ScaledTensor]:
    """
    JAX native activation implementation
    """
    act_len = len(activation_type)
    assert inputs.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {inputs.shape} and act_len {act_len}"
    )

    x = jnp.split(inputs, act_len, axis=-2)
    acts = []
    for idx, act_fn in enumerate(activation_type):
        x_i = _convert_to_activation_function(act_fn)(x[idx])
        acts.append(x_i)
    x = reduce(operator.mul, acts)
    x = jnp.squeeze(x, axis=-2)
    if quantizer:
        return quantizer.quantize(x, flatten_axis=-1)
    return x
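
# Reference-path sketch (hypothetical shapes): for a gated SiLU, the input
# carries both halves along the -2 axis and the result drops that axis:
#
#     inputs = jnp.ones((4, 128, 2, 1024), jnp.bfloat16)
#     y = _jax_act_lu(inputs, ("silu", "linear"))     # -> (4, 128, 1024)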


def _jax_quantize_dact_dbias(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    is_dbias: bool = True,
    quantizer: Optional[Quantizer] = None,
):
    """
    JAX implementation of dact_lu and dbias with optional quantization
    """
    act_len = len(activation_type)
    assert x.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {x.shape} and act_len {act_len}"
    )

    _, vjp_func = jax.vjp(
        partial(_jax_act_lu, activation_type=activation_type), x.astype(jnp.float32)
    )
    (dx,) = vjp_func(dz.astype(jnp.float32))

    dbias = None
    if is_dbias:
        dbias = _jax_dbias(dx, dtype=x.dtype, flatten_axis=-2)

    if quantizer is not None:
        dx = quantizer.quantize(dx, dq_dtype=x.dtype, flatten_axis=-2)
    else:
        dx = dx.astype(x.dtype)

    return dx, dbias
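
# The reference backward path above leans on jax.vjp; a minimal standalone
# sketch of the same idea (hypothetical shapes, float32 for exactness):
#
#     f = partial(_jax_act_lu, activation_type=("gelu",))
#     _, vjp = jax.vjp(f, jnp.ones((4, 1, 16), jnp.float32))
#     (dx,) = vjp(jnp.ones((4, 16), jnp.float32))    # dx: (4, 1, 16)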


def act_lu(
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Optional[Quantizer] = None,
) -> Union[jnp.ndarray, ScaledTensor]:
    """Activation with optional quantization.

    Args:
        x: Input tensor to be processed.
            Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
        activation_type: Type of activation function to apply.
        quantizer: Optional quantizer for FP8 quantization of the output.

    Returns:
        If quantizer is None:
            The activated input tensor with the same dtype as input.
        If quantizer is provided:
            A ScaledTensor containing the quantized activated input.
    """
    act_type_id = ActivationEnum[activation_type].value
    act_len = len(activation_type)
    assert x.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {x.shape} and act_len {act_len}"
    )
    if not ActLuPrimitive.enabled():
        return _jax_act_lu(x, activation_type, quantizer)
    # TE/common does not support colwise-only quantization yet
    if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
        return _jax_act_lu(x, activation_type, quantizer)
    # TE/common does not support 2x quantization for DelayedScaling yet
    war_output = try_apply_delayed_scaling_2x_war(
        f=act_lu, x=x, activation_type=activation_type, quantizer=quantizer
    )
    if war_output is not None:
        return war_output

    scale = jnp.empty((1,), jnp.float32)
    output_shape = (*x.shape[:-2], x.shape[-1])

    if quantizer is None:
        out, _, _, _, _ = ActLuPrimitive.outer_primitive.bind(
            x,
            scale,
            out_dtype=x.dtype,
            act_enum=act_type_id,
            act_len=act_len,
            scaling_mode=ScalingMode.NO_SCALING.value,
            is_2x=False,
            scale_dtype=jnp.float32,
            is_outer=True,
        )
        out = out.reshape(output_shape)
        return out
    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        # Current scaling does not support fused operations.
        # Perform the activation in higher precision, then quantize.
        out = act_lu(
            x=x.astype(jnp.float32),
            activation_type=activation_type,
            quantizer=None,
        )
        out, _ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype)
        return out

    if isinstance(quantizer, DelayedScaleQuantizer):
        scale = quantizer.scale

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
    ) = ActLuPrimitive.outer_primitive.bind(
        x,
        scale,
        out_dtype=quantizer.q_dtype,
        act_enum=act_type_id,
        act_len=act_len,
        scaling_mode=quantizer.scaling_mode.value,
        is_2x=quantizer.is_2x2x(),
        scale_dtype=quantizer.get_scale_dtype(),
        is_outer=True,
    )
    quantizer.update(updated_amax)

    return ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
    )
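
# Usage sketch for act_lu (hypothetical shapes; a quantizer would come from the
# library's quantize helpers, and my_fp8_quantizer below is a placeholder name):
#
#     x = jnp.ones((32, 128, 1, 4096), jnp.bfloat16)   # non-gated: ACT_DIM == 1
#     y = act_lu(x, ("gelu",))                         # -> (32, 128, 4096), bfloat16
#
#     xg = jnp.ones((32, 128, 2, 4096), jnp.bfloat16)  # gated: ACT_DIM == 2
#     y8 = act_lu(xg, ("silu", "linear"), quantizer=my_fp8_quantizer)  # -> ScaledTensor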


def quantize_dact_dbias(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]] = ("gelu",),
    is_dbias: bool = True,
    quantizer: Optional[Quantizer] = None,
) -> Tuple[ScaledTensor, jnp.ndarray]:
    """Compute gradients of activation and bias with optional quantization.

    Args:
        dz: Gradient of the output with respect to the activation output.
        x: Input tensor that was processed by the forward pass.
            Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
        activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",).
        is_dbias: If True, compute bias gradient. Defaults to True.
        quantizer: Optional quantizer for FP8 quantization of the output.

    Returns:
        Tuple[ScaledTensor, jnp.ndarray]: A tuple containing:
        - The gradient of the activation with respect to the input.
        - The gradient with respect to the bias.
    """
    act_len = len(activation_type)
    assert x.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {x.shape} and act_len {act_len}"
    )

    if not DActLuDBiasQuantizePrimitive.enabled():
        return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)
    # TE/common does not support colwise-only quantization yet
    if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
        return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)
    # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet
    if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
        out = dact_lu(
            dz.astype(jnp.float32), x.astype(jnp.float32), activation_type, quantizer=None
        )
        return _quantize_dbias_impl(
            out, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2
        )
    is_gated = act_len == 2
    # TE/common does not support DelayedScaling2x for gated-act yet
    if is_gated:
        war_output = try_apply_delayed_scaling_2x_war(
            f=quantize_dact_dbias,
            dz=dz,
            x=x,
            activation_type=activation_type,
            is_dbias=is_dbias,
            quantizer=quantizer,
1121
            flatten_axis=-2,
        )
        if war_output is not None:
            return war_output

    scale = jnp.empty((), jnp.float32)

    act_type_id = ActivationEnum[activation_type]

    if quantizer is None:
        output, _, _, _, _, _ = DActLuDBiasQuantizePrimitive.outer_primitive.bind(
            dz,
            x,
            scale,
            # outputs float32 for dbias accumulation
            out_dtype=(jnp.float32 if is_dbias else x.dtype),
            # default value for no scaling, TE/common ignore this value when scale is unset
            scaling_mode=ScalingMode.NO_SCALING.value,
            is_2x=False,  # unused
            scale_dtype=jnp.float32,  # unused
            is_dbias=False,
            act_enum=act_type_id,
            act_len=act_len,
            is_outer=True,
        )
        dbias = None
        if is_dbias:
            dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2)
        return output.astype(x.dtype), dbias

    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        # Current scaling does not support fused operations.
        # Perform the dact in higher precision, then quantize.
        out = dact_lu(
            dz=dz.astype(jnp.float32),
            x=x.astype(jnp.float32),
            activation_type=activation_type,
            quantizer=None,
        )
        out, dbias = _quantize_dbias_impl(
            out, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2
        )
        return out, dbias

    if isinstance(quantizer, DelayedScaleQuantizer):
        scale = quantizer.scale

    # TE/common dact_dbias_quantize does not support gated act yet
    if is_dbias and is_gated:
        dgated = dact_lu(
            dz.astype(jnp.float32), x.astype(jnp.float32), activation_type=activation_type
        )
        out, dbias = _quantize_dbias_impl(
            dgated, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2
        )
        return out, dbias

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
        dbias,
    ) = DActLuDBiasQuantizePrimitive.outer_primitive.bind(
        dz,
        x,
        scale,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        is_2x=quantizer.is_2x2x(),
        scale_dtype=quantizer.get_scale_dtype(),
        is_dbias=is_dbias,
        act_enum=act_type_id,
        act_len=act_len,
        is_outer=True,
    )

    # For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise
    if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x():
        colwise_scale_inv = rowwise_scale_inv

    quantizer.update(updated_amax)

    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=-2,  # as output has act axis
    )
    return out, dbias
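
# Usage sketch (hypothetical shapes): dz matches the forward output, x the
# forward input; dbias is None unless is_dbias=True, and my_fp8_quantizer is a
# placeholder name for a quantizer built elsewhere:
#
#     dz = jnp.ones((32, 128, 4096), jnp.bfloat16)
#     x = jnp.ones((32, 128, 2, 4096), jnp.bfloat16)
#     dx, dbias = quantize_dact_dbias(dz, x, ("silu", "linear"), is_dbias=True,
#                                     quantizer=my_fp8_quantizer)
#     # dx: ScaledTensor with data shape (32, 128, 2, 4096); dbias: (2, 4096)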


def dact_lu(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Optional[Quantizer] = None,
) -> Union[jnp.ndarray, ScaledTensor]:
    """
    Backward pass for activation with optional quantization.

    Args:
        dz: Gradient tensor from upstream.
        x: Input tensor that was used in forward pass.
        activation_type: Type of activation function that was applied.
        quantizer: Optional quantizer for FP8 quantization of the output gradient.

    Returns:
        The gradient of the activation with respect to the input.
    """
    output, _ = quantize_dact_dbias(
        dz=dz,
        x=x,
        activation_type=activation_type,
        is_dbias=False,
        quantizer=quantizer,
    )
    return output
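
# Usage sketch (hypothetical shapes): the unquantized backward pass returns a
# plain gradient with the activation axis restored:
#
#     dz = jnp.ones((32, 128, 4096), jnp.bfloat16)
#     x = jnp.ones((32, 128, 2, 4096), jnp.bfloat16)
#     dx = dact_lu(dz, x, ("silu", "linear"))   # -> (32, 128, 2, 4096)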