# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for activation"""
from typing import Sequence, Union, Callable, Optional, Tuple
import operator
from functools import reduce, partial

import jax
import jax.numpy as jnp
from jax import dtypes, ffi
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.sharding import PartitionSpec

import transformer_engine_jax
from transformer_engine_jax import NVTE_Activation_Type

from .base import BasePrimitive, register_primitive
from .misc import (
    jax_dtype_to_te_dtype,
    te_dtype_to_jax_dtype,
    get_padded_spec,
    check_valid_batch_dims,
    multidim_transpose,
    try_apply_delayed_scaling_2x_war,
    should_apply_1x_fused_dbias_war_for_arch_l_100,
    NamedSharding,
)
from .quantization import _jax_dbias, _quantize_dbias_impl, AmaxScope
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor
from ..quantize import (
    Quantizer,
    QuantizeLayout,
    DelayedScaleQuantizer,
    ScalingMode,
)


__all__ = ["act_lu", "dact_lu", "quantize_dact_dbias"]


ActivationEnum = {
    ("gelu",): NVTE_Activation_Type.GELU,
    ("gelu", "linear"): NVTE_Activation_Type.GEGLU,
    ("silu",): NVTE_Activation_Type.SILU,
    ("silu", "linear"): NVTE_Activation_Type.SWIGLU,
    ("relu",): NVTE_Activation_Type.RELU,
    ("relu", "linear"): NVTE_Activation_Type.REGLU,
    ("quick_gelu",): NVTE_Activation_Type.QGELU,
    ("quick_gelu", "linear"): NVTE_Activation_Type.QGEGLU,
    ("squared_relu",): NVTE_Activation_Type.SRELU,
    ("squared_relu", "linear"): NVTE_Activation_Type.SREGLU,
}
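# Illustrative lookup (comment-only sketch): gated activations are keyed by an
# ("act", "linear") tuple, e.g. ActivationEnum[("silu", "linear")] is
# NVTE_Activation_Type.SWIGLU, while ActivationEnum[("gelu",)] is
# NVTE_Activation_Type.GELU. Keys are tuples, so activation_type must be hashable.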


def _convert_to_activation_function(fn_or_string):
    """Convert a string to an activation function."""
    if fn_or_string == "linear":
        return lambda x: x
    if fn_or_string == "quick_gelu":
        return lambda x: jax.nn.sigmoid(1.702 * x) * x
    if fn_or_string == "squared_relu":
        return lambda x: reduce(operator.mul, [jax.nn.relu(x), jax.nn.relu(x)])
    if isinstance(fn_or_string, str):
        return getattr(jax.nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string
    raise ValueError(f"Unsupported {fn_or_string} to an activation function")


class ActLuPrimitive(BasePrimitive):
    """
    ActLu Primitive
    """

    name = "te_act_lu_ffi"
    multiple_results = True
    impl_static_args = (
        2,
        3,
        4,
        5,
        6,
        7,
        8,
    )  # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, is_outer
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        *,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
    ):
        """
        te_act_lu_p abstract
        """
        del act_enum
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval is None or scale_aval.dtype == jnp.float32
        assert x_aval.shape[-2] == act_len, (
            "activation input should be replicated by act_len in the -2 axis, got input shape"
            f" {x_aval.shape} and act_len {act_len}"
        )

        assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, (
            "Current tensor scaling is not yet supported for fused activation and quantization."
            " Please do activation in higher-precision then quantize with current tensor scaling."
        )

        out_shape = (*x_aval.shape[:-2], x_aval.shape[-1])  # Exclude act dim
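        # e.g. a gated input of shape (B, S, 2, H) yields an output of shape (B, S, H)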
        out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)

        updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(out_shape, is_padded=not is_outer, flatten_axis=-1)
        colwise_out_shape = out_shape
        if not is_2x:
            colwise_out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        return out_aval, colwise_out_aval, scale_inv_aval, colwise_scale_inv_aval, updated_amax_aval

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        *,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
    ):
        """
        te_act_lu_p lowering rules
        """
        del out_dtype, scale_dtype, act_len, is_outer
        x_aval, scale_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval is None or scale_aval.dtype == jnp.float32

        out = ffi.ffi_lowering(ActLuPrimitive.name)(
            ctx, x, scale, act_enum=act_enum, scaling_mode=scaling_mode.value, is_2x=is_2x
        )
        return out

    @staticmethod
    def impl(
        x,
        scale,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
    ):
        """
        te_act_lu_p impl
        """
        del is_outer
        assert ActLuPrimitive.inner_primitive is not None

        out, colwise_out, scale_inv, colwise_scale_inv, updated_amax = (
            ActLuPrimitive.inner_primitive.bind(
                x,
                scale,
                out_dtype=out_dtype,
                act_enum=act_enum,
                act_len=act_len,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                is_outer=False,
            )
        )
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(out.shape, is_padded=False, flatten_axis=-1)
        # Slice out padding for MXFP8, noop for DelayedScaling
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
        )
        if is_2x:
            colwise_scale_inv = jax.lax.slice(
                colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
            )

        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
    ):
        """
        te_act_lu_p batch rule for vmap
        """
        del is_outer
        check_valid_batch_dims(batch_dims)
        assert ActLuPrimitive.outer_primitive is not None
        x, scale = batched_args
        x_bdim, scale_bdim = batch_dims
        amax_bdim = scale_bdim

        out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim
        return (
            ActLuPrimitive.outer_primitive.bind(
                x,
                scale,
                out_dtype=out_dtype,
                act_enum=act_enum,
                act_len=act_len,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                is_outer=True,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del (
            out_dtype,
            result_infos,
            act_enum,
            scale_dtype,
            act_len,
            is_outer,
        )  # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        scale_spec = get_padded_spec(arg_infos[1])

        out_spec = (*x_spec[:-2], x_spec[-1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")

        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1)
            else:
                colwise_out_spec = out_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out"
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = out_spec

        if is_2x:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax")
        colwise_scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
        )

        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer  # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        scale_spec = get_padded_spec(arg_infos[1])

        out_spec = (*x_spec[:-2], x_spec[-1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")

        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1)
            else:
                colwise_out_spec = out_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out"
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = out_spec

        if is_2x:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax")
        colwise_scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
        )

        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
        )

        def sharded_impl(x, scale):
            local_x, local_colwise_x, local_scale_inv, local_colwise_scale_inv, local_amax = (
                ActLuPrimitive.impl(
                    x,
                    scale,
                    out_dtype=out_dtype,
                    act_enum=act_enum,
                    act_len=act_len,
                    scaling_mode=scaling_mode,
                    is_2x=is_2x,
                    scale_dtype=scale_dtype,
                    is_outer=True,
                )
            )

            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            else:
                global_updated_amax = local_amax

            return (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                global_updated_amax,
            )

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
        mesh,
        value_types,
        result_types,
    ):
        del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types
        prefix = "ActLu_"
        input_shape = value_types[0].shape
        output_shape = input_shape[:-2] + input_shape[-1:]
        # Pass the output shape so that the scale shapes are propagated correctly
        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            output_shape, unique_var=prefix + "x", flatten_axis=-1
        )
        x_axes = scale_rules.input_spec
        # Re-insert the act dim into the input spec
        x_axes = x_axes[:-1] + (prefix + "_act_dim",) + x_axes[-1:]
        out = scale_rules.input_spec

        colwise_out = (prefix + "out_colwise",)
        colwise_scale_inv = (prefix + "scale_inv_colwise",)
        if is_2x:
            colwise_scale_inv = scale_rules.colwise_rule
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out = multidim_transpose(out, transpose_axis=-1)
            else:
                colwise_out = out

        amax = (prefix + "amax",)

        return SdyShardingRule(
            (
                x_axes,
                ("…1",),
            ),
            (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax),
            **scale_rules.factor_sizes,
        )


register_primitive(ActLuPrimitive)


# TODO(Jeremy): replace is_2x with q_layout
class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
    """
    DActLu DBias Quantize Primitive
    """

    name = "te_dact_dbias_quantize_ffi"
    multiple_results = True
    # out_dtype, scaling_mode, is_2x, scale_dtype, is_dbias, act_enum, act_len, is_outer
    impl_static_args = (3, 4, 5, 6, 7, 8, 9, 10)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        dz_aval,
        x_aval,
        scale_aval,
        *,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
    ):
        """
        te_dact_dbias_quantize_p abstract
        """
        del act_enum
        dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dz_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_dtype
        assert x_aval.shape[-2] == act_len, (
            "activation input should be replicated by act_len in the -2 axis, got input shape"
            f" {x_aval.shape} and act_len {act_len}"
        )
        assert scale_aval.dtype == jnp.float32

        assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, (
            "Current tensor scaling is not supported for fused dact and quantization. Please do"
            " dact in higher-precision then quantize with current tensor scaling."
        )

        ir_hidden_size = dz_aval.shape[-1]
        gi_hidden_size = act_len * x_aval.shape[-1]
        assert act_len * ir_hidden_size == gi_hidden_size
        assert (
            x_aval.shape[:-2] == dz_aval.shape[:-1]
        ), "dz and x should have the same leading dimensions"
        out_shape = x_aval.shape
        out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)

        updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=-2)
        if is_2x:
            if ScalingMode(scaling_mode).is_tensor_scaling():
                colwise_out_shape = multidim_transpose(out_shape, transpose_axis=-2)
            else:
                colwise_out_shape = out_shape
        else:
            colwise_out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        if is_dbias:
            dbias_shape = (act_len, ir_hidden_size)
            (wkspace_info,) = transformer_engine_jax.get_dact_dbias_quantize_workspace_sizes(
                x_aval.size // gi_hidden_size,
                gi_hidden_size,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
                scaling_mode,
                is_2x,
            )
            wkspace_shape = wkspace_info[0]
            wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
        else:
            dbias_shape = (1,)
            wkspace_shape = (1,)
            wkspace_dtype = jnp.float32
        dbias_aval = jax.core.ShapedArray(shape=dbias_shape, dtype=dz_dtype)
        wkspace_aval = jax.core.ShapedArray(shape=wkspace_shape, dtype=wkspace_dtype)

        return (
            out_aval,
            colwise_out_aval,
            scale_inv_aval,
            colwise_scale_inv_aval,
            updated_amax_aval,
            dbias_aval,
            wkspace_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dact_dbias_quantize_p outer abstract
        """
        (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = (
            BaseDActLuDBiasQuantizePrimitive.abstract(*args, **kwargs)
        )
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def lowering(
        ctx,
        dz,
        x,
        scale,
        *,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
    ):
        """
        te_dact_dbias_quantize_p lowering rules
        """
        del out_dtype, scale_dtype, act_len, is_outer
        dz_aval, x_aval, scale_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_aval.dtype
        assert scale_aval.dtype == jnp.float32
        return ffi.ffi_lowering(BaseDActLuDBiasQuantizePrimitive.name)(
            ctx,
            dz,
            x,
            scale,
            scaling_mode=scaling_mode.value,
            is_2x=is_2x,
            is_dbias=is_dbias,
            act_enum=int(act_enum),
        )

    @staticmethod
    def impl(
        dz,
        x,
        scale,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
    ):
        """
        te_dact_dbias_quantize_p impl
        """
        del is_outer
        assert BaseDActLuDBiasQuantizePrimitive.inner_primitive is not None
        (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = (
            BaseDActLuDBiasQuantizePrimitive.inner_primitive.bind(
                dz,
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
                act_enum=act_enum,
                act_len=act_len,
                is_outer=False,
            )
        )
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x.shape, is_padded=False, flatten_axis=-2)
        # Slice out padding for MXFP8, noop for DelayedScaling
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
        )
        if is_2x:
            colwise_scale_inv = jax.lax.slice(
                colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
            )
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
    ):
        """
        te_dact_dbias_quantize_p batch rule for vmap
        """
        del is_outer
        check_valid_batch_dims(batch_dims)
        assert BaseDActLuDBiasQuantizePrimitive.outer_primitive is not None
        dz, x, scale = batched_args
        _, x_bdim, scale_bdim = batch_dims

        out_bdims = (
            x_bdim,  # rowwise output
            x_bdim,  # colwise output
            scale_bdim,  # rowwise scale_inv
            scale_bdim,  # colwise scale_inv
            scale_bdim,  # amax
            x_bdim,  # dbias
        )
        return (
            BaseDActLuDBiasQuantizePrimitive.outer_primitive.bind(
                dz,
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
                act_enum=act_enum,
                act_len=act_len,
                is_outer=True,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del out_dtype, result_infos, act_enum
        del scale_dtype, act_len, is_outer
        x_spec = get_padded_spec(arg_infos[1])
        scale_spec = get_padded_spec(arg_infos[2])

        assert (
            scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value
        ), "Partitioned current tensor scaling is not yet supported."

        out_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec), desc="BaseDActLuDBiasQuantizePrimitive.out"
        )
        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2)
            else:
                colwise_x_spec = x_spec
        else:
            colwise_x_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_x_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.colwise_out",
        )

        dbias_spec = x_spec[-2:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.dbias",
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if is_2x:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="BaseDActLuDBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="BaseDActLuDBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.colwise_scale_inv",
        )
        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer
        x_spec = get_padded_spec(arg_infos[1])
        scale_spec = get_padded_spec(arg_infos[2])

        out_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec), desc="BaseDActLuDBiasQuantizePrimitive.out"
        )

        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2)
            else:
                colwise_x_spec = x_spec
        else:
            colwise_x_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_x_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.colwise_out",
        )

        dbias_spec = x_spec[-2:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.dbias",
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if is_2x:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="BaseDActLuDBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="BaseDActLuDBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.colwise_scale_inv",
        )

        arg_shardings = list(arg_i.sharding for arg_i in arg_infos)
        # Ensure dz and x are partitioned the same way.
        arg_shardings[0] = NamedSharding(
            mesh,
            PartitionSpec(*x_spec[:-2], x_spec[-1]),
            desc="BaseDActLuDBiasQuantizePrimitive.dz",
        )
        arg_shardings = tuple(arg_shardings)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

        def sharded_impl(dz, x, scale):
            (out, colwise_out, scale_inv, colwise_scale_inv, local_amax, local_dbias) = (
                BaseDActLuDBiasQuantizePrimitive.impl(
                    dz,
                    x,
                    scale,
                    out_dtype=out_dtype,
                    scaling_mode=scaling_mode,
                    is_2x=is_2x,
                    scale_dtype=scale_dtype,
                    is_dbias=is_dbias,
                    act_enum=act_enum,
                    act_len=act_len,
                    is_outer=True,
                )
            )
            if is_dbias:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            else:
                global_dbias = local_dbias

            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            else:
                global_updated_amax = local_amax

            return out, colwise_out, scale_inv, colwise_scale_inv, global_updated_amax, global_dbias

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
        mesh,
        value_types,
        result_types,
    ):
        del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types
        prefix = "DActLuDBias_"
        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            value_types[1].shape, unique_var=prefix + "x", flatten_axis=-2
        )
        x_axes = scale_rules.input_spec
        dz_axes = (*x_axes[:-2], x_axes[-1])
        out = x_axes

        colwise_out = (prefix + "out_colwise",)
        colwise_scale_inv = (prefix + "scale_inv_colwise",)
        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=-2))
            else:
                colwise_out = out
                colwise_scale_inv = scale_rules.colwise_rule

        dbias = x_axes[-2:] if is_dbias else (prefix + "dbias",)
        amax = (prefix + "amax",)

        return SdyShardingRule(
            (dz_axes, x_axes, ("…2",)),
            (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias),
            **scale_rules.factor_sizes,
        )


register_primitive(BaseDActLuDBiasQuantizePrimitive)


class DActLuDBiasQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive):
    """Subclass of BaseDActLuDBiasQuantizePrimitive for DBias and fused activation quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""


class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive):
    """Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""


def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[NoScaleTensor, ScaledTensor]:
    """
    JAX native activation implementation
    """
    act_len = len(activation_type)
    assert inputs.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {inputs.shape} and act_len {act_len}"
    )
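    # Illustrative shape walk-through: for ("silu", "linear") with inputs of
    # shape (B, S, 2, H), the split below yields two (B, S, 1, H) halves;
    # silu(x[0]) * x[1] is then squeezed to (B, S, H).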

    x = jnp.split(inputs, act_len, axis=-2)
    acts = []
    for idx, act_fn in enumerate(activation_type):
        x_i = _convert_to_activation_function(act_fn)(x[idx])
        acts.append(x_i)
    x = reduce(operator.mul, acts)
    x = jnp.squeeze(x, axis=-2)
    if quantizer:
        return quantizer.quantize(x, flatten_axis=-1)
    return NoScaleTensor(data=x, amax=None)


def _jax_quantize_dact_dbias(
    dz: Union[jnp.ndarray, NoScaleTensor],
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    is_dbias: bool = True,
    quantizer: Optional[Quantizer] = None,
):
    """
    JAX implementation of dact_lu and dbias with optional quantization
    """
    act_len = len(activation_type)
    assert x.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {x.shape} and act_len {act_len}"
    )

    _, vjp_func = jax.vjp(
        partial(_jax_act_lu, activation_type=activation_type), x.astype(jnp.float32)
    )
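    # Illustrative note: for a product of activation branches, the VJP yields
    # dx[..., i, :] = dz * act_i'(x_i) * prod(act_j(x_j) for j != i), which
    # reduces to the usual SwiGLU/GeGLU gradients for gated activations.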
    # The VJP uses the non-quantized backward for dact, so dz must always be
    # wrapped in a NoScaleTensor, regardless of whether the forward pass used
    # quantization or whether this dact will quantize afterwards.
    dz = NoScaleTensor(data=dz.astype(jnp.float32), amax=None)
    (dx,) = vjp_func(dz)

    dbias = None
    if is_dbias:
        dbias = _jax_dbias(dx, dtype=x.dtype, flatten_axis=-2)

    if quantizer is not None:
        dx = quantizer.quantize(dx, dq_dtype=x.dtype, flatten_axis=-2)
    else:
        dx = dx.astype(x.dtype)
        dx = NoScaleTensor(data=dx, amax=None)

    return dx, dbias


def act_lu(
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Optional[Quantizer] = None,
    amax_scope: AmaxScope = AmaxScope.LOCAL,
) -> Union[jnp.ndarray, ScaledTensor]:
    """Activation with optional quantization.

    Args:
        x: Input tensor to be processed.
            Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
        activation_type: Type of activation function to apply.
        quantizer: Optional quantizer for FP8 quantization of the output.
        amax_scope: Scope over which to compute the amax. Only used with
            current tensor scaling. Defaults to AmaxScope.LOCAL.

    Returns:
        If quantizer is None:
            The activated input tensor with the same dtype as input.
        If quantizer is provided:
            A ScaledTensor containing the quantized activated input.
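
    Example (illustrative; assumes a gated projection whose output ``x`` has
    shape (batch, seq, 2, hidden)):

        y = act_lu(x, activation_type=("silu", "linear"))  # SwiGLU
        # y has shape (batch, seq, hidden) and the same dtype as x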
    """
    act_type_id = ActivationEnum[activation_type].value
    act_len = len(activation_type)
    assert x.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {x.shape} and act_len {act_len}"
    )

    if not ActLuPrimitive.enabled():
        return _jax_act_lu(x, activation_type, quantizer)

    # TE/common does not support colwise-only quantization yet
    if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
        return _jax_act_lu(x, activation_type, quantizer)

    # TE/common does not support 2x quantization for DelayedScaling yet
    war_output = try_apply_delayed_scaling_2x_war(
        f=act_lu, x=x, activation_type=activation_type, quantizer=quantizer
    )
    if war_output is not None:
        return war_output

    scale = jnp.empty((1,), jnp.float32)
    output_shape = (*x.shape[:-2], x.shape[-1])

    if quantizer is None:
        out, _, _, _, _ = ActLuPrimitive.outer_primitive.bind(
            x,
            scale,
            out_dtype=x.dtype,
            act_enum=act_type_id,
            act_len=act_len,
            scaling_mode=ScalingMode.NO_SCALING.value,
            is_2x=False,
            scale_dtype=jnp.float32,
            is_outer=True,
        )
        out = out.reshape(output_shape)
        out = NoScaleTensor(
            data=out,
            amax=None,
        )
        return out

    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        # Current scaling does not support fused operations. Perform the
        # activation in higher precision, then quantize after.
        out = act_lu(
            x=x,
            activation_type=activation_type,
            quantizer=None,
        )
        out, _ = _quantize_dbias_impl(
            out,
            is_dbias=False,
            quantizer=quantizer,
            dq_dtype=x.dtype,
            amax_scope=amax_scope,
        )
        return out

    if isinstance(quantizer, DelayedScaleQuantizer):
        scale = quantizer.scale

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
    ) = ActLuPrimitive.outer_primitive.bind(
        x,
        scale,
        out_dtype=quantizer.q_dtype,
        act_enum=act_type_id,
        act_len=act_len,
        scaling_mode=quantizer.scaling_mode.value,
        is_2x=quantizer.is_2x2x(),
        scale_dtype=quantizer.get_scale_dtype(),
        is_outer=True,
    )

    quantizer.update(updated_amax)

    return ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
    )


def quantize_dact_dbias(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]] = ("gelu",),
    is_dbias: bool = True,
    quantizer: Optional[Quantizer] = None,
) -> Tuple[ScaledTensor, jnp.ndarray]:
    """Compute gradients of activation and bias with optional quantization.

    Args:
        dz: Gradient of the output with respect to the activation output.
        x: Input tensor that was processed by the forward pass.
            Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
        activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",).
        is_dbias: If True, compute bias gradient. Defaults to True.
        quantizer: Optional quantizer for FP8 quantization of the output.

    Returns:
        Tuple[ScaledTensor, jnp.ndarray]: A tuple containing:
        - The gradient of the activation with respect to the input.
        - The gradient of the activation with respect to the bias.
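
    Example (illustrative; ``dz`` of shape (batch, seq, hidden) and forward
    input ``x`` of shape (batch, seq, 2, hidden) for a gated activation):

        dx, dbias = quantize_dact_dbias(dz, x, ("silu", "linear"), is_dbias=True)
        # dx has the shape of x; dbias has shape (2, hidden)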
    """

    act_len = len(activation_type)
    assert x.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {x.shape} and act_len {act_len}"
    )

    scale = jnp.empty((), jnp.float32)
    act_type_id = ActivationEnum[activation_type]
    PrimitiveClass = DActLuDBiasQuantizePrimitive if is_dbias else DActLuQuantizePrimitive
    if not PrimitiveClass.enabled() or (
        quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE
    ):
        return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)

    if quantizer is None:
        output, _, _, _, _, _ = PrimitiveClass.outer_primitive.bind(
            dz,
            x,
            scale,
            # outputs float32 for dbias accumulation
            out_dtype=(jnp.float32 if is_dbias else x.dtype),
            # default value for no scaling, TE/common ignore this value when scale is unset
            scaling_mode=ScalingMode.NO_SCALING.value,
            is_2x=False,  # unused
            scale_dtype=jnp.float32,  # unused
            is_dbias=False,
            act_enum=act_type_id,
            act_len=act_len,
            is_outer=True,
        )
        output = output.astype(x.dtype)
        dbias = None
        if is_dbias:
            dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2)

        output = NoScaleTensor(
            data=output,
            amax=None,
        )
        return output, dbias

    # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet
    if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
        out = dact_lu(
            dz.astype(jnp.float32), x.astype(jnp.float32), activation_type, quantizer=None
        )
        return _quantize_dbias_impl(
            out.data, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2
        )

    is_gated = act_len == 2
    # TE/common does not support DelayedScaling2x for gated-act yet
    if is_gated:
        war_output = try_apply_delayed_scaling_2x_war(
            f=quantize_dact_dbias,
            dz=dz,
            x=x,
            activation_type=activation_type,
            is_dbias=is_dbias,
            quantizer=quantizer,
            flatten_axis=-2,
        )
        if war_output is not None:
            return war_output

    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        # Current scaling does not support fused operations. Perform dact in higher precision then quantize after.
        out = dact_lu(
            dz=dz,
            x=x,
            activation_type=activation_type,
            quantizer=None,
        )
        out, dbias = _quantize_dbias_impl(
            out.data, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2
        )
        return out, dbias

    if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        scale = quantizer.scale

    # TE/common dact_dbias_quantize does not support gated act yet
    if is_dbias and is_gated:
        dgated = dact_lu(
            dz.astype(jnp.float32), x.astype(jnp.float32), activation_type=activation_type
        )
        out, dbias = _quantize_dbias_impl(
            dgated, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2
        )
        return out, dbias

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
        dbias,
    ) = PrimitiveClass.outer_primitive.bind(
        dz,
        x,
        scale,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        is_2x=quantizer.is_2x2x(),
        scale_dtype=quantizer.get_scale_dtype(),
        is_dbias=is_dbias,
        act_enum=act_type_id,
        act_len=act_len,
        is_outer=True,
    )

    # For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise
    if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x():
        colwise_scale_inv = rowwise_scale_inv

    quantizer.update(updated_amax)

    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=-2,  # as output has act axis
    )

    return out, dbias


def dact_lu(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Optional[Quantizer] = None,
) -> Union[jnp.ndarray, ScaledTensor]:
    """
    Backward pass for activation with optional quantization.

    Args:
        dz: Gradient tensor from upstream.
        x: Input tensor that was used in forward pass.
        activation_type: Type of activation function that was applied.
        quantizer: Optional quantizer for FP8 quantization of the output gradient.

    Returns:
        The gradient of the activation with respect to the input.
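
    Example (illustrative):

        dx = dact_lu(dz, x, activation_type=("gelu",))
        # dx has the same shape as x, including the activation axis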
    """
    output, _ = quantize_dact_dbias(
        dz=dz,
        x=x,
        activation_type=activation_type,
        is_dbias=False,
        quantizer=quantizer,
    )
    return output