# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for activation"""
from typing import Sequence, Union, Callable, Optional, Tuple
import operator
from functools import reduce, partial
from dataclasses import dataclass

import jax
import jax.numpy as jnp
from jax import dtypes, ffi
from jax.experimental.custom_partitioning import SdyShardingRule, BATCHING
from jax.sharding import PartitionSpec

import numpy as np

import transformer_engine_jax
from transformer_engine_jax import NVTE_Activation_Type

from .base import BasePrimitive, register_primitive
from .misc import (
    jax_dtype_to_te_dtype,
    te_dtype_to_jax_dtype,
    get_padded_spec,
    check_valid_batch_dims,
    multidim_transpose,
    try_apply_delayed_scaling_2x_war,
    should_apply_1x_fused_dbias_war_for_arch_l_100,
    NamedSharding,
)
from .quantization import _jax_dbias, quantize, quantize_dbias, _quantize_dbias_impl, AmaxScope
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor
from ..quantize import (
    Quantizer,
    DelayedScaleQuantizer,
    ScalingMode,
    QuantizeLayout,
)


__all__ = ["act_lu", "dact_lu", "quantize_dact_dbias"]


ActivationEnum = {
    ("gelu",): NVTE_Activation_Type.GELU,
    ("gelu", "linear"): NVTE_Activation_Type.GEGLU,
    ("sigmoid", "linear"): NVTE_Activation_Type.GLU,
    ("silu",): NVTE_Activation_Type.SILU,
    ("silu", "linear"): NVTE_Activation_Type.SWIGLU,
    ("relu",): NVTE_Activation_Type.RELU,
    ("relu", "linear"): NVTE_Activation_Type.REGLU,
    ("quick_gelu",): NVTE_Activation_Type.QGELU,
    ("quick_gelu", "linear"): NVTE_Activation_Type.QGEGLU,
    ("squared_relu",): NVTE_Activation_Type.SRELU,
    ("squared_relu", "linear"): NVTE_Activation_Type.SREGLU,
    ("clamped_silu", "clamped_linear"): NVTE_Activation_Type.CLAMPED_SWIGLU,
}

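# Illustrative note (added; not in the original file): keys of ActivationEnum are the
# activation_type tuples passed to act_lu/dact_lu. A gated MLP such as SwiGLU uses a 2-tuple,
# e.g. ActivationEnum[("silu", "linear")] -> NVTE_Activation_Type.SWIGLU, while a plain
# activation uses a 1-tuple, e.g. ActivationEnum[("gelu",)] -> NVTE_Activation_Type.GELU.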

@dataclass(frozen=True)
class ClampedSwigluParams:
    """Parameters for the Clamped SwiGLU activation function
    used in GPT OSS."""

    limit: float = 7.0
    alpha: float = 1.702

    def __hash__(self):
        """Custom hash function to ensure dataclass is hashable for jax jit to work.

        Returns:
            int: Hash value of the dataclass instance.
        """
        return hash((self.limit, self.alpha))

    def to_ffi_lowering_dict(self):
        """Convert the activation parameters to a dictionary format for FFI lowering.

        Returns:
            dict: A dictionary representation of the activation parameters consumable by
            XLA FFI bindings for activation functions.
        """
        return {"limit": np.float32(self.limit), "alpha": np.float32(self.alpha)}

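# Added reference note, derived from _convert_to_activation_function and _jax_act_lu below:
# with gate g and linear input u, the clamped SwiGLU used in GPT OSS computes
#   sigmoid(alpha * min(g, limit)) * min(g, limit) * (clip(u, -limit, limit) + 1)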

@dataclass(frozen=True)
class ActivationParams:
    """Parameters for various activation functions.
    Currently only Clamped SwiGLU activation has parameters.
    """

    clamped_swiglu: ClampedSwigluParams = ClampedSwigluParams()

    @staticmethod
    def create(activation_type, **kwargs):
        """Factory method to create ActivationParams based on activation_type."""
        CLAMPED_ACTIVATION_TYPES = {
            ("clamped_silu", "clamped_linear"),
            "clamped_silu",
            "clamped_linear",
        }
        if activation_type in CLAMPED_ACTIVATION_TYPES:
            return ActivationParams(ClampedSwigluParams(**kwargs))
        return ActivationParams()  # Default params for activations without parameters

    def __hash__(self):
        """Custom hash function to ensure dataclass is hashable for jax jit to work"""
        return hash((self.clamped_swiglu,))

    def to_ffi_lowering_dict(self):
        """Convert the activation parameters to a dictionary format for FFI lowering.
        Returns:
            dict: A dictionary representation of the activation parameters consumable by
            XLA FFI bindings for activation functions.
        """
        return {"clamped_swiglu": self.clamped_swiglu.to_ffi_lowering_dict()}

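# Example (illustrative, added): constructing parameters for the clamped-SwiGLU activation;
# activations without parameters fall back to the defaults.
#   params = ActivationParams.create(("clamped_silu", "clamped_linear"), limit=7.0, alpha=1.702)
#   assert params.clamped_swiglu.limit == 7.0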

def _convert_to_activation_function(fn_or_string, act_params: ActivationParams):
    """Convert a string to an activation function."""
    if fn_or_string == "linear":
        return lambda x: x
    if fn_or_string == "clamped_linear":
        # Used for ClampedSwiGLU (GPT OSS): the linear branch is clamped to
        # [-limit, limit] and shifted by +1.
        limit = act_params.clamped_swiglu.limit
        return lambda x: jnp.clip(x, min=-limit, max=limit) + 1
    if fn_or_string == "quick_gelu":
        return lambda x: jax.nn.sigmoid(1.702 * x) * x
    if fn_or_string == "squared_relu":
        return lambda x: reduce(operator.mul, [jax.nn.relu(x), jax.nn.relu(x)])
    if fn_or_string == "clamped_silu":
        limit = act_params.clamped_swiglu.limit
        alpha = act_params.clamped_swiglu.alpha
        return lambda x: jax.nn.sigmoid(alpha * jnp.minimum(x, limit)) * jnp.minimum(x, limit)
    if isinstance(fn_or_string, str):
        return getattr(jax.nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string
    raise ValueError(f"Unsupported activation function: {fn_or_string}")

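# Example (illustrative, added): _convert_to_activation_function("silu", ActivationParams())
# returns jax.nn.silu via the getattr fallback, while "clamped_silu"/"clamped_linear" use the
# limit/alpha values carried in ActivationParams.clamped_swiglu.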

class ActLuPrimitive(BasePrimitive):
    """
    ActLu Primitive
    """

    name = "te_act_lu_ffi"
    multiple_results = True
    impl_static_args = (
        3,
        4,
        5,
        6,
        7,
        8,
        9,
        10,
        11,
        12,
        13,
    )  # out_dtype, act_enum, act_len, scaling_mode, quantize_layout, scale_dtype, act_params, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling, is_outer
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        amax_aval,
        *,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        quantize_layout,
        scale_dtype,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
    ):
        """
        te_act_lu_p abstract
        """
        del act_enum, act_params, amax_scope, transpose_batch_sequence
        assert (
            not output_amax_when_no_scaling or scaling_mode == ScalingMode.NO_SCALING.value
        ), f"scaling_mode = {scaling_mode}"
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval is None or scale_aval.dtype == jnp.float32
        assert amax_aval is None or amax_aval.dtype == jnp.float32
        assert x_aval.shape[-2] == act_len, (
            "activation input should be replicated by act_len in the -2 axis, got input shape"
            f" {x_aval.shape} and act_len {act_len}"
        )

        assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, (
            "Current tensor scaling is not yet supported for fused activation and quantization."
            " Please do activation in higher-precision then quantize with current tensor scaling."
        )
        assert not ScalingMode(scaling_mode).is_nvfp4_scaling, (
            "NVFP4 block scaling is not yet supported for fused activation and quantization."
            " Please do activation in higher-precision then quantize with current tensor scaling."
        )
        assert (
            not quantize_layout.is_colwise_only
        ), "Fused activation with colwise-only quantization is not supported."

        out_shape = (*x_aval.shape[:-2], x_aval.shape[-1])  # Exclude act dim
        out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)

        updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(out_shape, is_padded=not is_outer, flatten_axis=-1)
        if quantize_layout.is_rowwise_only:
            out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype)
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        return out_aval, colwise_out_aval, scale_inv_aval, colwise_scale_inv_aval, updated_amax_aval

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        amax,
        *,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        quantize_layout,
        scale_dtype,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
    ):
        """
        te_act_lu_p lowering rules
        """
        del out_dtype, scale_dtype, act_len, is_outer, amax_scope, transpose_batch_sequence
        x_aval, scale_aval, amax_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval is None or scale_aval.dtype == jnp.float32
        assert amax_aval.dtype == jnp.float32

        out = ffi.ffi_lowering(
            ActLuPrimitive.name,
            operand_output_aliases={2: 4},  # donate amax buffer to updated_amax
        )(
            ctx,
            x,
            scale,
            amax,
            act_enum=act_enum,
            scaling_mode=scaling_mode.value,
            quantize_layout=quantize_layout.value.value,
            act_params=act_params.to_ffi_lowering_dict(),
            output_amax_when_no_scaling=output_amax_when_no_scaling,
        )
        return out

    @staticmethod
    def impl(
        x,
        scale,
        amax,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        quantize_layout,
        scale_dtype,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
    ):
        """
        te_act_lu_p implementation
        """
        del is_outer
        assert ActLuPrimitive.inner_primitive is not None

        out, colwise_out, scale_inv, colwise_scale_inv, updated_amax = (
            ActLuPrimitive.inner_primitive.bind(
                x,
                scale,
                amax,
                out_dtype=out_dtype,
                act_enum=act_enum,
                act_len=act_len,
                scaling_mode=scaling_mode,
                quantize_layout=quantize_layout,
                scale_dtype=scale_dtype,
                act_params=act_params,
                amax_scope=amax_scope,
                transpose_batch_sequence=transpose_batch_sequence,
                output_amax_when_no_scaling=output_amax_when_no_scaling,
                is_outer=False,
            )
        )
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(out.shape, is_padded=False, flatten_axis=-1)
        # Slice out padding for MXFP8, noop for DelayedScaling
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
        )
        if quantize_layout.is_rowwise_colwise:
            colwise_scale_inv = jax.lax.slice(
                colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
            )

        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        quantize_layout,
        scale_dtype,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
    ):
        """
        to describe batch rules for vmap
        """
        check_valid_batch_dims(batch_dims)
        assert ActLuPrimitive.outer_primitive is not None
        x, scale, amax = batched_args
        x_bdim, scale_bdim, _ = batch_dims
        amax_bdim = scale_bdim

        out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim
        return (
            ActLuPrimitive.outer_primitive.bind(
                x,
                scale,
                amax,
                out_dtype=out_dtype,
                act_enum=act_enum,
                act_len=act_len,
                scaling_mode=scaling_mode,
                quantize_layout=quantize_layout,
                scale_dtype=scale_dtype,
                act_params=act_params,
                amax_scope=amax_scope,
                transpose_batch_sequence=transpose_batch_sequence,
                output_amax_when_no_scaling=output_amax_when_no_scaling,
                is_outer=is_outer,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        quantize_layout,
        scale_dtype,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del (
            out_dtype,
            result_infos,
            act_enum,
            scale_dtype,
            act_len,
            act_params,
            amax_scope,
            transpose_batch_sequence,
            output_amax_when_no_scaling,
            is_outer,
        )  # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        scale_spec = get_padded_spec(arg_infos[1])

        out_spec = (*x_spec[:-2], x_spec[-1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")

        if quantize_layout.is_rowwise_colwise:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1)
            else:
                colwise_out_spec = out_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out"
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = out_spec

        if quantize_layout.is_rowwise_colwise:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax")
        colwise_scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
        )

        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        quantize_layout,
        scale_dtype,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer
        x_spec = get_padded_spec(arg_infos[0])
        scale_spec = get_padded_spec(arg_infos[1])

        out_spec = (*x_spec[:-2], x_spec[-1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")

        if quantize_layout.is_rowwise_colwise:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1)
            else:
                colwise_out_spec = out_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out"
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = out_spec

        if quantize_layout.is_rowwise_colwise:
            assert not ScalingMode(
                scaling_mode
            ).is_colwise_transposed, "Transpose layout scaling modes are not supported here yet"
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax")
        colwise_scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
        )

        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
        )

        def sharded_impl(x, scale, amax):
            (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                local_updated_amax,
            ) = ActLuPrimitive.impl(
                x,
                scale,
                amax,
                out_dtype=out_dtype,
                act_enum=act_enum,
                act_len=act_len,
                scaling_mode=scaling_mode,
                quantize_layout=quantize_layout,
                scale_dtype=scale_dtype,
                act_params=act_params,
                amax_scope=amax_scope,
                transpose_batch_sequence=transpose_batch_sequence,
                output_amax_when_no_scaling=output_amax_when_no_scaling,
                is_outer=True,
            )

            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(
                    local_updated_amax, mesh
                )
            elif scaling_mode == ScalingMode.NO_SCALING.value and output_amax_when_no_scaling:
                global_updated_amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP(
                    local_updated_amax, out_spec, transpose_batch_sequence, mesh
                )
            else:
                global_updated_amax = local_updated_amax

            return (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                global_updated_amax,
            )

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        quantize_layout,
        scale_dtype,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
        mesh,
        value_types,
        result_types,
    ):
        del (
            out_dtype,
            act_enum,
            act_len,
            scale_dtype,
            act_params,
            amax_scope,
            transpose_batch_sequence,
            output_amax_when_no_scaling,
            is_outer,
            mesh,
            result_types,
        )
        prefix = "ActLu"
        input_shape = value_types[0].shape
        output_shape = input_shape[:-2] + input_shape[-1:]
        # Here we pass the output shape so that the scales are propagated correctly
        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            output_shape, unique_var=prefix, flatten_axis=-1, q_layout=quantize_layout
        )
        # Correct the input spec with act dim
        input_spec = scale_rules.input_spec
        input_spec = input_spec[:-1] + (prefix + "_act_dim",) + input_spec[-1:]
        amax = (BATCHING + prefix + "_amax",)
        scale = (BATCHING + prefix + "_scale",)

        return SdyShardingRule(
            (tuple(input_spec), scale, amax),
            (
                scale_rules.rowwise_out_spec,
                scale_rules.colwise_out_spec,
                scale_rules.rowwise_scale_spec,
                scale_rules.colwise_scale_spec,
                amax,
            ),
            **scale_rules.factor_sizes,
        )


register_primitive(ActLuPrimitive)


class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
    """
    DActLu DBias Cast Transpose Primitive
    """

    name = "te_dact_dbias_quantize_ffi"
    multiple_results = True
    # out_dtype, scaling_mode, quantize_layout, scale_dtype, is_dbias, act_enum, act_len, act_params, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling, is_outer
    impl_static_args = (4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        dz_aval,
        x_aval,
        scale_aval,
        amax_aval,
        *,
        out_dtype,
        scaling_mode,
        quantize_layout,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
    ):
        """
        te_dact_dbias_quantize_p abstract
        """
        del act_enum, act_params, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling
        dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dz_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_dtype
        assert x_aval.shape[-2] == act_len, (
            "activation input should be replicated by act_len in the -2 axis, got input shape"
            f" {x_aval.shape} and act_len {act_len}"
        )
        assert scale_aval.dtype == jnp.float32
        assert amax_aval.dtype == jnp.float32

        assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, (
            "Current tensor scaling is not supported for fused dact and quantization. Please do"
            " dact in higher-precision then quantize with current tensor scaling."
        )

        ir_hidden_size = dz_aval.shape[-1]
        gi_hidden_size = act_len * x_aval.shape[-1]
        assert act_len * ir_hidden_size == gi_hidden_size
        assert (
            x_aval.shape[:-2] == dz_aval.shape[:-1]
        ), "dz and x should have the same leading dimensions"
        out_shape = x_aval.shape
        out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)

        updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=-2)
        if quantize_layout.is_rowwise_colwise:
            if ScalingMode(scaling_mode).is_tensor_scaling():
                colwise_out_shape = multidim_transpose(out_shape, transpose_axis=-2)
            else:
                colwise_out_shape = out_shape
        else:
            colwise_out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        if is_dbias:
            dbias_shape = (act_len, ir_hidden_size)
            (wkspace_info,) = transformer_engine_jax.get_dact_dbias_quantize_workspace_sizes(
                x_aval.size // gi_hidden_size,
                gi_hidden_size,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
                scaling_mode,
                quantize_layout.value,
            )
            wkspace_shape = wkspace_info[0]
            wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
        else:
            dbias_shape = (1,)
            wkspace_shape = (1,)
            wkspace_dtype = jnp.float32
        dbias_aval = jax.core.ShapedArray(shape=dbias_shape, dtype=dz_dtype)
        wkspace_aval = jax.core.ShapedArray(shape=wkspace_shape, dtype=wkspace_dtype)

        return (
            out_aval,
            colwise_out_aval,
            scale_inv_aval,
            colwise_scale_inv_aval,
            updated_amax_aval,
            dbias_aval,
            wkspace_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dact_dbias_quantize_p outer abstract
        """
        (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = (
            BaseDActLuDBiasQuantizePrimitive.abstract(*args, **kwargs)
        )
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def lowering(
        ctx,
        dz,
        x,
        scale,
        amax,
        *,
        out_dtype,
        scaling_mode,
        quantize_layout,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
    ):
        """
        te_dact_dbias_quantize_p lowering rules
        """
        del (
            out_dtype,
            scale_dtype,
            act_len,
            is_outer,
            amax_scope,
            transpose_batch_sequence,
        )
        dz_aval, x_aval, scale_aval, amax_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_aval.dtype
        assert scale_aval.dtype == amax_aval.dtype == jnp.float32
        return ffi.ffi_lowering(
            BaseDActLuDBiasQuantizePrimitive.name,
            operand_output_aliases={3: 4},  # donate amax buffer to updated_amax
        )(
            ctx,
            dz,
            x,
            scale,
            amax,
            scaling_mode=scaling_mode.value,
            quantize_layout=quantize_layout.value.value,
            is_dbias=is_dbias,
            act_enum=int(act_enum),
            act_params=act_params.to_ffi_lowering_dict(),
            output_amax_when_no_scaling=output_amax_when_no_scaling,
        )

    @staticmethod
    def impl(
        dz,
        x,
        scale,
        amax,
        out_dtype,
        scaling_mode,
        quantize_layout,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
    ):
        """
        te_dact_dbias_quantize_p impl
        """
        del is_outer
        assert BaseDActLuDBiasQuantizePrimitive.inner_primitive is not None
        (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = (
            BaseDActLuDBiasQuantizePrimitive.inner_primitive.bind(
                dz,
                x,
                scale,
                amax,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                quantize_layout=quantize_layout,
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
                act_enum=act_enum,
                act_len=act_len,
                act_params=act_params,
                amax_scope=amax_scope,
                transpose_batch_sequence=transpose_batch_sequence,
                output_amax_when_no_scaling=output_amax_when_no_scaling,
                is_outer=False,
            )
        )
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x.shape, is_padded=False, flatten_axis=-2)
        # Slice out padding for MXFP8, noop for DelayedScaling
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
        )
        if quantize_layout.is_rowwise_colwise:
            colwise_scale_inv = jax.lax.slice(
                colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
            )
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        scaling_mode,
        quantize_layout,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
    ):
        """
        to describe batch rules for vmap
        """
        check_valid_batch_dims(batch_dims)
        assert BaseDActLuDBiasQuantizePrimitive.outer_primitive is not None
        dz, x, scale, amax = batched_args
        _, x_bdim, scale_bdim, _ = batch_dims

        out_bdims = (
            x_bdim,  # rowwise output
            x_bdim,  # colwise output
            scale_bdim,  # rowwise scale_inv
            scale_bdim,  # colwise scale_inv
            scale_bdim,  # amax
            x_bdim,  # dbias
        )
        return (
            BaseDActLuDBiasQuantizePrimitive.outer_primitive.bind(
                dz,
                x,
                scale,
                amax,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                quantize_layout=quantize_layout,
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
                act_enum=act_enum,
                act_len=act_len,
                act_params=act_params,
                amax_scope=amax_scope,
                transpose_batch_sequence=transpose_batch_sequence,
                output_amax_when_no_scaling=output_amax_when_no_scaling,
                is_outer=is_outer,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        scaling_mode,
        quantize_layout,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del out_dtype, result_infos, act_enum, act_params, output_amax_when_no_scaling
        del scale_dtype, act_len, is_outer, amax_scope, transpose_batch_sequence

        x_spec = get_padded_spec(arg_infos[1])
        scale_spec = get_padded_spec(arg_infos[2])

        assert (
            scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value
        ), "Partitioned current tensor scaling is not yet supported."

        out_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec), desc="BaseDActLuDBiasQuantizePrimitive.out"
        )
        if quantize_layout.is_rowwise_colwise:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2)
            else:
                colwise_x_spec = x_spec
        else:
            colwise_x_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_x_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.colwise_out",
        )

        dbias_spec = x_spec[-2:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.dbias",
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if quantize_layout.is_rowwise_colwise:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="BaseDActLuDBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="BaseDActLuDBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.colwise_scale_inv",
        )
        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        scaling_mode,
        quantize_layout,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer
        x_spec = get_padded_spec(arg_infos[1])
        scale_spec = get_padded_spec(arg_infos[2])

        out_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec), desc="BaseDActLuDBiasQuantizePrimitive.out"
        )

        if quantize_layout.is_rowwise_colwise:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2)
            else:
                colwise_x_spec = x_spec
        else:
            colwise_x_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_x_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.colwise_out",
        )

        dbias_spec = x_spec[-2:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.dbias",
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if quantize_layout.is_rowwise_colwise:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="BaseDActLuDBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="BaseDActLuDBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.colwise_scale_inv",
        )

        arg_shardings = list(arg_i.sharding for arg_i in arg_infos)
        # Ensure dz and x are partitioned the same way.
        arg_shardings[0] = NamedSharding(
            mesh,
            PartitionSpec(*x_spec[:-2], x_spec[-1]),
            desc="BaseDActLuDBiasQuantizePrimitive.dz",
        )
        arg_shardings = tuple(arg_shardings)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

        def sharded_impl(dz, x, scale, amax):
            (out, colwise_out, scale_inv, colwise_scale_inv, local_updated_amax, local_dbias) = (
                BaseDActLuDBiasQuantizePrimitive.impl(
                    dz,
                    x,
                    scale,
                    amax,
                    out_dtype=out_dtype,
                    scaling_mode=scaling_mode,
                    quantize_layout=quantize_layout,
                    scale_dtype=scale_dtype,
                    is_dbias=is_dbias,
                    act_enum=act_enum,
                    act_len=act_len,
                    act_params=act_params,
                    output_amax_when_no_scaling=output_amax_when_no_scaling,
                    amax_scope=amax_scope,
                    transpose_batch_sequence=transpose_batch_sequence,
                    is_outer=True,
                )
            )
            if is_dbias:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            else:
                global_dbias = local_dbias

            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(
                    local_updated_amax, mesh
                )
            elif scaling_mode == ScalingMode.NO_SCALING.value and output_amax_when_no_scaling:
                global_updated_amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP(
                    local_updated_amax, x_spec, transpose_batch_sequence, mesh
                )
            else:
                global_updated_amax = local_updated_amax

            return out, colwise_out, scale_inv, colwise_scale_inv, global_updated_amax, global_dbias

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        scaling_mode,
        quantize_layout,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
        mesh,
        value_types,
        result_types,
    ):
        del (
            out_dtype,
            scale_dtype,
            act_enum,
            act_len,
            act_params,
            is_outer,
            output_amax_when_no_scaling,
            mesh,
            result_types,
            amax_scope,
            transpose_batch_sequence,
        )

        prefix = "DActLuDBias_"
        # get sharding rules based on the input shape
        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            value_types[1].shape,
            unique_var=prefix,
            flatten_axis=-2,
            q_layout=quantize_layout,
        )

        input_spec = scale_rules.input_spec
        dz_spec = (*input_spec[:-2], input_spec[-1])
        dbias = input_spec[-2:] if is_dbias else (prefix + "_dbias",)
        amax = (prefix + "_amax",)
        scale = (prefix + "_scale",)

        return SdyShardingRule(
            (tuple(dz_spec), tuple(input_spec), scale, amax),
            (
                scale_rules.rowwise_out_spec,
                scale_rules.colwise_out_spec,
                scale_rules.rowwise_scale_spec,
                scale_rules.colwise_scale_spec,
                amax,
                dbias,
            ),
            **scale_rules.factor_sizes,
        )


register_primitive(BaseDActLuDBiasQuantizePrimitive)


class DActLuDBiasQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive):
    """Subclass of BaseDActLuDBiasQuantizePrimitive for DBias and fused activation quantization.
    No change in functionality from the base primitive but named differently for use in more
    granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""


class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive):
    """Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without
    dbias. No change in functionality from the base primitive but named differently for use in
    more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""


def _jax_act_lu(
    inputs, activation_type, quantizer=None, act_params: Optional[ActivationParams] = None
) -> Union[NoScaleTensor, ScaledTensor]:
    """
    JAX native activation implementation
    """
    act_params = act_params if act_params is not None else ActivationParams()
    act_len = len(activation_type)
    assert inputs.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {inputs.shape} and act_len {act_len}"
    )
    x = jnp.split(inputs, act_len, axis=-2)
    acts = []
    for idx, act_fn in enumerate(activation_type):
        x_i = _convert_to_activation_function(act_fn, act_params)(x[idx])
        acts.append(x_i)
    x = reduce(operator.mul, acts)
    x = jnp.squeeze(x, axis=-2)
    if quantizer:
        return quantizer.quantize(x, flatten_axis=-1)
    return NoScaleTensor(data=x, amax=None)

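# Illustrative sketch (added; not in the original file): for a gated activation such as
# ("silu", "linear"), _jax_act_lu above is equivalent to
#   gate, up = jnp.split(inputs, 2, axis=-2)            # inputs: (..., 2, K)
#   out = jnp.squeeze(jax.nn.silu(gate) * up, axis=-2)  # out:    (..., K)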

def _jax_quantize_dact_dbias(
    dz: Union[jnp.ndarray, NoScaleTensor],
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    is_dbias: bool = True,
    quantizer: Optional[Quantizer] = None,
    act_params: Optional[ActivationParams] = None,
):
    """
    JAX implementation of dact_lu and dbias with optional quantization
    """
    act_params = act_params if act_params is not None else ActivationParams()
    act_len = len(activation_type)
    assert x.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {x.shape} and act_len {act_len}"
    )

    _, vjp_func = jax.vjp(
        partial(_jax_act_lu, activation_type=activation_type, act_params=act_params),
        x.astype(jnp.float32),
    )
    # The VJP uses the non-quantized backward for dact, so the cotangent passed to vjp_func must
    # always be wrapped in NoScaleTensor, regardless of whether the forward pass used quantization
    # or this dact will quantize afterwards.
    dz = NoScaleTensor(data=dz.astype(jnp.float32), amax=None)
    (dx,) = vjp_func(dz)

    dbias = None
    if is_dbias:
        dbias = _jax_dbias(dx, dtype=x.dtype, flatten_axis=-2)

    if quantizer is not None:
        dx = quantizer.quantize(dx, dq_dtype=x.dtype, flatten_axis=-2)
    else:
        dx = dx.astype(x.dtype)
        dx = NoScaleTensor(data=dx, amax=None)

    return dx, dbias


def act_lu(
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Optional[Quantizer] = None,
    act_params: Optional[ActivationParams] = None,
    amax_scope: AmaxScope = AmaxScope.LOCAL,
    transpose_batch_sequence: bool = False,
    output_amax_when_no_scaling: bool = False,
) -> Union[jnp.ndarray, ScaledTensor]:
    """Activation with optional quantization.

    Args:
        x: Input tensor to be processed.
            Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
        activation_type: Type of activation function to apply.
        quantizer: Optional quantizer for FP8 quantization of the output.
        act_params: Optional activation parameters (e.g. limit/alpha for clamped SwiGLU).
            Defaults to ActivationParams().
        amax_scope: Indicates the scope over which amax is computed. This only takes effect when
            using current-scaling. Default is AmaxScope.LOCAL.
        transpose_batch_sequence: Whether the batch and sequence dimensions of the input are
            transposed; only used when reducing amax across the mesh. Defaults to False.
        output_amax_when_no_scaling: If True, also return the amax of the activation output when
            no scaling is applied. Defaults to False.

    Returns:
        If quantizer is None:
            The activated input tensor with the same dtype as input.
        If quantizer is provided:
            A ScaledTensor containing the quantized activated input.
    """
    # TODO(Phuong): remove the output_amax_when_no_scaling exposure by introducing _act_lu_impl()
    # Do the same with dact_dbias_quantize() and layernorm_fwd()
    act_type_id = ActivationEnum[activation_type].value
    act_len = len(activation_type)
    assert x.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {x.shape} and act_len {act_len}"
    )
    act_params = act_params if act_params is not None else ActivationParams()
    if not ActLuPrimitive.enabled():
        act_out = _jax_act_lu(x, activation_type, act_params=act_params)
        assert (
            act_out.data.dtype == x.dtype
        ), f"JAX activation output dtype {act_out.data.dtype} must match input dtype {x.dtype}"
        if quantizer is None:
            return act_out

        return quantize(
            act_out,
            quantizer=quantizer,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
        )

    # TE/common does not support colwise-only quantization yet
    if quantizer is not None and quantizer.q_layout.is_colwise_only:
        return _jax_act_lu(x, activation_type, quantizer, act_params)

    # TE/common does not support 2x quantization for DelayedScaling yet
    war_output = try_apply_delayed_scaling_2x_war(
        f=act_lu,
        x=x,
        activation_type=activation_type,
        quantizer=quantizer,
        act_params=act_params,
        amax_scope=amax_scope,
        transpose_batch_sequence=transpose_batch_sequence,
        output_amax_when_no_scaling=output_amax_when_no_scaling,
    )
    if war_output is not None:
        return war_output

    scale = jnp.empty((1,), jnp.float32)
    output_shape = (*x.shape[:-2], x.shape[-1])
    amax = jnp.zeros((1,), jnp.float32)  # need to init with zero and shape=(1,)

    if quantizer is None:
        out, _, _, _, updated_amax = ActLuPrimitive.outer_primitive.bind(
            x,
            scale,
            amax,
            out_dtype=x.dtype,
            act_enum=act_type_id,
            act_len=act_len,
            scaling_mode=ScalingMode.NO_SCALING.value,
            quantize_layout=QuantizeLayout.ROWWISE,
            scale_dtype=jnp.float32,
            act_params=act_params,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
            output_amax_when_no_scaling=output_amax_when_no_scaling,
            is_outer=True,
        )
        out = out.reshape(output_shape)
        # TODO(Phuong): ScaledTensorFactory to create NoScaledTensor
        out = NoScaleTensor(
            data=out,
            amax=updated_amax if output_amax_when_no_scaling else None,
        )
        return out

    if (
        quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING
        or quantizer.scaling_mode.is_nvfp4_scaling
    ):
        # Current scaling does not support fused operations. Perform the activation in higher
        # precision, then quantize afterwards.
        out = act_lu(
            x=x,
            activation_type=activation_type,
            quantizer=None,
            act_params=act_params,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
            output_amax_when_no_scaling=True,
        )
        assert (
            out.data.dtype == x.dtype
        ), f"Activation output dtype {out.data.dtype} must match input dtype {x.dtype}"
        out = quantize(
            out,
            quantizer=quantizer,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
        )
        return out

    if isinstance(quantizer, DelayedScaleQuantizer):
        scale = quantizer.scale

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
    ) = ActLuPrimitive.outer_primitive.bind(
        x,
        scale,
        amax,
        out_dtype=quantizer.q_dtype,
        act_enum=act_type_id,
        act_len=act_len,
        scaling_mode=quantizer.scaling_mode.value,
        quantize_layout=quantizer.q_layout,
        scale_dtype=quantizer.get_scale_dtype(),
        act_params=act_params,
        amax_scope=amax_scope,
        transpose_batch_sequence=transpose_batch_sequence,
        output_amax_when_no_scaling=output_amax_when_no_scaling,
        is_outer=True,
    )

    quantizer.update(updated_amax)

    return ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
    )

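# Minimal usage sketch (illustrative, added; shapes and `my_quantizer` are assumptions):
#   x = jnp.zeros((batch, seqlen, 2, hidden), jnp.bfloat16)       # ACT_DIM == 2 for gated activations
#   y = act_lu(x, activation_type=("silu", "linear"))              # fused SwiGLU, no quantization
#   y_fp8 = act_lu(x, ("silu", "linear"), quantizer=my_quantizer)  # returns a ScaledTensor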

1396
1397
1398
1399
1400
1401
def quantize_dact_dbias(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]] = ("gelu",),
    is_dbias: bool = True,
    quantizer: Optional[Quantizer] = None,
1402
    act_params: Optional[ActivationParams] = None,
1403
1404
1405
    amax_scope: AmaxScope = AmaxScope.LOCAL,
    transpose_batch_sequence: bool = False,
    output_amax_when_no_scaling: bool = False,
1406
1407
1408
1409
1410
1411
) -> Tuple[ScaledTensor, jnp.ndarray]:
    """Compute gradients of activation and bias with optional quantization.

    Args:
        dz: Gradient of the output with respect to the activation output.
        x: Input tensor that was processed by the forward pass.
1412
            Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
1413
1414
1415
1416
1417
1418
1419
1420
1421
        activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",).
        is_dbias: If True, compute bias gradient. Defaults to True.
        quantizer: Optional quantizer for FP8 quantization of the output.

    Returns:
        Tuple[ScaledTensor, jnp.ndarray]: A tuple containing:
        - The gradient of the activation with respect to the input.
        - The gradient of the activation with respect to the bias.
    """
1422
    act_params = act_params if act_params is not None else ActivationParams()
1423
1424
1425
1426
1427
1428
    act_len = len(activation_type)
    assert x.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {x.shape} and act_len {act_len}"
    )

1429
1430
    scale = jnp.empty((1,), jnp.float32)
    amax = jnp.zeros((1,), jnp.float32)  # need to init with zero and shape=(1,)
Alp Dener's avatar
Alp Dener committed
1431
    act_type_id = ActivationEnum[activation_type]
1432
    PrimitiveClass = DActLuDBiasQuantizePrimitive if is_dbias else DActLuQuantizePrimitive
Alp Dener's avatar
Alp Dener committed
1433
    if not PrimitiveClass.enabled() or (
1434
        quantizer is not None and quantizer.q_layout.is_colwise_only
Alp Dener's avatar
Alp Dener committed
1435
    ):
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
        if quantizer is None:
            return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, act_params=act_params)
        dact_out, _ = _jax_quantize_dact_dbias(
            dz, x, activation_type, is_dbias=False, act_params=act_params
        )
        assert (
            dact_out.data.dtype == x.dtype
        ), f"JAX dact output dtype {dact_out.data.dtype} must match input dtype {x.dtype}"
        return quantize_dbias(
            dact_out,
            quantizer,
            is_dbias=is_dbias,
            flatten_axis=-2,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
        )

    if quantizer is None:
        output, _, _, _, updated_amax, _ = PrimitiveClass.outer_primitive.bind(
            dz,
            x,
            scale,
            amax,
            # outputs float32 for dbias accumulation
            out_dtype=(jnp.float32 if is_dbias else x.dtype),
            # default value for no scaling; TE/common ignores this value when scale is unset
            scaling_mode=ScalingMode.NO_SCALING.value,
            quantize_layout=QuantizeLayout.ROWWISE,  # unused
            scale_dtype=jnp.float32,  # unused
            is_dbias=False,
            act_enum=act_type_id,
            act_len=act_len,
            act_params=act_params,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
            output_amax_when_no_scaling=output_amax_when_no_scaling,
            is_outer=True,
        )
        output = output.astype(x.dtype)
        dbias = None
        if is_dbias:
            dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2)

        output = NoScaleTensor(
            data=output,
            amax=updated_amax if output_amax_when_no_scaling else None,
        )
        return output, dbias

    # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet
    if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
        out = dact_lu(
            dz.astype(jnp.float32),
            x.astype(jnp.float32),
            activation_type,
            quantizer=None,
            act_params=act_params,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
            output_amax_when_no_scaling=output_amax_when_no_scaling,
        )
        return _quantize_dbias_impl(
            out,
            quantizer,
            is_dbias=True,
            dq_dtype=x.dtype,
            flatten_axis=-2,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
        )

    is_gated = act_len == 2
    # TE/common does not support DelayedScaling2x for gated-act yet
    if is_gated:
        war_output = try_apply_delayed_scaling_2x_war(
            f=quantize_dact_dbias,
            dz=dz,
            x=x,
            activation_type=activation_type,
            is_dbias=is_dbias,
            quantizer=quantizer,
            flatten_axis=-2,
            act_params=act_params,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
            output_amax_when_no_scaling=output_amax_when_no_scaling,
        )
        if war_output is not None:
            return war_output

    if (
        quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING
        or quantizer.scaling_mode.is_nvfp4_scaling
    ):
        # Current scaling does not support fused operations. Perform dact in higher precision then quantize after.
        out = dact_lu(
            dz=dz,
            x=x,
            activation_type=activation_type,
            quantizer=None,
            act_params=act_params,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
            output_amax_when_no_scaling=True,
        )
        out, dbias = _quantize_dbias_impl(
            out,
            is_dbias=is_dbias,
            quantizer=quantizer,
            dq_dtype=x.dtype,
            flatten_axis=-2,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
        )
        return out, dbias

    if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        scale = quantizer.scale

    # TE/common dact_dbias_quantize does not support gated act yet
    if is_dbias and is_gated:
        dgated = dact_lu(
            dz.astype(jnp.float32),
            x.astype(jnp.float32),
            activation_type=activation_type,
            act_params=act_params,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
        )
        out, dbias = _quantize_dbias_impl(
            dgated,
            quantizer,
            is_dbias=True,
            dq_dtype=x.dtype,
            flatten_axis=-2,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
        )
        return out, dbias

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
        dbias,
    ) = PrimitiveClass.outer_primitive.bind(
        dz,
        x,
        scale,
        amax,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        quantize_layout=quantizer.q_layout,
        scale_dtype=quantizer.get_scale_dtype(),
        is_dbias=is_dbias,
        act_enum=act_type_id,
        act_len=act_len,
        act_params=act_params,
        amax_scope=amax_scope,
        transpose_batch_sequence=transpose_batch_sequence,
        output_amax_when_no_scaling=output_amax_when_no_scaling,
        is_outer=True,
    )

    # For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise
    if quantizer.scaling_mode.is_tensor_scaling() and quantizer.q_layout.is_rowwise_colwise:
        colwise_scale_inv = rowwise_scale_inv

    quantizer.update(updated_amax)

    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=-2,  # as output has act axis
    )

    return out, dbias


def dact_lu(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Optional[Quantizer] = None,
    act_params: Optional[ActivationParams] = None,
    amax_scope: AmaxScope = AmaxScope.LOCAL,
    transpose_batch_sequence: bool = False,
    output_amax_when_no_scaling: bool = False,
) -> Union[jnp.ndarray, ScaledTensor]:
    """
    Backward pass for activation with optional quantization.

    Args:
        dz: Gradient tensor from upstream.
        x: Input tensor that was used in forward pass.
        activation_type: Type of activation function that was applied.
        quantizer: Optional quantizer for FP8 quantization of the output gradient.
        act_params: Optional parameters for parameterized activations such as clamped SwiGLU.
            Defaults to ActivationParams().
        amax_scope: Scope over which the amax is computed for quantization. Defaults to
            AmaxScope.LOCAL.
        transpose_batch_sequence: Whether the batch and sequence dimensions of the input are
            transposed. Defaults to False.
        output_amax_when_no_scaling: If True, attach the computed amax to the returned tensor
            even when no quantization scaling is applied. Defaults to False.

    Returns:
        The gradient of the activation with respect to the input.
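
    Example:
        A minimal, non-normative sketch for a gated SiLU backward pass with no quantization;
        the shapes below are illustrative assumptions, not a contract of this API::

            dz = jnp.ones((16, 128), jnp.bfloat16)    # grad of the forward activation output
            x = jnp.ones((16, 2, 128), jnp.bfloat16)  # gated forward input, act axis of size 2
            dx = dact_lu(dz, x, activation_type=("silu", "linear"))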
    """
    act_params = act_params if act_params is not None else ActivationParams()
    output, _ = quantize_dact_dbias(
        dz=dz,
        x=x,
        activation_type=activation_type,
        is_dbias=False,
        quantizer=quantizer,
        act_params=act_params,
        amax_scope=amax_scope,
        transpose_batch_sequence=transpose_batch_sequence,
        output_amax_when_no_scaling=output_amax_when_no_scaling,
    )
    return output