# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for activation"""
from typing import Sequence, Union, Callable, Optional, Tuple
import operator
from functools import reduce, partial
from dataclasses import dataclass

import jax
import jax.numpy as jnp
from jax import dtypes, ffi
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.sharding import PartitionSpec

import numpy as np

import transformer_engine_jax
from transformer_engine_jax import NVTE_Activation_Type
from .base import BasePrimitive, register_primitive
from .misc import (
    jax_dtype_to_te_dtype,
    te_dtype_to_jax_dtype,
    get_padded_spec,
    check_valid_batch_dims,
    multidim_transpose,
    try_apply_delayed_scaling_2x_war,
    should_apply_1x_fused_dbias_war_for_arch_l_100,
    NamedSharding,
)
from .quantization import _jax_dbias, _quantize_dbias_impl, AmaxScope
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor
from ..quantize import (
    Quantizer,
    QuantizeLayout,
    DelayedScaleQuantizer,
    ScalingMode,
)

__all__ = ["act_lu", "dact_lu", "quantize_dact_dbias"]


ActivationEnum = {
    ("gelu",): NVTE_Activation_Type.GELU,
    ("gelu", "linear"): NVTE_Activation_Type.GEGLU,
    ("silu",): NVTE_Activation_Type.SILU,
    ("silu", "linear"): NVTE_Activation_Type.SWIGLU,
    ("relu",): NVTE_Activation_Type.RELU,
    ("relu", "linear"): NVTE_Activation_Type.REGLU,
    ("quick_gelu",): NVTE_Activation_Type.QGELU,
    ("quick_gelu", "linear"): NVTE_Activation_Type.QGEGLU,
    ("squared_relu",): NVTE_Activation_Type.SRELU,
    ("squared_relu", "linear"): NVTE_Activation_Type.SREGLU,
    ("clamped_silu", "clamped_linear"): NVTE_Activation_Type.CLAMPED_SWIGLU,
}
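
# Example (illustrative only): keys are the ``activation_type`` tuples used
# throughout this module, e.g.
#   ActivationEnum[("gelu",)]          -> NVTE_Activation_Type.GELU
#   ActivationEnum[("silu", "linear")] -> NVTE_Activation_Type.SWIGLU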


@dataclass(frozen=True)
class ClampedSwigluParams:
    """Parameters for the Clamped SwiGLU activation function
    used in GPT OSS."""

    limit: float = 7.0
    alpha: float = 1.702

    def __hash__(self):
        """Custom hash function to ensure dataclass is hashable for jax jit to work.

        Returns:
            int: Hash value of the dataclass instance.
        """
        return hash((self.limit, self.alpha))

    def to_ffi_lowering_dict(self):
        """Convert the activation parameters to a dictionary format for FFI lowering.

        Returns:
            dict: A dictionary representation of the activation parameters consumable by
            XLA FFI bindings for activation functions.
        """
        return {"limit": np.float32(self.limit), "alpha": np.float32(self.alpha)}


@dataclass(frozen=True)
class ActivationParams:
    """Parameters for various activation functions.
    Currently only Clamped SwiGLU activation has parameters.
    """

    clamped_swiglu: ClampedSwigluParams = ClampedSwigluParams()

    @staticmethod
    def create(activation_type, **kwargs):
        """Factory method to create ActivationParams based on activation_type."""
        CLAMPED_ACTIVATION_TYPES = {
            ("clamped_silu", "clamped_linear"),
            "clamped_silu",
            "clamped_linear",
        }
        if activation_type in CLAMPED_ACTIVATION_TYPES:
            return ActivationParams(ClampedSwigluParams(**kwargs))
        return ActivationParams()  # Default params for activations without parameters

    def __hash__(self):
        """Custom hash function to ensure dataclass is hashable for jax jit to work"""
        return hash((self.clamped_swiglu,))

    def to_ffi_lowering_dict(self):
        """Convert the activation parameters to a dictionary format for FFI lowering.
        Returns:
            dict: A dictionary representation of the activation parameters consumable by
            XLA FFI bindings for activation functions.
        """
        return {"clamped_swiglu": self.clamped_swiglu.to_ffi_lowering_dict()}


def _convert_to_activation_function(fn_or_string, act_params: ActivationParams):
    """Convert a string to an activation function."""
    if fn_or_string == "linear":
        return lambda x: x
    if fn_or_string == "clamped_linear":
        # This function is used for ClampedSwiGLU
        # used in GPT OSS where the gates are not only clamped
        # but also shifted by +1
        limit = act_params.clamped_swiglu.limit
        return lambda x: jnp.clip(x, min=-limit, max=limit) + 1
    if fn_or_string == "quick_gelu":
        return lambda x: jax.nn.sigmoid(1.702 * x) * x
    if fn_or_string == "squared_relu":
        return lambda x: reduce(operator.mul, [jax.nn.relu(x), jax.nn.relu(x)])
    if fn_or_string == "clamped_silu":
        limit = act_params.clamped_swiglu.limit
        alpha = act_params.clamped_swiglu.alpha
        return lambda x: jax.nn.sigmoid(alpha * jnp.minimum(x, limit)) * jnp.minimum(x, limit)
    if isinstance(fn_or_string, str):
        return getattr(jax.nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string
    raise ValueError(f"Unsupported {fn_or_string} to an activation function")


class ActLuPrimitive(BasePrimitive):
    """
    ActLu Primitive
    """

    name = "te_act_lu_ffi"
    multiple_results = True
    impl_static_args = (
        3,
        4,
        5,
        6,
        7,
        8,
        9,
        10,
        11,
        12,
        13,
    )  # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, act_params, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling, is_outer
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        amax_aval,
        *,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
    ):
        """
        te_act_lu_p abstract
        """
        del act_enum, act_params, amax_scope, transpose_batch_sequence
        assert (
            not output_amax_when_no_scaling or scaling_mode == ScalingMode.NO_SCALING.value
        ), f"scaling_mode = {scaling_mode}"
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval is None or scale_aval.dtype == jnp.float32
        assert amax_aval is None or amax_aval.dtype == jnp.float32
        assert x_aval.shape[-2] == act_len, (
            "activation input should be replicated by act_len in the -2 axis, got input shape"
            f" {x_aval.shape} and act_len {act_len}"
        )

        assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, (
            "Current tensor scaling is not yet supported for fused activation and quantization."
            " Please do activation in higher-precision then quantize with current tensor scaling."
        )

        out_shape = (*x_aval.shape[:-2], x_aval.shape[-1])  # Exclude act dim
        out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)

        updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(out_shape, is_padded=not is_outer, flatten_axis=-1)
        if not is_2x:
            out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype)
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        return out_aval, colwise_out_aval, scale_inv_aval, colwise_scale_inv_aval, updated_amax_aval

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        amax,
        *,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
    ):
        """
        te_act_lu_p lowering rules
        """
        del out_dtype, scale_dtype, act_len, is_outer, amax_scope, transpose_batch_sequence
        x_aval, scale_aval, amax_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval is None or scale_aval.dtype == jnp.float32
        assert amax_aval.dtype == jnp.float32

        out = ffi.ffi_lowering(
            ActLuPrimitive.name,
            operand_output_aliases={2: 4},  # donate amax buffer to updated_amax
        )(
            ctx,
            x,
            scale,
            amax,
            act_enum=act_enum,
            scaling_mode=scaling_mode.value,
            is_2x=is_2x,
            act_params=act_params.to_ffi_lowering_dict(),
            output_amax_when_no_scaling=output_amax_when_no_scaling,
        )
        return out

    @staticmethod
    def impl(
        x,
        scale,
        amax,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
    ):
        """
        te_act_lu_p implementation
        """
        del is_outer
        assert ActLuPrimitive.inner_primitive is not None

        out, colwise_out, scale_inv, colwise_scale_inv, updated_amax = (
            ActLuPrimitive.inner_primitive.bind(
                x,
                scale,
                amax,
                out_dtype=out_dtype,
                act_enum=act_enum,
                act_len=act_len,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                act_params=act_params,
                amax_scope=amax_scope,
                transpose_batch_sequence=transpose_batch_sequence,
                output_amax_when_no_scaling=output_amax_when_no_scaling,
                is_outer=False,
            )
        )
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(out.shape, is_padded=False, flatten_axis=-1)
        # Slice out padding for MXFP8, noop for DelayedScaling
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
        )
        if is_2x:
            colwise_scale_inv = jax.lax.slice(
                colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
            )
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
    ):
        """
        Batching rule for vmap
        """
        check_valid_batch_dims(batch_dims)
        assert ActLuPrimitive.outer_primitive is not None
        x, scale, amax = batched_args
        x_bdim, scale_bdim, _ = batch_dims
        amax_bdim = scale_bdim
        out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim
        return (
            ActLuPrimitive.outer_primitive.bind(
                x,
                scale,
                amax,
                out_dtype=out_dtype,
                act_enum=act_enum,
                act_len=act_len,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                act_params=act_params,
                amax_scope=amax_scope,
                transpose_batch_sequence=transpose_batch_sequence,
                output_amax_when_no_scaling=output_amax_when_no_scaling,
                is_outer=is_outer,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del (
            out_dtype,
            result_infos,
            act_enum,
            scale_dtype,
            act_len,
            act_params,
            amax_scope,
            transpose_batch_sequence,
            output_amax_when_no_scaling,
            is_outer,
        )  # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        scale_spec = get_padded_spec(arg_infos[1])

        out_spec = (*x_spec[:-2], x_spec[-1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")

        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1)
            else:
                colwise_out_spec = out_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out"
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = out_spec

        if is_2x:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax")
        colwise_scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
        )

        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer
        x_spec = get_padded_spec(arg_infos[0])
        scale_spec = get_padded_spec(arg_infos[1])

        out_spec = (*x_spec[:-2], x_spec[-1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")

        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1)
            else:
                colwise_out_spec = out_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out"
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = out_spec

        if is_2x:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax")
        colwise_scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
        )

        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
        )

        def sharded_impl(x, scale, amax):
            (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                local_updated_amax,
            ) = ActLuPrimitive.impl(
                x,
                scale,
                amax,
                out_dtype=out_dtype,
                act_enum=act_enum,
                act_len=act_len,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                act_params=act_params,
                amax_scope=amax_scope,
                transpose_batch_sequence=transpose_batch_sequence,
                output_amax_when_no_scaling=output_amax_when_no_scaling,
                is_outer=True,
            )

            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(
                    local_updated_amax, mesh
                )
            elif scaling_mode == ScalingMode.NO_SCALING.value and output_amax_when_no_scaling:
                global_updated_amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP(
                    local_updated_amax, out_spec, transpose_batch_sequence, mesh
                )
            else:
                global_updated_amax = local_updated_amax

            return (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                global_updated_amax,
            )
        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
        mesh,
        value_types,
        result_types,
    ):
        del (
            out_dtype,
            act_enum,
            act_len,
            scale_dtype,
            act_params,
            amax_scope,
            transpose_batch_sequence,
            output_amax_when_no_scaling,
            is_outer,
            mesh,
            result_types,
        )
        prefix = "ActLu_"
        input_shape = value_types[0].shape
        output_shape = input_shape[:-2] + input_shape[-1:]
        # Pass the output shape so that the scales are propagated correctly
        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            output_shape, unique_var=prefix + "x", flatten_axis=-1
        )
        x_axes = scale_rules.input_spec
        # Correct input spec with act dim
        x_axes = x_axes[:-1] + (prefix + "_act_dim",) + x_axes[-1:]
        out = scale_rules.input_spec
        colwise_out = (prefix + "out_colwise",)
        colwise_scale_inv = (prefix + "scale_inv_colwise",)
        if is_2x:
            colwise_scale_inv = scale_rules.colwise_rule
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out = multidim_transpose(out, transpose_axis=-1)
            else:
                colwise_out = out
                colwise_scale_inv = scale_rules.colwise_rule
        amax = (prefix + "amax",)

        return SdyShardingRule(
            (
                x_axes,
                ("…1",),
                amax,
            ),
            (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax),
            **scale_rules.factor_sizes,
        )


register_primitive(ActLuPrimitive)
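
# Note (summary of the contract above): ActLuPrimitive returns
# (rowwise_out, colwise_out, rowwise_scale_inv, colwise_scale_inv, updated_amax);
# when is_2x is False the colwise buffers are (1,)-shaped placeholders.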


# TODO(Jeremy): replace is_2x with q_layout
class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
    """
    DActLu DBias Quantize Primitive
    """

    name = "te_dact_dbias_quantize_ffi"
    multiple_results = True
    # out_dtype, scaling_mode, is_2x, scale_dtype, is_dbias, act_enum, act_len, act_params, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling, is_outer
    impl_static_args = (4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        dz_aval,
        x_aval,
        scale_aval,
        amax_aval,
        *,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
    ):
        """
        te_dact_dbias_quantize_p abstract
        """
        del act_enum, act_params, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling
        dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dz_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_dtype
        assert x_aval.shape[-2] == act_len, (
            "activation input should be replicated by act_len in the -2 axis, got input shape"
            f" {x_aval.shape} and act_len {act_len}"
        )
        assert scale_aval.dtype == jnp.float32
        assert amax_aval.dtype == jnp.float32

        assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, (
            "Current tensor scaling is not supported for fused dact and quantization. Please do"
            " dact in higher-precision then quantize with current tensor scaling."
        )

        ir_hidden_size = dz_aval.shape[-1]
        gi_hidden_size = act_len * x_aval.shape[-1]
        assert act_len * ir_hidden_size == gi_hidden_size
        assert (
            x_aval.shape[:-2] == dz_aval.shape[:-1]
        ), "dz and x should have the same leading dimensions"
        out_shape = x_aval.shape
        out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)
        updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=-2)
        if is_2x:
            if ScalingMode(scaling_mode).is_tensor_scaling():
                colwise_out_shape = multidim_transpose(out_shape, transpose_axis=-2)
            else:
                colwise_out_shape = out_shape
        else:
            colwise_out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )
        if is_dbias:
            dbias_shape = (act_len, ir_hidden_size)
            (wkspace_info,) = transformer_engine_jax.get_dact_dbias_quantize_workspace_sizes(
                x_aval.size // gi_hidden_size,
                gi_hidden_size,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
                scaling_mode,
                is_2x,
            )
            wkspace_shape = wkspace_info[0]
            wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
        else:
            dbias_shape = (1,)
            wkspace_shape = (1,)
            wkspace_dtype = jnp.float32
        dbias_aval = jax.core.ShapedArray(shape=dbias_shape, dtype=dz_dtype)
        wkspace_aval = jax.core.ShapedArray(shape=wkspace_shape, dtype=wkspace_dtype)
        return (
            out_aval,
            colwise_out_aval,
            scale_inv_aval,
            colwise_scale_inv_aval,
            updated_amax_aval,
            dbias_aval,
            wkspace_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dact_dbias_quantize_p outer abstract
        """
        (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = (
            BaseDActLuDBiasQuantizePrimitive.abstract(*args, **kwargs)
        )
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def lowering(
        ctx,
        dz,
        x,
        scale,
        amax,
        *,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
    ):
        """
        te_dact_dbias_quantize_p lowering rules
        """
        del (
            out_dtype,
            scale_dtype,
            act_len,
            is_outer,
            amax_scope,
            transpose_batch_sequence,
        )
        dz_aval, x_aval, scale_aval, amax_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_aval.dtype
        assert scale_aval.dtype == amax_aval.dtype == jnp.float32
        return ffi.ffi_lowering(
            BaseDActLuDBiasQuantizePrimitive.name,
            operand_output_aliases={3: 4},  # donate amax buffer to updated_amax
        )(
            ctx,
            dz,
            x,
            scale,
            amax,
            scaling_mode=scaling_mode.value,
            is_2x=is_2x,
            is_dbias=is_dbias,
            act_enum=int(act_enum),
            act_params=act_params.to_ffi_lowering_dict(),
            output_amax_when_no_scaling=output_amax_when_no_scaling,
        )

    @staticmethod
    def impl(
        dz,
        x,
        scale,
        amax,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
    ):
        """
        te_dact_dbias_quantize_p impl
        """
        del is_outer
        assert BaseDActLuDBiasQuantizePrimitive.inner_primitive is not None
        (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = (
            BaseDActLuDBiasQuantizePrimitive.inner_primitive.bind(
                dz,
                x,
                scale,
                amax,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
                act_enum=act_enum,
                act_len=act_len,
                act_params=act_params,
                amax_scope=amax_scope,
                transpose_batch_sequence=transpose_batch_sequence,
                output_amax_when_no_scaling=output_amax_when_no_scaling,
                is_outer=False,
            )
        )
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x.shape, is_padded=False, flatten_axis=-2)
        # Slice out padding for MXFP8, noop for DelayedScaling
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
        )
        if is_2x:
            colwise_scale_inv = jax.lax.slice(
                colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
            )
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
    ):
        """
        Batching rule for vmap
        """
        check_valid_batch_dims(batch_dims)
        assert BaseDActLuDBiasQuantizePrimitive.outer_primitive is not None
        dz, x, scale, amax = batched_args
        _, x_bdim, scale_bdim, _ = batch_dims

        out_bdims = (
            x_bdim,  # rowwise output
            x_bdim,  # colwise output
            scale_bdim,  # rowwise scale_inv
            scale_bdim,  # colwise scale_inv
            scale_bdim,  # amax
            x_bdim,  # dbias
        )
        return (
            BaseDActLuDBiasQuantizePrimitive.outer_primitive.bind(
                dz,
                x,
                scale,
                amax,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
                act_enum=act_enum,
                act_len=act_len,
                act_params=act_params,
                amax_scope=amax_scope,
                transpose_batch_sequence=transpose_batch_sequence,
                output_amax_when_no_scaling=output_amax_when_no_scaling,
                is_outer=is_outer,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del out_dtype, result_infos, act_enum, act_params, output_amax_when_no_scaling
        del scale_dtype, act_len, is_outer, amax_scope, transpose_batch_sequence

        x_spec = get_padded_spec(arg_infos[1])
        scale_spec = get_padded_spec(arg_infos[2])

        assert (
            scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value
        ), "Partitioned current tensor scaling is not yet supported."

        out_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec), desc="BaseDActLuDBiasQuantizePrimitive.out"
        )
        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2)
            else:
                colwise_x_spec = x_spec
        else:
            colwise_x_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_x_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.colwise_out",
        )

        dbias_spec = x_spec[-2:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.dbias",
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if is_2x:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="BaseDActLuDBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="BaseDActLuDBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.colwise_scale_inv",
        )
        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer
        x_spec = get_padded_spec(arg_infos[1])
        scale_spec = get_padded_spec(arg_infos[2])

        out_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec), desc="BaseDActLuDBiasQuantizePrimitive.out"
        )

        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2)
            else:
                colwise_x_spec = x_spec
        else:
            colwise_x_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_x_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.colwise_out",
        )

        dbias_spec = x_spec[-2:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.dbias",
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if is_2x:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*scale_inv_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.scale_inv",
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="BaseDActLuDBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.colwise_scale_inv",
        )

        arg_shardings = list(arg_i.sharding for arg_i in arg_infos)
        # Ensure dz and x are partitioned the same way.
        arg_shardings[0] = NamedSharding(
            mesh,
            PartitionSpec(*x_spec[:-2], x_spec[-1]),
            desc="BaseDActLuDBiasQuantizePrimitive.dz",
        )
        arg_shardings = tuple(arg_shardings)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

        def sharded_impl(dz, x, scale, amax):
            (out, colwise_out, scale_inv, colwise_scale_inv, local_updated_amax, local_dbias) = (
                BaseDActLuDBiasQuantizePrimitive.impl(
                    dz,
                    x,
                    scale,
                    amax,
                    out_dtype=out_dtype,
                    scaling_mode=scaling_mode,
                    is_2x=is_2x,
                    scale_dtype=scale_dtype,
                    is_dbias=is_dbias,
                    act_enum=act_enum,
                    act_len=act_len,
                    act_params=act_params,
                    output_amax_when_no_scaling=output_amax_when_no_scaling,
                    amax_scope=amax_scope,
                    transpose_batch_sequence=transpose_batch_sequence,
                    is_outer=True,
                )
            )
            if is_dbias:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            else:
                global_dbias = local_dbias

            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(
                    local_updated_amax, mesh
                )
            elif scaling_mode == ScalingMode.NO_SCALING.value and output_amax_when_no_scaling:
                global_updated_amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP(
                    local_updated_amax, x_spec, transpose_batch_sequence, mesh
                )
            else:
                global_updated_amax = local_updated_amax

            return out, colwise_out, scale_inv, colwise_scale_inv, global_updated_amax, global_dbias

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        act_params,
        amax_scope,
        transpose_batch_sequence,
        output_amax_when_no_scaling,
        is_outer,
        mesh,
        value_types,
        result_types,
    ):
        del (
            out_dtype,
            scale_dtype,
            act_enum,
            act_len,
            act_params,
            is_outer,
            output_amax_when_no_scaling,
            mesh,
            result_types,
            amax_scope,
            transpose_batch_sequence,
        )

        prefix = "DActLuDBias_"
        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            value_types[1].shape, unique_var=prefix + "x", flatten_axis=-2
        )
        x_axes = scale_rules.input_spec
        dz_axes = (*x_axes[:-2], x_axes[-1])
        out = x_axes
        colwise_out = (prefix + "out_colwise",)
        colwise_scale_inv = (prefix + "scale_inv_colwise",)
        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=-2))
            else:
                colwise_out = out
                colwise_scale_inv = scale_rules.colwise_rule
        dbias = x_axes[-2:] if is_dbias else (prefix + "dbias",)
        amax = (prefix + "amax",)

        return SdyShardingRule(
            (dz_axes, x_axes, ("…2",), amax),
            (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias),
            **scale_rules.factor_sizes,
        )


register_primitive(BaseDActLuDBiasQuantizePrimitive)


class DActLuDBiasQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive):
    """Subclass of BaseDActLuDBiasQuantizePrimitive for DBias and fused activation quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""


class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive):
    """Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""


def _jax_act_lu(
    inputs, activation_type, quantizer=None, act_params: Optional[ActivationParams] = None
) -> Union[NoScaleTensor, ScaledTensor]:
    """
    JAX native activation implementation
    """
    act_params = act_params if act_params is not None else ActivationParams()
    act_len = len(activation_type)
    assert inputs.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {inputs.shape} and act_len {act_len}"
    )
    x = jnp.split(inputs, act_len, axis=-2)
    acts = []
    for idx, act_fn in enumerate(activation_type):
        x_i = _convert_to_activation_function(act_fn, act_params)(x[idx])
        acts.append(x_i)
    x = reduce(operator.mul, acts)
    x = jnp.squeeze(x, axis=-2)
    if quantizer:
        return quantizer.quantize(x, flatten_axis=-1)
    return NoScaleTensor(data=x, amax=None)
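
# Shape sketch (illustrative): for a gated activation such as ("silu", "linear"),
# ``inputs`` has shape (..., 2, K) and the activation output has shape (..., K);
# for a non-gated activation such as ("gelu",) the act dimension is 1.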


def _jax_quantize_dact_dbias(
    dz: Union[jnp.ndarray, NoScaleTensor],
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    is_dbias: bool = True,
    quantizer: Optional[Quantizer] = None,
    act_params: Optional[ActivationParams] = None,
):
    """
    JAX implementation of dact_lu and dbias with optional quantization
    """
    act_params = act_params if act_params is not None else ActivationParams()
    act_len = len(activation_type)
    assert x.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {x.shape} and act_len {act_len}"
    )

    _, vjp_func = jax.vjp(
        partial(_jax_act_lu, activation_type=activation_type, act_params=act_params),
        x.astype(jnp.float32),
    )
    # The VJP uses the non-quantized backward for dact, so the input must always be
    # wrapped in a NoScaleTensor, regardless of whether the forward pass used
    # quantization or this dact output will be quantized afterwards.
    dz = NoScaleTensor(data=dz.astype(jnp.float32), amax=None)
    (dx,) = vjp_func(dz)
    dbias = None
    if is_dbias:
        dbias = _jax_dbias(dx, dtype=x.dtype, flatten_axis=-2)
    if quantizer is not None:
        dx = quantizer.quantize(dx, dq_dtype=x.dtype, flatten_axis=-2)
    else:
        dx = dx.astype(x.dtype)
        dx = NoScaleTensor(data=dx, amax=None)
    return dx, dbias
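
# Shape sketch (illustrative): dz matches the forward output (..., K), dx matches
# x with shape (..., ACT_DIM, K), and dbias (when requested) has shape (ACT_DIM, K).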


def act_lu(
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Optional[Quantizer] = None,
    act_params: Optional[ActivationParams] = None,
    amax_scope: AmaxScope = AmaxScope.LOCAL,
    transpose_batch_sequence: bool = False,
    output_amax_when_no_scaling: bool = False,
) -> Union[jnp.ndarray, ScaledTensor]:
    """Activation with optional quantization.

    Args:
        x: Input tensor to be processed.
            Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
        activation_type: Type of activation function to apply.
        quantizer: Optional quantizer for FP8 quantization of the output.
        amax_scope: Scope over which the amax reduction is performed; only used with
            current tensor scaling. Defaults to AmaxScope.LOCAL.
        act_params: Optional parameters for parameterized activations such as the
            clamped SwiGLU (limit/alpha). Defaults to ActivationParams().
        transpose_batch_sequence: Whether the batch and sequence axes of the input are
            transposed; used when reducing the amax across devices.
        output_amax_when_no_scaling: If True, attach the computed amax to the returned
            tensor even when no scaling mode is used.

    Returns:
        If quantizer is None:
            A NoScaleTensor wrapping the activated tensor (same dtype as the input).
        If quantizer is provided:
            A ScaledTensor containing the quantized activated input.
    """
    # TODO(Phuong): remove the output_amax_when_no_scaling exposure by introducing _act_lu_impl()
    # Do the same with dact_dbias_quantize() and layernorm_fwd()
    act_type_id = ActivationEnum[activation_type].value
    act_len = len(activation_type)
    assert x.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {x.shape} and act_len {act_len}"
    )
    act_params = act_params if act_params is not None else ActivationParams()
    if not ActLuPrimitive.enabled():
        return _jax_act_lu(x, activation_type, quantizer, act_params)
    # TE/common does not support colwise-only quantization yet
    if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
        return _jax_act_lu(x, activation_type, quantizer, act_params)
    # TE/common does not support 2x quantization for DelayedScaling yet
    war_output = try_apply_delayed_scaling_2x_war(
        f=act_lu,
        x=x,
        activation_type=activation_type,
        quantizer=quantizer,
        act_params=act_params,
        amax_scope=amax_scope,
        transpose_batch_sequence=transpose_batch_sequence,
        output_amax_when_no_scaling=output_amax_when_no_scaling,
    )
    if war_output is not None:
        return war_output

    scale = jnp.empty((1,), jnp.float32)
    output_shape = (*x.shape[:-2], x.shape[-1])
    amax = jnp.zeros((1,), jnp.float32)  # need to init with zero and shape=(1,)

    if quantizer is None:
        out, _, _, _, updated_amax = ActLuPrimitive.outer_primitive.bind(
            x,
            scale,
            amax,
            out_dtype=x.dtype,
            act_enum=act_type_id,
            act_len=act_len,
            scaling_mode=ScalingMode.NO_SCALING.value,
            is_2x=False,
            scale_dtype=jnp.float32,
            act_params=act_params,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
            output_amax_when_no_scaling=output_amax_when_no_scaling,
            is_outer=True,
        )
        out = out.reshape(output_shape)
        # TODO(Phuong): ScaledTensorFactory to create NoScaledTensor
        out = NoScaleTensor(
            data=out,
            amax=updated_amax if output_amax_when_no_scaling else None,
        )
        return out

    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        # Current scaling does not support fused operations.
        # Perform the activation in higher precision, then quantize.
        out = act_lu(
            x=x,
            activation_type=activation_type,
            quantizer=None,
            act_params=act_params,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
            output_amax_when_no_scaling=True,
        )
        out, _ = _quantize_dbias_impl(
            out,
            is_dbias=False,
            quantizer=quantizer,
            dq_dtype=x.dtype,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
        )
        return out
    if isinstance(quantizer, DelayedScaleQuantizer):
        scale = quantizer.scale

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
    ) = ActLuPrimitive.outer_primitive.bind(
        x,
        scale,
        amax,
        out_dtype=quantizer.q_dtype,
        act_enum=act_type_id,
        act_len=act_len,
        scaling_mode=quantizer.scaling_mode.value,
        is_2x=quantizer.is_2x2x(),
        scale_dtype=quantizer.get_scale_dtype(),
        act_params=act_params,
        amax_scope=amax_scope,
        transpose_batch_sequence=transpose_batch_sequence,
        output_amax_when_no_scaling=output_amax_when_no_scaling,
        is_outer=True,
    )
    quantizer.update(updated_amax)

    return ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
    )
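
# Usage sketch (illustrative; bf16 in, no quantizer):
#   x = jnp.ones((4, 128, 2, 1024), jnp.bfloat16)  # (..., ACT_DIM=2, K) for a gated activation
#   y = act_lu(x, ("silu", "linear"))              # NoScaleTensor with data of shape (4, 128, 1024)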


def quantize_dact_dbias(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]] = ("gelu",),
    is_dbias: bool = True,
    quantizer: Optional[Quantizer] = None,
    act_params: Optional[ActivationParams] = None,
    amax_scope: AmaxScope = AmaxScope.LOCAL,
    transpose_batch_sequence: bool = False,
    output_amax_when_no_scaling: bool = False,
) -> Tuple[ScaledTensor, jnp.ndarray]:
    """Compute gradients of activation and bias with optional quantization.

    Args:
        dz: Incoming gradient with respect to the activation output.
        x: Input tensor that was processed by the forward pass.
            Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
        activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",).
        is_dbias: If True, compute bias gradient. Defaults to True.
        quantizer: Optional quantizer for FP8 quantization of the output.

    Returns:
        Tuple[ScaledTensor, jnp.ndarray]: A tuple containing:
        - The gradient with respect to the activation input.
        - The gradient with respect to the bias (dbias).
    """
    act_params = act_params if act_params is not None else ActivationParams()
    act_len = len(activation_type)
    assert x.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {x.shape} and act_len {act_len}"
    )

    scale = jnp.empty((1,), jnp.float32)
    amax = jnp.zeros((1,), jnp.float32)  # need to init with zero and shape=(1,)
    act_type_id = ActivationEnum[activation_type]
    PrimitiveClass = DActLuDBiasQuantizePrimitive if is_dbias else DActLuQuantizePrimitive
    if not PrimitiveClass.enabled() or (
        quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE
    ):
        return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer, act_params)
    if quantizer is None:
        output, _, _, _, updated_amax, _ = PrimitiveClass.outer_primitive.bind(
            dz,
            x,
            scale,
            amax,
            # outputs float32 for dbias accumulation
            out_dtype=(jnp.float32 if is_dbias else x.dtype),
            # default value for no scaling, TE/common ignore this value when scale is unset
            scaling_mode=ScalingMode.NO_SCALING.value,
            is_2x=False,  # unused
            scale_dtype=jnp.float32,  # unused
            is_dbias=False,
            act_enum=act_type_id,
            act_len=act_len,
            act_params=act_params,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
            output_amax_when_no_scaling=output_amax_when_no_scaling,
            is_outer=True,
        )
        output = output.astype(x.dtype)
        dbias = None
        if is_dbias:
            dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2)

        output = NoScaleTensor(
            data=output,
            amax=updated_amax if output_amax_when_no_scaling else None,
        )
        return output, dbias
    # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet
    if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
        out = dact_lu(
            dz.astype(jnp.float32),
            x.astype(jnp.float32),
            activation_type,
            quantizer=None,
            act_params=act_params,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
            output_amax_when_no_scaling=output_amax_when_no_scaling,
        )
        return _quantize_dbias_impl(
            out.data,
            quantizer,
            is_dbias=True,
            dq_dtype=x.dtype,
            flatten_axis=-2,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
        )
    is_gated = act_len == 2
    # TE/common does not support DelayedScaling2x for gated-act yet
    if is_gated:
        war_output = try_apply_delayed_scaling_2x_war(
            f=quantize_dact_dbias,
            dz=dz,
            x=x,
            activation_type=activation_type,
            is_dbias=is_dbias,
            quantizer=quantizer,
            flatten_axis=-2,
            act_params=act_params,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
            output_amax_when_no_scaling=output_amax_when_no_scaling,
        )
        if war_output is not None:
            return war_output

    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        # Current scaling does not support fused operations. Perform dact in higher precision then quantize after.
        out = dact_lu(
            dz=dz,
            x=x,
            activation_type=activation_type,
            quantizer=None,
            act_params=act_params,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
            output_amax_when_no_scaling=True,
        )
        out, dbias = _quantize_dbias_impl(
            out,
            is_dbias=is_dbias,
            quantizer=quantizer,
            dq_dtype=x.dtype,
            flatten_axis=-2,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
        )
        return out, dbias

    if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        scale = quantizer.scale

    # TE/common dact_dbias_quantize does not support gated act yet
    if is_dbias and is_gated:
        dgated = dact_lu(
            dz.astype(jnp.float32),
            x.astype(jnp.float32),
            activation_type=activation_type,
            act_params=act_params,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
        )
        out, dbias = _quantize_dbias_impl(
            dgated,
            quantizer,
            is_dbias=True,
            dq_dtype=x.dtype,
            flatten_axis=-2,
            amax_scope=amax_scope,
            transpose_batch_sequence=transpose_batch_sequence,
        )
        return out, dbias

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
        dbias,
    ) = PrimitiveClass.outer_primitive.bind(
        dz,
        x,
        scale,
        amax,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        is_2x=quantizer.is_2x2x(),
        scale_dtype=quantizer.get_scale_dtype(),
        is_dbias=is_dbias,
        act_enum=act_type_id,
        act_len=act_len,
        act_params=act_params,
        amax_scope=amax_scope,
        transpose_batch_sequence=transpose_batch_sequence,
        output_amax_when_no_scaling=output_amax_when_no_scaling,
        is_outer=True,
    )

    # For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise
    if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x():
        colwise_scale_inv = rowwise_scale_inv

    quantizer.update(updated_amax)

    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=-2,  # as output has act axis
    )

    return out, dbias
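
# Usage sketch (illustrative; backward of a gated activation, no quantizer):
#   dx, dbias = quantize_dact_dbias(dz, x, ("silu", "linear"), is_dbias=True, quantizer=None)
#   # dx.data has x's shape (..., 2, K); dbias has shape (2, K)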


def dact_lu(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Optional[Quantizer] = None,
    act_params: Optional[ActivationParams] = None,
    amax_scope: AmaxScope = AmaxScope.LOCAL,
    transpose_batch_sequence: bool = False,
    output_amax_when_no_scaling: bool = False,
) -> Union[jnp.ndarray, ScaledTensor]:
    """
    Backward pass for activation with optional quantization.

    Args:
        dz: Gradient tensor from upstream.
        x: Input tensor that was used in forward pass.
        activation_type: Type of activation function that was applied.
        quantizer: Optional quantizer for FP8 quantization of the output gradient.

    Returns:
        The gradient with respect to the activation input.
    """
    act_params = act_params if act_params is not None else ActivationParams()
    output, _ = quantize_dact_dbias(
        dz=dz,
        x=x,
        activation_type=activation_type,
        is_dbias=False,
        quantizer=quantizer,
        act_params=act_params,
        amax_scope=amax_scope,
        transpose_batch_sequence=transpose_batch_sequence,
        output_amax_when_no_scaling=output_amax_when_no_scaling,
    )
    return output
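

# Usage sketch (illustrative): dact_lu is the gradient counterpart of act_lu. Given the
# upstream gradient dz of shape (..., K) and the forward input x of shape (..., ACT_DIM, K),
#   dx = dact_lu(dz, x, ("gelu",))
# returns the gradient with respect to x (optionally quantized when a quantizer is given).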