# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for activation"""
from typing import Sequence, Union, Callable, Optional, Tuple
import operator
from functools import reduce, partial

import jax
import jax.numpy as jnp
from jax import dtypes, ffi
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.sharding import PartitionSpec

import transformer_engine_jax
from transformer_engine_jax import NVTE_Activation_Type

from .base import BasePrimitive, register_primitive
from .misc import (
    jax_dtype_to_te_dtype,
    te_dtype_to_jax_dtype,
    get_padded_spec,
    check_valid_batch_dims,
    multidim_transpose,
    try_apply_delayed_scaling_2x_war,
    should_apply_1x_fused_dbias_war_for_arch_l_100,
    NamedSharding,
)
from .quantization import _jax_dbias, _quantize_dbias_impl, AmaxScope
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor
from ..quantize import (
    Quantizer,
    QuantizeLayout,
    DelayedScaleQuantizer,
    ScalingMode,
)


__all__ = ["act_lu", "dact_lu", "quantize_dact_dbias"]


ActivationEnum = {
    ("gelu",): NVTE_Activation_Type.GELU,
    ("gelu", "linear"): NVTE_Activation_Type.GEGLU,
    ("silu",): NVTE_Activation_Type.SILU,
    ("silu", "linear"): NVTE_Activation_Type.SWIGLU,
    ("relu",): NVTE_Activation_Type.RELU,
    ("relu", "linear"): NVTE_Activation_Type.REGLU,
    ("quick_gelu",): NVTE_Activation_Type.QGELU,
    ("quick_gelu", "linear"): NVTE_Activation_Type.QGEGLU,
    ("squared_relu",): NVTE_Activation_Type.SRELU,
    ("squared_relu", "linear"): NVTE_Activation_Type.SREGLU,
}
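
# Illustrative note: gated variants are written as a 2-tuple ending in "linear",
# e.g. ActivationEnum[("silu", "linear")] is NVTE_Activation_Type.SWIGLU, while
# ActivationEnum[("gelu",)] is NVTE_Activation_Type.GELU.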


def _convert_to_activation_function(fn_or_string):
    """Convert a string to an activation function."""
    if fn_or_string == "linear":
        return lambda x: x
    if fn_or_string == "quick_gelu":
        return lambda x: jax.nn.sigmoid(1.702 * x) * x
    if fn_or_string == "squared_relu":
        return lambda x: reduce(operator.mul, [jax.nn.relu(x), jax.nn.relu(x)])
    if isinstance(fn_or_string, str):
        return getattr(jax.nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string
    raise ValueError(f"Unsupported activation function: {fn_or_string}")
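
# Illustrative examples (not exhaustive):
#   _convert_to_activation_function("linear")      # identity, lambda x: x
#   _convert_to_activation_function("quick_gelu")  # lambda x: jax.nn.sigmoid(1.702 * x) * x
#   _convert_to_activation_function("silu")        # resolved via getattr(jax.nn, "silu")
#   _convert_to_activation_function(jax.nn.relu)   # callables are returned unchanged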


class ActLuPrimitive(BasePrimitive):
    """
    ActLu Primitive
    """

    name = "te_act_lu_ffi"
    multiple_results = True
    impl_static_args = (
        2,
        3,
        4,
        5,
        6,
        7,
        8,
    )  # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, is_outer
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        *,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
    ):
        """
        te_act_lu_p abstract
        """
        del act_enum
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval is None or scale_aval.dtype == jnp.float32
        assert x_aval.shape[-2] == act_len, (
            "activation input should be replicated by act_len in the -2 axis, got input shape"
            f" {x_aval.shape} and act_len {act_len}"
        )

        assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, (
            "Current tensor scaling is not yet supported for fused activation and quantization."
            " Please do activation in higher-precision then quantize with current tensor scaling."
        )

        out_shape = (*x_aval.shape[:-2], x_aval.shape[-1])  # Exclude act dim
        out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)

        updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(out_shape, is_padded=not is_outer, flatten_axis=-1)
        if not is_2x:
            out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype)
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        return out_aval, colwise_out_aval, scale_inv_aval, colwise_scale_inv_aval, updated_amax_aval
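
    # Note: with is_outer=False (inner primitive) the scale_inv shapes above are the padded
    # shapes; ActLuPrimitive.impl slices them back down to the unpadded shapes afterwards.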

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        *,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
    ):
        """
        te_act_lu_p lowering rules
        """
        del out_dtype, scale_dtype, act_len, is_outer
        x_aval, scale_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval is None or scale_aval.dtype == jnp.float32

        out = ffi.ffi_lowering(ActLuPrimitive.name)(
            ctx, x, scale, act_enum=act_enum, scaling_mode=scaling_mode.value, is_2x=is_2x
        )
        return out

    @staticmethod
    def impl(
        x,
        scale,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
    ):
        """
        te_act_lu_p impl
        """
        del is_outer
        assert ActLuPrimitive.inner_primitive is not None

        out, colwise_out, scale_inv, colwise_scale_inv, updated_amax = (
            ActLuPrimitive.inner_primitive.bind(
                x,
                scale,
                out_dtype=out_dtype,
                act_enum=act_enum,
                act_len=act_len,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                is_outer=False,
            )
        )
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(out.shape, is_padded=False, flatten_axis=-1)
        # Slice out padding for MXFP8, noop for DelayedScaling
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
        )
        if is_2x:
            colwise_scale_inv = jax.lax.slice(
                colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
            )

        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
    ):
        """
        to describe batch rules for vmap
        """
        del act_len, is_outer
        check_valid_batch_dims(batch_dims)
        assert ActLuPrimitive.outer_primitive is not None
        x, scale = batched_args
        x_bdim, scale_bdim = batch_dims
        amax_bdim = scale_bdim

        out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim
        return (
            ActLuPrimitive.outer_primitive.bind(
                x,
                scale,
                out_dtype=out_dtype,
                act_enum=act_enum,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del (
            out_dtype,
            result_infos,
            act_enum,
            scale_dtype,
            act_len,
            is_outer,
        )  # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        scale_spec = get_padded_spec(arg_infos[1])

        out_spec = (*x_spec[:-2], x_spec[-1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")

        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1)
            else:
                colwise_out_spec = out_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out"
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = out_spec

        if is_2x:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax")
        colwise_scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
        )

        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer  # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        scale_spec = get_padded_spec(arg_infos[1])

        out_spec = (*x_spec[:-2], x_spec[-1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")

        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1)
            else:
                colwise_out_spec = out_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out"
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = out_spec

        if is_2x:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax")
        colwise_scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
        )

        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
        )

        def sharded_impl(x, scale):
            local_x, local_colwise_x, local_scale_inv, local_colwise_scale_inv, local_amax = (
                ActLuPrimitive.impl(
                    x,
                    scale,
                    out_dtype=out_dtype,
                    act_enum=act_enum,
                    act_len=act_len,
                    scaling_mode=scaling_mode,
                    is_2x=is_2x,
                    scale_dtype=scale_dtype,
                    is_outer=True,
                )
            )

            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            else:
                global_updated_amax = local_amax

            return (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                global_updated_amax,
            )

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
        mesh,
        value_types,
        result_types,
    ):
        del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types
        prefix = "ActLuPrimitive_"
        x_rank = len(value_types[0].shape)
        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            x_rank - 1, unique_var=prefix + "x", flatten_axis=-2
        )
        x_axes = scale_rules.input_spec + (prefix + f"x{x_rank - 1}",)
        out = (*x_axes[:-2], x_axes[-1])
        scale_inv = scale_rules.rowwise_rule

        colwise_out = (prefix + "out_colwise",)
        colwise_scale_inv = (prefix + "scale_inv_colwise",)
        if is_2x:
            colwise_scale_inv = scale_rules.colwise_rule
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out = tuple(
                    multidim_transpose(x_axes, static_axis_boundary=-1, transpose_axis=-2)
                )
            else:
                colwise_out = out

        # amax is always a unit tensor.
        amax = (prefix + "amax",)

        return SdyShardingRule(
            (
                x_axes,
                ("…1",),
            ),
            (out, colwise_out, scale_inv, colwise_scale_inv, amax),
        )


register_primitive(ActLuPrimitive)
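
# Note on outputs: ActLuPrimitive.outer_primitive.bind(...) returns the 5-tuple
# (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax); when is_2x is False
# the colwise entries are (1,)-shaped placeholders.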


# TODO(Jeremy): replace is_2x with q_layout
class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
    """
    DActLu DBias Quantize Primitive
    """

    name = "te_dact_dbias_quantize_ffi"
    multiple_results = True
    # out_dtype, scaling_mode, is_2x, scale_dtype, is_dbias, act_enum, act_len, is_outer
    impl_static_args = (3, 4, 5, 6, 7, 8, 9, 10)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        dz_aval,
        x_aval,
        scale_aval,
        *,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
    ):
        """
        te_dact_dbias_quantize_p abstract
        """
        del act_enum
        dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dz_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_dtype
        assert x_aval.shape[-2] == act_len, (
            "activation input should be replicated by act_len in the -2 axis, got input shape"
            f" {x_aval.shape} and act_len {act_len}"
        )
        assert scale_aval.dtype == jnp.float32

        assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, (
            "Current tensor scaling is not supported for fused dact and quantization. Please do"
            " dact in higher-precision then quantize with current tensor scaling."
        )

        ir_hidden_size = dz_aval.shape[-1]
        gi_hidden_size = act_len * x_aval.shape[-1]
        assert act_len * ir_hidden_size == gi_hidden_size
        assert (
            x_aval.shape[:-2] == dz_aval.shape[:-1]
        ), "dz and x should have the same leading dimensions"
        out_shape = x_aval.shape
        out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)

        updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=-2)
        if is_2x:
            if ScalingMode(scaling_mode).is_tensor_scaling():
                colwise_out_shape = multidim_transpose(out_shape, transpose_axis=-2)
            else:
                colwise_out_shape = out_shape
        else:
            colwise_out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        if is_dbias:
            dbias_shape = (act_len, ir_hidden_size)
            (wkspace_info,) = transformer_engine_jax.get_dact_dbias_quantize_workspace_sizes(
                x_aval.size // gi_hidden_size,
                gi_hidden_size,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
                scaling_mode,
                is_2x,
            )
            wkspace_shape = wkspace_info[0]
            wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
        else:
            dbias_shape = (1,)
            wkspace_shape = (1,)
            wkspace_dtype = jnp.float32
        dbias_aval = jax.core.ShapedArray(shape=dbias_shape, dtype=dz_dtype)
        wkspace_aval = jax.core.ShapedArray(shape=wkspace_shape, dtype=wkspace_dtype)

        return (
            out_aval,
            colwise_out_aval,
            scale_inv_aval,
            colwise_scale_inv_aval,
            updated_amax_aval,
            dbias_aval,
            wkspace_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dact_dbias_quantize_p outer abstract
        """
        (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = (
            BaseDActLuDBiasQuantizePrimitive.abstract(*args, **kwargs)
        )
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def lowering(
        ctx,
        dz,
        x,
        scale,
        *,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
    ):
        """
        te_dact_dbias_quantize_p lowering rules
        """
        del out_dtype, scale_dtype, act_len, is_outer
        dz_aval, x_aval, scale_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_aval.dtype
        assert scale_aval.dtype == jnp.float32
        return ffi.ffi_lowering(BaseDActLuDBiasQuantizePrimitive.name)(
            ctx,
            dz,
            x,
            scale,
            scaling_mode=scaling_mode.value,
            is_2x=is_2x,
            is_dbias=is_dbias,
            act_enum=int(act_enum),
        )

    @staticmethod
    def impl(
        dz,
        x,
        scale,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
    ):
        """
        te_dact_dbias_quantize_p impl
        """
        del is_outer
        assert BaseDActLuDBiasQuantizePrimitive.inner_primitive is not None
        (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = (
            BaseDActLuDBiasQuantizePrimitive.inner_primitive.bind(
                dz,
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
                act_enum=act_enum,
                act_len=act_len,
                is_outer=False,
            )
        )
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x.shape, is_padded=False, flatten_axis=-2)
        # Slice out padding for MXFP8, noop for DelayedScaling
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
        )
        if is_2x:
            colwise_scale_inv = jax.lax.slice(
                colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
            )
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
    ):
        """
        to describe batch rules for vmap
        """
        del is_outer
        check_valid_batch_dims(batch_dims)
        assert BaseDActLuDBiasQuantizePrimitive.outer_primitive is not None
        dz, x, scale = batched_args
        _, x_bdim, scale_bdim = batch_dims

        out_bdims = (
            x_bdim,  # rowwise output
            scale_bdim,  # rowwise scale_inv
            x_bdim,  # colwise output
            scale_bdim,  # colwise scale_inv
            scale_bdim,  # amax
            x_bdim,  # dbias
        )
        return (
            BaseDActLuDBiasQuantizePrimitive.outer_primitive.bind(
                dz,
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
                act_enum=act_enum,
                act_len=act_len,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del out_dtype, result_infos, act_enum
        del scale_dtype, act_len, is_outer
        x_spec = get_padded_spec(arg_infos[1])
        scale_spec = get_padded_spec(arg_infos[2])

        assert (
            scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value
        ), "Partitioned current tensor scaling is not yet supported."

        out_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec), desc="BaseDActLuDBiasQuantizePrimitive.out"
        )
        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2)
            else:
                colwise_x_spec = x_spec
        else:
            colwise_x_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_x_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.colwise_out",
        )

        dbias_spec = x_spec[-2:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.dbias",
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if is_2x:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="BaseDActLuDBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="BaseDActLuDBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.colwise_scale_inv",
        )
        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer
        x_spec = get_padded_spec(arg_infos[1])
        scale_spec = get_padded_spec(arg_infos[2])

        out_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec), desc="BaseDActLuDBiasQuantizePrimitive.out"
        )

        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2)
            else:
                colwise_x_spec = x_spec
        else:
            colwise_x_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_x_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.colwise_out",
        )

        dbias_spec = x_spec[-2:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.dbias",
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if is_2x:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax")
        colwise_scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
        )

        arg_shardings = list(arg_i.sharding for arg_i in arg_infos)
        # Ensure dz and x are partitioned the same way.
        arg_shardings[0] = NamedSharding(
            mesh,
            PartitionSpec(*x_spec[:-2], x_spec[-1]),
            desc="BaseDActLuDBiasQuantizePrimitive.dz",
        )
        arg_shardings = tuple(arg_shardings)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

        def sharded_impl(dz, x, scale):
            (out, colwise_out, scale_inv, colwise_scale_inv, local_amax, local_dbias) = (
                BaseDActLuDBiasQuantizePrimitive.impl(
                    dz,
                    x,
                    scale,
                    out_dtype=out_dtype,
                    scaling_mode=scaling_mode,
                    is_2x=is_2x,
                    scale_dtype=scale_dtype,
                    is_dbias=is_dbias,
                    act_enum=act_enum,
                    act_len=act_len,
                    is_outer=True,
                )
            )
            if is_dbias:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            else:
                global_dbias = local_dbias

            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            else:
                global_updated_amax = local_amax

            return out, colwise_out, scale_inv, colwise_scale_inv, global_updated_amax, global_dbias

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
        mesh,
        value_types,
        result_types,
    ):
        del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types
        prefix = "BaseDActLuDBiasQuantizePrimitive_"
        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            len(value_types[1].shape), unique_var=prefix + "x", flatten_axis=-2
        )
        x_axes = scale_rules.input_spec
        dz_axes = (*x_axes[:-2], x_axes[-1])
        out = x_axes
        colwise_out = (prefix + "out_colwise",)
        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=-2))
            else:
                colwise_out = out

        dbias = x_axes[-2:] if is_dbias else (prefix + "dbias",)
        amax = (prefix + "amax",)

        return SdyShardingRule(
            (dz_axes, x_axes, ("…2",)),
            (out, colwise_out, scale_rules.rowwise_rule, scale_rules.colwise_rule, amax, dbias),
        )


register_primitive(BaseDActLuDBiasQuantizePrimitive)


class DActLuDBiasQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive):
    """Subclass of BaseDActLuDBiasQuantizePrimitive for DBias and fused activation quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""


class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive):
    """Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""


def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[NoScaleTensor, ScaledTensor]:
    """
    JAX native activation implementation
    """
    act_len = len(activation_type)
    assert inputs.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {inputs.shape} and act_len {act_len}"
    )

    x = jnp.split(inputs, act_len, axis=-2)
    acts = []
    for idx, act_fn in enumerate(activation_type):
        x_i = _convert_to_activation_function(act_fn)(x[idx])
        acts.append(x_i)
    x = reduce(operator.mul, acts)
    x = jnp.squeeze(x, axis=-2)
    if quantizer:
        return quantizer.quantize(x, flatten_axis=-1)
    return NoScaleTensor(data=x, amax=None)
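
# Shape sketch (illustrative): for activation_type=("silu", "linear") and inputs of shape
# (..., 2, K), _jax_act_lu computes silu(inputs[..., 0:1, :]) * inputs[..., 1:2, :], squeezes
# the activation axis, and returns data of shape (..., K) (quantized if a quantizer is given).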


def _jax_quantize_dact_dbias(
    dz: Union[jnp.ndarray, NoScaleTensor],
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    is_dbias: bool = True,
    quantizer: Optional[Quantizer] = None,
):
    """
    JAX implementation of dact_lu and dbias with optional quantization
    """
    act_len = len(activation_type)
    assert x.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {x.shape} and act_len {act_len}"
    )

    _, vjp_func = jax.vjp(
        partial(_jax_act_lu, activation_type=activation_type), x.astype(jnp.float32)
    )
    # VJP is using non-quantized backward for dact, so the input should always be wrapped in
    # NoScaleTensor regardless of whether the forward pass used quantization or this dact will
    # quantize afterwards.
    dz = NoScaleTensor(data=dz.astype(jnp.float32), amax=None)
    (dx,) = vjp_func(dz)

    dbias = None
    if is_dbias:
        dbias = _jax_dbias(dx, dtype=x.dtype, flatten_axis=-2)

    if quantizer is not None:
        dx = quantizer.quantize(dx, dq_dtype=x.dtype, flatten_axis=-2)
    else:
        dx = dx.astype(x.dtype)
        dx = NoScaleTensor(data=dx, amax=None)

    return dx, dbias
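
# Shape sketch (illustrative): for x of shape (..., ACT_DIM, K) the incoming dz has shape
# (..., K); dx matches x's shape, and when is_dbias=True the dbias is dx reduced over all
# leading dimensions, i.e. shape (ACT_DIM, K).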


def act_lu(
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Optional[Quantizer] = None,
    amax_scope: AmaxScope = AmaxScope.LOCAL,
) -> Union[jnp.ndarray, ScaledTensor]:
    """Activation with optional quantization.

    Args:
        x: Input tensor to be processed.
            Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
        activation_type: Type of activation function to apply.
        quantizer: Optional quantizer for FP8 quantization of the output.
        amax_scope: The scope over which the amax is computed. This only applies when using
            current tensor scaling. Defaults to AmaxScope.LOCAL.

    Returns:
        If quantizer is None:
            The activated input tensor with the same dtype as input.
        If quantizer is provided:
            A ScaledTensor containing the quantized activated input.
    """
    act_type_id = ActivationEnum[activation_type].value
    act_len = len(activation_type)
    assert x.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {x.shape} and act_len {act_len}"
    )

    if not ActLuPrimitive.enabled():
        return _jax_act_lu(x, activation_type, quantizer)

    # TE/common does not support colwise-only quantization yet
    if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
        return _jax_act_lu(x, activation_type, quantizer)

    # TE/common does not support 2x quantization for DelayedScaling yet
    war_output = try_apply_delayed_scaling_2x_war(
        f=act_lu, x=x, activation_type=activation_type, quantizer=quantizer
    )
    if war_output is not None:
        return war_output

    scale = jnp.empty((1,), jnp.float32)
    output_shape = (*x.shape[:-2], x.shape[-1])

    if quantizer is None:
        out, _, _, _, _ = ActLuPrimitive.outer_primitive.bind(
            x,
            scale,
            out_dtype=x.dtype,
            act_enum=act_type_id,
            act_len=act_len,
            scaling_mode=ScalingMode.NO_SCALING.value,
            is_2x=False,
            scale_dtype=jnp.float32,
            is_outer=True,
        )
        out = out.reshape(output_shape)
        out = NoScaleTensor(
            data=out,
            amax=None,
        )
        return out

    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        # Current scaling does not support fused operations. Perform the activation in higher
        # precision, then quantize after.
        out = act_lu(
            x=x,
            activation_type=activation_type,
            quantizer=None,
        )
        out, _ = _quantize_dbias_impl(
            out,
            is_dbias=False,
            quantizer=quantizer,
            dq_dtype=x.dtype,
            amax_scope=amax_scope,
        )
        return out

    if isinstance(quantizer, DelayedScaleQuantizer):
        scale = quantizer.scale

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
    ) = ActLuPrimitive.outer_primitive.bind(
        x,
        scale,
        out_dtype=quantizer.q_dtype,
        act_enum=act_type_id,
        act_len=act_len,
        scaling_mode=quantizer.scaling_mode.value,
        is_2x=quantizer.is_2x2x(),
        scale_dtype=quantizer.get_scale_dtype(),
        is_outer=True,
    )

    quantizer.update(updated_amax)

    return ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
    )
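
# Usage sketch (hypothetical shapes and quantizer, for intuition only):
#   x = jnp.zeros((batch, seqlen, 2, hidden), jnp.bfloat16)  # gated-activation input
#   y = act_lu(x, ("silu", "linear"))               # NoScaleTensor, y.data: (batch, seqlen, hidden)
#   y_q = act_lu(x, ("silu", "linear"), quantizer)  # ScaledTensor when a quantizer is provided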


def quantize_dact_dbias(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]] = ("gelu",),
    is_dbias: bool = True,
    quantizer: Optional[Quantizer] = None,
) -> Tuple[ScaledTensor, jnp.ndarray]:
    """Compute gradients of activation and bias with optional quantization.

    Args:
        dz: Gradient of the output with respect to the activation output.
        x: Input tensor that was processed by the forward pass.
            Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
        activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",).
        is_dbias: If True, compute bias gradient. Defaults to True.
        quantizer: Optional quantizer for FP8 quantization of the output.

    Returns:
        Tuple[ScaledTensor, jnp.ndarray]: A tuple containing:
        - The gradient of the activation with respect to the input.
        - The gradient of the activation with respect to the bias.
    """

    act_len = len(activation_type)
    assert x.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {x.shape} and act_len {act_len}"
    )

    scale = jnp.empty((), jnp.float32)
    act_type_id = ActivationEnum[activation_type]
    PrimitiveClass = DActLuDBiasQuantizePrimitive if is_dbias else DActLuQuantizePrimitive
    if not PrimitiveClass.enabled() or (
        quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE
    ):
        return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)

    if quantizer is None:
        output, _, _, _, _, _ = PrimitiveClass.outer_primitive.bind(
            dz,
            x,
            scale,
            # outputs float32 for dbias accumulation
            out_dtype=(jnp.float32 if is_dbias else x.dtype),
            # default value for no scaling, TE/common ignore this value when scale is unset
            scaling_mode=ScalingMode.NO_SCALING.value,
            is_2x=False,  # unused
            scale_dtype=jnp.float32,  # unused
            is_dbias=False,
            act_enum=act_type_id,
            act_len=act_len,
            is_outer=True,
        )
        output = output.astype(x.dtype)
        dbias = None
        if is_dbias:
            dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2)

        output = NoScaleTensor(
            data=output,
            amax=None,
        )
        return output, dbias

    # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet
    if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
        out = dact_lu(
            dz.astype(jnp.float32), x.astype(jnp.float32), activation_type, quantizer=None
        )
        return _quantize_dbias_impl(
            out.data, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2
        )

    is_gated = act_len == 2
    # TE/common does not support DelayedScaling2x for gated-act yet
    if is_gated:
        war_output = try_apply_delayed_scaling_2x_war(
            f=quantize_dact_dbias,
            dz=dz,
            x=x,
            activation_type=activation_type,
            is_dbias=is_dbias,
            quantizer=quantizer,
            flatten_axis=-2,
        )
        if war_output is not None:
            return war_output

    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        # Current scaling does not support fused operations. Perform dact in higher precision,
        # then quantize after.
        out = dact_lu(
            dz=dz,
            x=x,
            activation_type=activation_type,
            quantizer=None,
        )
        out, dbias = _quantize_dbias_impl(
            out.data, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2
        )
        return out, dbias

    if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        scale = quantizer.scale

    # TE/common dact_dbias_quantize does not support gated act yet
    if is_dbias and is_gated:
        dgated = dact_lu(
            dz.astype(jnp.float32), x.astype(jnp.float32), activation_type=activation_type
        )
        out, dbias = _quantize_dbias_impl(
            dgated, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2
        )
        return out, dbias

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
        dbias,
    ) = PrimitiveClass.outer_primitive.bind(
        dz,
        x,
        scale,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        is_2x=quantizer.is_2x2x(),
        scale_dtype=quantizer.get_scale_dtype(),
        is_dbias=is_dbias,
        act_enum=act_type_id,
        act_len=act_len,
        is_outer=True,
    )

    # For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise
    if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x():
        colwise_scale_inv = rowwise_scale_inv

    quantizer.update(updated_amax)

    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=-2,  # as output has act axis
    )

    return out, dbias
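
# Usage sketch (hypothetical shapes, for intuition only): with x of shape
# (batch, seqlen, 2, hidden) and dz of shape (batch, seqlen, hidden),
#   dx, dbias = quantize_dact_dbias(dz, x, ("silu", "linear"), is_dbias=True, quantizer=quantizer)
# returns dx matching x's shape (quantized when a quantizer is given) and dbias of shape (2, hidden).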


def dact_lu(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Optional[Quantizer] = None,
) -> Union[jnp.ndarray, ScaledTensor]:
    """
    Backward pass for activation with optional quantization.

    Args:
        dz: Gradient tensor from upstream.
        x: Input tensor that was used in forward pass.
        activation_type: Type of activation function that was applied.
        quantizer: Optional quantizer for FP8 quantization of the output gradient.

    Returns:
        The gradient of the activation with respect to the input.
    """
    output, _ = quantize_dact_dbias(
        dz=dz,
        x=x,
        activation_type=activation_type,
        is_dbias=False,
        quantizer=quantizer,
    )
    return output
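
# Usage sketch (hypothetical shapes): dact_lu is the backward counterpart of act_lu; given
# dz of shape (..., hidden) and the forward input x of shape (..., ACT_DIM, hidden),
#   dx = dact_lu(dz, x, ("gelu",))
# returns the gradient with respect to x (same shape as x), optionally quantized.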