# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for activation"""
from typing import Sequence, Union, Callable, Optional, Tuple
import operator
from functools import reduce, partial
from packaging import version

import jax
import jax.numpy as jnp
from jax import dtypes
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.sharding import PartitionSpec

import transformer_engine_jax
from transformer_engine_jax import NVTE_Activation_Type

from .base import BasePrimitive, register_primitive
from .misc import (
    jax_dtype_to_te_dtype,
    te_dtype_to_jax_dtype,
    get_padded_spec,
    check_valid_batch_dims,
    multidim_transpose,
    try_apply_delayed_scaling_2x_war,
    should_apply_1x_fused_dbias_war_for_arch_l_100,
    NamedSharding,
)
from .quantization import _jax_dbias, _quantize_dbias_impl
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor, ScaledTensorFactory
from ..quantize import (
    Quantizer,
    QuantizeLayout,
    DelayedScaleQuantizer,
    ScalingMode,
)

if version.parse(jax.__version__) >= version.parse("0.5.0"):
    from jax import ffi  # pylint: disable=ungrouped-imports
else:
    from jax.extend import ffi  # pylint: disable=ungrouped-imports

__all__ = ["act_lu", "dact_lu", "quantize_dact_dbias"]


ActivationEnum = {
    ("gelu",): NVTE_Activation_Type.GELU,
    ("gelu", "linear"): NVTE_Activation_Type.GEGLU,
    ("silu",): NVTE_Activation_Type.SILU,
    ("silu", "linear"): NVTE_Activation_Type.SWIGLU,
    ("relu",): NVTE_Activation_Type.RELU,
    ("relu", "linear"): NVTE_Activation_Type.REGLU,
    ("quick_gelu",): NVTE_Activation_Type.QGELU,
    ("quick_gelu", "linear"): NVTE_Activation_Type.QGEGLU,
    ("squared_relu",): NVTE_Activation_Type.SRELU,
    ("squared_relu", "linear"): NVTE_Activation_Type.SREGLU,
}
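# The lookup key is the tuple of component activations: a gated activation is
# addressed by its pair, e.g. ActivationEnum[("silu", "linear")] is SWIGLU,
# while ActivationEnum[("silu",)] is plain SILU.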


def _convert_to_activation_function(fn_or_string):
    """Convert a string to an activation function."""
    if fn_or_string == "linear":
        return lambda x: x
    if fn_or_string == "quick_gelu":
        return lambda x: jax.nn.sigmoid(1.702 * x) * x
    if fn_or_string == "squared_relu":
        return lambda x: reduce(operator.mul, [jax.nn.relu(x), jax.nn.relu(x)])
    if isinstance(fn_or_string, str):
        return getattr(jax.nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string
    raise ValueError(f"Don't know how to convert {fn_or_string} to an activation function")


class ActLuPrimitive(BasePrimitive):
    """
    ActLu Primitive
    """

    name = "te_act_lu_ffi"
    multiple_results = True
    impl_static_args = (
        2,
        3,
        4,
        5,
        6,
        7,
        8,
    )  # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, is_outer
    inner_primitive = None
    outer_primitive = None
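    # inner_primitive binds the per-device kernel while outer_primitive wraps it
    # with custom partitioning for use under jit; both are expected to be
    # populated by register_primitive below.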

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        *,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
    ):
        """
        te_act_lu_p abstract
        """
        del act_enum
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval is None or scale_aval.dtype == jnp.float32
        assert x_aval.shape[-2] == act_len, (
            "activation input should be replicated by act_len in the -2 axis, got input shape"
            f" {x_aval.shape} and act_len {act_len}"
        )

        assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, (
            "Current tensor scaling is not yet supported for fused activation and quantization."
            " Please do activation in higher-precision then quantize with current tensor scaling."
        )

        out_shape = (*x_aval.shape[:-2], x_aval.shape[-1])  # Exclude act dim
        out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)

        updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(out_shape, is_padded=not is_outer, flatten_axis=-1)
        if is_2x:
            if ScalingMode(scaling_mode).is_tensor_scaling():
                # Colwise output is stored transposed for tensor scaling,
                # matching the sharding rules below.
                colwise_out_shape = multidim_transpose(out_shape, transpose_axis=-1)
            else:
                colwise_out_shape = out_shape
        else:
            colwise_out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        return out_aval, colwise_out_aval, scale_inv_aval, colwise_scale_inv_aval, updated_amax_aval

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        *,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
    ):
        """
        te_act_lu_p lowering rules
        """
        del out_dtype, scale_dtype, act_len, is_outer
        x_aval, scale_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval is None or scale_aval.dtype == jnp.float32

        out = ffi.ffi_lowering(ActLuPrimitive.name)(
            ctx, x, scale, act_enum=act_enum, scaling_mode=scaling_mode.value, is_2x=is_2x
        )
        return out

    @staticmethod
    def impl(
        x,
        scale,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
    ):
        """
        to describe implementation
        """
        del is_outer
188
        assert ActLuPrimitive.inner_primitive is not None

        out, colwise_out, scale_inv, colwise_scale_inv, updated_amax = (
            ActLuPrimitive.inner_primitive.bind(
                x,
                scale,
                out_dtype=out_dtype,
                act_enum=act_enum,
                act_len=act_len,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                is_outer=False,
            )
        )
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(out.shape, is_padded=False, flatten_axis=-1)
        # Slice out padding for MXFP8, noop for DelayedScaling
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
        )
        if is_2x:
            colwise_scale_inv = jax.lax.slice(
                colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
            )

        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
    ):
        """
        te_act_lu_p batch rule for vmap
        """
        del is_outer
        check_valid_batch_dims(batch_dims)
        assert ActLuPrimitive.outer_primitive is not None
        x, scale = batched_args
        x_bdim, scale_bdim = batch_dims
        amax_bdim = scale_bdim

        out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim
        return (
            ActLuPrimitive.outer_primitive.bind(
                x,
                scale,
                out_dtype=out_dtype,
                act_enum=act_enum,
                act_len=act_len,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                is_outer=True,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del (
            out_dtype,
            result_infos,
            act_enum,
            scale_dtype,
            act_len,
            is_outer,
        )  # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        scale_spec = get_padded_spec(arg_infos[1])

        out_spec = (*x_spec[:-2], x_spec[-1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")

        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1)
            else:
                colwise_out_spec = out_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out"
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = out_spec

        if is_2x:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax")
        colwise_scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
        )

        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer  # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        scale_spec = get_padded_spec(arg_infos[1])

        out_spec = (*x_spec[:-2], x_spec[-1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")

        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1)
            else:
                colwise_out_spec = out_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out"
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = out_spec

        if is_2x:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax")
        colwise_scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
        )

        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
        )

        def sharded_impl(x, scale):
            local_x, local_colwise_x, local_scale_inv, local_colwise_scale_inv, local_amax = (
                ActLuPrimitive.impl(
                    x,
                    scale,
                    out_dtype=out_dtype,
                    act_enum=act_enum,
                    act_len=act_len,
                    scaling_mode=scaling_mode,
                    is_2x=is_2x,
                    scale_dtype=scale_dtype,
                    is_outer=True,
                )
            )

            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            else:
                global_updated_amax = local_amax

            return (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                global_updated_amax,
            )

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
        mesh,
        value_types,
        result_types,
    ):
        del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types
        prefix = "ActLuPrimitive_"
        x_rank = len(value_types[0].shape)
        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            x_rank - 1, unique_var=prefix + "x", flatten_axis=-2
        )
        x_axes = scale_rules.input_spec + (prefix + f"x{x_rank - 1}",)
        out = (*x_axes[:-2], x_axes[-1])
        scale_inv = scale_rules.rowwise_rule

        colwise_out = (prefix + "out_colwise",)
        colwise_scale_inv = (prefix + "scale_inv_colwise",)
        if is_2x:
            colwise_scale_inv = scale_rules.colwise_rule
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out = tuple(
                    multidim_transpose(x_axes, static_axis_boundary=-1, transpose_axis=-2)
                )
            else:
                colwise_out = out

        # amax is always a unit tensor.
        amax = (prefix + "amax",)

        return SdyShardingRule(
            (
                x_axes,
                ("…1",),
            ),
            (out, colwise_out, scale_inv, colwise_scale_inv, amax),
        )


register_primitive(ActLuPrimitive)


# TODO(Jeremy): replace is_2x with q_layout
class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
    """
    DActLu DBias Cast Transpose Primitive
    """

    name = "te_dact_dbias_quantize_ffi"
    multiple_results = True
    # out_dtype, scaling_mode, is_2x, scale_dtype, is_dbias, act_enum, act_len, is_outer
    impl_static_args = (3, 4, 5, 6, 7, 8, 9, 10)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        dz_aval,
        x_aval,
        scale_aval,
        *,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
    ):
        """
        te_dact_dbias_quantize_p abstract
        """
        del act_enum
        dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dz_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_dtype
        assert x_aval.shape[-2] == act_len, (
            "activation input should be replicated by act_len in the -2 axis, got input shape"
            f" {x_aval.shape} and act_len {act_len}"
        )
        assert scale_aval.dtype == jnp.float32

        assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, (
            "Current tensor scaling is not supported for fused dact and quantization. Please do"
            " dact in higher-precision then quantize with current tensor scaling."
        )

        ir_hidden_size = dz_aval.shape[-1]
        gi_hidden_size = act_len * x_aval.shape[-1]
        assert act_len * ir_hidden_size == gi_hidden_size
        assert (
            x_aval.shape[:-2] == dz_aval.shape[:-1]
        ), "dz and x should have the same leading dimensions"
        out_shape = x_aval.shape
        out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)

        updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=-2)
        if is_2x:
            if ScalingMode(scaling_mode).is_tensor_scaling():
                colwise_out_shape = multidim_transpose(out_shape, transpose_axis=-2)
            else:
                colwise_out_shape = out_shape
        else:
            colwise_out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        if is_dbias:
            dbias_shape = (act_len, ir_hidden_size)
            (wkspace_info,) = transformer_engine_jax.get_dact_dbias_quantize_workspace_sizes(
                x_aval.size // gi_hidden_size,
                gi_hidden_size,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
                scaling_mode,
                is_2x,
            )
            wkspace_shape = wkspace_info[0]
            wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
        else:
            dbias_shape = (1,)
            wkspace_shape = (1,)
            wkspace_dtype = jnp.float32
        dbias_aval = jax.core.ShapedArray(shape=dbias_shape, dtype=dz_dtype)
        wkspace_aval = jax.core.ShapedArray(shape=wkspace_shape, dtype=wkspace_dtype)

        return (
            out_aval,
            colwise_out_aval,
            scale_inv_aval,
            colwise_scale_inv_aval,
            updated_amax_aval,
            dbias_aval,
            wkspace_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dact_dbias_quantize_p outer abstract
        """
        (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = (
            BaseDActLuDBiasQuantizePrimitive.abstract(*args, **kwargs)
        )
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def lowering(
        ctx,
        dz,
        x,
        scale,
        *,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
    ):
        """
        te_dact_dbias_quantize_p lowering rules
        """
        del out_dtype, scale_dtype, act_len, is_outer
        dz_aval, x_aval, scale_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_aval.dtype
        assert scale_aval.dtype == jnp.float32
        return ffi.ffi_lowering(BaseDActLuDBiasQuantizePrimitive.name)(
            ctx,
            dz,
            x,
            scale,
            scaling_mode=scaling_mode.value,
            is_2x=is_2x,
            is_dbias=is_dbias,
            act_enum=int(act_enum),
        )

    @staticmethod
    def impl(
        dz,
        x,
        scale,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
    ):
        """
        te_dact_dbias_quantize_p impl
        """
        del is_outer
        assert BaseDActLuDBiasQuantizePrimitive.inner_primitive is not None
        (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = (
            BaseDActLuDBiasQuantizePrimitive.inner_primitive.bind(
                dz,
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
                act_enum=act_enum,
                act_len=act_len,
                is_outer=False,
            )
        )
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x.shape, is_padded=False, flatten_axis=-2)
        # Slice out padding for MXFP8, noop for DelayedScaling
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
        )
        if is_2x:
            colwise_scale_inv = jax.lax.slice(
                colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
            )
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
    ):
        """
        te_dact_dbias_quantize_p batch rule for vmap
        """
        del is_outer
        check_valid_batch_dims(batch_dims)
        assert BaseDActLuDBiasQuantizePrimitive.outer_primitive is not None
        dz, x, scale = batched_args
        _, x_bdim, scale_bdim = batch_dims

        out_bdims = (
            x_bdim,  # rowwise output
            x_bdim,  # colwise output
            scale_bdim,  # rowwise scale_inv
            scale_bdim,  # colwise scale_inv
            scale_bdim,  # amax
            x_bdim,  # dbias
        )
        return (
            BaseDActLuDBiasQuantizePrimitive.outer_primitive.bind(
                dz,
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
                act_enum=act_enum,
                act_len=act_len,
                is_outer=True,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del out_dtype, result_infos, act_enum
        del scale_dtype, act_len, is_outer
        x_spec = get_padded_spec(arg_infos[1])
        scale_spec = get_padded_spec(arg_infos[2])

        assert (
            scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value
        ), "Partitioned current tensor scaling is not yet supported."

        out_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec), desc="BaseDActLuDBiasQuantizePrimitive.out"
        )
        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2)
            else:
                colwise_x_spec = x_spec
        else:
            colwise_x_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_x_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.colwise_out",
        )

        dbias_spec = x_spec[-2:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.dbias",
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if is_2x:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="BaseDActLuDBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="BaseDActLuDBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.colwise_scale_inv",
        )
        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer
        x_spec = get_padded_spec(arg_infos[1])
        scale_spec = get_padded_spec(arg_infos[2])

        out_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec), desc="BaseDActLuDBiasQuantizePrimitive.out"
        )

        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2)
            else:
                colwise_x_spec = x_spec
        else:
            colwise_x_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_x_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.colwise_out",
        )

        dbias_spec = x_spec[-2:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.dbias",
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if is_2x:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="BaseDActLuDBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="BaseDActLuDBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.colwise_scale_inv",
        )

        arg_shardings = list(arg_i.sharding for arg_i in arg_infos)
        # Ensure dz and x are partitioned the same way.
        arg_shardings[0] = NamedSharding(
            mesh,
            PartitionSpec(*x_spec[:-2], x_spec[-1]),
            desc="BaseDActLuDBiasQuantizePrimitive.dz",
        )
        arg_shardings = tuple(arg_shardings)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

        def sharded_impl(dz, x, scale):
            (out, colwise_out, scale_inv, colwise_scale_inv, local_amax, local_dbias) = (
                BaseDActLuDBiasQuantizePrimitive.impl(
                    dz,
                    x,
                    scale,
                    out_dtype=out_dtype,
                    scaling_mode=scaling_mode,
                    is_2x=is_2x,
                    scale_dtype=scale_dtype,
                    is_dbias=is_dbias,
                    act_enum=act_enum,
                    act_len=act_len,
                    is_outer=True,
                )
            )
            if is_dbias:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            else:
                global_dbias = local_dbias

            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            else:
                global_updated_amax = local_amax

            return out, colwise_out, scale_inv, colwise_scale_inv, global_updated_amax, global_dbias

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
        mesh,
        value_types,
        result_types,
    ):
        del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types
        prefix = "BaseDActLuDBiasQuantizePrimitive_"
        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            len(value_types[1].shape), unique_var=prefix + "x", flatten_axis=-2
        )
        x_axes = scale_rules.input_spec
        dz_axes = (*x_axes[:-2], x_axes[-1])
        out = x_axes
        colwise_out = (prefix + "out_colwise",)
        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=-2))
            else:
                colwise_out = out

        dbias = x_axes[-2:] if is_dbias else (prefix + "dbias",)
        amax = (prefix + "amax",)

        return SdyShardingRule(
            (dz_axes, x_axes, ("…2",)),
            (out, colwise_out, scale_rules.rowwise_rule, scale_rules.colwise_rule, amax, dbias),
        )


register_primitive(BaseDActLuDBiasQuantizePrimitive)


class DActLuDBiasQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive):
    """Subclass of BaseDActLuDBiasQuantizePrimitive for DBias and fused activation quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""


class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive):
    """Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""


def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, ScaledTensor]:
    """
    JAX native activation implementation
    """
    act_len = len(activation_type)
    assert inputs.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {inputs.shape} and act_len {act_len}"
    )

    x = jnp.split(inputs, act_len, axis=-2)
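    # Each slice along the -2 axis feeds one component activation and the results
    # are multiplied elementwise, e.g. activation_type=("silu", "linear") yields
    # silu(x[0]) * x[1] (a gated SiLU).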
    acts = []
    for idx, act_fn in enumerate(activation_type):
        x_i = _convert_to_activation_function(act_fn)(x[idx])
        acts.append(x_i)
    x = reduce(operator.mul, acts)
    x = jnp.squeeze(x, axis=-2)
    if quantizer:
        return quantizer.quantize(x, flatten_axis=-1)
    return x


def _jax_quantize_dact_dbias(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    is_dbias: bool = True,
    quantizer: Optional[Quantizer] = None,
):
    """
    JAX implementation of dact_lu and dbias with optional quantization
    """
    act_len = len(activation_type)
    assert x.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {x.shape} and act_len {act_len}"
    )

    _, vjp_func = jax.vjp(
        partial(_jax_act_lu, activation_type=activation_type), x.astype(jnp.float32)
    )
    (dx,) = vjp_func(dz.astype(jnp.float32))
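    # dx is J_act(x)^T @ dz obtained via JAX's VJP, computed in float32 for
    # accuracy before being cast or quantized back below.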

    dbias = None
    if is_dbias:
        dbias = _jax_dbias(dx, dtype=x.dtype, flatten_axis=-2)

    if quantizer is not None:
        dx = quantizer.quantize(dx, dq_dtype=x.dtype, flatten_axis=-2)
    else:
        dx = dx.astype(x.dtype)

    return dx, dbias


def act_lu(
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Optional[Quantizer] = None,
    noop_scaled_tensor: bool = False,
) -> Union[jnp.ndarray, ScaledTensor]:
    """Activation with optional quantization.

    Args:
        x: Input tensor to be processed.
            Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
        activation_type: Type of activation function to apply.
        quantizer: Optional quantizer for FP8 quantization of the output.
        noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None.

    Returns:
        If quantizer is None:
            The activated input tensor with the same dtype as input.
        If quantizer is provided:
            A ScaledTensor containing the quantized activated input.
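
    Example (illustrative, with quantizer=None):
        x = jnp.ones((4, 2, 128), jnp.bfloat16)   # gated input, ACT_DIM == 2
        y = act_lu(x, activation_type=("silu", "linear"))
        # y.shape == (4, 128); the act dim is contracted away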
    """
    act_type_id = ActivationEnum[activation_type].value
    act_len = len(activation_type)
    assert x.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {x.shape} and act_len {act_len}"
    )

    if not ActLuPrimitive.enabled():
        return _jax_act_lu(x, activation_type, quantizer)

    # TE/common does not support colwise-only quantization yet
    if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
        return _jax_act_lu(x, activation_type, quantizer)

    # TE/common does not support 2x quantization for DelayedScaling yet
    war_output = try_apply_delayed_scaling_2x_war(
        f=act_lu, x=x, activation_type=activation_type, quantizer=quantizer
    )
    if war_output is not None:
        return war_output

    scale = jnp.empty((1,), jnp.float32)
    output_shape = (*x.shape[:-2], x.shape[-1])

    if quantizer is None:
        out, _, _, _, _ = ActLuPrimitive.outer_primitive.bind(
            x,
            scale,
            out_dtype=x.dtype,
            act_enum=act_type_id,
            act_len=act_len,
            scaling_mode=ScalingMode.NO_SCALING.value,
            is_2x=False,
            scale_dtype=jnp.float32,
            is_outer=True,
        )
        out = out.reshape(output_shape)
        if noop_scaled_tensor:
            return ScaledTensorFactory.create_2x(
                out, None, out, None, ScalingMode.NO_SCALING, dq_dtype=out.dtype
            )
        return out

    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        # Current scaling does not support fused operations.
        # Perform the activation in higher precision, then quantize.
        out = act_lu(
            x=x.astype(jnp.float32),
            activation_type=activation_type,
            quantizer=None,
        )
        out, _ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype)
        return out

    if isinstance(quantizer, DelayedScaleQuantizer):
        scale = quantizer.scale

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
    ) = ActLuPrimitive.outer_primitive.bind(
        x,
        scale,
        out_dtype=quantizer.q_dtype,
        act_enum=act_type_id,
        act_len=act_len,
        scaling_mode=quantizer.scaling_mode.value,
        is_2x=quantizer.is_2x2x(),
        scale_dtype=quantizer.get_scale_dtype(),
        is_outer=True,
    )

    quantizer.update(updated_amax)

    return ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
    )


def quantize_dact_dbias(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]] = ("gelu",),
    is_dbias: bool = True,
    quantizer: Optional[Quantizer] = None,
    noop_scaled_tensor: bool = False,
) -> Tuple[ScaledTensor, jnp.ndarray]:
    """Compute gradients of activation and bias with optional quantization.

    Args:
        dz: Gradient of the output with respect to the activation output.
        x: Input tensor that was processed by the forward pass.
            Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
        activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",).
        is_dbias: If True, compute bias gradient. Defaults to True.
        quantizer: Optional quantizer for FP8 quantization of the output.
        noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None.

    Returns:
        Tuple[ScaledTensor, jnp.ndarray]: A tuple containing:
        - The gradient with respect to the input x.
        - The gradient with respect to the bias (meaningful only when is_dbias=True).
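
    Example (illustrative, with quantizer=None):
        x = jnp.ones((4, 2, 128), jnp.bfloat16)   # forward input, ACT_DIM == 2
        dz = jnp.ones((4, 128), jnp.bfloat16)     # incoming gradient
        dx, dbias = quantize_dact_dbias(dz, x, activation_type=("silu", "linear"))
        # dx.shape == (4, 2, 128); dbias.shape == (2, 128)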
    """

    act_len = len(activation_type)
    assert x.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {x.shape} and act_len {act_len}"
    )

    scale = jnp.empty((), jnp.float32)
    act_type_id = ActivationEnum[activation_type]
    PrimitiveClass = DActLuDBiasQuantizePrimitive if is_dbias else DActLuQuantizePrimitive
    if not PrimitiveClass.enabled() or (
        quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE
    ):
        return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)

    if quantizer is None:
        output, _, _, _, _, _ = PrimitiveClass.outer_primitive.bind(
            dz,
            x,
            scale,
            # outputs float32 for dbias accumulation
            out_dtype=(jnp.float32 if is_dbias else x.dtype),
            # default value for no scaling, TE/common ignore this value when scale is unset
            scaling_mode=ScalingMode.NO_SCALING.value,
            is_2x=False,  # unused
            scale_dtype=jnp.float32,  # unused
            is_dbias=False,
            act_enum=act_type_id,
            act_len=act_len,
            is_outer=True,
        )
        output = output.astype(x.dtype)
        dbias = None
        if is_dbias:
            dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2)

        if noop_scaled_tensor:
            return (
                ScaledTensorFactory.create_2x(
                    output,
                    None,
                    output,
                    None,
                    ScalingMode.NO_SCALING,
                    dq_dtype=output.dtype,
                ),
                dbias,
            )

        return output, dbias

    # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet
    if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
        out = dact_lu(
            dz.astype(jnp.float32), x.astype(jnp.float32), activation_type, quantizer=None
        )
        return _quantize_dbias_impl(
            out, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2
        )

    is_gated = act_len == 2
    # TE/common does not support DelayedScaling2x for gated-act yet
    if is_gated:
        war_output = try_apply_delayed_scaling_2x_war(
            f=quantize_dact_dbias,
            dz=dz,
            x=x,
            activation_type=activation_type,
            is_dbias=is_dbias,
            quantizer=quantizer,
            flatten_axis=-2,
        )
        if war_output is not None:
            return war_output

    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        # Current scaling does not support fused operations.
        # Perform dact in higher precision, then quantize.
        out = dact_lu(
            dz=dz.astype(jnp.float32),
            x=x.astype(jnp.float32),
            activation_type=activation_type,
            quantizer=None,
        )
        out, dbias = _quantize_dbias_impl(
            out, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2
        )
        return out, dbias

    if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        scale = quantizer.scale

    # TE/common dact_dbias_quantize does not support gated act yet
    if is_dbias and is_gated:
        dgated = dact_lu(
            dz.astype(jnp.float32), x.astype(jnp.float32), activation_type=activation_type
        )
        out, dbias = _quantize_dbias_impl(
            dgated, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2
        )
        return out, dbias

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
        dbias,
    ) = PrimitiveClass.outer_primitive.bind(
        dz,
        x,
        scale,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        is_2x=quantizer.is_2x2x(),
        scale_dtype=quantizer.get_scale_dtype(),
        is_dbias=is_dbias,
        act_enum=act_type_id,
        act_len=act_len,
        is_outer=True,
    )

    # For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise
    if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x():
        colwise_scale_inv = rowwise_scale_inv

    quantizer.update(updated_amax)

    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=-2,  # as output has act axis
    )

    return out, dbias


def dact_lu(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Optional[Quantizer] = None,
    noop_scaled_tensor: bool = False,
) -> Union[jnp.ndarray, ScaledTensor]:
    """
    Backward pass for activation with optional quantization.

    Args:
        dz: Gradient tensor from upstream.
        x: Input tensor that was used in forward pass.
        activation_type: Type of activation function that was applied.
        quantizer: Optional quantizer for FP8 quantization of the output gradient.
        noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None.

    Returns:
        The gradient of the activation with respect to the input.
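
    Example (illustrative, with quantizer=None):
        x = jnp.ones((4, 2, 128), jnp.bfloat16)   # forward input, ACT_DIM == 2
        dz = jnp.ones((4, 128), jnp.bfloat16)     # incoming gradient
        dx = dact_lu(dz, x, activation_type=("silu", "linear"))
        # dx.shape == (4, 2, 128), matching the forward input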
    """
    output, _ = quantize_dact_dbias(
        dz=dz,
        x=x,
        activation_type=activation_type,
        is_dbias=False,
        quantizer=quantizer,
        noop_scaled_tensor=noop_scaled_tensor,
    )
    return output