# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX te custom call"""
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Tuple, Sequence, Union, Callable
from functools import partial, reduce
import operator
import os
import warnings

import numpy as np
import jax.numpy as jnp
from jax.lib import xla_client
from jax import core, dtypes
from jax.interpreters import xla, mlir
from jax.experimental.custom_partitioning import custom_partitioning
from jax.interpreters.mlir import ir, dtype_to_ir_type
from jax.sharding import PartitionSpec, NamedSharding
from jax._src.interpreters import batching
from jax._src import dispatch

import transformer_engine_jax
from transformer_engine_jax import DType as TEDType
from transformer_engine_jax import NVTE_Bias_Type
from transformer_engine_jax import NVTE_Mask_Type
from transformer_engine_jax import NVTE_QKV_Layout
from transformer_engine_jax import NVTE_Fused_Attn_Backend
from transformer_engine_jax import NVTE_Activation_Type

from .sharding import all_reduce_max_along_all_axes_except_PP
from .sharding import all_reduce_sum_along_dp_fsdp
from .sharding import get_all_mesh_axes, num_of_devices
from .sharding import get_padded_spec as te_get_padded_spec

try:
    from jaxlib.hlo_helpers import custom_call
except ImportError:
    # Newer JAX changed its API. But we want to support a few JAX
    # versions, so we still need this import.
    pass

for _name, _value in transformer_engine_jax.registrations().items():
    xla_client.register_custom_call_target(_name, _value, platform="CUDA")


def te_dtype_to_jax_dtype(te_dtype):
    """
    convert TE dtype to jax dtype
    """
    assert isinstance(te_dtype, TEDType)

    converter = {
        TEDType.kFloat32: jnp.float32,
        TEDType.kFloat16: jnp.float16,
        TEDType.kBFloat16: jnp.bfloat16,
        TEDType.kInt32: jnp.int32,
        TEDType.kInt64: jnp.int64,
        TEDType.kFloat8E4M3: jnp.float8_e4m3fn,
        TEDType.kFloat8E5M2: jnp.float8_e5m2,
        TEDType.kByte: jnp.uint8
    }

    if te_dtype not in converter:
        raise ValueError(f"Unsupported {te_dtype=}")

    return converter.get(te_dtype)


def te_dtype_to_ir_dtype(te_dtype):
    """
    convert TE dtype to MLIR dtype
    """
    return dtype_to_ir_type(np.dtype(te_dtype_to_jax_dtype(te_dtype)))


def jax_dtype_to_ir_dtype(jax_dtype):
    """
    convert Jax dtype to MLIR dtype
    """
    return dtype_to_ir_type(np.dtype(jax_dtype))


def jax_dtype_to_te_dtype(jax_dtype):
    """
    convert jax dtype to TE dtype
    """
    jax_dtype = dtypes.canonicalize_dtype(jax_dtype)

    converter = {
        jnp.float32.dtype: TEDType.kFloat32,
        jnp.float16.dtype: TEDType.kFloat16,
        jnp.bfloat16.dtype: TEDType.kBFloat16,
        jnp.int32.dtype: TEDType.kInt32,
        jnp.int64.dtype: TEDType.kInt64,
        jnp.float8_e4m3fn.dtype: TEDType.kFloat8E4M3,
        jnp.float8_e5m2.dtype: TEDType.kFloat8E5M2,
        jnp.uint8.dtype: TEDType.kByte,
    }

    if jax_dtype not in converter:
        raise ValueError(f"Unsupported {jax_dtype=}")

    return converter.get(jax_dtype)


def get_padded_spec(arg_info):
    """
    Get padded spec for partitioning from arguments' information
    """
    if arg_info.sharding is None:
        return te_get_padded_spec(None, arg_info.ndim)
    ndim, spec = arg_info.ndim, arg_info.sharding.spec
    return te_get_padded_spec(spec, ndim)


def _check_valid_batch_dims(bdims):
    """
    Assert that no unsupported batch dims are used
    """
    for dim in bdims:
        assert dim in [0, None], \
            "Currently only support batch_dim in [0, None], " \
            f"but got {dim=}"


ActivationEnum = {
    ('gelu',): NVTE_Activation_Type.GELU,
    ('gelu', 'linear'): NVTE_Activation_Type.GEGLU,
    ('silu',): NVTE_Activation_Type.SILU,
    ('silu', 'linear'): NVTE_Activation_Type.SWIGLU,
    ('relu',): NVTE_Activation_Type.RELU,
    ('relu', 'linear'): NVTE_Activation_Type.REGLU,
    ('quick_gelu',): NVTE_Activation_Type.QGELU,
    ('quick_gelu', 'linear'): NVTE_Activation_Type.QGEGLU,
    ('squared_relu',): NVTE_Activation_Type.SRELU,
    ('squared_relu', 'linear'): NVTE_Activation_Type.SREGLU,
}
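# Note (for reference): the keys of ActivationEnum are tuples of activation names as
# used by the fused activation kernels; a single-entry key such as ('gelu',) selects the
# plain activation, while a ('gelu', 'linear') pair selects the gated GLU variant (GEGLU).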


class BasePrimitive(metaclass=ABCMeta):
    """
    Base class for TE JAX primitives
    """

    @staticmethod
    @abstractmethod
    def abstract():
        """
        to describe computing graph
        """
        return NotImplemented

    @classmethod
    def outer_abstract(cls, *args, **kwargs):
        """
        optional abstract wrapper to eliminate workspace tensors
        """
        return cls.abstract(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def lowering():
        """
        to describe MLIR
        """
        return NotImplemented

    @staticmethod
    @abstractmethod
    def impl():
        """
        to describe implementation
        """
        return NotImplemented

    @staticmethod
    @abstractmethod
    def batcher():
        """
        to describe batch rules for vmap
        """
        return NotImplemented

    @staticmethod
    @abstractmethod
    def infer_sharding_from_operands():
        """
        to describe infer_sharding_from_operands for custom_partitioning
        """
        return NotImplemented

    @staticmethod
    @abstractmethod
    def partition():
        """
        to describe partition for custom_partitioning
        """
        return NotImplemented


def register_primitive(cls):
    """
    register jax primitive
    """

    def name_of_wrapper_p():
        return cls.name + "_wrapper"

    inner_p = core.Primitive(cls.name)
    dispatch.prim_requires_devices_during_lowering.add(inner_p)
    inner_p.multiple_results = cls.multiple_results
    inner_p.def_impl(partial(xla.apply_primitive, inner_p))
    inner_p.def_abstract_eval(cls.abstract)
    mlir.register_lowering(inner_p, cls.lowering, platform='cuda')
    cls.inner_primitive = inner_p

    outer_p = core.Primitive(name_of_wrapper_p())
    dispatch.prim_requires_devices_during_lowering.add(outer_p)
    outer_p.multiple_results = cls.multiple_results
    outer_p.def_impl(cls.impl)
    outer_p.def_abstract_eval(cls.outer_abstract)
    batching.primitive_batchers[outer_p] = cls.batcher
    outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args)
    outer_p_lower.def_partition(infer_sharding_from_operands=cls.infer_sharding_from_operands,
                                partition=cls.partition)
    mlir.register_lowering(outer_p,
                           mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results))
    cls.outer_primitive = outer_p
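
# Illustrative sketch (not part of this module's API): register_primitive expects a
# BasePrimitive subclass that fills in the class attributes and static methods used
# above. A hypothetical minimal subclass would look like:
#
#     class FooPrimitive(BasePrimitive):          # hypothetical example name
#         name = "te_foo"
#         multiple_results = False
#         impl_static_args = ()
#         inner_primitive = None
#         outer_primitive = None
#         # ... abstract, lowering, impl, batcher,
#         # infer_sharding_from_operands and partition go here ...
#
#     register_primitive(FooPrimitive)
#
# After registration, FooPrimitive.outer_primitive.bind(...) dispatches through the
# custom-partitioned wrapper, while the inner primitive lowers to the CUDA custom call.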


@dataclass
class CustomCallArgsWrapper:
    """
    wrapper of XLA custom call args
    """

    def __init__(self,
                 output_types,
                 operands,
                 operand_shapes,
                 operand_specific_layouts=None,
                 output_specific_layouts=None):
        self.output_types = output_types
        self.operands = operands
        self.operand_layouts = CustomCallArgsWrapper.generate_layouts(operand_shapes,
                                                                      operand_specific_layouts)
        output_shapes = [x.shape for x in output_types]
        self.output_layouts = CustomCallArgsWrapper.generate_layouts(output_shapes,
                                                                     output_specific_layouts)

    @staticmethod
    def generate_layouts(shapes, specific_layouts):
        """
        setup layouts for XLA custom call
        """

        def default_layout(shape):
            return range(len(shape) - 1, -1, -1)

        if specific_layouts is None:
            specific_layouts = {}

        layouts = []
        for idx, shape in enumerate(shapes):
            if idx in specific_layouts:
                layouts.append(specific_layouts[idx])
            else:
                layouts.append(default_layout(shape))
        return layouts
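
# Worked example (for reference, assuming the defaults above): for shapes
# [(2, 3, 4), (8,)] and no specific layouts, generate_layouts returns the
# minor-to-major defaults [range(2, -1, -1), range(0, -1, -1)], i.e. (2, 1, 0) and (0,).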


def custom_caller(name, args, opaque, has_side_effect, **kwargs):
    """
    XLA custom call wrapper
    """
    if hasattr(mlir, "custom_call"):
        out = mlir.custom_call(name,
                               result_types=args.output_types,
                               operands=args.operands,
                               operand_layouts=args.operand_layouts,
                               result_layouts=args.output_layouts,
                               backend_config=opaque,
                               has_side_effect=has_side_effect,
                               **kwargs).results
    else:
        # Need to disable one pylint error as the second function
        # parameter name changed recently in JAX. Otherwise we won't be
        # compatible with multiple JAX versions.
        out = custom_call(    # pylint: disable=too-many-function-args
            name,
            args.output_types,
            operands=args.operands,
            operand_layouts=args.operand_layouts,
            result_layouts=args.output_layouts,
            backend_config=opaque,
            has_side_effect=has_side_effect,
            **kwargs)
    return out


class LayerNormFwdPrimitive(BasePrimitive):
    """
    Layer Normalization Forward Primitive
    """
    name = "te_layernorm_forward"
    multiple_results = True
    impl_static_args = (3, 4)    # zero_centered_gamma, epsilon
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, gamma_aval, beta_aval, **kwargs):
        """
        LayerNorm fwd inner primitive abstract
        """
        x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]

        mu_rsigma_dtype = jnp.float32

        out_aval = core.raise_to_shaped(x_aval)
        mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigma_dtype)

        assert gamma_aval.size == beta_aval.size
        hidden_size = gamma_aval.size
        assert x_aval.size % hidden_size == 0

        wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes(
            x_aval.size // hidden_size,    # batch size
            hidden_size,
            jax_dtype_to_te_dtype(x_aval.dtype),    # in te_dtype
            jax_dtype_to_te_dtype(gamma_aval.dtype),    # weight te_dtype
            jax_dtype_to_te_dtype(x_aval.dtype),    # out te_dtype (same as input for Fp16/Bf16)
            True,
            kwargs['zero_centered_gamma'],
            kwargs['epsilon'])
        wkspace_aval = out_aval.update(shape=wkspace_info[0],
                                       dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
        barrier_aval = out_aval.update(shape=barrier_info[0],
                                       dtype=te_dtype_to_jax_dtype(barrier_info[1]))

        return out_aval, mu_aval, rsigma_aval, wkspace_aval, barrier_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        LayerNorm fwd outer primitive abstract
        """
        out_aval, mu_aval, rsigma_aval, _, _ = \
            LayerNormFwdPrimitive.abstract(*args, **kwargs)
        return out_aval, mu_aval, rsigma_aval

    @staticmethod
    def lowering(ctx, x, gamma, beta, *, zero_centered_gamma, epsilon):
        """
        LayerNorm fwd lowering rules
        """
        x_aval, gamma_aval, beta_aval = ctx.avals_in
        assert gamma_aval.dtype == beta_aval.dtype
        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        g_type = ir.RankedTensorType(gamma.type)
        g_shape = g_type.shape
        b_type = ir.RankedTensorType(beta.type)
        b_shape = b_type.shape

        assert g_type == b_type
        assert g_shape == b_shape

        # Output shape is same as the input shape, but the output type is same as the weight type.
        # See ln_api.cpp
        output_type = g_type.element_type
        ir_mu_dtype = ir.F32Type.get()
        ir_rsigma_dtype = ir.F32Type.get()

        out_shape = x_shape
        hidden_size = reduce(operator.mul, g_shape)
        batch_shape = out_shape[:-1]
        batch_size = reduce(operator.mul, x_shape) // hidden_size

        wkspace_aval, barrier_aval = ctx.avals_out[-2:]

        out_types = [
            ir.RankedTensorType.get(out_shape, output_type),
            ir.RankedTensorType.get(batch_shape, ir_mu_dtype),
            ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
            ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
            ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype))
        ]
        operands = [x, gamma, beta]
        operand_shapes = [x_shape, g_shape, b_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))

        opaque = transformer_engine_jax.pack_norm_descriptor(
            batch_size,
            hidden_size,
            wkspace_aval.size,
            barrier_aval.size,
            (0,),    # no dgamma_part in FWD pass
            (0,),    # no dbeta_part in FWD pass
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(gamma_aval.dtype),
            jax_dtype_to_te_dtype(wkspace_aval.dtype),
            jax_dtype_to_te_dtype(barrier_aval.dtype),
            TEDType.kByte,    # dummy dgamma_part te_dtype
            TEDType.kByte,    # dummy dbeta_part te_dtype
            zero_centered_gamma,
            epsilon,
            sm_margin,
        )

        out = custom_caller(LayerNormFwdPrimitive.name, args, opaque, False)

        return out

    @staticmethod
    def impl(x, gamma, beta, zero_centered_gamma, epsilon):
        """
        to describe implementation
        """
        assert LayerNormFwdPrimitive.inner_primitive is not None
        out, mu, rsigma, _, _ = LayerNormFwdPrimitive.inner_primitive.bind(
            x, gamma, beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon)
        return out, mu, rsigma

    @staticmethod
    def batcher(batched_args, batch_dims, *, zero_centered_gamma, epsilon):
        """
        to describe batch rules for vmap
        """
        _check_valid_batch_dims(batch_dims)
        assert LayerNormFwdPrimitive.outer_primitive is not None
        x, gamma, beta = batched_args
        x_bdim, _, _ = batch_dims

        out_bdims = x_bdim, x_bdim, x_bdim
        return LayerNormFwdPrimitive.outer_primitive.bind(x,
                                                          gamma,
                                                          beta,
                                                          zero_centered_gamma=zero_centered_gamma,
                                                          epsilon=epsilon), out_bdims

    @staticmethod
    def infer_sharding_from_operands(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
        del zero_centered_gamma, epsilon, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {LayerNormFwdPrimitive.name}! " \
                f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
                f"and hurt performance."
            )
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        mu_sharding = rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
        return (out_sharding, mu_sharding, rsigma_sharding)

    @staticmethod
    def partition(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
        del result_infos
        x_spec, g_spec, b_spec = map(get_padded_spec, arg_infos)
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {LayerNormFwdPrimitive.name}! " \
                f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
                f"and hurt performance."
            )
        if g_spec[-1] is not None:
            warnings.warn(
                f"{LayerNormFwdPrimitive.name} does not support sharding of parameter gamma " \
                f"Enforcing no sharding of parameters hidden dim! " \
            )
        if b_spec[-1] is not None:
            warnings.warn(
                f"{LayerNormFwdPrimitive.name} does not support sharding of parameter beta " \
                f"Enforcing no sharding of parameters hidden dim! " \
            )

        x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        g_sharding = NamedSharding(mesh, PartitionSpec(None))
        b_sharding = NamedSharding(mesh, PartitionSpec(None))
        out_sharding = x_sharding
        mu_sharding = rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))

        arg_shardings = (x_sharding, g_sharding, b_sharding)
        out_shardings = (out_sharding, mu_sharding, rsigma_sharding)
        impl = partial(LayerNormFwdPrimitive.impl,
                       zero_centered_gamma=zero_centered_gamma,
                       epsilon=epsilon)
        return mesh, impl, out_shardings, arg_shardings


register_primitive(LayerNormFwdPrimitive)


def layernorm_fwd(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, zero_centered_gamma: bool,
                  epsilon: float):
    """
    Wrapper for TE layernorm fwd
    """
    return LayerNormFwdPrimitive.outer_primitive.bind(x,
                                                      gamma,
                                                      beta,
                                                      zero_centered_gamma=zero_centered_gamma,
                                                      epsilon=epsilon)
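
# Usage sketch (illustrative only; shapes and dtypes below are assumptions, not
# requirements imposed here): with a [batch, seq, hidden] input and per-hidden-dim
# parameters,
#
#     x = jnp.zeros((8, 128, 1024), dtype=jnp.bfloat16)
#     gamma = jnp.ones((1024,), dtype=jnp.bfloat16)
#     beta = jnp.zeros((1024,), dtype=jnp.bfloat16)
#     out, mu, rsigma = layernorm_fwd(x, gamma, beta, zero_centered_gamma=False, epsilon=1e-6)
#
# out matches x's shape, while mu and rsigma are float32 with the hidden dim reduced away.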


class LayerNormBwdPrimitive(BasePrimitive):
    """
    Layer Normalization Backward Primitive
    """
    name = "te_layernorm_backward"
    multiple_results = True
    impl_static_args = (5, 6)    # zero_centered_gamma, epsilon
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval, **kwargs):
        """
        Layernorm bwd inner primitive abstract
        """
        w_dtype = dtypes.canonicalize_dtype(gamma_aval.dtype)
        mu_dtype = dtypes.canonicalize_dtype(mu_aval.dtype)
        rsigma_dtype = dtypes.canonicalize_dtype(rsigma_aval.dtype)

        assert dtypes.canonicalize_dtype(dz_aval.dtype) == w_dtype
        assert dz_aval.shape == x_aval.shape
        assert mu_aval.shape == rsigma_aval.shape == x_aval.shape[:-1]
        assert mu_dtype == rsigma_dtype == jnp.float32

        dx_aval = core.raise_to_shaped(dz_aval)
        dgamma_aval = dbeta_aval = core.raise_to_shaped(gamma_aval)

        wkspace_info, barrier_info, dgamma_part_info, dbeta_part_info = \
            transformer_engine_jax.get_layernorm_bwd_workspace_sizes(
                x_aval.size // gamma_aval.size,           # batch size
                gamma_aval.size,                          # hidden size
                jax_dtype_to_te_dtype(x_aval.dtype),      # input te_dtype
                jax_dtype_to_te_dtype(gamma_aval.dtype),  # weight te_dtype
                True, kwargs['zero_centered_gamma'], kwargs['epsilon']
            )
        wkspace_aval = dx_aval.update(shape=wkspace_info[0],
                                      dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
        barrier_aval = dx_aval.update(shape=barrier_info[0],
                                      dtype=te_dtype_to_jax_dtype(barrier_info[1]))
        dgamma_part_aval = dgamma_aval.update(shape=dgamma_part_info[0],
                                              dtype=te_dtype_to_jax_dtype(dgamma_part_info[1]))
        dbeta_part_aval = dbeta_aval.update(shape=dbeta_part_info[0],
                                            dtype=te_dtype_to_jax_dtype(dbeta_part_info[1]))

        return dx_aval, dgamma_aval, dbeta_aval, wkspace_aval, barrier_aval, \
               dgamma_part_aval, dbeta_part_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        LayerNorm bwd outer primitive abstract
        """
        dx_aval, dgamma_aval, dbeta_aval, _, _, _, _ = \
            LayerNormBwdPrimitive.abstract(*args, **kwargs)
        return dx_aval, dgamma_aval, dbeta_aval

    @staticmethod
    def lowering(ctx, dz, x, mu, rsigma, gamma, *, zero_centered_gamma, epsilon):
        """
        Layernorm bwd lowering rules
        """
        _, x_aval, _, _, gamma_aval = ctx.avals_in
        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        g_type = ir.RankedTensorType(gamma.type)
        g_shape = g_type.shape
        b_type = ir.RankedTensorType(gamma.type)
        b_shape = b_type.shape
        assert g_type == b_type
        assert g_shape == b_shape

        dz_shape = ir.RankedTensorType(dz.type).shape
        mu_shape = ir.RankedTensorType(mu.type).shape
        rsigma_shape = ir.RankedTensorType(rsigma.type).shape

        hidden_size = reduce(operator.mul, g_shape)
        batch_size = reduce(operator.mul, x_shape) // hidden_size

        out_types = [
            ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
            for output in ctx.avals_out
        ]

        operands = [dz, mu, rsigma, x, gamma]
        operand_shapes = [dz_shape, mu_shape, rsigma_shape, x_shape, g_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))

        wkspace_aval, barrier_aval, dgamma_part_aval, dbeta_part_aval = ctx.avals_out[-4:]
        opaque = transformer_engine_jax.pack_norm_descriptor(
            batch_size,
            hidden_size,
            wkspace_aval.size,
            barrier_aval.size,
            dgamma_part_aval.shape,
            dbeta_part_aval.shape,
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(gamma_aval.dtype),
            jax_dtype_to_te_dtype(wkspace_aval.dtype),
            jax_dtype_to_te_dtype(barrier_aval.dtype),
            jax_dtype_to_te_dtype(dgamma_part_aval.dtype),
            jax_dtype_to_te_dtype(dbeta_part_aval.dtype),
            zero_centered_gamma,
            epsilon,
            sm_margin,
        )

        out = custom_caller(LayerNormBwdPrimitive.name, args, opaque, False)

        return out

    @staticmethod
    def impl(dz, x, mu, rsigma, gamma, zero_centered_gamma, epsilon):
        assert LayerNormBwdPrimitive.inner_primitive is not None
        dx, dgamma, dbeta, _, _, _, _ = LayerNormBwdPrimitive.inner_primitive.bind(
            dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon)
        return dx, dgamma, dbeta

    @staticmethod
    def batcher(batched_args, batch_dims, *, zero_centered_gamma, epsilon):
        _check_valid_batch_dims(batch_dims)
        assert LayerNormBwdPrimitive.outer_primitive is not None
        dz, x, mu, rsigma, gamma = batched_args
        _, x_bdim, _, _, gamma_bdim = batch_dims

        out_bdims = x_bdim, gamma_bdim, gamma_bdim
        return LayerNormBwdPrimitive.outer_primitive.bind(dz,
                                                          x,
                                                          mu,
                                                          rsigma,
                                                          gamma,
                                                          zero_centered_gamma=zero_centered_gamma,
                                                          epsilon=epsilon), out_bdims

    @staticmethod
    def infer_sharding_from_operands(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
        del zero_centered_gamma, epsilon, result_infos
        x_spec = get_padded_spec(arg_infos[1])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {LayerNormBwdPrimitive.name}! " \
                f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
                f"and hurt performance."
            )
        g_b_spec = get_padded_spec(arg_infos[4])
        if g_b_spec[-1] is not None:
            warnings.warn(
                f"{LayerNormBwdPrimitive.name} does not support sharding of gradients " \
                f"of gamma and beta of Layernorm " \
                f"Enforcing no sharding of parameters hidden dim! " \
            )

        dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        dgamma_sharding = dbeta_sharding = NamedSharding(mesh, PartitionSpec(None))
        return dx_sharding, dgamma_sharding, dbeta_sharding

    @staticmethod
    def partition(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
        del result_infos
        x_spec = get_padded_spec(arg_infos[1])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {LayerNormBwdPrimitive.name}! " \
                f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
                f"and hurt performance."
            )
        g_b_spec = get_padded_spec(arg_infos[4])
        if g_b_spec[-1] is not None:
            warnings.warn(
                f"{LayerNormBwdPrimitive.name} does not support sharding of gradients " \
                f"of gamma and beta of Layernorm " \
                f"Enforcing no sharding of parameters hidden dim! " \
            )

        dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        dgamma_sharding = dbeta_sharding = NamedSharding(mesh, PartitionSpec(None))
        out_shardings = dx_sharding, dgamma_sharding, dbeta_sharding
        x_shardings = (dx_sharding,) * 2    # dz and x should have the same sharding.
        mu_shardings = (NamedSharding(mesh, PartitionSpec(*x_spec[:-1])),) * 2
        arg_shardings = (*x_shardings, *mu_shardings, NamedSharding(mesh, PartitionSpec(None)))

        def sharded_impl(dz, x, mu, rsigma, gamma):
            local_dx, local_dgamma, local_dbeta = \
                LayerNormBwdPrimitive.impl(dz, x, mu, rsigma, gamma,
                     zero_centered_gamma=zero_centered_gamma,
                     epsilon=epsilon)
            global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma)
            global_dbeta = all_reduce_sum_along_dp_fsdp(local_dbeta)
            return local_dx, global_dgamma, global_dbeta

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(LayerNormBwdPrimitive)


def layernorm_bwd(dz: jnp.ndarray, x: jnp.ndarray, mu: jnp.ndarray, rsigma: jnp.ndarray,
                  gamma: jnp.ndarray, zero_centered_gamma: bool, epsilon: float):
    """
    Wrapper for TE layernorm bwd
    """
    return LayerNormBwdPrimitive.outer_primitive.bind(dz,
                                                      x,
                                                      mu,
                                                      rsigma,
                                                      gamma,
                                                      zero_centered_gamma=zero_centered_gamma,
                                                      epsilon=epsilon)
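
# Usage sketch (illustrative only): given the saved mu/rsigma from layernorm_fwd and an
# incoming gradient dz with the same shape as x and the same dtype as gamma,
#
#     dx, dgamma, dbeta = layernorm_bwd(dz, x, mu, rsigma, gamma,
#                                       zero_centered_gamma=False, epsilon=1e-6)
#
# dx matches x's shape; dgamma and dbeta match gamma's shape.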


class RmsNormFwdPrimitive(BasePrimitive):
    """
    RMS Normalization Forward Primitive
    """
    name = "te_rmsnorm_forward"
    multiple_results = True
    impl_static_args = (2,)    # epsilon
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, gamma_aval, **kwargs):
        """
        RMSNorm fwd inner primitive abstract
        """
        x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]

        rsigma_dtype = jnp.float32

        out_aval = core.raise_to_shaped(x_aval)
        rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=rsigma_dtype)

        hidden_size = gamma_aval.size
        assert x_aval.size % hidden_size == 0

        wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes(
            x_aval.size // hidden_size,    # batch size
            hidden_size,
            jax_dtype_to_te_dtype(x_aval.dtype),    # in te_dtype
            jax_dtype_to_te_dtype(gamma_aval.dtype),    # weight te_dtype
            jax_dtype_to_te_dtype(x_aval.dtype),    # out te_dtype (same as input for Fp16/Bf16)
            False,
            False,
            kwargs['epsilon'])
        wkspace_aval = out_aval.update(shape=wkspace_info[0],
                                       dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
        barrier_aval = out_aval.update(shape=barrier_info[0],
                                       dtype=te_dtype_to_jax_dtype(barrier_info[1]))

        return out_aval, rsigma_aval, wkspace_aval, barrier_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        RMSNorm fwd outer primitive abstract
        """
        out_aval, rsigma_aval, _, _ = RmsNormFwdPrimitive.abstract(*args, **kwargs)
        return out_aval, rsigma_aval

    @staticmethod
    def lowering(ctx, x, gamma, *, epsilon):
        """
        RMSNorm fwd lowering rules
        """
        x_aval, gamma_aval = ctx.avals_in
        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        g_type = ir.RankedTensorType(gamma.type)
        g_shape = g_type.shape
        rsigma_element_type = ir.F32Type.get()

        out_shape = x_shape
        hidden_size = reduce(operator.mul, g_shape)
        batch_shape = out_shape[:-1]
        batch_size = reduce(operator.mul, x_shape) // hidden_size

        wkspace_aval, barrier_aval = ctx.avals_out[-2:]

        out_types = [
            ir.RankedTensorType.get(out_shape, x_type.element_type),
            ir.RankedTensorType.get(batch_shape, rsigma_element_type),
            ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
            ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype))
        ]
        operands = [x, gamma]
        operand_shapes = [x_shape, g_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))

        opaque = transformer_engine_jax.pack_norm_descriptor(
            batch_size,
            hidden_size,
            wkspace_aval.size,
            barrier_aval.size,
            (0,),    # no dgamma_part in FWD pass
            (0,),    # no dbeta_part in FWD pass
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(gamma_aval.dtype),
            jax_dtype_to_te_dtype(wkspace_aval.dtype),
            jax_dtype_to_te_dtype(barrier_aval.dtype),
            TEDType.kByte,    # dummy dgamma_part te_dtype
            TEDType.kByte,    # dummy dbeta_part te_dtype
            False,    # RMSNorm doesn't support zero_centered_gamma
            epsilon,
            sm_margin,
        )

        out = custom_caller(RmsNormFwdPrimitive.name, args, opaque, False)

        return out

    @staticmethod
    def impl(x, gamma, epsilon):
        """
        to describe implementation
        """
        assert RmsNormFwdPrimitive.inner_primitive is not None
        out, rsigma, _, _ = RmsNormFwdPrimitive.inner_primitive.bind(x, gamma, epsilon=epsilon)
        return out, rsigma

    @staticmethod
    def batcher(batched_args, batch_dims, *, epsilon):
        """
        to describe batch rules for vmap
        """
        _check_valid_batch_dims(batch_dims)
        assert RmsNormFwdPrimitive.outer_primitive is not None
        x, gamma = batched_args
        x_bdim, _ = batch_dims

        out_bdims = x_bdim, x_bdim
        return RmsNormFwdPrimitive.outer_primitive.bind(x, gamma, epsilon=epsilon), out_bdims

    @staticmethod
    def infer_sharding_from_operands(epsilon, mesh, arg_infos, result_infos):
        del epsilon, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {RmsNormFwdPrimitive.name}! " \
                f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
                f"and hurt performance."
            )
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
        return (out_sharding, rsigma_sharding)

    @staticmethod
    def partition(epsilon, mesh, arg_infos, result_infos):
        del result_infos
        x_spec, g_spec = map(get_padded_spec, arg_infos)
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {RmsNormFwdPrimitive.name}! " \
                f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
                f"and hurt performance."
            )
        if g_spec[-1] is not None:
            warnings.warn(
                f"{RmsNormFwdPrimitive.name} does not support sharding of parameter gamma " \
                f"Enforcing no sharding of parameters hidden dim! " \
            )

        x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        g_sharding = NamedSharding(mesh, PartitionSpec(None))
        out_sharding = x_sharding
        rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
        arg_shardings = (x_sharding, g_sharding)
        out_shardings = (out_sharding, rsigma_sharding)
        impl = partial(RmsNormFwdPrimitive.impl, epsilon=epsilon)
        return mesh, impl, out_shardings, arg_shardings


register_primitive(RmsNormFwdPrimitive)


def rmsnorm_fwd(x: jnp.ndarray, gamma: jnp.ndarray, epsilon: float):
    """
    Wrapper for TE rmsnorm fwd
    """
    return RmsNormFwdPrimitive.outer_primitive.bind(x, gamma, epsilon=epsilon)
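
# Usage sketch (illustrative only): RMSNorm has no beta and no mean statistic, so the
# wrapper returns only the normalized output and the reciprocal RMS values, e.g.
#
#     out, rsigma = rmsnorm_fwd(x, gamma, epsilon=1e-6)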


class RmsNormBwdPrimitive(BasePrimitive):
    """
    RMS Normalization Backward Primitive
    """
    name = "te_rmsnorm_backward"
    multiple_results = True
    impl_static_args = (4,)    # epsilon
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(dz_aval, x_aval, rsigma_aval, gamma_aval, **kwargs):
        """
        RMSNorm bwd inner primitive abstract
        """
        w_dtype = dtypes.canonicalize_dtype(gamma_aval.dtype)
        rsigma_dtype = dtypes.canonicalize_dtype(rsigma_aval.dtype)

        assert dtypes.canonicalize_dtype(dz_aval.dtype) == w_dtype
        assert dz_aval.shape == x_aval.shape
        assert rsigma_aval.shape == x_aval.shape[:-1]
        assert rsigma_dtype == jnp.float32

        dx_aval = core.raise_to_shaped(dz_aval)
        dgamma_aval = core.raise_to_shaped(gamma_aval)

        wkspace_info, barrier_info, dgamma_part_info, _ = \
            transformer_engine_jax.get_layernorm_bwd_workspace_sizes(
                x_aval.size // gamma_aval.size,           # batch size
                gamma_aval.size,                          # hidden size
                jax_dtype_to_te_dtype(x_aval.dtype),      # in te_dtype
                jax_dtype_to_te_dtype(gamma_aval.dtype),  # weight te_dtype
                False, False, kwargs['epsilon']
            )
        wkspace_aval = dx_aval.update(shape=wkspace_info[0],
                                      dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
        barrier_aval = dx_aval.update(shape=barrier_info[0],
                                      dtype=te_dtype_to_jax_dtype(barrier_info[1]))
        dgamma_part_aval = dgamma_aval.update(shape=dgamma_part_info[0],
                                              dtype=te_dtype_to_jax_dtype(dgamma_part_info[1]))

        return dx_aval, dgamma_aval, wkspace_aval, barrier_aval, dgamma_part_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        RMSNorm bwd outer primitive abstract
        """
        dx_aval, dgamma_aval, _, _, _ = RmsNormBwdPrimitive.abstract(*args, **kwargs)
        return dx_aval, dgamma_aval

    @staticmethod
    def lowering(ctx, dz, x, rsigma, gamma, *, epsilon):
        """
        RMSNorm bwd lowering rules
        """
        _, x_aval, _, gamma_aval = ctx.avals_in
        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        g_type = ir.RankedTensorType(gamma.type)
        g_shape = g_type.shape
        dz_shape = ir.RankedTensorType(dz.type).shape
        rsigma_shape = ir.RankedTensorType(rsigma.type).shape

        hidden_size = reduce(operator.mul, g_shape)
        batch_size = reduce(operator.mul, x_shape) // hidden_size

        wkspace_aval, barrier_aval, dgamma_part_aval = ctx.avals_out[-3:]

        out_types = [
            ir.RankedTensorType.get(x_shape, x_type.element_type),
            ir.RankedTensorType.get(g_shape, g_type.element_type),
            ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
            ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)),
            ir.RankedTensorType.get(dgamma_part_aval.shape,
                                    jax_dtype_to_ir_dtype(dgamma_part_aval.dtype))
        ]
        operands = [dz, rsigma, x, gamma]
        operand_shapes = [dz_shape, rsigma_shape, x_shape, g_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))

        opaque = transformer_engine_jax.pack_norm_descriptor(
            batch_size,
            hidden_size,
            wkspace_aval.size,
            barrier_aval.size,
            dgamma_part_aval.shape,
            (0,),    # no dbeta_part for RMSnorm
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(gamma_aval.dtype),
            jax_dtype_to_te_dtype(wkspace_aval.dtype),
            jax_dtype_to_te_dtype(barrier_aval.dtype),
            jax_dtype_to_te_dtype(dgamma_part_aval.dtype),
            TEDType.kByte,    # dummy dbeta_part te_dtype
            False,    # RMSNorm doesn't support zero_centered_gamma
            epsilon,
            sm_margin,
        )

        out = custom_caller(RmsNormBwdPrimitive.name, args, opaque, False)

        return out

    @staticmethod
    def impl(dz, x, rsigma, gamma, epsilon):
        assert RmsNormBwdPrimitive.inner_primitive is not None
        dx, dgamma, _, _, _ = \
            RmsNormBwdPrimitive.inner_primitive.bind(dz, x, rsigma, gamma, epsilon=epsilon)
        return dx, dgamma

    @staticmethod
    def batcher(batched_args, batch_dims, *, epsilon):
        _check_valid_batch_dims(batch_dims)
        assert RmsNormBwdPrimitive.outer_primitive is not None
        dz, x, rsigma, gamma = batched_args
        _, x_bdim, _, gamma_bdim = batch_dims

        out_bdims = x_bdim, gamma_bdim
        return RmsNormBwdPrimitive.outer_primitive.bind(dz, x, rsigma, gamma,
                                                        epsilon=epsilon), out_bdims

    @staticmethod
    def infer_sharding_from_operands(epsilon, mesh, arg_infos, result_infos):
        del epsilon, result_infos
        x_spec = get_padded_spec(arg_infos[1])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {RmsNormBwdPrimitive.name}! " \
                f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
                f"and hurt performance."
            )
        g_spec = get_padded_spec(arg_infos[3])
        if g_spec[-1] is not None:
            warnings.warn(
                f"{RmsNormBwdPrimitive.name} does not support sharding of parameter gamma " \
                f"Enforcing no sharding of parameters hidden dim! " \
            )
        dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        dgamma_sharding = NamedSharding(mesh, PartitionSpec(None))
        return dx_sharding, dgamma_sharding

    @staticmethod
    def partition(epsilon, mesh, arg_infos, result_infos):
        del result_infos
        x_spec = get_padded_spec(arg_infos[1])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {RmsNormBwdPrimitive.name}! " \
                f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
                f"and hurt performance."
            )
        g_spec = get_padded_spec(arg_infos[3])
        if g_spec[-1] is not None:
            warnings.warn(
                f"{RmsNormBwdPrimitive.name} does not support sharding of parameter gamma " \
                f"Enforcing no sharding of parameters hidden dim! " \
            )
        dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        dgamma_sharding = NamedSharding(mesh, PartitionSpec(None))
        out_shardings = dx_sharding, dgamma_sharding
        x_shardings = (dx_sharding,) * 2    # dz and x should have the same sharding.
        rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
        arg_shardings = (*x_shardings, rsigma_sharding, NamedSharding(mesh, PartitionSpec(None)))

        def sharded_impl(dz, x, rsigma, gamma):
            local_dx, local_dgamma = \
                RmsNormBwdPrimitive.impl(dz, x, rsigma, gamma, epsilon=epsilon)
            global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma)
            return local_dx, global_dgamma

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(RmsNormBwdPrimitive)


def rmsnorm_bwd(dz: jnp.ndarray, x: jnp.ndarray, rsigma: jnp.ndarray, gamma: jnp.ndarray,
                epsilon: float):
    """
    Wrapper for TE rmsnorm bwd
    """
    return RmsNormBwdPrimitive.outer_primitive.bind(dz, x, rsigma, gamma, epsilon=epsilon)
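
# Usage sketch (illustrative only): mirrors rmsnorm_fwd, consuming the saved rsigma, e.g.
#
#     dx, dgamma = rmsnorm_bwd(dz, x, rsigma, gamma, epsilon=1e-6)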


class SoftmaxPrimitive(BasePrimitive):
    """
    Softmax Primitive
    """
    max_k_seqlen_supported = 16384
    name = "te_softmax_internal_placeholder"

    @staticmethod
    @abstractmethod
    def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
                            dtype: jnp.dtype) -> bool:
        """Check Softmax kernel availability based on size"""
        raise NotImplementedError

    @staticmethod
    def get_batch_per_block(k_seqlen: int) -> int:
        """Get batch per CTA in Softmax kernels"""
        threads_per_warp = 32
        threads_per_block = 128    # Depends on the kernel implementation

        pow2 = 1 << (k_seqlen - 1).bit_length()
        warp_size = pow2 if pow2 < threads_per_warp else threads_per_warp
        batches_per_warp = 2 if pow2 <= 128 else 1
        warps_per_block = threads_per_block // warp_size
        batches_per_block = warps_per_block * batches_per_warp
        return batches_per_block
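
    # Worked example (for reference): with k_seqlen = 512, pow2 = 512, so warp_size = 32,
    # batches_per_warp = 1, warps_per_block = 128 // 32 = 4 and get_batch_per_block
    # returns 4. For k_seqlen = 128, pow2 = 128, warp_size = 32, batches_per_warp = 2,
    # and the result is 4 * 2 = 8.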

    @staticmethod
    def forward_abstract(logits_aval, scale_factor):
        """
        softmax_forward abstract
        """
        del scale_factor
        i_dtype = dtypes.canonicalize_dtype(logits_aval.dtype)
        assert i_dtype in [jnp.float16, jnp.bfloat16]
        i_shape = logits_aval.shape
        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
        q_seqlen = i_shape[-2]
        k_seqlen = i_shape[-1]
        assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
        assert q_seqlen > 1

        out_aval = core.raise_to_shaped(logits_aval)
        return out_aval

    @staticmethod
    def forward_lowering(name, ctx, logits, *, scale_factor):
        """
        softmax_forward lowering rules
        """
        i_aval, = ctx.avals_in
        i_type = ir.RankedTensorType(logits.type)
        i_shape = i_type.shape
        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
        batch = reduce(operator.mul, i_shape[:-3])
        pad_batch = batch
        heads = i_shape[-3]
        q_seqlen = i_shape[-2]
        k_seqlen = i_shape[-1]

        out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)]
        operands = [logits]
        operand_shapes = [i_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        opaque = transformer_engine_jax.pack_softmax_descriptor(batch, pad_batch, heads, q_seqlen,
                                                                k_seqlen,
                                                                jax_dtype_to_te_dtype(i_aval.dtype),
                                                                scale_factor)

        out = custom_caller(name, args, opaque, False)

        return [out]

    @staticmethod
    def forward_impl(primitive, logits, scale_factor):
        """
        softmax_forward implementation
        """
        assert primitive is not None
        output = primitive.bind(logits, scale_factor=scale_factor)
        return output

    @staticmethod
    def forward_batcher(primitive, batched_args, batch_dims, *, scale_factor):
        """
        softmax_forward batcher
        """
        assert primitive is not None
        logits, = batched_args
        logits_bdim, = batch_dims

        out_bdims = logits_bdim
        return primitive.bind(logits, scale_factor=scale_factor), out_bdims

    @classmethod
    def forward_infer_sharding_from_operands(cls, scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_forward infer_sharding_from_operands
        """
        del scale_factor, result_infos    # Unused.
        logits_spec = get_padded_spec(arg_infos[0])
        if logits_spec[-1] is not None:
            warnings.warn(
                f"Sharding the hidden dimension is not supported in {cls.name}! " \
                f"Forcing XLA to not shard the hidden dim, which might introduce extra " \
                f"collective ops and hurt performance."
            )
        out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec[:-1], None))
        return out_sharding

    @classmethod
    def forward_partition(cls, impl, scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_forward partitioning
        """
        del result_infos
        logits_spec = get_padded_spec(arg_infos[0])
        if logits_spec[-1] is not None:
            warnings.warn(
                f"Sharding the hidden dimension is not supported in {cls.name}! " \
                f"Forcing XLA to not shard the hidden dim, which might introduce extra " \
                f"collective ops and hurt performance."
            )
        out_shardings = NamedSharding(mesh, PartitionSpec(*logits_spec[:-1], None))
        arg_shardings = (out_shardings,)
        impl = partial(impl, scale_factor=scale_factor)
        return mesh, impl, out_shardings, arg_shardings

    @staticmethod
    def backward_abstract(dz_aval, softmax_out_aval, scale_factor=None):    # pylint: disable=unused-argument
        """
        softmax_backward abstract
        """
        dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        softmax_out_dtype = dtypes.canonicalize_dtype(softmax_out_aval.dtype)
        assert dz_dtype == softmax_out_dtype
        assert dz_dtype in [jnp.float16, jnp.bfloat16]
        assert softmax_out_dtype in [jnp.float16, jnp.bfloat16]

        assert dz_aval.shape == softmax_out_aval.shape

        dx_aval = core.raise_to_shaped(dz_aval)
        return dx_aval

    @staticmethod
    def backward_lowering(name, ctx, dz, softmax_out, *, scale_factor):
        """
        softmax_backward lowering rules
        """
        dz_aval, _ = ctx.avals_in

        dz_type = ir.RankedTensorType(dz.type)
        dz_shape = dz_type.shape

        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
        batch = reduce(operator.mul, dz_shape[:-3])
        pad_batch = batch    # unused
        heads = dz_shape[-3]
        q_seqlen = dz_shape[-2]
        k_seqlen = dz_shape[-1]

        softmax_out_type = ir.RankedTensorType(softmax_out.type)
        softmax_out_shape = softmax_out_type.shape

        out_types = [ir.RankedTensorType.get(dz_shape, dz_type.element_type)]
        operands = [dz, softmax_out]
        operand_shapes = [dz_shape, softmax_out_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        opaque = transformer_engine_jax.pack_softmax_descriptor(
            batch, pad_batch, heads, q_seqlen, k_seqlen, jax_dtype_to_te_dtype(dz_aval.dtype),
            scale_factor)

        out = custom_caller(name, args, opaque, False)

        return [out]

    @staticmethod
    def backward_impl(primitive, dz, softmax_out, scale_factor):
        """
        softmax_backward implementation
        """
        assert primitive is not None
        dx = primitive.bind(dz, softmax_out, scale_factor=scale_factor)
        return dx

    @staticmethod
    def backward_batcher(primitive, batched_args, batch_dims, *, scale_factor):
        """
        softmax_backward batcher
        """
        assert primitive is not None
        dz, softmax_out = batched_args
        _, softmax_out_bdim = batch_dims

        out_bdims = softmax_out_bdim
        return primitive.bind(dz, softmax_out, scale_factor=scale_factor), out_bdims

    @classmethod
    def backward_infer_sharding_from_operands(cls, scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_backward infer_sharding_from_operands
        """
        del scale_factor, result_infos    # Unused.
        dz_spec = get_padded_spec(arg_infos[0])
        if dz_spec[-1] is not None:
            warnings.warn(
                f"Sharding the hidden dimension is not supported in {cls.name}! " \
                f"Forcing XLA to not shard the hidden dim, which might introduce extra " \
                f"collective ops and hurt performance."
            )
        dx_sharding = NamedSharding(mesh, PartitionSpec(*dz_spec[:-1], None))
        return dx_sharding

    @classmethod
    def backward_partition(cls, impl, scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_backward partition
        """
        del result_infos

        dz_spec = get_padded_spec(arg_infos[0])
        softmax_out_spec = get_padded_spec(arg_infos[1])
        if dz_spec[-1] is not None or softmax_out_spec[-1] is not None:
            warnings.warn(
                f"Sharding the hidden dimension is not supported in {cls.name}! " \
                f"Forcing XLA to not shard the hidden dim, which might introduce extra " \
                f"collective ops and hurt performance."
            )

        dz_sharding = NamedSharding(mesh, PartitionSpec(*dz_spec[:-1], None))
        softmax_out_sharding = NamedSharding(mesh, PartitionSpec(*softmax_out_spec[:-1], None))
        dx_sharding = dz_sharding
        arg_shardings = (dz_sharding, softmax_out_sharding)
        out_shardings = dx_sharding

        impl = partial(impl, scale_factor=scale_factor)
        return mesh, impl, out_shardings, arg_shardings


class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Softmax Fwd Primitive
    """
    name = "te_scaled_softmax_forward"
    multiple_results = False
    impl_static_args = (1,)    # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
                            dtype: jnp.dtype) -> bool:
        """Check Softmax kernel availability based on size"""
        attn_batches = batch * heads
        dtype = dtypes.canonicalize_dtype(dtype)
        if (dtype in [jnp.float16, jnp.bfloat16]
                and 16 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
                and q_seqlen % 4 == 0    # q_seqlen must be divisible by 4
                and attn_batches % 4 == 0    # batch * heads must be divisible by 4
           ):
            if 0 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported:
                batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen)
                return q_seqlen % batch_per_block == 0
        return False

    @staticmethod
    def abstract(logits_aval, scale_factor):    # pylint: disable=unused-argument
        """
        te_scaled_softmax_forward abstract
        """
        return SoftmaxPrimitive.forward_abstract(logits_aval, scale_factor)

    @staticmethod
    def lowering(ctx, logits, *, scale_factor):
        """
        te_scaled_softmax_forward lowering rules
        """
        return SoftmaxPrimitive.forward_lowering(ScaledSoftmaxFwdPrimitive.name,
                                                 ctx,
                                                 logits,
                                                 scale_factor=scale_factor)

    @staticmethod
    def impl(logits, scale_factor):
        return SoftmaxPrimitive.forward_impl(ScaledSoftmaxFwdPrimitive.inner_primitive, logits,
                                             scale_factor)

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        _check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.forward_batcher(ScaledSoftmaxFwdPrimitive.outer_primitive,
                                                batched_args,
                                                batch_dims,
                                                scale_factor=scale_factor)

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        return ScaledSoftmaxFwdPrimitive.forward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos)

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        return ScaledSoftmaxFwdPrimitive.forward_partition(ScaledSoftmaxFwdPrimitive.impl,
                                                           scale_factor, mesh, arg_infos,
                                                           result_infos)


register_primitive(ScaledSoftmaxFwdPrimitive)


def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
    """
    scaled_softmax_forward wrapper
    Return FP16/BF16 tensor
    """
    return ScaledSoftmaxFwdPrimitive.outer_primitive.bind(logits, scale_factor=scale_factor)
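# A minimal usage sketch (illustrative only, with made-up sizes): the softmax
# primitives expect FP16/BF16 logits shaped [..., heads, q_seqlen, k_seqlen].
#
#     logits = jnp.zeros((2, 16, 128, 128), dtype=jnp.float16)
#     probs = scaled_softmax_fwd(logits, scale_factor=1.0)    # same shape/dtype as logits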


class ScaledSoftmaxBwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Softmax Bwd Primitive
    """
    name = "te_scaled_softmax_backward"
    multiple_results = False
    impl_static_args = (2,)    # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
                            dtype: jnp.dtype) -> bool:
        """Check Softmax kernel availability based on size"""
        return ScaledSoftmaxFwdPrimitive.is_kernel_available(batch, heads, q_seqlen, k_seqlen,
                                                             dtype)

    @staticmethod
    def abstract(dz_aval, softmax_out_aval, scale_factor):
        """
        te_scaled_softmax_backward abstract
        """
        return SoftmaxPrimitive.backward_abstract(dz_aval, softmax_out_aval, scale_factor)

    @staticmethod
    def lowering(ctx, dz, softmax_out, *, scale_factor):
        """
        te_scaled_softmax_backward lowering rules
        """
        out = SoftmaxPrimitive.backward_lowering(ScaledSoftmaxBwdPrimitive.name,
                                                 ctx,
                                                 dz,
                                                 softmax_out,
                                                 scale_factor=scale_factor)
        return out

    @staticmethod
    def impl(dz, softmax_out, scale_factor):
        return SoftmaxPrimitive.backward_impl(ScaledSoftmaxBwdPrimitive.inner_primitive,
                                              dz,
                                              softmax_out,
                                              scale_factor=scale_factor)

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        _check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.backward_batcher(ScaledSoftmaxBwdPrimitive.outer_primitive,
                                                 batched_args,
                                                 batch_dims,
                                                 scale_factor=scale_factor)

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        return ScaledSoftmaxBwdPrimitive.backward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos)

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        return ScaledSoftmaxBwdPrimitive.backward_partition(ScaledSoftmaxBwdPrimitive.impl,
                                                            scale_factor, mesh, arg_infos,
                                                            result_infos)


register_primitive(ScaledSoftmaxBwdPrimitive)


def scaled_softmax_bwd(dz: jnp.ndarray, softmax_out: jnp.ndarray,
                       scale_factor: float) -> jnp.ndarray:
    """
    scaled_softmax_backward wrapper
    Return FP16/BF16 tensor
    """
    return ScaledSoftmaxBwdPrimitive.outer_primitive.bind(dz,
                                                          softmax_out,
                                                          scale_factor=scale_factor)


class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Masked Softmax Fwd Primitive
    """
    name = "te_scaled_masked_softmax_forward"
    multiple_results = False
    impl_static_args = (2,)    # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
                            dtype: jnp.dtype) -> bool:
        """Check Softmax kernel availability based on size"""
        attn_batches = batch * heads

        dtype = dtypes.canonicalize_dtype(dtype)
        if (dtype in [jnp.float16, jnp.bfloat16]
                and 16 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
                and q_seqlen % 4 == 0    # q_seqlen must be divisible by 4
                and attn_batches % 4 == 0    # batch * heads must be divisible by 4
           ):
            if 0 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported:
                batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen)
                return q_seqlen % batch_per_block == 0
        return False

    @staticmethod
    def abstract(logits_aval, mask_aval, scale_factor):    # pylint: disable=unused-argument
        """
        te_scaled_masked_softmax_forward abstract
        """
        i_dtype = dtypes.canonicalize_dtype(logits_aval.dtype)
        assert i_dtype in [jnp.float16, jnp.bfloat16]
        i_shape = logits_aval.shape
        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
        batch = reduce(operator.mul, i_shape[:-3])
        q_seqlen = i_shape[-2]
        k_seqlen = i_shape[-1]
        assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
        assert q_seqlen > 1
        mask_dtype = dtypes.canonicalize_dtype(mask_aval.dtype)
        assert mask_dtype in [
            jnp.uint8,
        ]
        mask_shape = mask_aval.shape
        pad_batch = batch = reduce(operator.mul, mask_shape[:-3])
        assert pad_batch in (1, batch)    # 1 means broadcast
        assert mask_shape[-3] == 1    # 1 means broadcast
        assert mask_shape[-2] == q_seqlen
        assert mask_shape[-1] == k_seqlen

        out_aval = core.raise_to_shaped(logits_aval)
        return out_aval

    @staticmethod
    def lowering(ctx, logits, mask, *, scale_factor):
        """
        te_scaled_masked_softmax_forward lowering rules
        """

        logits_aval, _ = ctx.avals_in
        i_type = ir.RankedTensorType(logits.type)
        i_shape = i_type.shape
        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
        batch = reduce(operator.mul, i_shape[:-3])
        heads = i_shape[-3]
        q_seqlen = i_shape[-2]
        k_seqlen = i_shape[-1]

        mask_type = ir.RankedTensorType(mask.type)
        mask_shape = mask_type.shape
        pad_batch = reduce(operator.mul, mask_shape[:-3])

        out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)]
        operands = [logits, mask]
        operand_shapes = [i_shape, mask_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        opaque = transformer_engine_jax.pack_softmax_descriptor(
            batch, pad_batch, heads, q_seqlen, k_seqlen, jax_dtype_to_te_dtype(logits_aval.dtype),
            scale_factor)
        out = custom_caller(ScaledMaskedSoftmaxFwdPrimitive.name, args, opaque, False)

        return [out]

    @staticmethod
    def impl(logits, mask, scale_factor):
        assert ScaledMaskedSoftmaxFwdPrimitive.inner_primitive is not None
        output = ScaledMaskedSoftmaxFwdPrimitive.inner_primitive.bind(logits,
                                                                      mask,
                                                                      scale_factor=scale_factor)
        return output

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        _check_valid_batch_dims(batch_dims)
        assert ScaledMaskedSoftmaxFwdPrimitive.outer_primitive is not None
        logits, mask = batched_args
        logits_bdim, _ = batch_dims

        out_bdims = logits_bdim
        return ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
            logits, mask, scale_factor=scale_factor), out_bdims

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        return ScaledMaskedSoftmaxFwdPrimitive.forward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos)

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        return ScaledMaskedSoftmaxFwdPrimitive.backward_partition(
            ScaledMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos)


register_primitive(ScaledMaskedSoftmaxFwdPrimitive)


def scaled_masked_softmax_fwd(logits: jnp.ndarray, mask: jnp.ndarray,
                              scale_factor: float) -> jnp.ndarray:
    """
    scaled_masked_softmax_forward wrapper
    Return FP16/BF16 tensor
    """
    return ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind(logits,
                                                                mask,
                                                                scale_factor=scale_factor)
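# Shape sketch for the masked variant (hypothetical sizes, derived from the abstract
# rule above): logits are [..., heads, q_seqlen, k_seqlen]; the mask is uint8 with
# shape [..., 1, q_seqlen, k_seqlen], and its leading batch dims must be 1 (broadcast)
# or match the logits' batch dims.
#
#     logits = jnp.zeros((2, 16, 128, 128), dtype=jnp.bfloat16)
#     mask = jnp.zeros((2, 1, 128, 128), dtype=jnp.uint8)
#     probs = scaled_masked_softmax_fwd(logits, mask, scale_factor=1.0)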


class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Masked Softmax Bwd Primitive
    """
    name = "te_scaled_masked_softmax_backward"
    multiple_results = False
    impl_static_args = (2,)    # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
                            dtype: jnp.dtype) -> bool:
        """Check Softmax kernel availability based on size"""
        return ScaledSoftmaxFwdPrimitive.is_kernel_available(batch, heads, q_seqlen, k_seqlen,
                                                             dtype)

    @staticmethod
    def abstract(dz_aval, softmax_out_aval, *, scale_factor):
        """
        te_scaled_masked_softmax_backward abstract
        """
        return SoftmaxPrimitive.backward_abstract(dz_aval, softmax_out_aval, scale_factor)

    @staticmethod
    def lowering(ctx, dz, softmax_out, *, scale_factor):
        """
        te_scaled_masked_softmax_backward lowering rules
        """
        out = SoftmaxPrimitive.backward_lowering(ScaledMaskedSoftmaxBwdPrimitive.name,
                                                 ctx,
                                                 dz,
                                                 softmax_out,
                                                 scale_factor=scale_factor)

        return out

    @staticmethod
    def impl(dz, softmax_out, scale_factor):
        return SoftmaxPrimitive.backward_impl(ScaledMaskedSoftmaxBwdPrimitive.inner_primitive,
                                              dz,
                                              softmax_out,
                                              scale_factor=scale_factor)

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        _check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.backward_batcher(ScaledMaskedSoftmaxBwdPrimitive.outer_primitive,
                                                 batched_args,
                                                 batch_dims,
                                                 scale_factor=scale_factor)

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        return ScaledMaskedSoftmaxBwdPrimitive.backward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos)

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        return ScaledMaskedSoftmaxBwdPrimitive.backward_partition(
            ScaledMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos)


register_primitive(ScaledMaskedSoftmaxBwdPrimitive)


def scaled_masked_softmax_bwd(dz: jnp.ndarray, softmax_out: jnp.ndarray,
                              scale_factor: float) -> jnp.ndarray:
    """
    scaled_masked_softmax_backward wrapper
    Return FP16/BF16 tensor
    """
    return ScaledMaskedSoftmaxBwdPrimitive.outer_primitive.bind(dz,
                                                                softmax_out,
                                                                scale_factor=scale_factor)


class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Upper Triang Masked Softmax Fwd Primitive
    """
    name = "te_scaled_upper_triang_masked_softmax_forward"
    multiple_results = False
    impl_static_args = (1,)    # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
                            dtype: jnp.dtype) -> bool:
        """Check Softmax kernel availability based on size"""
        attn_batches = batch * heads

        dtype = dtypes.canonicalize_dtype(dtype)
        if (dtype in [jnp.float16, jnp.bfloat16]
                and 16 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
                and q_seqlen % 4 == 0    # q_seqlen must be divisible by 4
                and attn_batches % 4 == 0    # batch * heads must be divisible by 4
                and k_seqlen == q_seqlen):
            if 0 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported:
                batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen)
                return attn_batches % batch_per_block == 0
        return False

    @staticmethod
    def abstract(logits_aval, scale_factor):    # pylint: disable=unused-argument
        """
        te_scaled_upper_triang_masked_softmax_forward abstract
        """
        q_seqlen = logits_aval.shape[-2]
        k_seqlen = logits_aval.shape[-1]
        assert q_seqlen == k_seqlen
        return SoftmaxPrimitive.forward_abstract(logits_aval, scale_factor)

    @staticmethod
    def lowering(ctx, logits, *, scale_factor):
        """
        te_scaled_upper_triang_masked_softmax_forward lowering rules
        """
        return SoftmaxPrimitive.forward_lowering(ScaledUpperTriangMaskedSoftmaxFwdPrimitive.name,
                                                 ctx,
                                                 logits,
                                                 scale_factor=scale_factor)

    @staticmethod
    def impl(logits, scale_factor):
        return SoftmaxPrimitive.forward_impl(
            ScaledUpperTriangMaskedSoftmaxFwdPrimitive.inner_primitive, logits, scale_factor)

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        _check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.forward_batcher(
            ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor)

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.forward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos)

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.forward_partition(
            ScaledUpperTriangMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos,
            result_infos)


register_primitive(ScaledUpperTriangMaskedSoftmaxFwdPrimitive)


def scaled_upper_triang_masked_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
    """
    scaled_upper_triang_masked_softmax_forward wrapper
    Return FP16/BF16 tensor
    """
    return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
        logits, scale_factor=scale_factor)
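# Note: the upper-triangular (causal) softmax variants operate on square attention
# matrices; the abstract rule above asserts q_seqlen == k_seqlen before dispatching.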


class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Upper Triang Masked Softmax Bwd Primitive
    """
    name = "te_scaled_upper_triang_masked_softmax_backward"
    multiple_results = False
    impl_static_args = (2,)    # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
                            dtype: jnp.dtype) -> bool:
        """Check Softmax kernel availability based on size"""
        return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype)

    @staticmethod
    def abstract(dz_aval, softmax_out_aval, *, scale_factor):
        """
        te_scaled_upper_triang_masked_softmax_backward abstract
        """
        return SoftmaxPrimitive.backward_abstract(dz_aval, softmax_out_aval, scale_factor)

    @staticmethod
    def lowering(ctx, dz, softmax_out, *, scale_factor):
        """
        te_scaled_upper_triang_masked_softmax_backward lowering rules
        """
        out = SoftmaxPrimitive.backward_lowering(ScaledUpperTriangMaskedSoftmaxBwdPrimitive.name,
                                                 ctx,
                                                 dz,
                                                 softmax_out,
                                                 scale_factor=scale_factor)

        return out

    @staticmethod
    def impl(dz, softmax_out, scale_factor):
        return SoftmaxPrimitive.backward_impl(
            ScaledUpperTriangMaskedSoftmaxBwdPrimitive.inner_primitive,
            dz,
            softmax_out,
            scale_factor=scale_factor)

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        _check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.backward_batcher(
            ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor)

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.backward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos)

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.backward_partition(
            ScaledUpperTriangMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos,
            result_infos)


register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive)


def scaled_upper_triang_masked_softmax_bwd(dz: jnp.ndarray, softmax_out: jnp.ndarray,
                                           scale_factor: float) -> jnp.ndarray:
    """
    scaled_upper_triang_masked_softmax_backward wrapper
    Return FP16/BF16 tensor
    """
    return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive.bind(
        dz, softmax_out, scale_factor=scale_factor)


@dataclass(frozen=True)
class FusedAttnHelper:
    """
    Helper for the fused attention backend
    """

    q_dtype: jnp.dtype
    kv_dtype: jnp.dtype
    qkv_layout: NVTE_QKV_Layout
    attn_bias_type: NVTE_Bias_Type
    attn_mask_type: NVTE_Mask_Type
    dropout_probability: float
    q_num_heads: int
    kv_num_heads: int
    q_max_seqlen: int
    kv_max_seqlen: int
    head_dim: int

    def is_fused_attn_kernel_available(self):
        """Check if there is available fused attention kernel"""
        return self.get_fused_attn_backend() != NVTE_Fused_Attn_Backend.NVTE_No_Backend

    def get_fused_attn_backend(self):
        """Get the fused attention kernel backend"""
        return transformer_engine_jax.get_fused_attn_backend(
            jax_dtype_to_te_dtype(self.q_dtype), jax_dtype_to_te_dtype(self.kv_dtype),
            self.qkv_layout, self.attn_bias_type, self.attn_mask_type, self.dropout_probability,
            self.q_num_heads, self.kv_num_heads, self.q_max_seqlen, self.kv_max_seqlen,
            self.head_dim)

    @staticmethod
    def parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout):
        """Parse qkv aval"""
        match qkv_layout:
            case NVTE_QKV_Layout.NVTE_BS3HD:
                *q_batch_shape, q_max_seqlen, nqkv, attn_heads, q_head_dim = q_aval.shape
                kv_batch_shape = q_batch_shape
                kv_max_seqlen = q_max_seqlen
                num_gqa_groups = attn_heads
                kv_head_dim = q_head_dim
                assert nqkv == 3
            case NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
                *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
                *kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, kv_head_dim = k_aval.shape
                assert nkv == 2
            case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
                *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
                *kv_batch_shape, kv_max_seqlen, num_gqa_groups, kv_head_dim = k_aval.shape
                assert k_aval.shape == v_aval.shape
            case _:
                raise ValueError(f"Unexpected {qkv_layout=}")
        assert q_batch_shape == kv_batch_shape
        assert q_head_dim == kv_head_dim
        assert q_aval.dtype == k_aval.dtype == v_aval.dtype

        return (q_batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, q_head_dim)
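    # Layout conventions assumed by parse_qkv_aval (read off the match arms above):
    #   NVTE_BS3HD:          q is [..., seqlen, 3, heads, head_dim]; k and v may be
    #                        placeholder tensors (only their dtype is checked)
    #   NVTE_BSHD_BS2HD:     q is [..., s_q, heads, head_dim]; k is [..., s_kv, 2, gqa_groups, head_dim]
    #   NVTE_BSHD_BSHD_BSHD: q is [..., s_q, heads, head_dim]; k and v are [..., s_kv, gqa_groups, head_dim]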


@dataclass(frozen=True)
class _FusedAttnRNGStateChecker:
    """
    Checker for guarding the fused attention rng state.
    The fused attention backend requires a 64 bits seed and a 64 bits offset.
    However, JAX doesn't enable 64 bits by default,
    so we have to emulate seed as two 32 bits array.
    The offset calculation is maintained in the backend.
    """
    rng_state_dtype: jnp.dtype = jnp.uint32
    # (seed,) with internal dtype int64
    seed_size: int = 2
    # (seed, offset) with internal dtype int64
    rng_state_size: int = 2 * 2

    def check_seed(self, seed, dropout_probability, is_training):
        """
        Check the seed and convert the data type of seed if possible.
        """
        # Jax can't bind None, create a dummy tensor for None
        if seed is None:
            dropout_enabled = dropout_probability > 0 and is_training
            assert not dropout_enabled, "seed is not allowed to be None when dropout is enabled."
            seed = jnp.zeros(2, dtype=self.rng_state_dtype)
            seed = jnp.repeat(seed, num_of_devices())

        if seed.dtype != self.rng_state_dtype:
            warnings.warn(
                f"Requested {seed.dtype=} is not available, and will be "
                f"casted to dtype {self.rng_state_dtype}. "
                f"Please use threefry/rbg/unsafe_rbg PRNG implementations to remove this warning.")
            seed = seed.astype(self.rng_state_dtype)

        assert seed.dtype == self.rng_state_dtype
        # Backend takes an int64_t seed, so only the first two u32 elements are taken
        assert seed.size >= self.seed_size

        return seed
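# Hedged usage note: with JAX's default threefry PRNG, jax.random.PRNGKey(0) yields a
# uint32 key of size 2, which satisfies the dtype/size checks in check_seed above;
# passing seed=None is only accepted when dropout is disabled (see the assertion).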


def generate_cu_seqlen(actual_seqlen):
    """
    Generate the cumulative sequence lengths (cu_seqlen) for a batch
    """
    cu_seqlen = jnp.cumsum(actual_seqlen)
    cu_seqlen = jnp.hstack((0, cu_seqlen))
    return cu_seqlen
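# Worked example (illustrative): for actual_seqlen = [3, 5, 2] the helper above returns
# [0, 3, 8, 10], i.e. prefix sums with a leading zero, which is the cu_seqlen layout the
# fused attention kernels consume.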


class FusedAttnFwdPrimitive(BasePrimitive):
    """
    Fused Attention Forward Primitive
    """
    name = "te_fused_attn_forward"
    multiple_results = True
    impl_static_args = (7, 8, 9, 10, 11, 12)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(q_aval, k_aval, v_aval, bias_aval, q_seqlen_or_cu_seqlen_aval,
                 kv_seqlen_or_cu_seqlen_aval, seed_aval, *, attn_bias_type, attn_mask_type,
                 qkv_layout, scaling_factor, dropout_probability, is_training):
        """
        Fused attention fwd abstract
        """
        q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
        k_dtype = dtypes.canonicalize_dtype(k_aval.dtype)
        v_dtype = dtypes.canonicalize_dtype(v_aval.dtype)
        bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
        assert q_dtype == k_dtype == v_dtype == bias_dtype
        assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype

        batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = \
            FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout)
        output_shape = (*batch_shape, q_max_seqlen, attn_heads, head_dim)
        out_aval = q_aval.update(shape=output_shape, dtype=q_dtype)
        # backend determines the softmax buffer shape/dtype
        backend = FusedAttnHelper(q_dtype, k_dtype, qkv_layout, attn_bias_type, attn_mask_type,
                                  dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen,
                                  kv_max_seqlen, head_dim).get_fused_attn_backend()
        if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
            softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen)
            softmax_dtype = q_dtype
        elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
            softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1)
            softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
        else:
            raise ValueError(f'Unsupported {backend=}')
        softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=softmax_dtype)

        # JAX does not enable 64-bit int by default so we get XLA to allocate x8 memory with
        # 32-bit unsigned int to get the buffer size we need in the C++ kernel
        checker = _FusedAttnRNGStateChecker()
        seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype)
        assert seed_dtype == checker.rng_state_dtype
        rng_state_shape = (seed_aval.shape[0], checker.rng_state_size)
        rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype)

        if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
            bias_batch = bias_heads = 0
        else:
            *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
            bias_batch = reduce(operator.mul, bias_batch_shape)

        # do a dummy kernel call here to get workspace buffer shapes/dtypes that XLA needs to
        # prepare for the active fused-attn backend
        input_batch = reduce(operator.mul, batch_shape)
        wkspace_info = transformer_engine_jax.get_fused_attn_fwd_workspace_sizes(
            input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups,
            bias_heads, head_dim, scaling_factor, dropout_probability, attn_bias_type,
            attn_mask_type, qkv_layout, jax_dtype_to_te_dtype(q_aval.dtype), is_training)
        wkspace_aval = q_aval.update(shape=wkspace_info[0],
                                     dtype=te_dtype_to_jax_dtype(wkspace_info[1]))

        return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        Fused attention fwd outer primitive abstract
        """
        out_aval, softmax_aux_aval, rng_state_aval, _ = \
            FusedAttnFwdPrimitive.abstract(*args, **kwargs)
        return out_aval, softmax_aux_aval, rng_state_aval

    @staticmethod
    def lowering(ctx, q, k, v, bias, q_cu_seqlen, kv_cu_seqlen, seed, *, attn_bias_type,
                 attn_mask_type, qkv_layout, scaling_factor, dropout_probability, is_training):
        """
        Fused attention fwd lowering rules
        """
        operands = [q, k, v, bias, q_cu_seqlen, kv_cu_seqlen, seed]
        operand_shapes = map(lambda x: x.type.shape, operands)
        out_types = [
            ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
            for output in ctx.avals_out
        ]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in

        batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = \
            FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout)

        input_batch = reduce(operator.mul, batch_shape)

        if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
            bias_batch = bias_heads = 0
        else:
            *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
            bias_batch = reduce(operator.mul, bias_batch_shape)

        wkspace_aval = ctx.avals_out[-1]

        opaque = transformer_engine_jax.pack_fused_attn_descriptor(
            input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups,
            bias_heads, head_dim, wkspace_aval.size, scaling_factor, dropout_probability,
            attn_bias_type, attn_mask_type, qkv_layout, jax_dtype_to_te_dtype(q_aval.dtype),
            jax_dtype_to_te_dtype(wkspace_aval.dtype), is_training)
        out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)
        return out

    @staticmethod
    def impl(q, k, v, bias, q_seqlen, kv_seqlen, seed, attn_bias_type, attn_mask_type, qkv_layout,
             scaling_factor, dropout_probability, is_training):
        assert FusedAttnFwdPrimitive.inner_primitive is not None
        q_cu_seqlen = generate_cu_seqlen(q_seqlen)
        kv_cu_seqlen = generate_cu_seqlen(kv_seqlen)
        output, softmax_aux, rng_state, _ = FusedAttnFwdPrimitive.inner_primitive.bind(
            q,
            k,
            v,
            bias,
            q_cu_seqlen,
            kv_cu_seqlen,
            seed,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            qkv_layout=qkv_layout,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training)
        return output, softmax_aux, rng_state

    @staticmethod
    def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, qkv_layout,
                scaling_factor, dropout_probability, is_training):
        _check_valid_batch_dims(batch_dims)
        assert FusedAttnFwdPrimitive.outer_primitive is not None
        q_bdim, *_, seed_bdim = batch_dims
        out_bdims = q_bdim, q_bdim, seed_bdim
        return FusedAttnFwdPrimitive.outer_primitive.bind(*batched_args,
                                                          attn_bias_type=attn_bias_type,
                                                          attn_mask_type=attn_mask_type,
                                                          qkv_layout=qkv_layout,
                                                          scaling_factor=scaling_factor,
                                                          dropout_probability=dropout_probability,
                                                          is_training=is_training), out_bdims

    @staticmethod
    def infer_sharding_from_operands(attn_bias_type, attn_mask_type, qkv_layout, scaling_factor,
                                     dropout_probability, is_training, mesh, arg_infos,
                                     result_infos):
        del attn_bias_type, attn_mask_type, scaling_factor
        del dropout_probability, is_training, result_infos
        q_spec = get_padded_spec(arg_infos[0])
        k_spec = get_padded_spec(arg_infos[1])
        match qkv_layout:
            case NVTE_QKV_Layout.NVTE_BS3HD:
                # q_spec = (...batch, q_seqlen, head, hidden)
                out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec[:-3], *q_spec[-2:]))
                softmax_aux_sharding = NamedSharding(
                    mesh, PartitionSpec(*q_spec[:-4], q_spec[-2], q_spec[-4], None))
            case NVTE_QKV_Layout.NVTE_BSHD_BS2HD:
                # q_spec = (...batch, q_seqlen, head, hidden)
                # k_spec = (...batch, kv_seqlen, 2, num_gqa_groups, hidden)
                out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
                softmax_aux_sharding = NamedSharding(
                    mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-4]))
            case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD:
                # q_spec = (...batch, q_seqlen, head, hidden)
                # k_spec = (...batch, kv_seqlen, num_gqa_groups, hidden)
                out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
                softmax_aux_sharding = NamedSharding(
                    mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-3]))
            case _:
                raise ValueError(f"Unsupported {qkv_layout=}")
        rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
        return (out_sharding, softmax_aux_sharding, rng_state_sharding)

    @staticmethod
    def partition(attn_bias_type, attn_mask_type, qkv_layout, scaling_factor, dropout_probability,
                  is_training, mesh, arg_infos, result_infos):
        out_sharding = result_infos[0].sharding
        softmax_aux_sharding = result_infos[1].sharding
        rng_state_sharding = seed_sharding = NamedSharding(mesh,
                                                           PartitionSpec(get_all_mesh_axes(), None))
        arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding])
        out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
        impl = partial(FusedAttnFwdPrimitive.impl,
                       attn_bias_type=attn_bias_type,
                       attn_mask_type=attn_mask_type,
                       qkv_layout=qkv_layout,
                       scaling_factor=scaling_factor,
                       dropout_probability=dropout_probability,
                       is_training=is_training)
        return mesh, impl, out_shardings, arg_shardings


register_primitive(FusedAttnFwdPrimitive)


class FusedAttnBwdPrimitive(BasePrimitive):
    """
    Fused Attention Backward Primitive
    """
    name = "te_fused_attn_backward"
    multiple_results = True
    impl_static_args = (10, 11, 12, 13, 14, 15)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(q_aval, k_aval, v_aval, bias_aval, softmax_aux_aval, rng_state_aval, output_aval,
                 doutput_aval, q_cu_seqlen_aval, kv_cu_seqlen_aval, *, attn_bias_type,
                 attn_mask_type, qkv_layout, scaling_factor, dropout_probability, is_training):
        """
        Fused attention bwd abstract
        """
        del softmax_aux_aval, rng_state_aval, output_aval

        q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
        k_dtype = dtypes.canonicalize_dtype(k_aval.dtype)
        v_dtype = dtypes.canonicalize_dtype(v_aval.dtype)
        bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
        doutput_dtype = dtypes.canonicalize_dtype(doutput_aval.dtype)
        assert q_dtype == k_dtype == v_dtype == bias_dtype == doutput_dtype
        assert q_cu_seqlen_aval.dtype == kv_cu_seqlen_aval.dtype

        batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = \
            FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout)
        if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
            bias_batch = bias_heads = 0
        else:
            *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
            bias_batch = reduce(operator.mul, bias_batch_shape)

        input_batch = reduce(operator.mul, batch_shape)
        wkspace_shape, wkspace_dtype = \
            transformer_engine_jax.get_fused_attn_bwd_workspace_sizes(
                input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups,
                bias_heads, head_dim, scaling_factor, dropout_probability, attn_bias_type,
                attn_mask_type, qkv_layout, jax_dtype_to_te_dtype(q_aval.dtype), is_training)
        dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
        dk_aval = k_aval.update(shape=k_aval.shape, dtype=k_dtype)
        dv_aval = v_aval.update(shape=v_aval.shape, dtype=v_dtype)
        dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype)
        wkspace_aval = q_aval.update(shape=wkspace_shape,
                                     dtype=te_dtype_to_jax_dtype(wkspace_dtype))

        return dq_aval, dk_aval, dv_aval, dbias_aval, wkspace_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        Fused attention bwd outer primitive abstract
        """
        dq_aval, dk_aval, dv_aval, dbias_aval, _ = \
            FusedAttnBwdPrimitive.abstract(*args, **kwargs)
        return dq_aval, dk_aval, dv_aval, dbias_aval

    @staticmethod
    def lowering(ctx, q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen,
                 kv_cu_seqlen, *, attn_bias_type, attn_mask_type, qkv_layout, scaling_factor,
                 dropout_probability, is_training):
        """
        Fused attention bwd lowering rules
        """
        operands = [
            q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, kv_cu_seqlen
        ]
        operand_shapes = map(lambda x: x.type.shape, operands)
        out_types = [
            ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
            for output in ctx.avals_out
        ]

        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in

        batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = \
            FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout)

        input_batch = reduce(operator.mul, batch_shape)

        if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
            bias_batch = bias_heads = 0
        else:
            *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
            bias_batch = reduce(operator.mul, bias_batch_shape)

        wkspace_aval = ctx.avals_out[-1]

        opaque = transformer_engine_jax.pack_fused_attn_descriptor(
            input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups,
            bias_heads, head_dim, wkspace_aval.size, scaling_factor, dropout_probability,
            attn_bias_type, attn_mask_type, qkv_layout, jax_dtype_to_te_dtype(q_aval.dtype),
            jax_dtype_to_te_dtype(wkspace_aval.dtype), is_training)
        out = custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)

        return out

    @staticmethod
    def impl(q, k, v, bias, softmax_aux, rng_state, output, doutput, q_seqlen, kv_seqlen,
             attn_bias_type, attn_mask_type, qkv_layout, scaling_factor, dropout_probability,
             is_training):
        assert FusedAttnBwdPrimitive.inner_primitive is not None
        q_cu_seqlen = generate_cu_seqlen(q_seqlen)
        kv_cu_seqlen = generate_cu_seqlen(kv_seqlen)
        dq, dk, dv, dbias, _ = FusedAttnBwdPrimitive.inner_primitive.bind(
            q,
            k,
            v,
            bias,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_cu_seqlen,
            kv_cu_seqlen,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            qkv_layout=qkv_layout,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training)
        return dq, dk, dv, dbias

    @staticmethod
    def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, qkv_layout,
                scaling_factor, dropout_probability, is_training):
        _check_valid_batch_dims(batch_dims)
        assert FusedAttnBwdPrimitive.outer_primitive is not None
        q_bdim, k_bdim, v_bdim, *_ = batch_dims
        out_bdims = q_bdim, k_bdim, v_bdim, q_bdim
        return FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args,
                                                          attn_bias_type=attn_bias_type,
                                                          attn_mask_type=attn_mask_type,
                                                          qkv_layout=qkv_layout,
                                                          scaling_factor=scaling_factor,
                                                          dropout_probability=dropout_probability,
                                                          is_training=is_training), out_bdims

    @staticmethod
    def infer_sharding_from_operands(attn_bias_type, attn_mask_type, qkv_layout, scaling_factor,
                                     dropout_probability, is_training, mesh, arg_infos,
                                     result_infos):
        del attn_bias_type, attn_mask_type, qkv_layout, scaling_factor
        del dropout_probability, is_training, result_infos
        q_spec = get_padded_spec(arg_infos[0])
        k_spec = get_padded_spec(arg_infos[1])
        v_spec = get_padded_spec(arg_infos[2])
        bias_spec = get_padded_spec(arg_infos[3])
        dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
        dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
        dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
        return (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)

    @staticmethod
    def partition(attn_bias_type, attn_mask_type, qkv_layout, scaling_factor, dropout_probability,
                  is_training, mesh, arg_infos, result_infos):
        del result_infos
        q_spec = get_padded_spec(arg_infos[0])
        k_spec = get_padded_spec(arg_infos[1])
        v_spec = get_padded_spec(arg_infos[2])
        bias_spec = get_padded_spec(arg_infos[3])
        dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
        dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
        dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)

        def sharded_impl(q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen,
                         kv_cu_seqlen):
            local_dq, local_dk, local_dv, local_dbias = FusedAttnBwdPrimitive.impl(
                q,
                k,
                v,
                bias,
                softmax_aux,
                rng_state,
                output,
                doutput,
                q_cu_seqlen,
                kv_cu_seqlen,
                attn_bias_type=attn_bias_type,
                attn_mask_type=attn_mask_type,
                qkv_layout=qkv_layout,
                scaling_factor=scaling_factor,
                dropout_probability=dropout_probability,
                is_training=is_training)
            global_dbias = local_dbias
            if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
            return local_dq, local_dk, local_dv, global_dbias

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(FusedAttnBwdPrimitive)


def fused_attn_fwd_qkvpacked(qkv: jnp.ndarray, bias: jnp.ndarray, seqlen: jnp.ndarray,
                             seed: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
                             attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
                             dropout_probability: float, is_training: bool):
    """
    Wrapper for TE self fused attention fwd
    Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
    """
    checker = _FusedAttnRNGStateChecker()
    seed = checker.check_seed(seed, dropout_probability, is_training)

    if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
        assert bias is None
        bias = jnp.zeros(0, dtype=qkv.dtype)
    _not_used = jnp.zeros(0, qkv.dtype)
    return FusedAttnFwdPrimitive.outer_primitive.bind(qkv,
                                                      _not_used,
                                                      _not_used,
                                                      bias,
                                                      seqlen,
                                                      seqlen,
                                                      seed,
                                                      attn_bias_type=attn_bias_type,
                                                      attn_mask_type=attn_mask_type,
                                                      qkv_layout=NVTE_QKV_Layout.NVTE_BS3HD,
                                                      scaling_factor=scaling_factor,
                                                      dropout_probability=dropout_probability,
                                                      is_training=is_training)
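# Illustrative shapes for the packed-QKV path (hypothetical sizes): with
# qkv_layout=NVTE_BS3HD the packed input is [batch..., seqlen, 3, heads, head_dim],
# e.g. qkv of shape (2, 512, 3, 16, 64); bias is None for NVTE_NO_BIAS; and seqlen
# holds the actual (unpadded) length of each sequence, from which cu_seqlen is built.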


def fused_attn_bwd_qkvpacked(qkv: jnp.ndarray, bias: jnp.ndarray, softmax_aux: jnp.ndarray,
                             rng_state: jnp.ndarray, output: jnp.ndarray, doutput: jnp.ndarray,
                             seqlen: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
                             attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
                             dropout_probability: float, is_training: bool):
    """
    Wrapper for TE self fused attention bwd
    Return the gradients of self fused attention with packed qkv input
    """
    if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
        assert bias is None
        bias = jnp.zeros(0, dtype=qkv.dtype)
    dummy_input = jnp.zeros(0, dtype=qkv.dtype)
    dqkv, *_, dbias = FusedAttnBwdPrimitive.outer_primitive.bind(
        qkv,
        dummy_input,
        dummy_input,
        bias,
        softmax_aux,
        rng_state,
        output,
        doutput,
        seqlen,
        seqlen,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        qkv_layout=NVTE_QKV_Layout.NVTE_BS3HD,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training)
    return dqkv, dbias


def fused_attn_fwd_kvpacked(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray,
                            q_seqlen: jnp.ndarray, kv_seqlen: jnp.ndarray, seed: jnp.ndarray,
                            attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type,
                            scaling_factor: float, dropout_probability: float, is_training: bool):
    """
    Wrapper for TE fused attention fwd with kvpacked inputs
    Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
    """
    checker = _FusedAttnRNGStateChecker()
    seed = checker.check_seed(seed, dropout_probability, is_training)

    if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
        assert bias is None
        bias = jnp.zeros(0, dtype=q.dtype)
    return FusedAttnFwdPrimitive.outer_primitive.bind(q,
                                                      kv,
                                                      jnp.zeros(0, q.dtype),
                                                      bias,
                                                      q_seqlen,
                                                      kv_seqlen,
                                                      seed,
                                                      attn_bias_type=attn_bias_type,
                                                      attn_mask_type=attn_mask_type,
                                                      qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BS2HD,
                                                      scaling_factor=scaling_factor,
                                                      dropout_probability=dropout_probability,
                                                      is_training=is_training)


def fused_attn_bwd_kvpacked(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray,
                            softmax_aux: jnp.ndarray, rng_state: jnp.ndarray, output: jnp.ndarray,
                            doutput: jnp.ndarray, q_seqlen: jnp.ndarray, kv_seqlen: jnp.ndarray,
                            attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type,
                            scaling_factor: float, dropout_probability: float, is_training: bool):
    """
    Wrapper for TE fused attention bwd with kvpacked inputs
    Return the gradients of fused attention with packed kv input
    """
    if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
        assert bias is None
        bias = jnp.zeros(0, dtype=q.dtype)
    dummy_input = jnp.zeros(0, q.dtype)
    dq, dkv, _, dbias = FusedAttnBwdPrimitive.outer_primitive.bind(
        q,
        kv,
        dummy_input,
        bias,
        softmax_aux,
        rng_state,
        output,
        doutput,
        q_seqlen,
        kv_seqlen,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BS2HD,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training)
    return dq, dkv, dbias


def fused_attn_fwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray,
                   q_seqlen: jnp.ndarray, kv_seqlen: jnp.ndarray, seed: jnp.ndarray,
                   attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type,
                   scaling_factor: float, dropout_probability: float, is_training: bool):
    """
    Wrapper for TE fused attention fwd, where query, key, value are separated tensors
    Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
    """
    checker = _FusedAttnRNGStateChecker()
    seed = checker.check_seed(seed, dropout_probability, is_training)

    if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
        assert bias is None
        bias = jnp.zeros(0, dtype=q.dtype)

    return FusedAttnFwdPrimitive.outer_primitive.bind(
        q,
        k,
        v,
        bias,
        q_seqlen,
        kv_seqlen,
        seed,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training)


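# Illustrative usage sketch for the separate-Q/K/V forward wrapper above (never called by
# this module).  Shapes and the seed=None convention for disabled dropout are assumptions.
def _example_fused_attn_fwd():
    batch, seqlen, heads, head_dim = 2, 128, 16, 64
    q = jnp.zeros((batch, seqlen, heads, head_dim), jnp.bfloat16)
    k = jnp.zeros((batch, seqlen, heads, head_dim), jnp.bfloat16)
    v = jnp.zeros((batch, seqlen, heads, head_dim), jnp.bfloat16)
    seqlens = jnp.full((batch,), seqlen, jnp.int32)
    return fused_attn_fwd(
        q, k, v, None, seqlens, seqlens, None,
        attn_bias_type=NVTE_Bias_Type.NVTE_NO_BIAS,
        attn_mask_type=NVTE_Mask_Type.NVTE_NO_MASK,
        scaling_factor=1.0 / (head_dim ** 0.5),
        dropout_probability=0.0,
        is_training=False)

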
def fused_attn_bwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray,
                   softmax_aux: jnp.ndarray, rng_state: jnp.ndarray, output: jnp.ndarray,
                   doutput: jnp.ndarray, q_seqlen: jnp.ndarray, kv_seqlen: jnp.ndarray,
                   attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type,
                   scaling_factor: float, dropout_probability: float, is_training: bool):
    """
    Wrapper for TE fused attention bwd
    Return the gradients of fused attention with separated query, key, value tensors
    """
    if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
        assert bias is None
        bias = jnp.zeros(0, dtype=q.dtype)
    return FusedAttnBwdPrimitive.outer_primitive.bind(
        q,
        k,
        v,
        bias,
        softmax_aux,
        rng_state,
        output,
        doutput,
        q_seqlen,
        kv_seqlen,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training)


class ActLuPrimitive(BasePrimitive):
    """
    Activation Forward Primitive
    """
    name = "te_act_lu"
    multiple_results = False
    inner_primitive = None
    outer_primitive = None
    impl_static_args = (1,)

    @staticmethod
    def abstract(x_aval, *, act_enum):  # pylint: disable=unused-argument
        """
        act_lu abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]

        x_shape = x_aval.shape
        assert (x_shape[-2] == 2 or x_shape[-2] == 1)
        hidden_size = x_shape[-1]
        batch_shapes = x_shape[:-2]
        out_aval = core.raise_to_shaped(x_aval)
        out_shape = (batch_shapes) + (hidden_size,)
        out_aval = out_aval.update(shape=out_shape, dtype=dtype)

        return out_aval

    @staticmethod
    def lowering(ctx, x, *, act_enum):
        """
        act_lu lowering rules
        """
        (x_aval,) = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        ir_x_type = ir.RankedTensorType(x.type)
        ir_x_shape = ir_x_type.shape
        out_shape = ir_x_shape[:-2] + [ir_x_shape[-1]]

        out_types = [
            ir.RankedTensorType.get(out_shape, ir_x_type.element_type),
        ]
        operands = [x]
        operand_shapes = [ir_x_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        hidden_size = ir_x_shape[-1]
        batch_size = reduce(operator.mul, ir_x_shape[:-2])
        in_dtype = jax_dtype_to_te_dtype(x_aval.dtype)
        opaque = transformer_engine_jax.pack_common_descriptor(
            (batch_size, hidden_size), in_dtype, in_dtype, act_enum)

        out = custom_caller(ActLuPrimitive.name, args, opaque, False)

        return [out]

    @staticmethod
    def impl(x, act_enum):
        assert ActLuPrimitive.inner_primitive is not None
        out = ActLuPrimitive.inner_primitive.bind(x, act_enum=act_enum)
        return out

    @staticmethod
    def batcher(batched_args, batch_dims, *, act_enum):
        """
        act_lu batcher
        """
        _check_valid_batch_dims(batch_dims)
        assert ActLuPrimitive.outer_primitive is not None
        inputs, = batched_args
        inputs_bdim, = batch_dims

        out_bdims = inputs_bdim
        return ActLuPrimitive.outer_primitive.bind(inputs, act_enum=act_enum), out_bdims

    @staticmethod
    def infer_sharding_from_operands(act_enum, mesh, arg_infos, result_infos):
        """
        act_lu infer_sharding_from_operands
        """
        del result_infos, act_enum    # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
        return out_sharding

    @staticmethod
    def partition(act_enum, mesh, arg_infos, result_infos):
        """
        act_lu partitioning
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))

        def sharded_impl(x):
            return ActLuPrimitive.impl(x, act_enum=act_enum)

        return mesh, sharded_impl, out_sharding, arg_shardings


register_primitive(ActLuPrimitive)


def act_lu(inputs: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]) -> jnp.ndarray:
    """
    act_lu wrapper
    Return act_lu(inputs)
    Input shape: (N, 1, H) for non-gated activations
                 (N, 2, H) for gated activations
    """
    act_type_id = ActivationEnum[activation_type]
    return ActLuPrimitive.outer_primitive.bind(inputs, act_enum=act_type_id)


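# Illustrative usage sketch (never called by this module).  It assumes ('gelu',) is a valid
# ActivationEnum key; the (N, 1, H) layout for a non-gated activation follows the docstring.
def _example_act_lu():
    x = jnp.zeros((32, 1, 1024), jnp.bfloat16)
    return act_lu(x, ('gelu',))

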
class DActLuPrimitive(BasePrimitive):
    """
    Activation Backward Primitive
    """
    name = "te_dact_lu"
    multiple_results = False
    inner_primitive = None
    outer_primitive = None
    impl_static_args = (2,)

    @staticmethod
    def abstract(dz_aval, x_aval, *, act_enum):  # pylint: disable=unused-argument
        """
        dact_lu abstract
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dtype
        for axis in range(len(dz_aval.shape) - 1):
            assert dz_aval.shape[axis] == x_aval.shape[axis]
        assert (x_aval.shape[-2] == 2 or x_aval.shape[-2] == 1)

        i_hidden_size = dz_aval.shape[-1]
        g_hidden_size = x_aval.shape[-1]
        assert i_hidden_size == g_hidden_size
        out_aval = core.raise_to_shaped(x_aval)

        return out_aval

    @staticmethod
    def lowering(ctx, dz, x, *, act_enum):
        """
        dact_lu lowering rules
        """
        in_aval, gi_aval = ctx.avals_in
        assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert gi_aval.dtype == in_aval.dtype
        ir_in_type = ir.RankedTensorType(dz.type)
        ir_in_shape = ir_in_type.shape
        gi_type = ir.RankedTensorType(x.type)
        gi_shape = gi_type.shape
#        assert ir_in_shape == gi_shape
        for axis in range(len(ir_in_shape) - 1):
            assert ir_in_shape[axis] == gi_shape[axis]

        ir_batch_size = reduce(operator.mul, ir_in_shape[:-1])
        i_hidden_size = ir_in_shape[-1]
        g_hidden_size = gi_shape[-1]
        assert i_hidden_size == g_hidden_size
        out_dtype = ir_in_type.element_type
        out_shape = gi_shape

        out_types = [
            ir.RankedTensorType.get(out_shape, out_dtype),
        ]
        operands = [dz, x]
        operand_shapes = [ir_in_shape, gi_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        in_dtype = jax_dtype_to_te_dtype(in_aval.dtype)
        opaque = transformer_engine_jax.pack_common_descriptor((ir_batch_size, i_hidden_size),
                                                               in_dtype, in_dtype, act_enum)

        out = custom_caller(DActLuPrimitive.name, args, opaque, False)

        return [out]

    @staticmethod
    def impl(dz, x, act_enum):
        """
        dact_lu implementation
        """
        assert DActLuPrimitive.inner_primitive is not None
        dx = DActLuPrimitive.inner_primitive.bind(dz, x, act_enum=act_enum)
        return dx

    @staticmethod
    def batcher(batched_args, batch_dims, *, act_enum):
        """
        dact_lu batcher
        """
        _check_valid_batch_dims(batch_dims)
        assert DActLuPrimitive.outer_primitive is not None

        dz, x = batched_args
        _, x_bdim = batch_dims

        out_bdims = x_bdim
        return DActLuPrimitive.outer_primitive.bind(dz, x, act_enum=act_enum), out_bdims

    @staticmethod
    def infer_sharding_from_operands(act_enum, mesh, arg_infos, result_infos):
        """
        dact_lu infer_sharding_from_operands
        """
        del result_infos, act_enum    # Unused.
        act_lu_out_spec = get_padded_spec(arg_infos[1])
        dx_sharding = NamedSharding(mesh, PartitionSpec(*act_lu_out_spec))
        return dx_sharding

    @staticmethod
    def partition(act_enum, mesh, arg_infos, result_infos):
        """
        dact_lu partition
        """
        del result_infos
        dx_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = dx_sharding

        def sharded_impl(dz, x):
            return DActLuPrimitive.impl(dz, x, act_enum=act_enum)

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(DActLuPrimitive)


def dact_lu(inputs: jnp.ndarray, act_lu_inputs: jnp.ndarray,
            activation_type: Sequence[Union[str, Callable]]) -> jnp.ndarray:
    """
    dact_lu fusion wrapper
    Return the gradient of act_lu with respect to its input, given the incoming gradient
    """
    act_type_id = ActivationEnum[activation_type]
    return DActLuPrimitive.outer_primitive.bind(inputs, act_lu_inputs, act_enum=act_type_id)


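# Illustrative usage sketch (never called by this module).  The incoming gradient has shape
# (N, H) while the saved forward input keeps its (N, 1, H) layout; the ('gelu',)
# ActivationEnum key is an assumption.
def _example_dact_lu():
    dz = jnp.zeros((32, 1024), jnp.bfloat16)              # gradient w.r.t. act_lu output
    fwd_input = jnp.zeros((32, 1, 1024), jnp.bfloat16)    # input saved from the forward pass
    return dact_lu(dz, fwd_input, ('gelu',))

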
def _normalize_axis_boundary(axis, ndim):
    return axis if axis >= 0 else ndim + axis


def _multidim_transpose(shape, static_axis_boundary, transpose_axis_boundary):
    """
    te_cast_transpose_p multi-dims transpose

    static_axis_boundary: int, axes with index <= static_axis_boundary are excluded from the
        transpose; -1 means all axes are involved in the transpose.
    transpose_axis_boundary: int, where to split the multi-dimensional tensor into a 2D matrix
        for the transpose. Note: transpose_axis_boundary must be greater than static_axis_boundary.

    examples:
        X in shape (dim0, dim1, dim2, dim3, dim4)

        static_axis_boundary == -1, transpose_axis_boundary == 2
            Xt = (dim2, dim3, dim4, dim0, dim1)

        static_axis_boundary == 0, transpose_axis_boundary == 2
            Xt = (dim0, dim2, dim3, dim4, dim1)

        static_axis_boundary == 0, transpose_axis_boundary == 3
            Xt = (dim0, dim3, dim4, dim1, dim2)
    """
    if static_axis_boundary < 0:
        static_axis_boundary = -1    # means no static axes
    assert static_axis_boundary < len(shape) - 2    # at least 2 remaining for transpose.
    transpose_start_idx = static_axis_boundary + 1
    transpose_axis_boundary = _normalize_axis_boundary(transpose_axis_boundary, len(shape))
    assert transpose_start_idx < transpose_axis_boundary
    return (*shape[:transpose_start_idx], *shape[transpose_axis_boundary:],
            *shape[transpose_start_idx:transpose_axis_boundary])


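# Shape-rule sketch for _multidim_transpose (mirrors the docstring examples; not used by the
# library itself).
def _multidim_transpose_examples():
    shape = (2, 3, 4, 5, 6)    # (dim0, dim1, dim2, dim3, dim4)
    assert _multidim_transpose(shape, -1, 2) == (4, 5, 6, 2, 3)
    assert _multidim_transpose(shape, 0, 2) == (2, 4, 5, 6, 3)
    assert _multidim_transpose(shape, 0, 3) == (2, 5, 6, 3, 4)

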
class CastTransposePrimitive(BasePrimitive):
    """
    Cast Transpose Primitive
    """
    name = "te_cast_transpose"
    multiple_results = True
    impl_static_args = (4, 5, 6)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype, static_axis_boundary,
                 transpose_axis_boundary):
        """
        te_cast_transpose_p abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        transposed_x_shape = _multidim_transpose(x_aval.shape, static_axis_boundary,
                                                 transpose_axis_boundary)

        casted_x_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
        casted_xt_aval = x_aval.update(shape=transposed_x_shape, dtype=out_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)

        return casted_x_aval, casted_xt_aval, updated_amax_aval

    @staticmethod
    def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary,
                 transpose_axis_boundary):
        """
        te_cast_transpose_p lowering rules
        """
        x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_x_type = ir.RankedTensorType(x.type)
        ir_x_shape = ir_x_type.shape
        if static_axis_boundary >= 0:
            for i in range(static_axis_boundary + 1):
                assert ir_x_shape[i] == 1
        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape

        transposed_x_shape = _multidim_transpose(ir_x_shape, static_axis_boundary,
                                                 transpose_axis_boundary)

        out_types = [
            ir.RankedTensorType.get(ir_x_shape, ir_out_dtype),
            ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
        ]
        operands = [x, amax, scale, scale_inv]
        operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        contracted_x_shape = (reduce(operator.mul, ir_x_shape[:transpose_axis_boundary]),
                              reduce(operator.mul, ir_x_shape[transpose_axis_boundary:]))
        opaque = transformer_engine_jax.pack_common_descriptor(contracted_x_shape,
                                                               jax_dtype_to_te_dtype(x_aval.dtype),
                                                               jax_dtype_to_te_dtype(out_dtype))

        out = custom_caller(CastTransposePrimitive.name,
                            args,
                            opaque,
                            False,
                            operand_output_aliases={1: 2})

        return out

    @staticmethod
    def impl(x, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis_boundary):
        """
        te_cast_transpose implementation
        """
        assert CastTransposePrimitive.inner_primitive is not None
        casted_x, casted_transposed_x, updated_amax = \
            CastTransposePrimitive.inner_primitive.bind(
                x, amax, scale, scale_inv, out_dtype=out_dtype,
                static_axis_boundary=static_axis_boundary,
                transpose_axis_boundary=transpose_axis_boundary)
        return casted_x, casted_transposed_x, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary,
                transpose_axis_boundary):
        _check_valid_batch_dims(batch_dims)
        assert CastTransposePrimitive.outer_primitive is not None
        assert static_axis_boundary < 0

        x, amax, scale, scale_inv = batched_args
        x_bdim, amax_bdim, *_ = batch_dims

        # Minus batch dim.
        transpose_axis_boundary = _normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1)
        transpose_axis_boundary += 1    # Plus batch dim

        out_bdims = x_bdim, x_bdim, amax_bdim
        return CastTransposePrimitive.outer_primitive.bind(
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=x_bdim,
            transpose_axis_boundary=transpose_axis_boundary), out_bdims

    @staticmethod
    def infer_sharding_from_operands(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh,
                                     arg_infos, result_infos):
        del out_dtype, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        return (casted_x_sharding, casted_transposed_x_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos,
                  result_infos):
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (casted_x_sharding, casted_transposed_x_sharding, amax_sharding)

        def sharded_impl(x, amax, scale, scale_inv):
            local_cx, local_cxt, local_updated_amax = \
                CastTransposePrimitive.impl(x, amax, scale, scale_inv,
                                            out_dtype=out_dtype,
                                            static_axis_boundary=static_axis_boundary,
                                            transpose_axis_boundary=transpose_axis_boundary)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax)

            return local_cx, local_cxt, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(CastTransposePrimitive)


def cast_transpose(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
                   out_dtype: jnp.dtype, static_axis_boundary: int,
                   transpose_axis_boundary: int) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    cast transpose wrapper
    Return two tensors, FP8(inputs) and FP8(inputs.T), which are scaled by `scale`
    """
    return CastTransposePrimitive.outer_primitive.bind(
        x,
        amax,
        scale,
        scale_inv,
        out_dtype=out_dtype,
        static_axis_boundary=static_axis_boundary,
        transpose_axis_boundary=transpose_axis_boundary)


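# Illustrative usage sketch (never called by this module): cast a bf16 tensor to FP8 and get
# its transpose in one pass.  The (1,)-shaped fp32 amax/scale/scale_inv tensors are an
# assumption about the caller's FP8 meta layout.
def _example_cast_transpose():
    x = jnp.zeros((16, 128, 256), jnp.bfloat16)
    amax = jnp.zeros((1,), jnp.float32)
    scale = jnp.ones((1,), jnp.float32)
    scale_inv = jnp.ones((1,), jnp.float32)
    casted, casted_t, updated_amax = cast_transpose(
        x, amax, scale, scale_inv,
        out_dtype=jnp.float8_e4m3fn,
        static_axis_boundary=-1,
        transpose_axis_boundary=1)
    # casted has shape (16, 128, 256); casted_t has shape (128, 256, 16)
    return casted, casted_t, updated_amax

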
class CastFP8Primitive(BasePrimitive):
    """
    Cast Primitive
    """
    name = "te_quantize"
    multiple_results = True
    impl_static_args = (4,)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype):
        """
        te_cast abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        casted_x_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)

        return casted_x_aval, updated_amax_aval

    @staticmethod
    def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype):
        """
        te_cast lowering rules
        """
        x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_x_type = ir.RankedTensorType(x.type)
        ir_x_shape = ir_x_type.shape
        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape

        out_types = [
            ir.RankedTensorType.get(ir_x_shape, ir_out_dtype),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
        ]
        operands = [x, amax, scale, scale_inv]
        operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        opaque = transformer_engine_jax.pack_common_descriptor(ir_x_shape,
                                                               jax_dtype_to_te_dtype(x_aval.dtype),
                                                               jax_dtype_to_te_dtype(out_dtype))

        out = custom_caller(CastFP8Primitive.name,
                            args,
                            opaque,
                            False,
                            operand_output_aliases={1: 1})

        return out

    @staticmethod
    def impl(x, amax, scale, scale_inv, out_dtype):
        """
        te_cast implementation
        """
        assert CastFP8Primitive.inner_primitive is not None
        casted_x, updated_amax = \
            CastFP8Primitive.inner_primitive.bind(
                x, amax, scale, scale_inv, out_dtype=out_dtype)
        return casted_x, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype):
        _check_valid_batch_dims(batch_dims)
        assert CastFP8Primitive.outer_primitive is not None

        x, amax, scale, scale_inv = batched_args
        x_bdim, amax_bdim, *_ = batch_dims

        out_bdims = x_bdim, amax_bdim
        return CastFP8Primitive.outer_primitive.bind(x, amax, scale, scale_inv,
                                                     out_dtype=out_dtype), out_bdims

    @staticmethod
    def infer_sharding_from_operands(out_dtype, mesh, arg_infos, result_infos):
        del out_dtype, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        return (casted_x_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, mesh, arg_infos, result_infos):
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (casted_x_sharding, amax_sharding)

        def sharded_impl(x, amax, scale, scale_inv):
            local_cx, local_updated_amax = \
                CastFP8Primitive.impl(x, amax, scale, scale_inv, out_dtype=out_dtype)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax)

            return local_cx, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(CastFP8Primitive)


def cast_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
             out_dtype: TEDType) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Cast wrapper
    Return FP8 tensor
    """
    return CastFP8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, out_dtype=out_dtype)


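# Minimal quantization sketch (never called by this module), under the same (1,)-shaped
# fp32 FP8-meta assumption as the cast_transpose example above.
def _example_cast_fp8():
    x = jnp.zeros((32, 1024), jnp.bfloat16)
    amax = jnp.zeros((1,), jnp.float32)
    scale = jnp.ones((1,), jnp.float32)
    scale_inv = jnp.ones((1,), jnp.float32)
    casted_x, updated_amax = cast_fp8(x, amax, scale, scale_inv, out_dtype=jnp.float8_e4m3fn)
    return casted_x, updated_amax

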
class TransposePrimitive(BasePrimitive):
    """
    Transpose Primitive
    """
    name = "te_transpose"
    multiple_results = False
    impl_static_args = (1, 2)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, *, static_axis_boundary, transpose_axis_boundary):
        """
        _transpose abstract
        """
        transposed_x_shape = _multidim_transpose(x_aval.shape, static_axis_boundary,
                                                 transpose_axis_boundary)
        xt_aval = x_aval.update(shape=transposed_x_shape, dtype=x_aval.dtype)

        return xt_aval

    @staticmethod
    def lowering(ctx, x, *, static_axis_boundary, transpose_axis_boundary):
        """
        _transpose cuda lowering
        """
        x_aval = ctx.avals_in[0]
        assert x_aval.dtype in [
            jnp.float32, jnp.float16, jnp.bfloat16, jnp.float8_e4m3fn, jnp.float8_e5m2
        ]

        ir_x_type = ir.RankedTensorType(x.type)
        ir_x_shape = ir_x_type.shape
        ir_out_dtype = jax_dtype_to_ir_dtype(x_aval.dtype)
        if static_axis_boundary >= 0:
            for i in range(static_axis_boundary + 1):
                assert ir_x_shape[i] == 1

        transposed_x_shape = _multidim_transpose(ir_x_shape, static_axis_boundary,
                                                 transpose_axis_boundary)

        out_types = [ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype)]
        operands = [x]
        operand_shapes = [ir_x_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        te_dtype = jax_dtype_to_te_dtype(x_aval.dtype)
        contracted_x_shape = (reduce(operator.mul, ir_x_shape[:transpose_axis_boundary]),
                              reduce(operator.mul, ir_x_shape[transpose_axis_boundary:]))
        opaque = transformer_engine_jax.pack_common_descriptor(contracted_x_shape, te_dtype,
                                                               te_dtype)

        out = custom_caller(TransposePrimitive.name, args, opaque, False)

        return [out]

    @staticmethod
    def impl(x, static_axis_boundary, transpose_axis_boundary):
        """
        te_transpose implementation
        """
        assert TransposePrimitive.inner_primitive is not None
        transposed_x = \
            TransposePrimitive.inner_primitive.bind(x,
                                                    static_axis_boundary=static_axis_boundary,
                                                    transpose_axis_boundary=transpose_axis_boundary)
        return transposed_x

    @staticmethod
    def batcher(batched_args, batch_dims, *, static_axis_boundary, transpose_axis_boundary):
        _check_valid_batch_dims(batch_dims)
        assert TransposePrimitive.outer_primitive is not None
        assert static_axis_boundary < 0

        x, = batched_args
        x_bdim, = batch_dims

        # Minus batch dim.
        transpose_axis_boundary = _normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1)
        transpose_axis_boundary += 1    # Plus batch dim

        out_bdims = x_bdim
        return TransposePrimitive.outer_primitive.bind(
            x, static_axis_boundary=x_bdim,
            transpose_axis_boundary=transpose_axis_boundary), out_bdims

    @staticmethod
    def infer_sharding_from_operands(static_axis_boundary, transpose_axis_boundary, mesh, arg_infos,
                                     result_infos):
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        return transposed_x_sharding

    @staticmethod
    def partition(static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos):
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = transposed_x_sharding

        impl = partial(TransposePrimitive.impl,
                       static_axis_boundary=static_axis_boundary,
                       transpose_axis_boundary=transpose_axis_boundary)

        return mesh, impl, out_shardings, arg_shardings


register_primitive(TransposePrimitive)


def transpose(x: jnp.ndarray, static_axis_boundary: int,
              transpose_axis_boundary: int) -> jnp.ndarray:
    """
    transpose wrapper
    """
    return TransposePrimitive.outer_primitive.bind(x,
                                                   static_axis_boundary=static_axis_boundary,
                                                   transpose_axis_boundary=transpose_axis_boundary)


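# Illustrative usage sketch (never called by this module): treat the tensor as a
# (16, 128*256) matrix and transpose it, keeping the trailing axes' internal layout.
def _example_transpose():
    x = jnp.zeros((16, 128, 256), jnp.bfloat16)
    xt = transpose(x, static_axis_boundary=-1, transpose_axis_boundary=1)
    # xt has shape (128, 256, 16)
    return xt

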
class LayerNormFwdFp8Primitive(BasePrimitive):
    """
    Layer Normalization Forward FP8 Primitive
    """
    name = "te_layernorm_forward_fp8"
    multiple_results = True
    impl_static_args = (6, 7, 8)    # out_dtype, zero_centered_gamma, epsilon
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, gamma_aval, beta_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype,
                 zero_centered_gamma, epsilon):
        """
        LayerNorm fwd (fp8 out) inner primitive abstract
        """
        x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)

        assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        mu_rsigma_dtype = jnp.float32

        assert gamma_aval.size == beta_aval.size

        wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes(
            x_aval.size // gamma_aval.size,    # batch size
            gamma_aval.size,    # hidden size
            jax_dtype_to_te_dtype(x_aval.dtype),    # in type
            jax_dtype_to_te_dtype(gamma_aval.dtype),    # weight type
            jax_dtype_to_te_dtype(out_dtype),
            True,
            zero_centered_gamma,
            epsilon)

        out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
        mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigma_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
        wkspace_aval = x_aval.update(shape=wkspace_info[0],
                                     dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
        barrier_aval = x_aval.update(shape=barrier_info[0],
                                     dtype=te_dtype_to_jax_dtype(barrier_info[1]))

        return out_aval, mu_aval, rsigma_aval, updated_amax_aval, wkspace_aval, barrier_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        LayerNorm fwd (fp8 out) outer primitive abstract
        """
        out_aval, mu_aval, rsigma_aval, updated_amax_aval, _, _ = \
            LayerNormFwdFp8Primitive.abstract(*args, **kwargs)
        return out_aval, mu_aval, rsigma_aval, updated_amax_aval

    @staticmethod
    def lowering(ctx, x, gamma, beta, amax, scale, scale_inv, *, out_dtype, zero_centered_gamma,
                 epsilon):
        """
        LayerNorm fwd (fp8 out) lowering rules
        """
        x_aval, gamma_aval, beta_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in

        # Currently the C side only supports casting to E4M3.
        assert out_dtype == jnp.float8_e4m3fn

        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert gamma_aval.dtype == beta_aval.dtype
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        g_type = ir.RankedTensorType(gamma.type)
        g_shape = g_type.shape
        b_type = ir.RankedTensorType(beta.type)
        b_shape = b_type.shape

        assert g_type == b_type
        assert g_shape == b_shape

        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_mu_dtype = ir.F32Type.get()
        ir_rsigma_dtype = ir.F32Type.get()
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape

        out_shape = x_shape
        hidden_size = reduce(operator.mul, g_shape)
        batch_shape = out_shape[:-1]
        batch_size = reduce(operator.mul, x_shape) // hidden_size

        wkspace_aval, barrier_aval = ctx.avals_out[-2:]

        out_types = [
            ir.RankedTensorType.get(out_shape, ir_out_dtype),
            ir.RankedTensorType.get(batch_shape, ir_mu_dtype),
            ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
            ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
            ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype))
        ]
        operands = [x, gamma, beta, amax, scale, scale_inv]
        operand_shapes = [
            x_shape, g_shape, b_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape
        ]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))

        opaque = transformer_engine_jax.pack_norm_descriptor(
            batch_size,
            hidden_size,
            wkspace_aval.size,
            barrier_aval.size,
            (0,),    # no dgamma_part in FWD pass
            (0,),    # no dbeta_part in FWD pass
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(gamma_aval.dtype),
            jax_dtype_to_te_dtype(wkspace_aval.dtype),
            jax_dtype_to_te_dtype(barrier_aval.dtype),
            TEDType.kByte,    # dummy dgamma_part te_dtype
            TEDType.kByte,    # dummy dbeta_part te_dtype
            zero_centered_gamma,
            epsilon,
            sm_margin,
        )

        out = custom_caller(LayerNormFwdFp8Primitive.name,
                            args,
                            opaque,
                            False,
                            operand_output_aliases={3: 3})

        return out

    @staticmethod
    def impl(x, gamma, beta, amax, scale, scale_inv, out_dtype, zero_centered_gamma, epsilon):
        """
        LayerNorm fwd (fp8 out) implementation
        """
        assert LayerNormFwdFp8Primitive.inner_primitive is not None
        out, mu, rsigma, updated_amax, _, _ = LayerNormFwdFp8Primitive.inner_primitive.bind(
            x,
            gamma,
            beta,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon)
        return out, mu, rsigma, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, zero_centered_gamma, epsilon):
        """
        LayerNorm fwd (fp8 out) batch rules for vmap
        """
        _check_valid_batch_dims(batch_dims)
        assert LayerNormFwdFp8Primitive.outer_primitive is not None
        x, gamma, beta, amax, scale, scale_inv = batched_args
        x_bdim, _, _, amax_bdim, _, _ = batch_dims

        out_bdims = x_bdim, x_bdim, x_bdim, amax_bdim
        return LayerNormFwdFp8Primitive.outer_primitive.bind(
            x,
            gamma,
            beta,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon), out_bdims

    @staticmethod
    def infer_sharding_from_operands(out_dtype, zero_centered_gamma, epsilon, mesh, arg_infos,
                                     result_infos):
        del out_dtype, zero_centered_gamma, epsilon, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {LayerNormFwdPrimitive.name}! " \
                f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
                f"and hurt performance.")

        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        mu_sharding = rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[3])))
        return (out_sharding, mu_sharding, rsigma_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        g_spec = get_padded_spec(arg_infos[1])
        b_spec = get_padded_spec(arg_infos[2])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {LayerNormFwdFp8Primitive.name}! " \
                f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
                f"and hurt performance."
            )
        if g_spec[-1] is not None:
            warnings.warn(
                f"{LayerNormFwdFp8Primitive.name} does not support sharding of parameter gamma " \
                f"Enforcing no sharding of parameters hidden dim! " \
            )
        if b_spec[-1] is not None:
            warnings.warn(
                f"{LayerNormFwdFp8Primitive.name} does not support sharding of parameter beta " \
                f"Enforcing no sharding of parameters hidden dim! " \
            )
        x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        g_sharding = NamedSharding(mesh, PartitionSpec(None))
        b_sharding = NamedSharding(mesh, PartitionSpec(None))
        out_sharding = x_sharding
        mu_sharding = rsigma_sharding = NamedSharding(
            mesh, PartitionSpec(*get_padded_spec(arg_infos[0])[:-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[3])))
        fp8_meta_sharding = amax_sharding
        arg_shardings = (x_sharding, g_sharding, b_sharding) + (fp8_meta_sharding,) * 3
        out_shardings = (out_sharding, mu_sharding, rsigma_sharding, amax_sharding)

        def sharded_impl(x, gamma, beta, amax, scale, scale_inv):
            local_x, local_mu, local_rsigma, local_amax = \
                LayerNormFwdFp8Primitive.impl(x, gamma, beta, amax, scale, scale_inv,
                                              out_dtype=out_dtype,
                                              zero_centered_gamma=zero_centered_gamma,
                                              epsilon=epsilon)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)

            return local_x, local_mu, local_rsigma, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(LayerNormFwdFp8Primitive)


def layernorm_fwd_fp8(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, amax: jnp.ndarray,
                      scale: jnp.ndarray, scale_inv: jnp.ndarray, out_dtype: jnp.dtype,
                      zero_centered_gamma: bool, epsilon: float):
    """
    Wrapper for TE layernorm fwd (fp8 out)
    """
    return LayerNormFwdFp8Primitive.outer_primitive.bind(x,
                                                         gamma,
                                                         beta,
                                                         amax,
                                                         scale,
                                                         scale_inv,
                                                         out_dtype=out_dtype,
                                                         zero_centered_gamma=zero_centered_gamma,
                                                         epsilon=epsilon)


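# Illustrative usage sketch (never called by this module).  The (1,)-shaped fp32
# amax/scale/scale_inv tensors are an assumption about the FP8 meta layout, and the output
# dtype must currently be float8_e4m3fn.
def _example_layernorm_fwd_fp8():
    hidden = 1024
    x = jnp.zeros((32, 128, hidden), jnp.bfloat16)
    gamma = jnp.ones((hidden,), jnp.bfloat16)
    beta = jnp.zeros((hidden,), jnp.bfloat16)
    amax = jnp.zeros((1,), jnp.float32)
    scale = jnp.ones((1,), jnp.float32)
    scale_inv = jnp.ones((1,), jnp.float32)
    return layernorm_fwd_fp8(x, gamma, beta, amax, scale, scale_inv,
                             out_dtype=jnp.float8_e4m3fn,
                             zero_centered_gamma=False,
                             epsilon=1e-5)

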
class RmsNormFwdFp8Primitive(BasePrimitive):
    """
    RMS Normalization Forward FP8 Primitive
    """
    name = "te_rmsnorm_forward_fp8"
    multiple_results = True
    impl_static_args = (5, 6)    # out_dtype, epsilon
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval, out_dtype, epsilon):
        """
        RMSNorm fwd (fp8 out) inner primitive abstract
        """
        x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)

        assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        hidden_size = gamma_aval.size
        assert x_aval.size % hidden_size == 0

        rsigma_dtype = jnp.float32

        wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes(
            x_aval.size // hidden_size,    # batch_size
            hidden_size,
            jax_dtype_to_te_dtype(x_aval.dtype),    # in te_dtype
            jax_dtype_to_te_dtype(gamma_aval.dtype),    # weight te_dtype
            jax_dtype_to_te_dtype(out_dtype),    # out te_dtype
            False,
            False,
            epsilon)

        out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
        rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=rsigma_dtype)
        amax_aval = out_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
        wkspace_aval = x_aval.update(shape=wkspace_info[0],
                                     dtype=te_dtype_to_jax_dtype(wkspace_info[1]))
        barrier_aval = x_aval.update(shape=barrier_info[0],
                                     dtype=te_dtype_to_jax_dtype(barrier_info[1]))

        return out_aval, rsigma_aval, amax_aval, wkspace_aval, barrier_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        RMSNorm fwd (fp8 out) outer primitive abstract
        """
        out_aval, rsigma_aval, amax_aval, _, _ = RmsNormFwdFp8Primitive.abstract(*args, **kwargs)
        return out_aval, rsigma_aval, amax_aval

    @staticmethod
    def lowering(ctx, x, gamma, amax, scale, scale_inv, *, out_dtype, epsilon):
        """
        RMSNorm fwd (fp8 out) lowering rules
        """

        # Currently the C side only supports casting to E4M3.
        assert out_dtype == jnp.float8_e4m3fn

        x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in

        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        g_type = ir.RankedTensorType(gamma.type)
        g_shape = g_type.shape

        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_rsigma_dtype = ir.F32Type.get()
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape

        out_shape = x_shape
        hidden_size = reduce(operator.mul, g_shape)
        batch_shape = out_shape[:-1]
        batch_size = reduce(operator.mul, x_shape) // hidden_size

        wkspace_aval, barrier_aval = ctx.avals_out[-2:]

        out_types = [
            ir.RankedTensorType.get(out_shape, ir_out_dtype),
            ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
            ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
            ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype))
        ]
        operands = [x, gamma, amax, scale, scale_inv]
        operand_shapes = [x_shape, g_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))

        opaque = transformer_engine_jax.pack_norm_descriptor(
            batch_size,
            hidden_size,
            wkspace_aval.size,
            barrier_aval.size,
            (0,),    # no dgamma_part in FWD pass
            (0,),    # no dbeta_part in FWD pass
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(gamma_aval.dtype),
            jax_dtype_to_te_dtype(wkspace_aval.dtype),
            jax_dtype_to_te_dtype(barrier_aval.dtype),
            TEDType.kByte,    # dummy dgamma_part te_dtype
            TEDType.kByte,    # dummy dbeta_part te_dtype
            False,    # RMSNorm doesn't support zero_centered_gamma
            epsilon,
            sm_margin,
        )

        out = custom_caller(RmsNormFwdFp8Primitive.name,
                            args,
                            opaque,
                            False,
                            operand_output_aliases={2: 2})

        return out

    @staticmethod
    def impl(x, gamma, amax, scale, scale_inv, out_dtype, epsilon):
        """
        RMSNorm fwd (fp8 out) implementation
        """
        assert RmsNormFwdFp8Primitive.inner_primitive is not None
        out, rsigma, amax, _, _ = RmsNormFwdFp8Primitive.inner_primitive.bind(x,
                                                                              gamma,
                                                                              amax,
                                                                              scale,
                                                                              scale_inv,
                                                                              out_dtype=out_dtype,
                                                                              epsilon=epsilon)
        return out, rsigma, amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, epsilon):
        """
        RMSNorm fwd (fp8 out) batch rules for vmap
        """
        _check_valid_batch_dims(batch_dims)
        assert RmsNormFwdFp8Primitive.outer_primitive is not None
        x, gamma, amax, scale, scale_inv = batched_args
        x_bdim, _, amax_bdim, _, _ = batch_dims
        out_bdims = x_bdim, x_bdim, amax_bdim
        return RmsNormFwdFp8Primitive.outer_primitive.bind(x,
                                                           gamma,
                                                           amax,
                                                           scale,
                                                           scale_inv,
                                                           out_dtype=out_dtype,
                                                           epsilon=epsilon), out_bdims

    @staticmethod
    def infer_sharding_from_operands(out_dtype, epsilon, mesh, arg_infos, result_infos):
        del out_dtype, epsilon, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {RmsNormFwdFp8Primitive.name}! " \
                f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
                f"and hurt performance."
            )
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        return (out_sharding, rsigma_sharding, amax_sharding)

3718
3719
3720
3721
    @staticmethod
    def partition(out_dtype, epsilon, mesh, arg_infos, result_infos):
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
3722
        g_spec = get_padded_spec(arg_infos[1])
3723
3724
3725
3726
3727
3728
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {RmsNormFwdFp8Primitive.name}! " \
                f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
                f"and hurt performance."
            )
3729
3730
3731
3732
3733
        if g_spec[-1] is not None:
            warnings.warn(
                f"{RmsNormFwdFp8Primitive.name} does not support sharding of parameter gamma " \
                f"Enforcing no sharding of parameters hidden dim! " \
            )
3734
        x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
3735
        g_sharding = NamedSharding(mesh, PartitionSpec(None))
3736
3737
3738
3739
3740
3741
        out_sharding = x_sharding
        rsigma_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[0])[:-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        fp8_meta_sharding = amax_sharding
        arg_shardings = (x_sharding, g_sharding) + (fp8_meta_sharding,) * 3
        out_shardings = (out_sharding, rsigma_sharding, amax_sharding)
3742

3743
3744
3745
3746
3747
        def sharded_impl(x, gamma, amax, scale, scale_inv):
            local_x, local_rsigma, local_amax= \
                RmsNormFwdFp8Primitive.impl(x, gamma, amax, scale, scale_inv,
                                            out_dtype=out_dtype, epsilon=epsilon)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)

            return local_x, local_rsigma, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(RmsNormFwdFp8Primitive)


def rmsnorm_fwd_fp8(x: jnp.ndarray, gamma: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
                    scale_inv: jnp.ndarray, out_dtype: jnp.dtype, epsilon: float):
    """
    Wrapper for TE rmsnorm fwd (fp8 out)
    """
    return RmsNormFwdFp8Primitive.outer_primitive.bind(x,
                                                       gamma,
                                                       amax,
                                                       scale,
                                                       scale_inv,
                                                       out_dtype=out_dtype,
                                                       epsilon=epsilon)
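

# Illustrative usage sketch (not part of the original API surface): shows how
# rmsnorm_fwd_fp8 might be invoked. The helper name `_example_rmsnorm_fwd_fp8`,
# the tensor shapes, and the epsilon value are assumptions for demonstration;
# the FP8 meta tensors (amax/scale/scale_inv) are single-element float32 arrays.
def _example_rmsnorm_fwd_fp8():
    """Minimal sketch of calling rmsnorm_fwd_fp8 with an FP8 (E4M3) output."""
    hidden = 128
    x = jnp.zeros((16, hidden), dtype=jnp.bfloat16)
    gamma = jnp.ones((hidden,), dtype=jnp.bfloat16)
    amax = jnp.zeros((1,), dtype=jnp.float32)
    scale = jnp.ones((1,), dtype=jnp.float32)
    scale_inv = jnp.ones((1,), dtype=jnp.float32)
    # Returns the FP8-casted normalized output, the reciprocal RMS (rsigma),
    # and the updated amax.
    out, rsigma, updated_amax = rmsnorm_fwd_fp8(
        x, gamma, amax, scale, scale_inv, out_dtype=jnp.float8_e4m3fn, epsilon=1e-6)
    return out, rsigma, updated_amax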


class ActLuFp8Primitive(BasePrimitive):
    """
    ActLu FP8 Primitive
    """
    name = "te_act_lu_fp8"
    multiple_results = True
    impl_static_args = (4, 5)    # out_dtype, act_enum
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype,
                 act_enum):  # pylint: disable=unused-argument
        """
        te_act_lu_fp8_p abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        # Currently, only casting to E4M3 is supported on the C side.
        assert out_dtype == jnp.float8_e4m3fn
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        assert (x_aval.shape[-2] == 1 or x_aval.shape[-2] == 2)
        hidden_size = x_aval.shape[-1]
        batch_shape = x_aval.shape[:-2]
        out_shape = batch_shape + (hidden_size,)
        out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)

        return out_aval, updated_amax_aval

    @staticmethod
    def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype, act_enum):
        """
        te_act_lu_fp8_p lowering rules
        """
        x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_x_type = ir.RankedTensorType(x.type)
        ir_x_shape = ir_x_type.shape
        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape

        hidden_size = ir_x_shape[-1]
        batch_shape = ir_x_shape[:-2]
        batch_size = reduce(operator.mul, batch_shape)
        out_shape = batch_shape + [hidden_size]
        out_types = [
            ir.RankedTensorType.get(out_shape, ir_out_dtype),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
        ]
        operands = [x, amax, scale, scale_inv]
        operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        opaque = transformer_engine_jax.pack_common_descriptor(
            (batch_size, hidden_size),
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(out_dtype),
            act_enum)

        out = custom_caller(ActLuFp8Primitive.name,
                            args,
                            opaque,
                            False,
                            operand_output_aliases={1: 1})

        return out

    @staticmethod
    def impl(x, amax, scale, scale_inv, out_dtype, act_enum):
        """
        to describe implementation
        """
        assert ActLuFp8Primitive.inner_primitive is not None
        out, updated_amax = ActLuFp8Primitive.inner_primitive.bind(x,
                                                                   amax,
                                                                   scale,
                                                                   scale_inv,
                                                                   out_dtype=out_dtype,
                                                                   act_enum=act_enum)
        return out, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, act_enum):
        """
        to describe batch rules for vmap
        """
        _check_valid_batch_dims(batch_dims)
        assert ActLuFp8Primitive.outer_primitive is not None
        x, amax, scale, scale_inv = batched_args
        x_bdim, amax_bdim, _, _ = batch_dims

        out_bdims = x_bdim, amax_bdim
        return ActLuFp8Primitive.outer_primitive.bind(x, amax, scale, scale_inv,
                                                      out_dtype=out_dtype,
                                                      act_enum=act_enum), out_bdims

    @staticmethod
    def infer_sharding_from_operands(out_dtype, act_enum, mesh, arg_infos, result_infos):
        del out_dtype, result_infos, act_enum
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        return (out_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, act_enum, mesh, arg_infos, result_infos):
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (out_sharding, amax_sharding)

        def sharded_impl(x, amax, scale, scale_inv):
            local_x, local_amax = ActLuFp8Primitive.impl(x,
                                                         amax,
                                                         scale,
                                                         scale_inv,
                                                         out_dtype=out_dtype,
                                                         act_enum=act_enum)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)

            return local_x, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(ActLuFp8Primitive)


def act_lu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
               out_dtype: jnp.dtype, activation_type: Sequence[Union[str, Callable]]
               ) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    act wrapper
    Return FP8(act_lu(x))
    Input shape: (N, 1, H) for non-gated activations
                 (N, 2, H) for gated activations
    """
    act_type_id = ActivationEnum[activation_type]
    return ActLuFp8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, out_dtype=out_dtype,
                                                  act_enum=act_type_id)
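

# Illustrative usage sketch (not part of the original API surface): shows how
# act_lu_fp8 might be called for a gated activation. The helper name
# `_example_act_lu_fp8` and the shapes are assumptions; note the expected
# (N, 2, H) input layout for gated activations such as ('gelu', 'linear').
def _example_act_lu_fp8():
    """Minimal sketch of calling act_lu_fp8 on a gated-activation input."""
    hidden = 128
    x = jnp.zeros((16, 2, hidden), dtype=jnp.bfloat16)    # (N, 2, H) gated layout
    amax = jnp.zeros((1,), dtype=jnp.float32)
    scale = jnp.ones((1,), dtype=jnp.float32)
    scale_inv = jnp.ones((1,), dtype=jnp.float32)
    # Returns FP8(act_lu(x)) with shape (N, H) and the updated amax.
    out, updated_amax = act_lu_fp8(x, amax, scale, scale_inv,
                                   out_dtype=jnp.float8_e4m3fn,
                                   activation_type=('gelu', 'linear'))
    return out, updated_amax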


class DActLuDBiasCastTransposePrimitive(BasePrimitive):
    """
    DActLu DBias Cast Transpose Primitive
    """
    name = "te_dact_lu_dbias_cast_transpose"
    multiple_results = True
    # out_dtype, static_axis_boundary, transpose_axis_boundary, act_enum
    impl_static_args = (5, 6, 7, 8)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype,
                 static_axis_boundary, transpose_axis_boundary,
                 act_enum):  # pylint: disable=unused-argument
        """
        te_dact_lu_dbias_cast_transpose_p abstract
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dtype
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_hidden_size = dz_aval.shape[-1]
        gi_hidden_size = x_aval.shape[-1]
        assert ir_hidden_size == gi_hidden_size
        t_shape = _multidim_transpose(x_aval.shape,
                                      static_axis_boundary, transpose_axis_boundary)
        out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype)
        t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)

        dbias_shape = (*x_aval.shape[:static_axis_boundary + 1], gi_hidden_size)
        dbias = dz_aval.update(shape=dbias_shape, dtype=dtype)

        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)

        wkspace_info, = transformer_engine_jax.get_dact_dbias_ct_workspace_sizes(
            x_aval.size // gi_hidden_size,
            gi_hidden_size,
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(out_dtype),
        )
        wkspace_aval = x_aval.update(shape=wkspace_info[0],
                                     dtype=te_dtype_to_jax_dtype(wkspace_info[1]))

        return out, t_out, dbias, updated_amax_aval, wkspace_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dact_lu_dbias_cast_transpose_p outer abstract
        """

        out, t_out, dbias, updated_amax_aval, _ = \
            DActLuDBiasCastTransposePrimitive.abstract(*args, **kwargs)
        return out, t_out, dbias, updated_amax_aval

    @staticmethod
    def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary,
                 transpose_axis_boundary, act_enum):
        """
        te_dact_lu_dbias_cast_transpose_p lowering rules
        """
        dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_aval.dtype
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_dz_type = ir.RankedTensorType(dz.type)
        ir_dz_shape = ir_dz_type.shape
        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        dz_batch_size = reduce(operator.mul, ir_dz_shape[:-1])
        x_batch_size = reduce(operator.mul, x_shape[:-2])
        assert dz_batch_size == x_batch_size
        ir_hidden_size = ir_dz_shape[-1]
        contracted_x_shape = (x_batch_size, ir_hidden_size)

        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape
        transposed_x_shape = _multidim_transpose(x_shape, static_axis_boundary,
                                                 transpose_axis_boundary)
        dbias_shape = (*x_shape[:static_axis_boundary + 1], ir_hidden_size)

        wkspace_aval = ctx.avals_out[-1]

        out_types = [
            ir.RankedTensorType.get(x_shape, ir_out_dtype),
            ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype),
            ir.RankedTensorType.get(dbias_shape, ir_dz_type.element_type),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
            ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
        ]
        operands = [dz, x, amax, scale, scale_inv]
        operand_shapes = [ir_dz_shape, x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
        opaque = transformer_engine_jax.pack_common_wk_descriptor(
            contracted_x_shape, wkspace_aval.shape, jax_dtype_to_te_dtype(dz_aval.dtype),
            jax_dtype_to_te_dtype(out_dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype),
            act_enum)

        out = custom_caller(DActLuDBiasCastTransposePrimitive.name,
                            args,
                            opaque,
                            False,
                            operand_output_aliases={2: 3})

        return out

    @staticmethod
    def impl(dz, x, amax, scale, scale_inv, out_dtype, static_axis_boundary,
             transpose_axis_boundary, act_enum):
        """
        to describe implementation
        """
        assert DActLuDBiasCastTransposePrimitive.inner_primitive is not None
        out, t_out, dbias, updated_amax, _ = DActLuDBiasCastTransposePrimitive.inner_primitive.bind(
            dz,
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary,
            act_enum=act_enum)
        return out, t_out, dbias, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary,
                transpose_axis_boundary, act_enum):
        """
        to describe batch rules for vmap
        """
        del static_axis_boundary
        _check_valid_batch_dims(batch_dims)
        assert DActLuDBiasCastTransposePrimitive.outer_primitive is not None
        dz, x, amax, scale, scale_inv = batched_args
        x_bdim, _, amax_bdim, _, _ = batch_dims

        # Minus batch dim.
        transpose_axis_boundary = _normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1)
        transpose_axis_boundary += 1    # Plus batch dim

        out_bdims = x_bdim, x_bdim, x_bdim, amax_bdim
        return DActLuDBiasCastTransposePrimitive.outer_primitive.bind(
            dz,
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=x_bdim,
            transpose_axis_boundary=transpose_axis_boundary,
            act_enum=act_enum), out_bdims

    @staticmethod
    def infer_sharding_from_operands(out_dtype, static_axis_boundary, transpose_axis_boundary,
                                     act_enum, mesh, arg_infos, result_infos):
        del out_dtype, result_infos, act_enum
        x_spec = get_padded_spec(arg_infos[1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        transposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        dbias_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec[:static_axis_boundary + 1], x_spec[-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        return (out_sharding, transposed_out_sharding, dbias_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, static_axis_boundary, transpose_axis_boundary,
                  act_enum, mesh, arg_infos, result_infos):
        del result_infos
        x_spec = get_padded_spec(arg_infos[1])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))

        dbias_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec[:static_axis_boundary + 1], x_spec[-1]))

        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (casted_x_sharding, casted_transposed_x_sharding, dbias_sharding,
                         amax_sharding)

        def sharded_impl(dz, x, amax, scale, scale_inv):
            local_out, local_t_out, local_dbias, local_amax = \
            DActLuDBiasCastTransposePrimitive.impl(
                dz,
                x,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=static_axis_boundary,
                transpose_axis_boundary=transpose_axis_boundary,
                act_enum=act_enum)
            global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
            return local_out, local_t_out, global_dbias, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(DActLuDBiasCastTransposePrimitive)


def dact_lu_dbias_cast_transpose(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    amax: jnp.ndarray,
    scale: jnp.ndarray,
    scale_inv: jnp.ndarray,
    out_dtype: TEDType,
    static_axis_boundary: int,
    transpose_axis_boundary: int = -1,
    activation_type: Sequence[Union[str, Callable]] = ('gelu',)
    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    cast transpose dact_lu and dbias fusion wrapper
    Return FP8(dact_lu(inputs)), dbias
    ONLY support non-gated activation type
    """
    if static_axis_boundary < 0:
        static_axis_boundary = -1    # means no static axes

    act_type_id = ActivationEnum[activation_type]
    return DActLuDBiasCastTransposePrimitive.outer_primitive.bind(
        dz,
        x,
        amax,
        scale,
        scale_inv,
        out_dtype=out_dtype,
        static_axis_boundary=static_axis_boundary,
        transpose_axis_boundary=transpose_axis_boundary,
        act_enum=act_type_id)
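

# Illustrative usage sketch (not part of the original API surface): shows how
# dact_lu_dbias_cast_transpose might be called in a backward pass for a
# non-gated activation. The helper name `_example_dact_lu_dbias_cast_transpose`
# and the shapes are assumptions for demonstration only.
def _example_dact_lu_dbias_cast_transpose():
    """Minimal sketch of the fused dact_lu + dbias + cast + transpose call."""
    hidden = 128
    dz = jnp.zeros((16, hidden), dtype=jnp.bfloat16)       # incoming gradient
    x = jnp.zeros((16, 1, hidden), dtype=jnp.bfloat16)     # (N, 1, H) non-gated layout
    amax = jnp.zeros((1,), dtype=jnp.float32)
    scale = jnp.ones((1,), dtype=jnp.float32)
    scale_inv = jnp.ones((1,), dtype=jnp.float32)
    # Returns FP8(dact_lu(dz)), its transpose, dbias, and the updated amax.
    out, t_out, dbias, updated_amax = dact_lu_dbias_cast_transpose(
        dz, x, amax, scale, scale_inv, out_dtype=jnp.float8_e4m3fn,
        static_axis_boundary=-1, transpose_axis_boundary=-1,
        activation_type=('gelu',))
    return out, t_out, dbias, updated_amax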


class DBiasCastTransposePrimitive(BasePrimitive):
    """
    DBias Cast Transpose Primitive
    """
    name = "te_dbias_cast_transpose"
    multiple_results = True
    # out_dtype, static_axis_boundary, transpose_axis_boundary
    impl_static_args = (4, 5, 6)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(dz_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype,
                 static_axis_boundary, transpose_axis_boundary):
        """
        te_dbias_cast_transpose_p abstract
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        gi_hidden_size = reduce(operator.mul, dz_aval.shape[transpose_axis_boundary:])
        t_shape = _multidim_transpose(dz_aval.shape, static_axis_boundary, transpose_axis_boundary)
        out = dz_aval.update(shape=dz_aval.shape, dtype=out_dtype)
        t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)

        dbias_shape = (*dz_aval.shape[:static_axis_boundary + 1], gi_hidden_size)
        dbias = dz_aval.update(shape=dbias_shape, dtype=dtype)

        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
        wkspace_info, = transformer_engine_jax.get_dbias_ct_workspace_sizes(
            dz_aval.size // gi_hidden_size,
            gi_hidden_size,
            jax_dtype_to_te_dtype(dz_aval.dtype),
            jax_dtype_to_te_dtype(out_dtype)
        )
        wkspace_aval = dz_aval.update(shape=wkspace_info[0],
                                     dtype=te_dtype_to_jax_dtype(wkspace_info[1]))

        return out, t_out, dbias, updated_amax_aval, wkspace_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dbias_cast_transpose_p outer abstract
        """

        out, t_out, dbias, updated_amax_aval, _ = \
            DBiasCastTransposePrimitive.abstract(*args, **kwargs)
        return out, t_out, dbias, updated_amax_aval

    @staticmethod
    def lowering(ctx, dz, amax, scale, scale_inv, *, out_dtype, static_axis_boundary,
                 transpose_axis_boundary):
        """
        te_dbias_cast_transpose_p lowering rules
        """
        dz_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_dz_type = ir.RankedTensorType(dz.type)
        ir_dz_shape = ir_dz_type.shape
        batch_size = reduce(operator.mul, ir_dz_shape[:transpose_axis_boundary])
        ir_hidden_size = reduce(operator.mul, ir_dz_shape[transpose_axis_boundary:])
        contracted_dz_shape = (batch_size, ir_hidden_size)
        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape
        transposed_dz_shape = _multidim_transpose(ir_dz_shape, static_axis_boundary,
                                                 transpose_axis_boundary)
        dbias_shape = (*ir_dz_shape[:static_axis_boundary + 1], ir_hidden_size)

        wkspace_aval = ctx.avals_out[-1]

        out_types = [
            ir.RankedTensorType.get(ir_dz_shape, ir_out_dtype),
            ir.RankedTensorType.get(transposed_dz_shape, ir_out_dtype),
            ir.RankedTensorType.get(dbias_shape, ir_dz_type.element_type),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
            ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
        ]
        operands = [dz, amax, scale, scale_inv]
        operand_shapes = [ir_dz_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
        opaque = transformer_engine_jax.pack_common_wk_descriptor(
            contracted_dz_shape, wkspace_aval.shape, jax_dtype_to_te_dtype(dz_aval.dtype),
            jax_dtype_to_te_dtype(out_dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype))

        out = custom_caller(DBiasCastTransposePrimitive.name,
                            args,
                            opaque,
                            False,
                            operand_output_aliases={1: 3})

        return out

    @staticmethod
    def impl(dz, amax, scale, scale_inv, out_dtype, static_axis_boundary,
             transpose_axis_boundary):
        """
        to describe implementation
        """
        assert DBiasCastTransposePrimitive.inner_primitive is not None
        out, t_out, dbias, updated_amax, _ = DBiasCastTransposePrimitive.inner_primitive.bind(
            dz,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary)
        return out, t_out, dbias, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary,
                transpose_axis_boundary):
        """
        to describe batch rules for vmap
        """
        del static_axis_boundary
        _check_valid_batch_dims(batch_dims)
        assert DBiasCastTransposePrimitive.outer_primitive is not None
        dz, amax, scale, scale_inv = batched_args
        dz_bdim, _, amax_bdim, _, _ = batch_dims

        # Minus batch dim.
        transpose_axis_boundary = _normalize_axis_boundary(transpose_axis_boundary, dz.ndim - 1)
        transpose_axis_boundary += 1    # Plus batch dim

        out_bdims = dz_bdim, dz_bdim, dz_bdim, amax_bdim
        return DBiasCastTransposePrimitive.outer_primitive.bind(
            dz,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=dz_bdim,
            transpose_axis_boundary=transpose_axis_boundary), out_bdims

    @staticmethod
    def infer_sharding_from_operands(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh,
                                     arg_infos, result_infos):
        del out_dtype, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        transposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        dbias_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec[:static_axis_boundary + 1], x_spec[-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        return (out_sharding, transposed_out_sharding, dbias_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos,
                  result_infos):
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))

        dbias_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec[:static_axis_boundary + 1], x_spec[-1]))

        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (casted_x_sharding, casted_transposed_x_sharding, dbias_sharding,
                         amax_sharding)

        def sharded_impl(dz, amax, scale, scale_inv):
            local_out, local_t_out, local_dbias, local_amax = DBiasCastTransposePrimitive.impl(
                dz,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=static_axis_boundary,
                transpose_axis_boundary=transpose_axis_boundary)
            global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
            return local_out, local_t_out, global_dbias, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(DBiasCastTransposePrimitive)


def dbias_cast_transpose(
    dz: jnp.ndarray,
    amax: jnp.ndarray,
    scale: jnp.ndarray,
    scale_inv: jnp.ndarray,
    out_dtype: TEDType,
    static_axis_boundary: int,
    transpose_axis_boundary: int = -1
    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    cast transpose dbias partial fusion wrapper
    Return FP8(inputs), dbias
    """
    if static_axis_boundary < 0:
        static_axis_boundary = -1    # means no static axes

    return DBiasCastTransposePrimitive.outer_primitive.bind(
        dz,
        amax,
        scale,
        scale_inv,
        out_dtype=out_dtype,
        static_axis_boundary=static_axis_boundary,
        transpose_axis_boundary=transpose_axis_boundary)
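

# Illustrative usage sketch (not part of the original API surface): shows how
# dbias_cast_transpose might be called to fuse the bias gradient reduction with
# the FP8 cast and transpose. The helper name `_example_dbias_cast_transpose`
# and the shapes are assumptions for demonstration only.
def _example_dbias_cast_transpose():
    """Minimal sketch of the fused dbias + cast + transpose call."""
    hidden = 128
    dz = jnp.zeros((16, hidden), dtype=jnp.bfloat16)    # incoming gradient
    amax = jnp.zeros((1,), dtype=jnp.float32)
    scale = jnp.ones((1,), dtype=jnp.float32)
    scale_inv = jnp.ones((1,), dtype=jnp.float32)
    # Returns FP8(dz), its transpose, dbias, and the updated amax.
    out, t_out, dbias, updated_amax = dbias_cast_transpose(
        dz, amax, scale, scale_inv, out_dtype=jnp.float8_e4m3fn,
        static_axis_boundary=-1, transpose_axis_boundary=-1)
    return out, t_out, dbias, updated_amax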


class DgatedActLuCastTransposePrimitive(BasePrimitive):
    """
    Dgated ActLu Cast Transpose Primitive
    """
    name = "te_dgated_act_lu_cast_transpose"
    multiple_results = True
    impl_static_args = (5, 6, 7)    # out_dtype, static_axis_boundary, act_enum
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype,
                 static_axis_boundary, act_enum):  # pylint: disable=unused-argument
        """
        te_dgated_act_lu_cast_transpose_p abstract
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dtype
        assert x_aval.shape[-2] == 2    # Linear + GeLU
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_hidden_size = dz_aval.shape[-1]
        gi_hidden_size = x_aval.shape[-1]
        assert ir_hidden_size == gi_hidden_size
        t_shape = _multidim_transpose(x_aval.shape, static_axis_boundary, -2)
        out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype)
        t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
        return out, t_out, updated_amax_aval

    @staticmethod
    def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, act_enum):
        """
        te_dgated_act_lu_cast_transpose_p lowering rules
        """
        dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_aval.dtype
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_dz_type = ir.RankedTensorType(dz.type)
        ir_dz_shape = ir_dz_type.shape
        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        dz_batch_size = reduce(operator.mul, ir_dz_shape[:-1])
        x_batch_size = reduce(operator.mul, x_shape[:-2])
        assert dz_batch_size == x_batch_size
        assert x_shape[-2] == 2    # Linear + GeLU
        ir_hidden_size = ir_dz_shape[-1]
        gi_hidden_size = x_shape[-1]
        assert ir_hidden_size == gi_hidden_size
        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape
        transposed_x_shape = _multidim_transpose(x_shape, static_axis_boundary, -2)
        out_types = [
            ir.RankedTensorType.get(x_shape, ir_out_dtype),
            ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
        ]
        operands = [dz, x, amax, scale, scale_inv]
        operand_shapes = [ir_dz_shape, x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
        contracted_x_shape = (x_batch_size, x_shape[-1])
        opaque = transformer_engine_jax.pack_common_descriptor(
            contracted_x_shape,
            jax_dtype_to_te_dtype(dz_aval.dtype),
            jax_dtype_to_te_dtype(out_dtype),
            act_enum)

        out = custom_caller(DgatedActLuCastTransposePrimitive.name,
                            args,
                            opaque,
                            False,
                            operand_output_aliases={2: 2})

        return out

    @staticmethod
    def impl(dz, x, amax, scale, scale_inv, out_dtype, static_axis_boundary, act_enum):
        """
        to describe implementation
        """
        assert DgatedActLuCastTransposePrimitive.inner_primitive is not None
        out, t_out, updated_amax = DgatedActLuCastTransposePrimitive.inner_primitive.bind(
            dz,
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary,
            act_enum=act_enum)
        return out, t_out, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary, act_enum):
        """
        to describe batch rules for vmap
        """
        del static_axis_boundary
        _check_valid_batch_dims(batch_dims)
        assert DgatedActLuCastTransposePrimitive.outer_primitive is not None
        dz, x, amax, scale, scale_inv = batched_args
        x_bdim, _, amax_bdim, _, _ = batch_dims

        out_bdims = x_bdim, x_bdim, amax_bdim
        return DgatedActLuCastTransposePrimitive.outer_primitive.bind(
            dz, x, amax, scale, scale_inv, out_dtype=out_dtype,
            static_axis_boundary=x_bdim,
            act_enum=act_enum), out_bdims

    @staticmethod
    def infer_sharding_from_operands(out_dtype, static_axis_boundary, act_enum,
                                     mesh, arg_infos, result_infos):
        del out_dtype, result_infos, act_enum
        x_spec = get_padded_spec(arg_infos[1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, -2)
        transposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        return (out_sharding, transposed_out_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, static_axis_boundary, act_enum,
                  mesh, arg_infos, result_infos):
        del result_infos
        x_spec = get_padded_spec(arg_infos[1])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, -2)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))

        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (casted_x_sharding, casted_transposed_x_sharding, amax_sharding)

        def sharded_impl(dz, x, amax, scale, scale_inv):
            local_out, local_t_out, local_amax = DgatedActLuCastTransposePrimitive.impl(
                dz,
                x,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=static_axis_boundary,
                act_enum=act_enum)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
            return local_out, local_t_out, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(DgatedActLuCastTransposePrimitive)


def dgated_act_lu_cast_transpose(
    dz: jnp.ndarray, x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
    scale_inv: jnp.ndarray, out_dtype: TEDType,
    static_axis_boundary: int,
    activation_type: Sequence[Union[str, Callable]]
    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    cast transpose d_gated_act_lu fusion wrapper
    Return FP8(dgated_act_lu(inputs))
    """
    act_type_id = ActivationEnum[activation_type]
    return DgatedActLuCastTransposePrimitive.outer_primitive.bind(
        dz,
        x,
        amax,
        scale,
        scale_inv,
        out_dtype=out_dtype,
        static_axis_boundary=static_axis_boundary,
        act_enum=act_type_id)
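

# Illustrative usage sketch (not part of the original API surface): shows how
# dgated_act_lu_cast_transpose might be called in the backward pass of a gated
# activation. The helper name `_example_dgated_act_lu_cast_transpose` and the
# shapes are assumptions for demonstration only.
def _example_dgated_act_lu_cast_transpose():
    """Minimal sketch of the fused dgated_act_lu + cast + transpose call."""
    hidden = 128
    dz = jnp.zeros((16, hidden), dtype=jnp.bfloat16)       # incoming gradient
    x = jnp.zeros((16, 2, hidden), dtype=jnp.bfloat16)     # (N, 2, H) gated layout
    amax = jnp.zeros((1,), dtype=jnp.float32)
    scale = jnp.ones((1,), dtype=jnp.float32)
    scale_inv = jnp.ones((1,), dtype=jnp.float32)
    # Returns FP8(dgated_act_lu(dz)), its transpose, and the updated amax.
    out, t_out, updated_amax = dgated_act_lu_cast_transpose(
        dz, x, amax, scale, scale_inv, out_dtype=jnp.float8_e4m3fn,
        static_axis_boundary=-1, activation_type=('gelu', 'linear'))
    return out, t_out, updated_amax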