"git@developer.sourcefind.cn:gaoqiong/pybind11.git" did not exist on "146695a904526adc70c97615f7c395c4b3115c18"
activation.py 40.9 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
#
# See LICENSE for license information.
"""JAX/TE custom ops for activation"""
5
from typing import Sequence, Union, Callable, Optional, Tuple
6
import operator
7
from functools import reduce, partial
8
from packaging import version
9

10
import jax
11
import jax.numpy as jnp
12
from jax import dtypes
13
from jax.experimental.custom_partitioning import SdyShardingRule
14
from jax.sharding import PartitionSpec
15

16
17
import transformer_engine_jax
from transformer_engine_jax import NVTE_Activation_Type
18
19
20
21

from .base import BasePrimitive, register_primitive
from .misc import (
    jax_dtype_to_te_dtype,
    te_dtype_to_jax_dtype,
    get_padded_spec,
    check_valid_batch_dims,
    multidim_transpose,
    try_apply_delayed_scaling_2x_war,
    should_apply_1x_fused_dbias_war_for_arch_l_100,
    NamedSharding,
)
from .quantization import _jax_dbias, _quantize_dbias_impl
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor, ScaledTensorFactory
from ..quantize import (
    Quantizer,
    QuantizeLayout,
    DelayedScaleQuantizer,
    ScalingMode,
)

if version.parse(jax.__version__) >= version.parse("0.5.0"):
    from jax import ffi  # pylint: disable=ungrouped-imports
else:
    from jax.extend import ffi  # pylint: disable=ungrouped-imports

__all__ = ["act_lu", "dact_lu", "quantize_dact_dbias"]


ActivationEnum = {
    ("gelu",): NVTE_Activation_Type.GELU,
    ("gelu", "linear"): NVTE_Activation_Type.GEGLU,
    ("silu",): NVTE_Activation_Type.SILU,
    ("silu", "linear"): NVTE_Activation_Type.SWIGLU,
    ("relu",): NVTE_Activation_Type.RELU,
    ("relu", "linear"): NVTE_Activation_Type.REGLU,
    ("quick_gelu",): NVTE_Activation_Type.QGELU,
    ("quick_gelu", "linear"): NVTE_Activation_Type.QGEGLU,
    ("squared_relu",): NVTE_Activation_Type.SRELU,
    ("squared_relu", "linear"): NVTE_Activation_Type.SREGLU,
}
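# Note (illustrative): gated variants are keyed by a 2-tuple ending in "linear", e.g.
# ActivationEnum[("silu", "linear")] is NVTE_Activation_Type.SWIGLU, while non-gated
# activations use a 1-tuple such as ("gelu",).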


def _convert_to_activation_function(fn_or_string):
    """Convert a string to an activation function."""
    if fn_or_string == "linear":
        return lambda x: x
    if fn_or_string == "quick_gelu":
        return lambda x: jax.nn.sigmoid(1.702 * x) * x
    if fn_or_string == "squared_relu":
        return lambda x: reduce(operator.mul, [jax.nn.relu(x), jax.nn.relu(x)])
    if isinstance(fn_or_string, str):
        return getattr(jax.nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string
    raise ValueError(f"Cannot convert {fn_or_string} to an activation function")


class ActLuPrimitive(BasePrimitive):
    """
    ActLu Primitive
    """

    name = "te_act_lu_ffi"
    multiple_results = True
    impl_static_args = (
        2,
        3,
        4,
        5,
        6,
        7,
        8,
    )  # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, is_outer
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        *,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
    ):
        """
        te_act_lu_p abstract
        """
        del act_enum
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval is None or scale_aval.dtype == jnp.float32
        assert x_aval.shape[-2] == act_len, (
            "activation input should be replicated by act_len in the -2 axis, got input shape"
            f" {x_aval.shape} and act_len {act_len}"
        )

        assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, (
            "Current tensor scaling is not yet supported for fused activation and quantization."
            " Please do activation in higher-precision then quantize with current tensor scaling."
        )

        out_shape = (*x_aval.shape[:-2], x_aval.shape[-1])  # Exclude act dim
        out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)

        updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(out_shape, is_padded=not is_outer, flatten_axis=-1)
        if not is_2x:
            out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype)
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        return out_aval, colwise_out_aval, scale_inv_aval, colwise_scale_inv_aval, updated_amax_aval

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        *,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
    ):
        """
        te_act_lu_p lowering rules
        """
        del out_dtype, scale_dtype, act_len, is_outer
        x_aval, scale_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval is None or scale_aval.dtype == jnp.float32

        out = ffi.ffi_lowering(ActLuPrimitive.name)(
            ctx, x, scale, act_enum=act_enum, scaling_mode=scaling_mode.value, is_2x=is_2x
        )
        return out

    @staticmethod
    def impl(
        x,
        scale,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
    ):
        """
        te_act_lu_p implementation
        """
        del is_outer
        assert ActLuPrimitive.inner_primitive is not None

        out, colwise_out, scale_inv, colwise_scale_inv, updated_amax = (
            ActLuPrimitive.inner_primitive.bind(
                x,
                scale,
                out_dtype=out_dtype,
                act_enum=act_enum,
                act_len=act_len,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                is_outer=False,
            )
        )
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(out.shape, is_padded=False, flatten_axis=-1)
        # Slice out padding for MXFP8, noop for DelayedScaling
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
        )
        if is_2x:
            colwise_scale_inv = jax.lax.slice(
                colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
            )

        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
    ):
        """
        Batch rules for vmap
        """
        del is_outer
        check_valid_batch_dims(batch_dims)
        assert ActLuPrimitive.outer_primitive is not None
        x, scale = batched_args
        x_bdim, scale_bdim = batch_dims
        amax_bdim = scale_bdim

        out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim
        return (
            ActLuPrimitive.outer_primitive.bind(
                x,
                scale,
                out_dtype=out_dtype,
                act_enum=act_enum,
                act_len=act_len,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del (
            out_dtype,
            result_infos,
            act_enum,
            scale_dtype,
            act_len,
            is_outer,
        )  # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        scale_spec = get_padded_spec(arg_infos[1])

        out_spec = (*x_spec[:-2], x_spec[-1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")

        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1)
            else:
                colwise_out_spec = out_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out"
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = out_spec

        if is_2x:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax")
        colwise_scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
        )

        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer  # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        scale_spec = get_padded_spec(arg_infos[1])

        out_spec = (*x_spec[:-2], x_spec[-1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")

        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1)
            else:
                colwise_out_spec = out_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out"
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = out_spec

        if is_2x:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax")
        colwise_scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
        )

        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
        )

        def sharded_impl(x, scale):
            local_x, local_colwise_x, local_scale_inv, local_colwise_scale_inv, local_amax = (
                ActLuPrimitive.impl(
                    x,
                    scale,
                    out_dtype=out_dtype,
                    act_enum=act_enum,
                    act_len=act_len,
                    scaling_mode=scaling_mode,
                    is_2x=is_2x,
                    scale_dtype=scale_dtype,
                    is_outer=True,
                )
            )

            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            else:
                global_updated_amax = local_amax

            return (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                global_updated_amax,
            )

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
        mesh,
        value_types,
        result_types,
    ):
        del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types

        x_rank = len(value_types[0].shape)
        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            x_rank - 1, unique_var="ActLuPrimitive_i", flatten_axis=-2
        )
        x_axes = scale_rules.input_spec + (f"x{x_rank-1}",)
        out = (*x_axes[:-2], x_axes[-1])
        scale_inv = scale_rules.rowwise_rule
        colwise_scale_inv = scale_rules.colwise_rule

        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out = tuple(
                    multidim_transpose(x_axes, static_axis_boundary=-1, transpose_axis=-2)
                )
            else:
                colwise_out = out
        else:
            colwise_out = ("j",)
            colwise_scale_inv = ("k",)

        # amax is always a unit tensor.
        amax = ("l",)

        return SdyShardingRule(
            (
                x_axes,
                "…1",
            ),
            (out, colwise_out, scale_inv, colwise_scale_inv, amax),
            **scale_rules.factor_sizes,
        )


register_primitive(ActLuPrimitive)


# TODO(Jeremy): replace is_2x with q_layout
class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
    """
    DActLu DBias Cast Transpose Primitive
    """

    name = "te_dact_dbias_quantize_ffi"
    multiple_results = True
    # out_dtype, scaling_mode, is_2x, scale_dtype, is_dbias, act_enum, act_len, is_outer
    impl_static_args = (3, 4, 5, 6, 7, 8, 9, 10)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        dz_aval,
        x_aval,
        scale_aval,
        *,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
    ):
        """
        te_dact_dbias_quantize_p abstract
        """
        del act_enum
        dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dz_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_dtype
        assert x_aval.shape[-2] == act_len, (
            "activation input should be replicated by act_len in the -2 axis, got input shape"
            f" {x_aval.shape} and act_len {act_len}"
        )
        assert scale_aval.dtype == jnp.float32

        assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, (
            "Current tensor scaling is not supported for fused dact and quantization. Please do"
            " dact in higher-precision then quantize with current tensor scaling."
        )

        ir_hidden_size = dz_aval.shape[-1]
        gi_hidden_size = act_len * x_aval.shape[-1]
        assert act_len * ir_hidden_size == gi_hidden_size
        assert (
            x_aval.shape[:-2] == dz_aval.shape[:-1]
        ), "dz and x should have the same leading dimensions"
        out_shape = x_aval.shape
        out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)

        updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=-2)
        if is_2x:
            if ScalingMode(scaling_mode).is_tensor_scaling():
                colwise_out_shape = multidim_transpose(out_shape, transpose_axis=-2)
            else:
                colwise_out_shape = out_shape
        else:
            colwise_out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        if is_dbias:
            dbias_shape = (act_len, ir_hidden_size)
            (wkspace_info,) = transformer_engine_jax.get_dact_dbias_quantize_workspace_sizes(
                x_aval.size // gi_hidden_size,
                gi_hidden_size,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
                scaling_mode,
                is_2x,
            )
            wkspace_shape = wkspace_info[0]
            wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
        else:
            dbias_shape = (1,)
            wkspace_shape = (1,)
            wkspace_dtype = jnp.float32
        dbias_aval = jax.core.ShapedArray(shape=dbias_shape, dtype=dz_dtype)
        wkspace_aval = jax.core.ShapedArray(shape=wkspace_shape, dtype=wkspace_dtype)

        return (
            out_aval,
            colwise_out_aval,
            scale_inv_aval,
            colwise_scale_inv_aval,
            updated_amax_aval,
            dbias_aval,
            wkspace_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dact_dbias_quantize_p outer abstract
        """
        (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = (
            BaseDActLuDBiasQuantizePrimitive.abstract(*args, **kwargs)
        )
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def lowering(
        ctx,
        dz,
        x,
        scale,
        *,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
    ):
        """
        te_dact_dbias_quantize_p lowering rules
        """
        del out_dtype, scale_dtype, act_len, is_outer
        dz_aval, x_aval, scale_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_aval.dtype
        assert scale_aval.dtype == jnp.float32
        return ffi.ffi_lowering(BaseDActLuDBiasQuantizePrimitive.name)(
            ctx,
            dz,
            x,
            scale,
            scaling_mode=scaling_mode.value,
            is_2x=is_2x,
            is_dbias=is_dbias,
            act_enum=int(act_enum),
        )

    @staticmethod
    def impl(
        dz,
        x,
        scale,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
    ):
        """
        te_dact_dbias_quantize_p impl
        """
        del is_outer
        assert BaseDActLuDBiasQuantizePrimitive.inner_primitive is not None
        (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = (
            BaseDActLuDBiasQuantizePrimitive.inner_primitive.bind(
                dz,
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
                act_enum=act_enum,
                act_len=act_len,
                is_outer=False,
            )
        )
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x.shape, is_padded=False, flatten_axis=-2)
        # Slice out padding for MXFP8, noop for DelayedScaling
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
        )
        if is_2x:
            colwise_scale_inv = jax.lax.slice(
                colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
            )
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
    ):
        """
        Batch rules for vmap
        """
        del is_outer
        check_valid_batch_dims(batch_dims)
        assert BaseDActLuDBiasQuantizePrimitive.outer_primitive is not None
        dz, x, scale = batched_args
        _, x_bdim, scale_bdim = batch_dims

        out_bdims = (
            x_bdim,  # rowwise output
            x_bdim,  # colwise output
            scale_bdim,  # rowwise scale_inv
            scale_bdim,  # colwise scale_inv
            scale_bdim,  # amax
            x_bdim,  # dbias
        )
        return (
            BaseDActLuDBiasQuantizePrimitive.outer_primitive.bind(
                dz,
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
                act_enum=act_enum,
                act_len=act_len,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del out_dtype, result_infos, act_enum
        del scale_dtype, act_len, is_outer
        x_spec = get_padded_spec(arg_infos[1])
        scale_spec = get_padded_spec(arg_infos[2])

        assert (
            scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value
        ), "Partitioned current tensor scaling is not yet supported."

        out_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec), desc="BaseDActLuDBiasQuantizePrimitive.out"
        )
        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2)
            else:
                colwise_x_spec = x_spec
        else:
            colwise_x_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_x_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.colwise_out",
        )

        dbias_spec = x_spec[-2:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.dbias",
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if is_2x:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="BaseDActLuDBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="BaseDActLuDBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.colwise_scale_inv",
        )
        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer
        x_spec = get_padded_spec(arg_infos[1])
        scale_spec = get_padded_spec(arg_infos[2])

        out_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec), desc="BaseDActLuDBiasQuantizePrimitive.out"
        )

        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2)
            else:
                colwise_x_spec = x_spec
        else:
            colwise_x_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_x_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.colwise_out",
        )

        dbias_spec = x_spec[-2:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.dbias",
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if is_2x:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="BaseDActLuDBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="BaseDActLuDBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.colwise_scale_inv",
        )

        arg_shardings = list(arg_i.sharding for arg_i in arg_infos)
        # Ensure dz and x are partitioned the same way.
        arg_shardings[0] = NamedSharding(
            mesh,
            PartitionSpec(*x_spec[:-2], x_spec[-1]),
            desc="BaseDActLuDBiasQuantizePrimitive.dz",
        )
        arg_shardings = tuple(arg_shardings)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

        def sharded_impl(dz, x, scale):
            (out, colwise_out, scale_inv, colwise_scale_inv, local_amax, local_dbias) = (
                BaseDActLuDBiasQuantizePrimitive.impl(
                    dz,
                    x,
                    scale,
                    out_dtype=out_dtype,
                    scaling_mode=scaling_mode,
                    is_2x=is_2x,
                    scale_dtype=scale_dtype,
                    is_dbias=is_dbias,
                    act_enum=act_enum,
                    act_len=act_len,
                    is_outer=True,
                )
            )
            if is_dbias:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            else:
                global_dbias = local_dbias

            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            else:
                global_updated_amax = local_amax

            return out, colwise_out, scale_inv, colwise_scale_inv, global_updated_amax, global_dbias

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
        mesh,
        value_types,
        result_types,
    ):
        del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types

        x_rank = len(value_types[1].shape)
        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            x_rank, unique_var="BaseDActLuDBiasQuantizePrimitive_i", flatten_axis=-2
        )
        x_axes = scale_rules.input_spec
        out = x_axes
        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=-2))
            else:
                colwise_out = tuple(x_axes)
        else:
            colwise_out = ("j",)

        dbias = x_axes[-2:] if is_dbias else ("k",)
        amax = ("…4",)

        return SdyShardingRule(
            (("…0",), tuple(x_axes), ("…2",)),
            (out, colwise_out, scale_rules.rowwise_rule, scale_rules.colwise_rule, amax, dbias),
            **scale_rules.factor_sizes,
        )


register_primitive(BaseDActLuDBiasQuantizePrimitive)


class DActLuDBiasQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive):
    """Subclass of BaseDActLuDBiasQuantizePrimitive for DBias and fused activation quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE."""


class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive):
    """Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE."""


def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, ScaledTensor]:
    """
    JAX native activation implementation
    """
    act_len = len(activation_type)
    assert inputs.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {inputs.shape} and act_len {act_len}"
    )

    x = jnp.split(inputs, act_len, axis=-2)
    acts = []
    for idx, act_fn in enumerate(activation_type):
        x_i = _convert_to_activation_function(act_fn)(x[idx])
        acts.append(x_i)
    x = reduce(operator.mul, acts)
    x = jnp.squeeze(x, axis=-2)
    if quantizer:
        return quantizer.quantize(x, flatten_axis=-1)
    return x


def _jax_quantize_dact_dbias(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    is_dbias: bool = True,
    quantizer: Optional[Quantizer] = None,
):
    """
    JAX implementation of dact_lu and dbias with optional quantization
    """
    act_len = len(activation_type)
    assert x.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {x.shape} and act_len {act_len}"
    )

    _, vjp_func = jax.vjp(
        partial(_jax_act_lu, activation_type=activation_type), x.astype(jnp.float32)
    )
    (dx,) = vjp_func(dz.astype(jnp.float32))

    dbias = None
    if is_dbias:
        dbias = _jax_dbias(dx, dtype=x.dtype, flatten_axis=-2)

    if quantizer is not None:
        dx = quantizer.quantize(dx, dq_dtype=x.dtype, flatten_axis=-2)
    else:
        dx = dx.astype(x.dtype)

    return dx, dbias


def act_lu(
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Optional[Quantizer] = None,
) -> Union[jnp.ndarray, ScaledTensor]:
    """Activation with optional quantization.

    Args:
        x: Input tensor to be processed.
            Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
        activation_type: Type of activation function to apply.
        quantizer: Optional quantizer for FP8 quantization of the output.

    Returns:
        If quantizer is None:
            The activated input tensor with the same dtype as input.
        If quantizer is provided:
            A ScaledTensor containing the quantized activated input.
    """
    act_type_id = ActivationEnum[activation_type].value
    act_len = len(activation_type)
    assert x.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {x.shape} and act_len {act_len}"
    )

    if not ActLuPrimitive.enabled():
        return _jax_act_lu(x, activation_type, quantizer)

    # TE/common does not support colwise-only quantization yet
    if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
        return _jax_act_lu(x, activation_type, quantizer)

    # TE/common does not support 2x quantization for DelayedScaling yet
    war_output = try_apply_delayed_scaling_2x_war(
        f=act_lu, x=x, activation_type=activation_type, quantizer=quantizer
    )
    if war_output is not None:
        return war_output

    scale = jnp.empty((1,), jnp.float32)
    output_shape = (*x.shape[:-2], x.shape[-1])

    if quantizer is None:
        out, _, _, _, _ = ActLuPrimitive.outer_primitive.bind(
            x,
            scale,
            out_dtype=x.dtype,
            act_enum=act_type_id,
            act_len=act_len,
            scaling_mode=ScalingMode.NO_SCALING.value,
            is_2x=False,
            scale_dtype=jnp.float32,
            is_outer=True,
        )
        out = out.reshape(output_shape)
        return out

    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        # Current scaling does not support fused operations. Perform act in higher precision then quantize after.
        out = act_lu(
            x=x.astype(jnp.float32),
            activation_type=activation_type,
            quantizer=None,
        )
        out, _ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype)
        return out

    if isinstance(quantizer, DelayedScaleQuantizer):
        scale = quantizer.scale

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
    ) = ActLuPrimitive.outer_primitive.bind(
        x,
        scale,
        out_dtype=quantizer.q_dtype,
        act_enum=act_type_id,
        act_len=act_len,
        scaling_mode=quantizer.scaling_mode.value,
        is_2x=quantizer.is_2x2x(),
        scale_dtype=quantizer.get_scale_dtype(),
        is_outer=True,
    )

    quantizer.update(updated_amax)

    return ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
    )
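# Example usage of act_lu (an illustrative sketch; `batch`, `seqlen`, `hidden`, and
# `my_quantizer` are assumptions, not defined in this module):
#
#     x = jnp.ones((batch, seqlen, 2, hidden), jnp.bfloat16)  # gated act: act axis of size 2
#     y = act_lu(x, ("silu", "linear"))  # bf16 output of shape (batch, seqlen, hidden)
#     y_q = act_lu(x, ("silu", "linear"), quantizer=my_quantizer)  # ScaledTensor output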


def quantize_dact_dbias(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]] = ("gelu",),
    is_dbias: bool = True,
    quantizer: Optional[Quantizer] = None,
) -> Tuple[ScaledTensor, jnp.ndarray]:
    """Compute gradients of activation and bias with optional quantization.

    Args:
        dz: Gradient of the output with respect to the activation output.
        x: Input tensor that was processed by the forward pass.
            Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
        activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",).
        is_dbias: If True, compute bias gradient. Defaults to True.
        quantizer: Optional quantizer for FP8 quantization of the output.

    Returns:
        Tuple[ScaledTensor, jnp.ndarray]: A tuple containing:
        - The gradient of the activation with respect to the input.
        - The gradient of the activation with respect to the bias.
    """

    act_len = len(activation_type)
    assert x.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {x.shape} and act_len {act_len}"
    )

    PrimitiveClass = DActLuDBiasQuantizePrimitive if is_dbias else DActLuQuantizePrimitive
    if not PrimitiveClass.enabled():
        return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)

    # TE/common does not support colwise-only quantization yet
    if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
        return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)

    # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet
    if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
        out = dact_lu(
            dz.astype(jnp.float32), x.astype(jnp.float32), activation_type, quantizer=None
        )
        return _quantize_dbias_impl(
            out, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2
        )

    is_gated = act_len == 2
    # TE/common does not support DelayedScaling2x for gated-act yet
    if is_gated:
        war_output = try_apply_delayed_scaling_2x_war(
            f=quantize_dact_dbias,
            dz=dz,
            x=x,
            activation_type=activation_type,
            is_dbias=is_dbias,
            quantizer=quantizer,
            flatten_axis=-2,
        )
        if war_output is not None:
            return war_output

    scale = jnp.empty((), jnp.float32)

    act_type_id = ActivationEnum[activation_type]

    if quantizer is None:
        output, _, _, _, _, _ = PrimitiveClass.outer_primitive.bind(
            dz,
            x,
            scale,
            # outputs float32 for dbias accumulation
            out_dtype=(jnp.float32 if is_dbias else x.dtype),
            # default value for no scaling, TE/common ignore this value when scale is unset
            scaling_mode=ScalingMode.NO_SCALING.value,
            is_2x=False,  # unused
            scale_dtype=jnp.float32,  # unused
            is_dbias=False,
            act_enum=act_type_id,
            act_len=act_len,
            is_outer=True,
        )
        dbias = None
        if is_dbias:
            dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2)
        return output.astype(x.dtype), dbias

    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        # Current scaling does not support fused operations. Perform dact in higher precision then quantize after.
        out = dact_lu(
            dz=dz.astype(jnp.float32),
            x=x.astype(jnp.float32),
            activation_type=activation_type,
            quantizer=None,
        )
        out, dbias = _quantize_dbias_impl(
            out, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2
        )
        return out, dbias

    if isinstance(quantizer, DelayedScaleQuantizer):
        scale = quantizer.scale

    # TE/common dact_dbias_quantize does not support gated act yet
    if is_dbias and is_gated:
        dgated = dact_lu(
            dz.astype(jnp.float32), x.astype(jnp.float32), activation_type=activation_type
        )
        out, dbias = _quantize_dbias_impl(
            dgated, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2
        )
        return out, dbias

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
        dbias,
    ) = PrimitiveClass.outer_primitive.bind(
        dz,
        x,
        scale,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        is_2x=quantizer.is_2x2x(),
        scale_dtype=quantizer.get_scale_dtype(),
        is_dbias=is_dbias,
        act_enum=act_type_id,
        act_len=act_len,
        is_outer=True,
    )

    # For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise
    if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x():
        colwise_scale_inv = rowwise_scale_inv

    quantizer.update(updated_amax)

    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=-2,  # as output has act axis
    )

    return out, dbias
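# Example usage of quantize_dact_dbias (an illustrative sketch; shapes and `my_quantizer` are
# assumptions): for a non-gated "gelu" forward pass with x of shape (batch, seqlen, 1, hidden)
# and upstream gradient dz of shape (batch, seqlen, hidden),
#
#     dx, dbias = quantize_dact_dbias(dz, x, ("gelu",), is_dbias=True, quantizer=my_quantizer)
#
# returns the quantized input gradient (a ScaledTensor) and the bias gradient.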


def dact_lu(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Optional[Quantizer] = None,
) -> Union[jnp.ndarray, ScaledTensor]:
    """
    Backward pass for activation with optional quantization.

    Args:
        dz: Gradient tensor from upstream.
        x: Input tensor that was used in forward pass.
        activation_type: Type of activation function that was applied.
        quantizer: Optional quantizer for FP8 quantization of the output gradient.

    Returns:
        The gradient of the activation with respect to the input.
    """
    output, _ = quantize_dact_dbias(
        dz=dz,
        x=x,
        activation_type=activation_type,
        is_dbias=False,
        quantizer=quantizer,
    )
    return output
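# Example usage of dact_lu (an illustrative sketch; shapes are assumptions): for a gated
# ("silu", "linear") forward pass with x of shape (batch, seqlen, 2, hidden) and upstream
# gradient dz of shape (batch, seqlen, hidden),
#
#     dx = dact_lu(dz, x, ("silu", "linear"))
#
# dx has the same shape as x; passing a quantizer instead returns a ScaledTensor.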