# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for activation"""
from typing import Sequence, Union, Callable, Optional, Tuple
import operator
from functools import reduce, partial
from dataclasses import dataclass

import jax
import jax.numpy as jnp
from jax import dtypes, ffi
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.sharding import PartitionSpec

import numpy as np

import transformer_engine_jax
from transformer_engine_jax import NVTE_Activation_Type

from .base import BasePrimitive, register_primitive
from .misc import (
    jax_dtype_to_te_dtype,
    te_dtype_to_jax_dtype,
    get_padded_spec,
    check_valid_batch_dims,
    multidim_transpose,
    try_apply_delayed_scaling_2x_war,
    should_apply_1x_fused_dbias_war_for_arch_l_100,
    NamedSharding,
)
from .quantization import _jax_dbias, _quantize_dbias_impl, AmaxScope
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor
from ..quantize import (
    Quantizer,
    QuantizeLayout,
    DelayedScaleQuantizer,
    ScalingMode,
)


__all__ = ["act_lu", "dact_lu", "quantize_dact_dbias"]


ActivationEnum = {
    ("gelu",): NVTE_Activation_Type.GELU,
    ("gelu", "linear"): NVTE_Activation_Type.GEGLU,
    ("silu",): NVTE_Activation_Type.SILU,
    ("silu", "linear"): NVTE_Activation_Type.SWIGLU,
    ("relu",): NVTE_Activation_Type.RELU,
    ("relu", "linear"): NVTE_Activation_Type.REGLU,
    ("quick_gelu",): NVTE_Activation_Type.QGELU,
    ("quick_gelu", "linear"): NVTE_Activation_Type.QGEGLU,
    ("squared_relu",): NVTE_Activation_Type.SRELU,
    ("squared_relu", "linear"): NVTE_Activation_Type.SREGLU,
    ("clamped_silu", "clamped_linear"): NVTE_Activation_Type.CLAMPED_SWIGLU,
}
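
# Example: the keys of ActivationEnum are the activation-type tuples used by `act_lu` below.
# A gated activation is a 2-tuple whose second entry names the gate branch, e.g.
#
#     ActivationEnum[("silu", "linear")]   # -> NVTE_Activation_Type.SWIGLU
#     ActivationEnum[("gelu",)]            # -> NVTE_Activation_Type.GELU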


@dataclass(frozen=True)
class ClampedSwigluParams:
    """Parameters for the Clamped SwiGLU activation function
    used in GPT OSS."""

    limit: float = 7.0
    alpha: float = 1.702

    def __hash__(self):
        """Custom hash function to ensure dataclass is hashable for jax jit to work.

        Returns:
            int: Hash value of the dataclass instance.
        """
        return hash((self.limit, self.alpha))

    def to_ffi_lowering_dict(self):
        """Convert the activation parameters to a dictionary format for FFI lowering.

        Returns:
            dict: A dictionary representation of the activation parameters consumable by
            XLA FFI bindings for activation functions.
        """
        return {"limit": np.float32(self.limit), "alpha": np.float32(self.alpha)}


@dataclass(frozen=True)
class ActivationParams:
    """Parameters for various activation functions.
    Currently only Clamped SwiGLU activation has parameters.
    """

    clamped_swiglu: ClampedSwigluParams = ClampedSwigluParams()

    @staticmethod
    def create(activation_type, **kwargs):
        """Factory method to create ActivationParams based on activation_type."""
        CLAMPED_ACTIVATION_TYPES = {
            ("clamped_silu", "clamped_linear"),
            "clamped_silu",
            "clamped_linear",
        }
        if activation_type in CLAMPED_ACTIVATION_TYPES:
            return ActivationParams(ClampedSwigluParams(**kwargs))
        return ActivationParams()  # Default params for activations without parameters

    def __hash__(self):
        """Custom hash function to ensure dataclass is hashable for jax jit to work"""
        return hash((self.clamped_swiglu,))

    def to_ffi_lowering_dict(self):
        """Convert the activation parameters to a dictionary format for FFI lowering.
        Returns:
            dict: A dictionary representation of the activation parameters consumable by
            XLA FFI bindings for activation functions.
        """
        return {"clamped_swiglu": self.clamped_swiglu.to_ffi_lowering_dict()}



def _convert_to_activation_function(fn_or_string, act_params: ActivationParams):
    """Convert a string to an activation function."""
    if fn_or_string == "linear":
        return lambda x: x
    if fn_or_string == "clamped_linear":
        # Used for ClampedSwiGLU (GPT OSS), where the gates are not only clamped
        # but also shifted by +1.
        limit = act_params.clamped_swiglu.limit
        return lambda x: jnp.clip(x, min=-limit, max=limit) + 1
    if fn_or_string == "quick_gelu":
        return lambda x: jax.nn.sigmoid(1.702 * x) * x
    if fn_or_string == "squared_relu":
        return lambda x: reduce(operator.mul, [jax.nn.relu(x), jax.nn.relu(x)])
    if fn_or_string == "clamped_silu":
        limit = act_params.clamped_swiglu.limit
        alpha = act_params.clamped_swiglu.alpha
        return lambda x: jax.nn.sigmoid(alpha * jnp.minimum(x, limit)) * jnp.minimum(x, limit)
    if isinstance(fn_or_string, str):
        return getattr(jax.nn, fn_or_string)
    if callable(fn_or_string):
        return fn_or_string
    raise ValueError(f"Unsupported {fn_or_string} to an activation function")
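
# Illustrative sketch: resolving activation names to callables. Plain names such as "gelu" or
# "relu" are looked up on jax.nn, while the clamped variants read their limit/alpha from
# act_params. Multiplying the two branch outputs reproduces a gated activation by hand.
#
#     params = ActivationParams.create(("clamped_silu", "clamped_linear"), limit=7.0)
#     silu_fn = _convert_to_activation_function("clamped_silu", params)
#     gate_fn = _convert_to_activation_function("clamped_linear", params)
#     y = silu_fn(jnp.ones((4, 8))) * gate_fn(jnp.ones((4, 8)))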


class ActLuPrimitive(BasePrimitive):
    """
    ActLu Primitive
    """

    name = "te_act_lu_ffi"
    multiple_results = True
    # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, is_outer, act_params
    impl_static_args = (2, 3, 4, 5, 6, 7, 8, 9)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        *,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
        act_params,
    ):
        """
        te_act_lu_p abstract
        """
        del act_enum, act_params

        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval is None or scale_aval.dtype == jnp.float32
        assert x_aval.shape[-2] == act_len, (
            "activation input should be replicated by act_len in the -2 axis, got input shape"
            f" {x_aval.shape} and act_len {act_len}"
        )

        assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, (
            "Current tensor scaling is not yet supported for fused activation and quantization."
            " Please do activation in higher-precision then quantize with current tensor scaling."
        )

        out_shape = (*x_aval.shape[:-2], x_aval.shape[-1])  # Exclude act dim
        out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)

        updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(out_shape, is_padded=not is_outer, flatten_axis=-1)
        if not is_2x:
            out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype)
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        return out_aval, colwise_out_aval, scale_inv_aval, colwise_scale_inv_aval, updated_amax_aval

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        *,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
        act_params,
    ):
        """
        te_act_lu_p lowering rules
        """
        del out_dtype, scale_dtype, act_len, is_outer
        x_aval, scale_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval is None or scale_aval.dtype == jnp.float32
        out = ffi.ffi_lowering(ActLuPrimitive.name)(
            ctx,
            x,
            scale,
            act_enum=act_enum,
            scaling_mode=scaling_mode.value,
            is_2x=is_2x,
            act_params=act_params.to_ffi_lowering_dict(),
        )
        return out

    @staticmethod
    def impl(
        x,
        scale,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
        act_params,
    ):
        """
        to describe implementation
        """
        del is_outer
        assert ActLuPrimitive.inner_primitive is not None

        out, colwise_out, scale_inv, colwise_scale_inv, updated_amax = (
            ActLuPrimitive.inner_primitive.bind(
                x,
                scale,
                out_dtype=out_dtype,
                act_enum=act_enum,
                act_len=act_len,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                is_outer=False,
                act_params=act_params,
            )
        )
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(out.shape, is_padded=False, flatten_axis=-1)
        # Slice out padding for MXFP8, noop for DelayedScaling
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
        )
        if is_2x:
            colwise_scale_inv = jax.lax.slice(
                colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
            )

        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
        act_params,
    ):
        """
        to describe batch rules for vmap
        """
        del act_len, is_outer
        check_valid_batch_dims(batch_dims)
        assert ActLuPrimitive.outer_primitive is not None
        x, scale = batched_args
        x_bdim, scale_bdim = batch_dims
        amax_bdim = scale_bdim

        out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim
        return (
            ActLuPrimitive.outer_primitive.bind(
                x,
                scale,
                out_dtype=out_dtype,
                act_enum=act_enum,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                act_params=act_params,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
        act_params,
        mesh,
        arg_infos,
        result_infos,
    ):
        del (
            out_dtype,
            result_infos,
            act_enum,
            scale_dtype,
            act_len,
            is_outer,
            act_params,
        )  # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        scale_spec = get_padded_spec(arg_infos[1])

        out_spec = (*x_spec[:-2], x_spec[-1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")

        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1)
            else:
                colwise_out_spec = out_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out"
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = out_spec

        if is_2x:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax")
        colwise_scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
        )

        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
        act_params,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer  # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        scale_spec = get_padded_spec(arg_infos[1])

        out_spec = (*x_spec[:-2], x_spec[-1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")

        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1)
            else:
                colwise_out_spec = out_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out"
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = out_spec

        if is_2x:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax")
        colwise_scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
        )

        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
        )

        def sharded_impl(x, scale):
            local_x, local_colwise_x, local_scale_inv, local_colwise_scale_inv, local_amax = (
                ActLuPrimitive.impl(
                    x,
                    scale,
                    out_dtype=out_dtype,
                    act_enum=act_enum,
                    act_len=act_len,
                    scaling_mode=scaling_mode,
                    is_2x=is_2x,
                    scale_dtype=scale_dtype,
                    is_outer=True,
                    act_params=act_params,
                )
            )

            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            else:
                global_updated_amax = local_amax

            return (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                global_updated_amax,
            )

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        act_enum,
        act_len,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_outer,
        act_params,
        mesh,
        value_types,
        result_types,
    ):
        del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types, act_params
        prefix = "ActLu_"
        input_shape = value_types[0].shape
        output_shape = input_shape[:-2] + input_shape[-1:]
        # Here we pass len of output so that the scales are propagated correctly
        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            output_shape, unique_var=prefix + "x", flatten_axis=-1
        )
        x_axes = scale_rules.input_spec
        # Correct input spec with act dim
        x_axes = x_axes[:-1] + (prefix + "_act_dim",) + x_axes[-1:]
        out = scale_rules.input_spec

        colwise_out = (prefix + "out_colwise",)
        colwise_scale_inv = (prefix + "scale_inv_colwise",)
        if is_2x:
            colwise_scale_inv = scale_rules.colwise_rule
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out = multidim_transpose(out, transpose_axis=-1)
            else:
                colwise_out = out
                colwise_scale_inv = scale_rules.colwise_rule

        amax = (prefix + "amax",)

        return SdyShardingRule(
            (
                x_axes,
                ("…1",),
            ),
            (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax),
            **scale_rules.factor_sizes,
        )


register_primitive(ActLuPrimitive)


# TODO(Jeremy): replace is_2x with q_layout
class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
    """
    DActLu DBias Cast Transpose Primitive
    """

    name = "te_dact_dbias_quantize_ffi"
    multiple_results = True
    # out_dtype, scaling_mode, is_2x, scale_dtype, is_dbias, act_enum, act_len, is_outer, act_params
    impl_static_args = (3, 4, 5, 6, 7, 8, 9, 10, 11)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        dz_aval,
        x_aval,
        scale_aval,
        *,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
        act_params,
    ):
        """
        te_dact_dbias_quantize_p abstract
        """
        del act_enum, act_params
        dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dz_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_dtype
        assert x_aval.shape[-2] == act_len, (
            "activation input should be replicated by act_len in the -2 axis, got input shape"
            f" {x_aval.shape} and act_len {act_len}"
        )
        assert scale_aval.dtype == jnp.float32

        assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, (
            "Current tensor scaling is not supported for fused dact and quantization. Please do"
            " dact in higher-precision then quantize with current tensor scaling."
        )

        ir_hidden_size = dz_aval.shape[-1]
        gi_hidden_size = act_len * x_aval.shape[-1]
        assert act_len * ir_hidden_size == gi_hidden_size
        assert (
            x_aval.shape[:-2] == dz_aval.shape[:-1]
        ), "dz and x should have the same leading dimensions"
        out_shape = x_aval.shape
        out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)

        updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)

        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=-2)
        if is_2x:
            if ScalingMode(scaling_mode).is_tensor_scaling():
                colwise_out_shape = multidim_transpose(out_shape, transpose_axis=-2)
            else:
                colwise_out_shape = out_shape
        else:
            colwise_out_shape = (1,)
            colwise_scale_inv_shape = (1,)
        colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )

        if is_dbias:
            dbias_shape = (act_len, ir_hidden_size)
            (wkspace_info,) = transformer_engine_jax.get_dact_dbias_quantize_workspace_sizes(
                x_aval.size // gi_hidden_size,
                gi_hidden_size,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
                scaling_mode,
                is_2x,
            )
            wkspace_shape = wkspace_info[0]
            wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
        else:
            dbias_shape = (1,)
            wkspace_shape = (1,)
            wkspace_dtype = jnp.float32
        dbias_aval = jax.core.ShapedArray(shape=dbias_shape, dtype=dz_dtype)
        wkspace_aval = jax.core.ShapedArray(shape=wkspace_shape, dtype=wkspace_dtype)

        return (
            out_aval,
            colwise_out_aval,
            scale_inv_aval,
            colwise_scale_inv_aval,
            updated_amax_aval,
            dbias_aval,
            wkspace_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dact_dbias_quantize_p outer abstract
        """
        (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = (
            BaseDActLuDBiasQuantizePrimitive.abstract(*args, **kwargs)
        )
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def lowering(
        ctx,
        dz,
        x,
        scale,
        *,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
        act_params,
    ):
        """
        te_dact_dbias_quantize_p lowering rules
        """
        del out_dtype, scale_dtype, act_len, is_outer
        dz_aval, x_aval, scale_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_aval.dtype
        assert scale_aval.dtype == jnp.float32
        return ffi.ffi_lowering(BaseDActLuDBiasQuantizePrimitive.name)(
            ctx,
            dz,
            x,
            scale,
            scaling_mode=scaling_mode.value,
            is_2x=is_2x,
            is_dbias=is_dbias,
            act_enum=int(act_enum),
            act_params=act_params.to_ffi_lowering_dict(),
        )

    @staticmethod
    def impl(
        dz,
        x,
        scale,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
        act_params,
    ):
        """
        te_dact_dbias_quantize_p impl
        """
        del is_outer
        assert BaseDActLuDBiasQuantizePrimitive.inner_primitive is not None
        (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = (
            BaseDActLuDBiasQuantizePrimitive.inner_primitive.bind(
                dz,
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
                act_enum=act_enum,
                act_len=act_len,
                is_outer=False,
                act_params=act_params,
            )
        )
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x.shape, is_padded=False, flatten_axis=-2)
        # Slice out padding for MXFP8, noop for DelayedScaling
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
        )
        if is_2x:
            colwise_scale_inv = jax.lax.slice(
                colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
            )
        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
        act_params,
    ):
        """
        to describe batch rules for vmap
        """
        del is_outer
        check_valid_batch_dims(batch_dims)
        assert BaseDActLuDBiasQuantizePrimitive.outer_primitive is not None
        dz, x, scale = batched_args
        _, x_bdim, scale_bdim = batch_dims

        out_bdims = (
            x_bdim,  # rowwise output
            scale_bdim,  # rowwise scale_inv
            x_bdim,  # colwise output
            scale_bdim,  # colwise scale_inv
            scale_bdim,  # amax
            x_bdim,  # dbias
        )
        return (
            BaseDActLuDBiasQuantizePrimitive.outer_primitive.bind(
                dz,
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                is_dbias=is_dbias,
                act_enum=act_enum,
                act_len=act_len,
                act_params=act_params,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
        act_params,
        mesh,
        arg_infos,
        result_infos,
    ):
        del out_dtype, result_infos, act_enum, act_params
        del scale_dtype, act_len, is_outer
        x_spec = get_padded_spec(arg_infos[1])
        scale_spec = get_padded_spec(arg_infos[2])

        assert (
            scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value
        ), "Partitioned current tensor scaling is not yet supported."

        out_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec), desc="BaseDActLuDBiasQuantizePrimitive.out"
        )
        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2)
            else:
                colwise_x_spec = x_spec
        else:
            colwise_x_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_x_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.colwise_out",
        )

        dbias_spec = x_spec[-2:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.dbias",
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if is_2x:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="BaseDActLuDBiasQuantizePrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(
            mesh, PartitionSpec(*amax_spec), desc="BaseDActLuDBiasQuantizePrimitive.amax"
        )
        colwise_scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_scale_inv_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.colwise_scale_inv",
        )
        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

    @staticmethod
    def partition(
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
        act_params,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer
        x_spec = get_padded_spec(arg_infos[1])
        scale_spec = get_padded_spec(arg_infos[2])

        out_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec), desc="BaseDActLuDBiasQuantizePrimitive.out"
        )

        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2)
            else:
                colwise_x_spec = x_spec
        else:
            colwise_x_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_x_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.colwise_out",
        )

        dbias_spec = x_spec[-2:] if is_dbias else (None,)
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(*dbias_spec),
            desc="BaseDActLuDBiasQuantizePrimitive.dbias",
        )

        scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
        if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
            scale_inv_spec = amax_spec = scale_spec
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            scale_inv_spec = x_spec

        if is_2x:
            colwise_scale_inv_spec = scale_inv_spec

        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax")
        colwise_scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
        )

        arg_shardings = list(arg_i.sharding for arg_i in arg_infos)
        # Ensure dz and x are partitioned the same way.
        arg_shardings[0] = NamedSharding(
            mesh,
            PartitionSpec(*x_spec[:-2], x_spec[-1]),
            desc="BaseDActLuDBiasQuantizePrimitive.dz",
        )
        arg_shardings = tuple(arg_shardings)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

        def sharded_impl(dz, x, scale):
            (out, colwise_out, scale_inv, colwise_scale_inv, local_amax, local_dbias) = (
                BaseDActLuDBiasQuantizePrimitive.impl(
                    dz,
                    x,
                    scale,
                    out_dtype=out_dtype,
                    scaling_mode=scaling_mode,
                    is_2x=is_2x,
                    scale_dtype=scale_dtype,
                    is_dbias=is_dbias,
                    act_enum=act_enum,
                    act_len=act_len,
                    is_outer=True,
                    act_params=act_params,
                )
            )
            if is_dbias:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            else:
                global_dbias = local_dbias

            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            else:
                global_updated_amax = local_amax

            return out, colwise_out, scale_inv, colwise_scale_inv, global_updated_amax, global_dbias

        return mesh, sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        is_dbias,
        act_enum,
        act_len,
        is_outer,
        act_params,
        mesh,
        value_types,
        result_types,
    ):
        del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types, act_params
        prefix = "DActLuDBias_"
        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            value_types[1].shape, unique_var=prefix + "x", flatten_axis=-2
        )
        x_axes = scale_rules.input_spec
        dz_axes = (*x_axes[:-2], x_axes[-1])
        out = x_axes

        colwise_out = (prefix + "out_colwise",)
        colwise_scale_inv = (prefix + "scale_inv_colwise",)
        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=-2))
            else:
                colwise_out = out
                colwise_scale_inv = scale_rules.colwise_rule

        dbias = x_axes[-2:] if is_dbias else (prefix + "dbias",)
        amax = (prefix + "amax",)

        return SdyShardingRule(
            (dz_axes, x_axes, ("…2",)),
            (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias),
            **scale_rules.factor_sizes,
        )


register_primitive(BaseDActLuDBiasQuantizePrimitive)


class DActLuDBiasQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive):
    """Subclass of BaseDActLuDBiasQuantizePrimitive for DBias and fused activation quantization.
    No change in functionality from the base primitive, but named differently for use in more
    granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""


class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive):
    """Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without
    dbias. No change in functionality from the base primitive, but named differently for use in
    more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""


def _jax_act_lu(
    inputs, activation_type, quantizer=None, act_params: Optional[ActivationParams] = None
) -> Union[NoScaleTensor, ScaledTensor]:
    """
    JAX native activation implementation
    """
    act_params = act_params if act_params is not None else ActivationParams()
    act_len = len(activation_type)
    assert inputs.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {inputs.shape} and act_len {act_len}"
    )
    x = jnp.split(inputs, act_len, axis=-2)
    acts = []
    for idx, act_fn in enumerate(activation_type):
        x_i = _convert_to_activation_function(act_fn, act_params)(x[idx])
        acts.append(x_i)
    x = reduce(operator.mul, acts)
    x = jnp.squeeze(x, axis=-2)
    if quantizer:
        return quantizer.quantize(x, flatten_axis=-1)
    return NoScaleTensor(data=x, amax=None)
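
# Illustrative sketch: the pure-JAX reference path expects the activation axis to be explicit.
# For a gated activation the input carries both branches stacked on axis -2. `batch`, `seq`,
# `hidden`, and `my_quantizer` below are placeholders, not names defined in this module.
#
#     x = jnp.ones((batch, seq, 2, hidden))       # 2 == len(("silu", "linear"))
#     y = _jax_act_lu(x, ("silu", "linear"))      # NoScaleTensor, data shape (batch, seq, hidden)
#     y_q = _jax_act_lu(x, ("silu", "linear"), quantizer=my_quantizer)  # ScaledTensor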


def _jax_quantize_dact_dbias(
    dz: Union[jnp.ndarray, NoScaleTensor],
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    is_dbias: bool = True,
    quantizer: Optional[Quantizer] = None,
    act_params: Optional[ActivationParams] = None,
):
    """
    JAX implementation of dact_lu and dbias with optional quantization
    """
    act_params = act_params if act_params is not None else ActivationParams()
    act_len = len(activation_type)
    assert x.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {x.shape} and act_len {act_len}"
    )

    _, vjp_func = jax.vjp(
        partial(_jax_act_lu, activation_type=activation_type, act_params=act_params),
        x.astype(jnp.float32),
    )
    # The VJP uses the non-quantized backward for dact, so the cotangent is always wrapped in a
    # NoScaleTensor, regardless of whether the forward pass used quantization or whether this
    # dact will quantize afterwards.
    dz = NoScaleTensor(data=dz.astype(jnp.float32), amax=None)
    (dx,) = vjp_func(dz)

    dbias = None
    if is_dbias:
        dbias = _jax_dbias(dx, dtype=x.dtype, flatten_axis=-2)

    if quantizer is not None:
        dx = quantizer.quantize(dx, dq_dtype=x.dtype, flatten_axis=-2)
    else:
        dx = dx.astype(x.dtype)
        dx = NoScaleTensor(data=dx, amax=None)

    return dx, dbias


def act_lu(
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Optional[Quantizer] = None,
    act_params: Optional[ActivationParams] = None,
    amax_scope: AmaxScope = AmaxScope.LOCAL,
) -> Union[jnp.ndarray, ScaledTensor]:
    """Activation with optional quantization.

    Args:
        x: Input tensor to be processed.
            Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for
            gated activations.
        activation_type: Type of activation function to apply.
        quantizer: Optional quantizer for FP8 quantization of the output.
        act_params: Optional parameters for parameterized activations (currently only clamped
            SwiGLU). Defaults to ActivationParams().
        amax_scope: Indicates the scope over which to run the amax calculation. This only takes
            effect when using current-scaling. Defaults to AmaxScope.LOCAL.

    Returns:
        If quantizer is None:
            The activated input tensor with the same dtype as input.
        If quantizer is provided:
            A ScaledTensor containing the quantized activated input.
    """
    act_type_id = ActivationEnum[activation_type].value
    act_len = len(activation_type)
    assert x.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {x.shape} and act_len {act_len}"
    )
    act_params = act_params if act_params is not None else ActivationParams()
    if not ActLuPrimitive.enabled():
        return _jax_act_lu(x, activation_type, quantizer, act_params)

    # TE/common does not support colwise-only quantization yet
    if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
        return _jax_act_lu(x, activation_type, quantizer, act_params)

    # TE/common does not support 2x quantization for DelayedScaling yet
    war_output = try_apply_delayed_scaling_2x_war(
        f=act_lu, x=x, activation_type=activation_type, quantizer=quantizer, act_params=act_params
    )
    if war_output is not None:
        return war_output

    scale = jnp.empty((1,), jnp.float32)
    output_shape = (*x.shape[:-2], x.shape[-1])

    if quantizer is None:
        out, _, _, _, _ = ActLuPrimitive.outer_primitive.bind(
            x,
            scale,
            out_dtype=x.dtype,
            act_enum=act_type_id,
            act_len=act_len,
            scaling_mode=ScalingMode.NO_SCALING.value,
            is_2x=False,
            scale_dtype=jnp.float32,
            is_outer=True,
            act_params=act_params,
        )
        out = out.reshape(output_shape)
        out = NoScaleTensor(
            data=out,
            amax=None,
        )
        return out

    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        # Current scaling does not support fused operations.
        # Perform the activation in higher precision, then quantize afterwards.
        out = act_lu(
            x=x,
            activation_type=activation_type,
            quantizer=None,
            act_params=act_params,
        )
        out, _ = _quantize_dbias_impl(
            out,
            is_dbias=False,
            quantizer=quantizer,
            dq_dtype=x.dtype,
            amax_scope=amax_scope,
        )
        return out

    if isinstance(quantizer, DelayedScaleQuantizer):
        scale = quantizer.scale

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
    ) = ActLuPrimitive.outer_primitive.bind(
        x,
        scale,
        out_dtype=quantizer.q_dtype,
        act_enum=act_type_id,
        act_len=act_len,
        scaling_mode=quantizer.scaling_mode.value,
        is_2x=quantizer.is_2x2x(),
        scale_dtype=quantizer.get_scale_dtype(),
        is_outer=True,
        act_params=act_params,
    )

    quantizer.update(updated_amax)

    return ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
    )
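
# Illustrative usage sketch (`ffn_input` and `my_quantizer` are placeholders):
#
#     gated = act_lu(ffn_input, ("silu", "linear"))                      # -> NoScaleTensor
#     quantized = act_lu(ffn_input, ("gelu",), quantizer=my_quantizer)   # -> ScaledTensor
#     clamped = act_lu(
#         ffn_input,
#         ("clamped_silu", "clamped_linear"),
#         act_params=ActivationParams.create(("clamped_silu", "clamped_linear"), limit=7.0),
#     )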


def quantize_dact_dbias(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]] = ("gelu",),
    is_dbias: bool = True,
    quantizer: Optional[Quantizer] = None,
    act_params: Optional[ActivationParams] = None,
) -> Tuple[ScaledTensor, jnp.ndarray]:
    """Compute gradients of activation and bias with optional quantization.

    Args:
        dz: Gradient of the output with respect to the activation output.
        x: Input tensor that was processed by the forward pass.
            Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for
            gated activations.
        activation_type: Type of activation function used in the forward pass. Defaults to
            ("gelu",).
        is_dbias: If True, compute bias gradient. Defaults to True.
        quantizer: Optional quantizer for FP8 quantization of the output.
        act_params: Optional parameters for parameterized activations (currently only clamped
            SwiGLU). Defaults to ActivationParams().

    Returns:
        Tuple[ScaledTensor, jnp.ndarray]: A tuple containing:
        - The gradient of the activation with respect to the input.
        - The gradient of the activation with respect to the bias.
    """
    act_params = act_params if act_params is not None else ActivationParams()
    act_len = len(activation_type)
    assert x.shape[-2] == act_len, (
        "activation input should be replicated by act_len in the -2 axis, got input shape"
        f" {x.shape} and act_len {act_len}"
    )

    scale = jnp.empty((), jnp.float32)
    act_type_id = ActivationEnum[activation_type]
    PrimitiveClass = DActLuDBiasQuantizePrimitive if is_dbias else DActLuQuantizePrimitive
    if not PrimitiveClass.enabled() or (
        quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE
    ):
        return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer, act_params)

    if quantizer is None:
        output, _, _, _, _, _ = PrimitiveClass.outer_primitive.bind(
            dz,
            x,
            scale,
            # outputs float32 for dbias accumulation
            out_dtype=(jnp.float32 if is_dbias else x.dtype),
            # default value for no scaling, TE/common ignores this value when scale is unset
            scaling_mode=ScalingMode.NO_SCALING.value,
            is_2x=False,  # unused
            scale_dtype=jnp.float32,  # unused
            is_dbias=False,
            act_enum=act_type_id,
            act_len=act_len,
            is_outer=True,
            act_params=act_params,
        )
        output = output.astype(x.dtype)
        dbias = None
        if is_dbias:
            dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2)

        output = NoScaleTensor(
            data=output,
            amax=None,
        )
        return output, dbias

    # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet
    if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
        out = dact_lu(
            dz.astype(jnp.float32),
            x.astype(jnp.float32),
            activation_type,
            quantizer=None,
            act_params=act_params,
        )
        return _quantize_dbias_impl(
            out.data, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2
        )

    is_gated = act_len == 2
    # TE/common does not support DelayedScaling2x for gated-act yet
    if is_gated:
        war_output = try_apply_delayed_scaling_2x_war(
            f=quantize_dact_dbias,
            dz=dz,
            x=x,
            activation_type=activation_type,
            is_dbias=is_dbias,
            quantizer=quantizer,
            flatten_axis=-2,
            act_params=act_params,
        )
        if war_output is not None:
            return war_output

    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        # Current scaling does not support fused operations.
        # Perform dact in higher precision, then quantize afterwards.
        out = dact_lu(
            dz=dz,
            x=x,
            activation_type=activation_type,
            quantizer=None,
            act_params=act_params,
        )
        out, dbias = _quantize_dbias_impl(
            out.data, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2
        )
        return out, dbias

    if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        scale = quantizer.scale

    # TE/common dact_dbias_quantize does not support gated act yet
    if is_dbias and is_gated:
        dgated = dact_lu(
            dz.astype(jnp.float32),
            x.astype(jnp.float32),
            activation_type=activation_type,
            act_params=act_params,
        )
        out, dbias = _quantize_dbias_impl(
            dgated, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2
        )
        return out, dbias

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
        dbias,
    ) = PrimitiveClass.outer_primitive.bind(
        dz,
        x,
        scale,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        is_2x=quantizer.is_2x2x(),
        scale_dtype=quantizer.get_scale_dtype(),
        is_dbias=is_dbias,
        act_enum=act_type_id,
        act_len=act_len,
        is_outer=True,
        act_params=act_params,
    )

    # For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise
    if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x():
        colwise_scale_inv = rowwise_scale_inv

    quantizer.update(updated_amax)

    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_layout=quantizer.q_layout,
        data_layout=quantizer.get_data_layout(),
        flatten_axis=-2,  # as output has act axis
    )

    return out, dbias
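
# Illustrative usage sketch for a custom-VJP backward pass (`dz`, `ffn_input`, and `my_quantizer`
# are placeholders): computes the activation gradient and the bias gradient in one call,
# optionally quantizing the activation gradient.
#
#     dgrad, dbias = quantize_dact_dbias(
#         dz=dz,                               # shape (..., K)
#         x=ffn_input,                         # shape (..., 2, K) for gated activations
#         activation_type=("silu", "linear"),
#         is_dbias=True,
#         quantizer=my_quantizer,
#     )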


def dact_lu(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Optional[Quantizer] = None,
    act_params: Optional[ActivationParams] = None,
) -> Union[jnp.ndarray, ScaledTensor]:
    """
    Backward pass for activation with optional quantization.

    Args:
        dz: Gradient tensor from upstream.
        x: Input tensor that was used in the forward pass.
        activation_type: Type of activation function that was applied.
        quantizer: Optional quantizer for FP8 quantization of the output gradient.
        act_params: Optional parameters for parameterized activations (currently only clamped
            SwiGLU). Defaults to ActivationParams().

    Returns:
        The gradient of the activation with respect to the input.
    """
    act_params = act_params if act_params is not None else ActivationParams()
    output, _ = quantize_dact_dbias(
        dz=dz,
        x=x,
        activation_type=activation_type,
        is_dbias=False,
        quantizer=quantizer,
        act_params=act_params,
    )
    return output
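
# Illustrative usage sketch (`dz` and `ffn_input` are placeholders): dact_lu is the dbias-free
# convenience wrapper around quantize_dact_dbias.
#
#     dgrad = dact_lu(dz, ffn_input, ("gelu", "linear"))                          # NoScaleTensor
#     dgrad_q = dact_lu(dz, ffn_input, ("gelu", "linear"), quantizer=my_quantizer)  # ScaledTensor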