# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for activation"""
from typing import Sequence, Union, Callable, Optional, Tuple
import operator
from functools import reduce, partial
from dataclasses import dataclass

import jax
import jax.numpy as jnp
from jax import dtypes, ffi
from jax.experimental.custom_partitioning import SdyShardingRule, BATCHING
from jax.sharding import PartitionSpec

import numpy as np

import transformer_engine_jax
from transformer_engine_jax import NVTE_Activation_Type

from .base import BasePrimitive, register_primitive
from .misc import (
    jax_dtype_to_te_dtype,
    te_dtype_to_jax_dtype,
    get_padded_spec,
    check_valid_batch_dims,
    multidim_transpose,
    try_apply_delayed_scaling_2x_war,
    should_apply_1x_fused_dbias_war_for_arch_l_100,
    NamedSharding,
)
from .quantization import _jax_dbias, quantize, quantize_dbias, _quantize_dbias_impl, AmaxScope
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor
from ..quantize import (
    Quantizer,
    QuantizeLayout,
    DelayedScaleQuantizer,
    ScalingMode,
)


__all__ = ["act_lu", "dact_lu", "quantize_dact_dbias"]


ActivationEnum = {
    ("gelu",): NVTE_Activation_Type.GELU,
    ("gelu", "linear"): NVTE_Activation_Type.GEGLU,
    ("silu",): NVTE_Activation_Type.SILU,
    ("silu", "linear"): NVTE_Activation_Type.SWIGLU,
    ("relu",): NVTE_Activation_Type.RELU,
    ("relu", "linear"): NVTE_Activation_Type.REGLU,
    ("quick_gelu",): NVTE_Activation_Type.QGELU,
    ("quick_gelu", "linear"): NVTE_Activation_Type.QGEGLU,
    ("squared_relu",): NVTE_Activation_Type.SRELU,
    ("squared_relu", "linear"): NVTE_Activation_Type.SREGLU,
    ("clamped_silu", "clamped_linear"): NVTE_Activation_Type.CLAMPED_SWIGLU,
}
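
# Lookup sketch (illustrative only, not exercised by the module itself): ActivationEnum is keyed
# by the tuple form of ``activation_type`` used throughout this file, e.g.
#
#     act_type_id = ActivationEnum[("silu", "linear")]  # -> NVTE_Activation_Type.SWIGLU
#     act_len = len(("silu", "linear"))                 # 2 -> gated activation, input shape (..., 2, K)
#
# Single activations use a 1-tuple key, e.g. ActivationEnum[("gelu",)].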


@dataclass(frozen=True)
class ClampedSwigluParams:
    """Parameters for the Clamped SwiGLU activation function
    used in GPT OSS."""

    limit: float = 7.0
    alpha: float = 1.702

    def __hash__(self):
        """Custom hash function to ensure dataclass is hashable for jax jit to work.

        Returns:
            int: Hash value of the dataclass instance.
        """
        return hash((self.limit, self.alpha))

    def to_ffi_lowering_dict(self):
        """Convert the activation parameters to a dictionary format for FFI lowering.

        Returns:
            dict: A dictionary representation of the activation parameters consumable by
            XLA FFI bindings for activation functions.
        """
        return {"limit": np.float32(self.limit), "alpha": np.float32(self.alpha)}


@dataclass(frozen=True)
class ActivationParams:
    """Parameters for various activation functions.
    Currently only Clamped SwiGLU activation has parameters.
    """

    clamped_swiglu: ClampedSwigluParams = ClampedSwigluParams()

    @staticmethod
    def create(activation_type, **kwargs):
        """Factory method to create ActivationParams based on activation_type."""
        CLAMPED_ACTIVATION_TYPES = {
            ("clamped_silu", "clamped_linear"),
            "clamped_silu",
            "clamped_linear",
        }
        if activation_type in CLAMPED_ACTIVATION_TYPES:
            return ActivationParams(ClampedSwigluParams(**kwargs))
        return ActivationParams()  # Default params for activations without parameters

    def __hash__(self):
        """Custom hash function to ensure dataclass is hashable for jax jit to work"""
        return hash((self.clamped_swiglu,))

    def to_ffi_lowering_dict(self):
        """Convert the activation parameters to a dictionary format for FFI lowering.
        Returns:
            dict: A dictionary representation of the activation parameters consumable by
            XLA FFI bindings for activation functions.
        """
        return {"clamped_swiglu": self.clamped_swiglu.to_ffi_lowering_dict()}


def _convert_to_activation_function(fn_or_string, act_params: ActivationParams):
    """Convert a string to an activation function."""
    if fn_or_string == "linear":
        return lambda x: x
    if fn_or_string == "clamped_linear":
        # This function is used for ClampedSwiGLU
        # used in GPT OSS where the gates are not only clamped
        # but also shifted by +1
        limit = act_params.clamped_swiglu.limit
        return lambda x: jnp.clip(x, min=-limit, max=limit) + 1
    if fn_or_string == "quick_gelu":
        return lambda x: jax.nn.sigmoid(1.702 * x) * x
    if fn_or_string == "squared_relu":
        return lambda x: reduce(operator.mul, [jax.nn.relu(x), jax.nn.relu(x)])
    if fn_or_string == "clamped_silu":
        limit = act_params.clamped_swiglu.limit
        alpha = act_params.clamped_swiglu.alpha
        return lambda x: jax.nn.sigmoid(alpha * jnp.minimum(x, limit)) * jnp.minimum(x, limit)
    if isinstance(fn_or_string, str):
        return getattr(jax.nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string
    raise ValueError(f"Unsupported {fn_or_string} to an activation function")


class ActLuPrimitive(BasePrimitive):
    """
    ActLu Primitive
    """

    name = "te_act_lu_ffi"
    multiple_results = True
    impl_static_args = (
        3,
        4,
        5,
        6,
        7,
        8,
        9,
        10,
        11,
        12,
        13,
    )  # out_dtype, act_enum, act_len, scaling_mode, quantize_layout, scale_dtype, act_params, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling, is_outer
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        amax_aval,
        *,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        quantize_layout,
        scale_dtype,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
    ):
        """
        te_act_lu_p abstract
        """
        del act_enum, act_params, amax_scope, transpose_batch_sequence
        assert (
            not output_amax_when_no_scaling or scaling_mode == ScalingMode.NO_SCALING.value
        ), f"scaling_mode = {scaling_mode}"

        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval is None or scale_aval.dtype == jnp.float32
        assert amax_aval is None or amax_aval.dtype == jnp.float32
        assert x_aval.shape[-2] == act_len, (
            "activation input should be replicated by act_len in the -2 axis, got input shape"
            f" {x_aval.shape} and act_len {act_len}"
        )

        assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, (
            "Current tensor scaling is not yet supported for fused activation and quantization."
            " Please do activation in higher-precision then quantize with current tensor scaling."
        )
        assert not ScalingMode(scaling_mode).is_nvfp4_scaling, (
            "NVFP4 block scaling is not yet supported for fused activation and quantization."
            " Please do activation in higher-precision then quantize with current tensor scaling."
        )
        assert (
            not quantize_layout.is_colwise_only
        ), "Fused activation with colwise-only quantization is not supported."

        out_shape = (*x_aval.shape[:-2], x_aval.shape[-1])  # Exclude act dim
        out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)

        updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(out_shape, is_padded=not is_outer, flatten_axis=-1)
        if quantize_layout.is_rowwise_only:
            out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype)
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        return out_aval, colwise_out_aval, scale_inv_aval, colwise_scale_inv_aval, updated_amax_aval

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        amax,
        *,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        quantize_layout,
        scale_dtype,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
    ):
        """
        te_gated_act_lu_p lowering rules
        """
        del out_dtype, scale_dtype, act_len, is_outer, amax_scope, transpose_batch_sequence
        x_aval, scale_aval, amax_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval is None or scale_aval.dtype == jnp.float32
        assert amax_aval.dtype == jnp.float32

        out = ffi.ffi_lowering(
            ActLuPrimitive.name,
            operand_output_aliases={2: 4},  # donate amax buffer to updated_amax
        )(
            ctx,
            x,
            scale,
            amax,
            act_enum=act_enum,
            scaling_mode=scaling_mode.value,
            quantize_layout=quantize_layout.value.value,
            act_params=act_params.to_ffi_lowering_dict(),
            output_amax_when_no_scaling=output_amax_when_no_scaling,
        )
        return out

    @staticmethod
    def impl(
        x,
        scale,
        amax,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        quantize_layout,
        scale_dtype,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
    ):
        """
        to describe implementation
        """
        del is_outer
        assert ActLuPrimitive.inner_primitive is not None

        out, colwise_out, scale_inv, colwise_scale_inv, updated_amax = (
            ActLuPrimitive.inner_primitive.bind(
                x,
                scale,
                amax,
                out_dtype=out_dtype,
                act_enum=act_enum,
                act_len=act_len,
                scaling_mode=scaling_mode,
                quantize_layout=quantize_layout,
                scale_dtype=scale_dtype,
                act_params=act_params,
                amax_scope=amax_scope,
                transpose_batch_sequence=transpose_batch_sequence,
                output_amax_when_no_scaling=output_amax_when_no_scaling,
                is_outer=False,
            )
        )
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(out.shape, is_padded=False, flatten_axis=-1)
        # Slice out padding for MXFP8, noop for DelayedScaling
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
        )
        if quantize_layout.is_rowwise_colwise:
            colwise_scale_inv = jax.lax.slice(
                colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
            )

        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        quantize_layout,
        scale_dtype,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
    ):
        """
        to describe batch rules for vmap
        """
        check_valid_batch_dims(batch_dims)
        assert ActLuPrimitive.outer_primitive is not None
        x, scale, amax = batched_args
        x_bdim, scale_bdim, _ = batch_dims
        amax_bdim = scale_bdim

        out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim
        return (
            ActLuPrimitive.outer_primitive.bind(
                x,
                scale,
                amax,
                out_dtype=out_dtype,
                act_enum=act_enum,
                act_len=act_len,
                scaling_mode=scaling_mode,
                quantize_layout=quantize_layout,
                scale_dtype=scale_dtype,
                act_params=act_params,
                amax_scope=amax_scope,
                transpose_batch_sequence=transpose_batch_sequence,
                output_amax_when_no_scaling=output_amax_when_no_scaling,
                is_outer=is_outer,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        quantize_layout,
        scale_dtype,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del (
            out_dtype,
            result_infos,
            act_enum,
            scale_dtype,
            act_len,
            act_params,
            amax_scope,
            transpose_batch_sequence,
            output_amax_when_no_scaling,
            is_outer,
        )  # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        scale_spec = get_padded_spec(arg_infos[1])

        out_spec = (*x_spec[:-2], x_spec[-1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")

        if quantize_layout.is_rowwise_colwise:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1)
            else:
                colwise_out_spec = out_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out"
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = out_spec

        if quantize_layout.is_rowwise_colwise:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax")
        colwise_scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
        )

        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        quantize_layout,
        scale_dtype,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer
        x_spec = get_padded_spec(arg_infos[0])
        scale_spec = get_padded_spec(arg_infos[1])

        out_spec = (*x_spec[:-2], x_spec[-1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")

        if quantize_layout.is_rowwise_colwise:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1)
            else:
                colwise_out_spec = out_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out"
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = out_spec

        if quantize_layout.is_rowwise_colwise:
            assert not ScalingMode(
                scaling_mode
            ).is_colwise_transposed, "Transpose layout scaling modes are not supported here yet"
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax")
        colwise_scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
        )

        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
        )

        def sharded_impl(x, scale, amax):
            (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                local_updated_amax,
            ) = ActLuPrimitive.impl(
                x,
                scale,
                amax,
                out_dtype=out_dtype,
                act_enum=act_enum,
                act_len=act_len,
                scaling_mode=scaling_mode,
                quantize_layout=quantize_layout,
                scale_dtype=scale_dtype,
                act_params=act_params,
                amax_scope=amax_scope,
                transpose_batch_sequence=transpose_batch_sequence,
                output_amax_when_no_scaling=output_amax_when_no_scaling,
                is_outer=True,
            )

            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(
                    local_updated_amax, mesh
                )
            elif scaling_mode == ScalingMode.NO_SCALING.value and output_amax_when_no_scaling:
                global_updated_amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP(
                    local_updated_amax, out_spec, transpose_batch_sequence, mesh
                )
            else:
                global_updated_amax = local_updated_amax

            return (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                global_updated_amax,
            )

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        quantize_layout,
        scale_dtype,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
        mesh,
        value_types,
        result_types,
    ):
        del (
            out_dtype,
            act_enum,
            act_len,
            scale_dtype,
            act_params,
            amax_scope,
            transpose_batch_sequence,
            output_amax_when_no_scaling,
            is_outer,
            mesh,
            result_types,
        )
        prefix = "ActLu"
        input_shape = value_types[0].shape
        output_shape = input_shape[:-2] + input_shape[-1:]
        # Here we pass len of output so that the scales are propagated correctly
        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            output_shape, unique_var=prefix, flatten_axis=-1, q_layout=quantize_layout
        )

        # Correct the input spec with act dim
        input_spec = scale_rules.input_spec
        input_spec = input_spec[:-1] + (prefix + "_act_dim",) + input_spec[-1:]
        amax = (BATCHING + prefix + "_amax",)
        scale = (BATCHING + prefix + "_scale",)

        return SdyShardingRule(
            (tuple(input_spec), scale, amax),
            (
                scale_rules.rowwise_out_spec,
                scale_rules.colwise_out_spec,
                scale_rules.rowwise_scale_spec,
                scale_rules.colwise_scale_spec,
                amax,
            ),
            **scale_rules.factor_sizes,
        )


register_primitive(ActLuPrimitive)


class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
    """
    DActLu DBias Cast Transpose Primitive
    """

    name = "te_dact_dbias_quantize_ffi"
    multiple_results = True
    # out_dtype, scaling_mode, quantize_layout, scale_dtype, is_dbias, act_enum, act_len, act_params, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling, is_outer
    impl_static_args = (4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        dz_aval,
        x_aval,
        scale_aval,
        amax_aval,
        *,
        out_dtype,
        scaling_mode,
        quantize_layout,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
    ):
        """
        te_dact_dbias_quantize_p abstract
        """
        del act_enum, act_params, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling
        dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dz_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_dtype
        assert x_aval.shape[-2] == act_len, (
            "activation input should be replicated by act_len in the -2 axis, got input shape"
            f" {x_aval.shape} and act_len {act_len}"
        )
        assert scale_aval.dtype == jnp.float32
        assert amax_aval.dtype == jnp.float32

        assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, (
            "Current tensor scaling is not supported for fused dact and quantization. Please do"
            " dact in higher-precision then quantize with current tensor scaling."
        )

        ir_hidden_size = dz_aval.shape[-1]
        gi_hidden_size = act_len * x_aval.shape[-1]
        assert act_len * ir_hidden_size == gi_hidden_size
        assert (
            x_aval.shape[:-2] == dz_aval.shape[:-1]
        ), "dz and x should have the same leading dimensions"
        out_shape = x_aval.shape
        out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)

        updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=-2)
        if quantize_layout.is_rowwise_colwise:
            if ScalingMode(scaling_mode).is_tensor_scaling():
                colwise_out_shape = multidim_transpose(out_shape, transpose_axis=-2)
            else:
                colwise_out_shape = out_shape
        else:
            colwise_out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        if is_dbias:
            dbias_shape = (act_len, ir_hidden_size)
            (wkspace_info,) = transformer_engine_jax.get_dact_dbias_quantize_workspace_sizes(
                x_aval.size // gi_hidden_size,
                gi_hidden_size,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
                scaling_mode,
                quantize_layout.value,
            )
            wkspace_shape = wkspace_info[0]
            wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
        else:
            dbias_shape = (1,)
            wkspace_shape = (1,)
            wkspace_dtype = jnp.float32
        dbias_aval = jax.core.ShapedArray(shape=dbias_shape, dtype=dz_dtype)
        wkspace_aval = jax.core.ShapedArray(shape=wkspace_shape, dtype=wkspace_dtype)

        return (
            out_aval,
            colwise_out_aval,
            scale_inv_aval,
            colwise_scale_inv_aval,
            updated_amax_aval,
            dbias_aval,
            wkspace_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dact_dbias_quantize_p outer abstract
        """
        (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = (
            BaseDActLuDBiasQuantizePrimitive.abstract(*args, **kwargs)
        )
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def lowering(
        ctx,
        dz,
        x,
        scale,
        amax,
        *,
        out_dtype,
        scaling_mode,
        quantize_layout,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
    ):
        """
        te_dact_dbias_quantize_p lowering rules
        """
        del (
            out_dtype,
            scale_dtype,
            act_len,
            is_outer,
            amax_scope,
            transpose_batch_sequence,
        )
        dz_aval, x_aval, scale_aval, amax_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_aval.dtype
        assert scale_aval.dtype == amax_aval.dtype == jnp.float32
        return ffi.ffi_lowering(
            BaseDActLuDBiasQuantizePrimitive.name,
            operand_output_aliases={3: 4},  # donate amax buffer to updated_amax
        )(
            ctx,
            dz,
            x,
            scale,
            amax,
            scaling_mode=scaling_mode.value,
            quantize_layout=quantize_layout.value.value,
            is_dbias=is_dbias,
            act_enum=int(act_enum),
            act_params=act_params.to_ffi_lowering_dict(),
            output_amax_when_no_scaling=output_amax_when_no_scaling,
        )

    @staticmethod
    def impl(
        dz,
        x,
        scale,
        amax,
        out_dtype,
        scaling_mode,
        quantize_layout,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
    ):
        """
        te_dact_dbias_quantize_p impl
        """
        del is_outer
        assert BaseDActLuDBiasQuantizePrimitive.inner_primitive is not None
        (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = (
            BaseDActLuDBiasQuantizePrimitive.inner_primitive.bind(
                dz,
                x,
                scale,
                amax,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                quantize_layout=quantize_layout,
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
                act_enum=act_enum,
                act_len=act_len,
                act_params=act_params,
                amax_scope=amax_scope,
                transpose_batch_sequence=transpose_batch_sequence,
                output_amax_when_no_scaling=output_amax_when_no_scaling,
                is_outer=False,
            )
        )
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x.shape, is_padded=False, flatten_axis=-2)
        # Slice out padding for MXFP8, noop for DelayedScaling
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
        )
        if quantize_layout.is_rowwise_colwise:
            colwise_scale_inv = jax.lax.slice(
                colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
            )
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        scaling_mode,
        quantize_layout,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
    ):
        """
        to describe batch rules for vmap
        """
        check_valid_batch_dims(batch_dims)
        assert BaseDActLuDBiasQuantizePrimitive.outer_primitive is not None
        dz, x, scale, amax = batched_args
        _, x_bdim, scale_bdim, _ = batch_dims

        out_bdims = (
            x_bdim,  # rowwise output
            scale_bdim,  # rowwise scale_inv
            x_bdim,  # colwise output
            scale_bdim,  # colwise scale_inv
            scale_bdim,  # amax
            x_bdim,  # dbias
        )
        return (
            BaseDActLuDBiasQuantizePrimitive.outer_primitive.bind(
                dz,
                x,
                scale,
                amax,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                quantize_layout=quantize_layout,
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
                act_enum=act_enum,
                act_len=act_len,
                act_params=act_params,
                amax_scope=amax_scope,
                transpose_batch_sequence=transpose_batch_sequence,
                output_amax_when_no_scaling=output_amax_when_no_scaling,
                is_outer=is_outer,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        scaling_mode,
        quantize_layout,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del out_dtype, result_infos, act_enum, act_params, output_amax_when_no_scaling
        del scale_dtype, act_len, is_outer, amax_scope, transpose_batch_sequence

        x_spec = get_padded_spec(arg_infos[1])
        scale_spec = get_padded_spec(arg_infos[2])

        assert (
            scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value
        ), "Partitioned current tensor scaling is not yet supported."

        out_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec), desc="BaseDActLuDBiasQuantizePrimitive.out"
        )
        if quantize_layout.is_rowwise_colwise:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2)
            else:
                colwise_x_spec = x_spec
        else:
            colwise_x_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_x_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.colwise_out",
        )

        dbias_spec = x_spec[-2:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.dbias",
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if quantize_layout.is_rowwise_colwise:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="BaseDActLuDBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="BaseDActLuDBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.colwise_scale_inv",
        )
        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        scaling_mode,
        quantize_layout,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer
        x_spec = get_padded_spec(arg_infos[1])
        scale_spec = get_padded_spec(arg_infos[2])

        out_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec), desc="BaseDActLuDBiasQuantizePrimitive.out"
        )

        if quantize_layout.is_rowwise_colwise:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2)
            else:
                colwise_x_spec = x_spec
        else:
            colwise_x_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_x_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.colwise_out",
        )

        dbias_spec = x_spec[-2:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.dbias",
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if quantize_layout.is_rowwise_colwise:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax")
        colwise_scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
        )

        arg_shardings = list(arg_i.sharding for arg_i in arg_infos)
        # Ensure dz and x are partitioned the same way.
        arg_shardings[0] = NamedSharding(
            mesh,
            PartitionSpec(*x_spec[:-2], x_spec[-1]),
            desc="BaseDActLuDBiasQuantizePrimitive.dz",
        )
        arg_shardings = tuple(arg_shardings)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

        def sharded_impl(dz, x, scale, amax):
            (out, colwise_out, scale_inv, colwise_scale_inv, local_updated_amax, local_dbias) = (
                BaseDActLuDBiasQuantizePrimitive.impl(
                    dz,
                    x,
                    scale,
                    amax,
                    out_dtype=out_dtype,
                    scaling_mode=scaling_mode,
                    quantize_layout=quantize_layout,
                    scale_dtype=scale_dtype,
                    is_dbias=is_dbias,
                    act_enum=act_enum,
                    act_len=act_len,
                    act_params=act_params,
                    output_amax_when_no_scaling=output_amax_when_no_scaling,
                    amax_scope=amax_scope,
                    transpose_batch_sequence=transpose_batch_sequence,
                    is_outer=True,
                )
            )
            if is_dbias:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            else:
                global_dbias = local_dbias

            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(
                    local_updated_amax, mesh
                )
            elif scaling_mode == ScalingMode.NO_SCALING.value and output_amax_when_no_scaling:
                global_updated_amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP(
                    local_updated_amax, x_spec, transpose_batch_sequence, mesh
                )
            else:
                global_updated_amax = local_updated_amax

            return out, colwise_out, scale_inv, colwise_scale_inv, global_updated_amax, global_dbias

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        scaling_mode,
        quantize_layout,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
        mesh,
        value_types,
        result_types,
    ):
        del (
            out_dtype,
            scale_dtype,
            act_enum,
            act_len,
            act_params,
            is_outer,
            output_amax_when_no_scaling,
            mesh,
            result_types,
            amax_scope,
            transpose_batch_sequence,
        )

        prefix = "DActLuDBias_"
        # get sharding rules based on the input shape
        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            value_types[1].shape,
            unique_var=prefix,
            flatten_axis=-2,
            q_layout=quantize_layout,
        )

        input_spec = scale_rules.input_spec
        dz_spec = (*input_spec[:-2], input_spec[-1])
        dbias = input_spec[-2:] if is_dbias else (prefix + "_dbias",)
        amax = (prefix + "_amax",)
        scale = (prefix + "_scale",)

        return SdyShardingRule(
            (tuple(dz_spec), tuple(input_spec), scale, amax),
            (
                scale_rules.rowwise_out_spec,
                scale_rules.colwise_out_spec,
                scale_rules.rowwise_scale_spec,
                scale_rules.colwise_scale_spec,
                amax,
                dbias,
            ),
            **scale_rules.factor_sizes,
        )


register_primitive(BaseDActLuDBiasQuantizePrimitive)


class DActLuDBiasQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive):
    """Subclass of BaseDActLuDBiasQuantizePrimitive for DBias and fused activation quantization.

    No change in functionality from the base primitive, but named differently for more granular
    disabling of primitives via NVTE_JAX_CUSTOM_CALLS.
    """


class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive):
    """Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias.

    No change in functionality from the base primitive, but named differently for more granular
    disabling of primitives via NVTE_JAX_CUSTOM_CALLS.
    """


def _jax_act_lu(
    inputs, activation_type, quantizer=None, act_params: Optional[ActivationParams] = None
) -> Union[NoScaleTensor, ScaledTensor]:
    """
    JAX native activation implementation
    """
    act_params = act_params if act_params is not None else ActivationParams()
    act_len = len(activation_type)
    assert inputs.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {inputs.shape} and act_len {act_len}"
    )
    x = jnp.split(inputs, act_len, axis=-2)
    acts = []
    for idx, act_fn in enumerate(activation_type):
        x_i = _convert_to_activation_function(act_fn, act_params)(x[idx])
        acts.append(x_i)
    x = reduce(operator.mul, acts)
    x = jnp.squeeze(x, axis=-2)
    if quantizer:
        return quantizer.quantize(x, flatten_axis=-1)
    return NoScaleTensor(data=x, amax=None)
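
# Minimal usage sketch for the reference path above (no quantizer; shapes are illustrative):
#
#     x = jnp.ones((4, 2, 8), jnp.bfloat16)        # (..., act_len=2, K) for a gated activation
#     out = _jax_act_lu(x, ("silu", "linear"))     # NoScaleTensor with data of shape (4, 8)
#     assert out.data.shape == (4, 8)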


def _jax_quantize_dact_dbias(
    dz: Union[jnp.ndarray, NoScaleTensor],
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    is_dbias: bool = True,
    quantizer: Optional[Quantizer] = None,
    act_params: Optional[ActivationParams] = None,
):
    """
    JAX implementation of dact_lu and dbias with optional quantization
    """
    act_params = act_params if act_params is not None else ActivationParams()
    act_len = len(activation_type)
    assert x.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {x.shape} and act_len {act_len}"
    )

    _, vjp_func = jax.vjp(
        partial(_jax_act_lu, activation_type=activation_type, act_params=act_params),
        x.astype(jnp.float32),
    )
    # VJP is using non-quantized backward for dact, so the input should always be wrapped in
    # NoScaleTensor regardless of whether the forward pass used quantization or this dact will
    # quantize afterwards.
    dz = NoScaleTensor(data=dz.astype(jnp.float32), amax=None)
    (dx,) = vjp_func(dz)

    dbias = None
    if is_dbias:
        dbias = _jax_dbias(dx, dtype=x.dtype, flatten_axis=-2)

    if quantizer is not None:
        dx = quantizer.quantize(dx, dq_dtype=x.dtype, flatten_axis=-2)
    else:
        dx = dx.astype(x.dtype)
        dx = NoScaleTensor(data=dx, amax=None)

    return dx, dbias
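
# Backward-path sketch (illustrative; shapes are hypothetical): the reference dact is simply the
# VJP of the forward helper above, so for a gated SiLU one would expect roughly
#
#     x = jnp.ones((4, 2, 8), jnp.float32)   # forward input, (..., act_len, K)
#     dz = jnp.ones((4, 8), jnp.float32)     # cotangent of the squeezed forward output
#     dx, dbias = _jax_quantize_dact_dbias(dz, x, ("silu", "linear"), is_dbias=True)
#     # dx wraps an array with the input's (..., 2, 8) layout; dbias is reduced over leading dims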


def act_lu(
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Optional[Quantizer] = None,
    act_params: Optional[ActivationParams] = None,
    amax_scope: AmaxScope = AmaxScope.LOCAL,
    transpose_batch_sequence: bool = False,
    output_amax_when_no_scaling: bool = False,
) -> Union[jnp.ndarray, ScaledTensor]:
    """Activation with optional quantization.

    Args:
        x: Input tensor to be processed.
            Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
        activation_type: Type of activation function to apply.
        quantizer: Optional quantizer for FP8 quantization of the output.
        amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL.

    Returns:
        If quantizer is None:
            The activated input tensor with the same dtype as input.
        If quantizer is provided:
            A ScaledTensor containing the quantized activated input.
    """
    # TODO(Phuong): remove the output_amax_when_no_scaling exposure by introducing _act_lu_impl()
    # Do the same with dact_dbias_quantize() and layernorm_fwd()
    act_type_id = ActivationEnum[activation_type].value
    act_len = len(activation_type)
    assert x.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {x.shape} and act_len {act_len}"
    )
    act_params = act_params if act_params is not None else ActivationParams()
    if not ActLuPrimitive.enabled():
        act_out = _jax_act_lu(x, activation_type, act_params=act_params)
        assert (
            act_out.data.dtype == x.dtype
        ), f"JAX activation output dtype {act_out.data.dtype} must match input dtype {x.dtype}"
        if quantizer is None:
            return act_out

        return quantize(
            act_out,
            quantizer=quantizer,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
        )

    # TE/common does not support colwise-only quantization yet
    if quantizer is not None and quantizer.q_layout.is_colwise_only:
        return _jax_act_lu(x, activation_type, quantizer, act_params)

    # TE/common does not support 2x quantization for DelayedScaling yet
    war_output = try_apply_delayed_scaling_2x_war(
        f=act_lu,
        x=x,
        activation_type=activation_type,
        quantizer=quantizer,
        act_params=act_params,
        amax_scope=amax_scope,
        transpose_batch_sequence=transpose_batch_sequence,
        output_amax_when_no_scaling=output_amax_when_no_scaling,
    )
    if war_output is not None:
        return war_output

    scale = jnp.empty((1,), jnp.float32)
    output_shape = (*x.shape[:-2], x.shape[-1])
    amax = jnp.zeros((1,), jnp.float32)  # need to init with zero and shape=(1,)

    if quantizer is None:
        out, _, _, _, updated_amax = ActLuPrimitive.outer_primitive.bind(
            x,
            scale,
            amax,
            out_dtype=x.dtype,
            act_enum=act_type_id,
            act_len=act_len,
            scaling_mode=ScalingMode.NO_SCALING.value,
            quantize_layout=QuantizeLayout.ROWWISE,
            scale_dtype=jnp.float32,
            act_params=act_params,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
            output_amax_when_no_scaling=output_amax_when_no_scaling,
            is_outer=True,
        )
        out = out.reshape(output_shape)
        # TODO(Phuong): ScaledTensorFactory to create NoScaledTensor
        out = NoScaleTensor(
            data=out,
            amax=updated_amax if output_amax_when_no_scaling else None,
        )
        return out

    if (
        quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING
        or quantizer.scaling_mode.is_nvfp4_scaling
    ):
        # Current scaling does not support fused operations. Perform the activation in higher
        # precision then quantize after.
        out = act_lu(
            x=x,
            activation_type=activation_type,
            quantizer=None,
            act_params=act_params,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
            output_amax_when_no_scaling=True,
        )
        assert (
            out.data.dtype == x.dtype
        ), f"Activation output dtype {out.data.dtype} must match input dtype {x.dtype}"
        out = quantize(
            out,
            quantizer=quantizer,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
        )
        return out

    if isinstance(quantizer, DelayedScaleQuantizer):
        scale = quantizer.scale

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
    ) = ActLuPrimitive.outer_primitive.bind(
        x,
        scale,
        amax,
        out_dtype=quantizer.q_dtype,
        act_enum=act_type_id,
        act_len=act_len,
        scaling_mode=quantizer.scaling_mode.value,
        quantize_layout=quantizer.q_layout,
        scale_dtype=quantizer.get_scale_dtype(),
        act_params=act_params,
        amax_scope=amax_scope,
        transpose_batch_sequence=transpose_batch_sequence,
        output_amax_when_no_scaling=output_amax_when_no_scaling,
        is_outer=True,
    )

    quantizer.update(updated_amax)

    return ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
    )
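
# Usage sketch (illustrative; quantizer construction depends on the caller's quantization recipe):
#
#     gelu_out = act_lu(x, ("gelu",))                          # plain activation -> NoScaleTensor
#     swiglu_out = act_lu(x, ("silu", "linear"))               # gated activation, expects x of (..., 2, K)
#     fp8_out = act_lu(x, ("gelu",), quantizer=my_quantizer)   # hypothetical quantizer -> ScaledTensor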


1395
1396
1397
1398
1399
1400
def quantize_dact_dbias(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]] = ("gelu",),
    is_dbias: bool = True,
    quantizer: Optional[Quantizer] = None,
1401
    act_params: Optional[ActivationParams] = None,
1402
1403
1404
    amax_scope: AmaxScope = AmaxScope.LOCAL,
    transpose_batch_sequence: bool = False,
    output_amax_when_no_scaling: bool = False,
1405
1406
1407
1408
1409
1410
) -> Tuple[ScaledTensor, jnp.ndarray]:
    """Compute gradients of activation and bias with optional quantization.

    Args:
        dz: Gradient of the output with respect to the activation output.
        x: Input tensor that was processed by the forward pass.
1411
            Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
1412
1413
1414
1415
1416
1417
1418
1419
1420
        activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",).
        is_dbias: If True, compute bias gradient. Defaults to True.
        quantizer: Optional quantizer for FP8 quantization of the output.

    Returns:
        Tuple[ScaledTensor, jnp.ndarray]: A tuple containing:
        - The gradient of the activation with respect to the input.
        - The gradient of the activation with respect to the bias.
    """
1421
    act_params = act_params if act_params is not None else ActivationParams()
1422
1423
1424
1425
1426
1427
    act_len = len(activation_type)
    assert x.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {x.shape} and act_len {act_len}"
    )

1428
1429
    scale = jnp.empty((1,), jnp.float32)
    amax = jnp.zeros((1,), jnp.float32)  # need to init with zero and shape=(1,)
Alp Dener's avatar
Alp Dener committed
1430
    act_type_id = ActivationEnum[activation_type]
1431
    PrimitiveClass = DActLuDBiasQuantizePrimitive if is_dbias else DActLuQuantizePrimitive
Alp Dener's avatar
Alp Dener committed
1432
    if not PrimitiveClass.enabled() or (
1433
        quantizer is not None and quantizer.q_layout.is_colwise_only
Alp Dener's avatar
Alp Dener committed
1434
    ):
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
        if quantizer is None:
            return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, act_params=act_params)
        dact_out, _ = _jax_quantize_dact_dbias(
            dz, x, activation_type, is_dbias=False, act_params=act_params
        )
        assert (
            dact_out.data.dtype == x.dtype
        ), f"JAX dact output dtype {dact_out.data.dtype} must match input dtype {x.dtype}"
        return quantize_dbias(
            dact_out,
            quantizer,
            is_dbias=is_dbias,
            flatten_axis=-2,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
        )

    if quantizer is None:
        output, _, _, _, updated_amax, _ = PrimitiveClass.outer_primitive.bind(
            dz,
            x,
            scale,
            amax,
            # outputs float32 for dbias accumulation
            out_dtype=(jnp.float32 if is_dbias else x.dtype),
            # placeholder value for no scaling; TE/common ignores it when the scale is unset
            scaling_mode=ScalingMode.NO_SCALING.value,
            quantize_layout=QuantizeLayout.ROWWISE,  # unused
            scale_dtype=jnp.float32,  # unused
            is_dbias=False,
            act_enum=act_type_id,
            act_len=act_len,
            act_params=act_params,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
            output_amax_when_no_scaling=output_amax_when_no_scaling,
            is_outer=True,
        )
        output = output.astype(x.dtype)
        dbias = None
        if is_dbias:
            dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2)

        output = NoScaleTensor(
            data=output,
            amax=updated_amax if output_amax_when_no_scaling else None,
        )
        return output, dbias

    # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet
    if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
        out = dact_lu(
            dz.astype(jnp.float32),
            x.astype(jnp.float32),
            activation_type,
            quantizer=None,
            act_params=act_params,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
            output_amax_when_no_scaling=output_amax_when_no_scaling,
        )
        return _quantize_dbias_impl(
            out,
            quantizer,
            is_dbias=True,
            dq_dtype=x.dtype,
            flatten_axis=-2,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
        )

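    # Gated activations (e.g. SwiGLU/GEGLU) pair the activation with a "linear" branch,
    # so the forward input carries two rows along its -2 axis.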
    is_gated = act_len == 2
    # TE/common does not support DelayedScaling2x for gated-act yet
    if is_gated:
        war_output = try_apply_delayed_scaling_2x_war(
            f=quantize_dact_dbias,
            dz=dz,
            x=x,
            activation_type=activation_type,
            is_dbias=is_dbias,
            quantizer=quantizer,
            flatten_axis=-2,
            act_params=act_params,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
            output_amax_when_no_scaling=output_amax_when_no_scaling,
        )
        if war_output is not None:
            return war_output

    if (
        quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING
        or quantizer.scaling_mode.is_nvfp4_scaling
    ):
        # Current scaling does not support fused operations, so perform dact in
        # higher precision and then quantize the result.
        out = dact_lu(
            dz=dz,
            x=x,
            activation_type=activation_type,
            quantizer=None,
            act_params=act_params,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
            output_amax_when_no_scaling=True,
        )
        out, dbias = _quantize_dbias_impl(
            out,
            is_dbias=is_dbias,
            quantizer=quantizer,
            dq_dtype=x.dtype,
            flatten_axis=-2,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
        )
        return out, dbias

    if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        scale = quantizer.scale

    # TE/common dact_dbias_quantize does not support gated act yet
    if is_dbias and is_gated:
        dgated = dact_lu(
            dz.astype(jnp.float32),
            x.astype(jnp.float32),
            activation_type=activation_type,
            act_params=act_params,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
        )
        out, dbias = _quantize_dbias_impl(
            dgated,
            quantizer,
            is_dbias=True,
            dq_dtype=x.dtype,
            flatten_axis=-2,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
        )
        return out, dbias

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
        dbias,
    ) = PrimitiveClass.outer_primitive.bind(
        dz,
        x,
        scale,
        amax,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        quantize_layout=quantizer.q_layout,
        scale_dtype=quantizer.get_scale_dtype(),
        is_dbias=is_dbias,
        act_enum=act_type_id,
        act_len=act_len,
        act_params=act_params,
        amax_scope=amax_scope,
        transpose_batch_sequence=transpose_batch_sequence,
        output_amax_when_no_scaling=output_amax_when_no_scaling,
        is_outer=True,
    )

    # For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise
    if quantizer.scaling_mode.is_tensor_scaling() and quantizer.q_layout.is_rowwise_colwise:
        colwise_scale_inv = rowwise_scale_inv

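    # Feed the newly computed amax back to the quantizer state
    # (e.g. the amax history used by delayed scaling).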
    quantizer.update(updated_amax)

    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=-2,  # as output has act axis
    )

    return out, dbias


def dact_lu(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Optional[Quantizer] = None,
    act_params: Optional[ActivationParams] = None,
    amax_scope: AmaxScope = AmaxScope.LOCAL,
    transpose_batch_sequence: bool = False,
    output_amax_when_no_scaling: bool = False,
) -> Union[jnp.ndarray, ScaledTensor]:
    """
1633
    Backward pass for activation with optional quantization.
1634

1635
1636
1637
1638
1639
1640
1641
1642
1643
    Args:
        dz: Gradient tensor from upstream.
        x: Input tensor that was used in forward pass.
        activation_type: Type of activation function that was applied.
        quantizer: Optional quantizer for FP8 quantization of the output gradient.

    Returns:
        The gradient of the activation with respect to the input.
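
    Example (a minimal illustrative sketch; the shapes and dtype are assumptions):

        import jax.numpy as jnp

        dz = jnp.ones((16, 128, 1024), jnp.bfloat16)    # upstream gradient
        x = jnp.ones((16, 128, 1, 1024), jnp.bfloat16)  # activation input, ACT_DIM == 1
        dgrad = dact_lu(dz, x, activation_type=("gelu",))  # unquantized GELU backward pass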
    """
    act_params = act_params if act_params is not None else ActivationParams()
    output, _ = quantize_dact_dbias(
        dz=dz,
        x=x,
        activation_type=activation_type,
        is_dbias=False,
        quantizer=quantizer,
        act_params=act_params,
        amax_scope=amax_scope,
        transpose_batch_sequence=transpose_batch_sequence,
        output_amax_when_no_scaling=output_amax_when_no_scaling,
    )
    return output