cpp_extensions.py 145 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX te custom call"""

from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Tuple
from functools import partial, reduce
import operator
11
12
import warnings

13
14
15
16
17
import numpy as np
import jax.numpy as jnp
from jax.lib import xla_client
from jax import core, dtypes
from jax.interpreters import xla, mlir
18
from jax.experimental.custom_partitioning import custom_partitioning
19
from jax.interpreters.mlir import ir, dtype_to_ir_type
20
21
from jax.sharding import PartitionSpec, NamedSharding
from jax._src.interpreters import batching
22

23
24
25
26
27
28
29
try:
    from jaxlib.hlo_helpers import custom_call
except ImportError:
    # Newer JAX changed its API. But we want to support a few JAX
    # version, so we still need this import.
    pass

30
31
import transformer_engine_jax
from transformer_engine_jax import DType as TEDType
32
33
from transformer_engine_jax import NVTE_Bias_Type
from transformer_engine_jax import NVTE_Mask_Type
34
35
from transformer_engine_jax import NVTE_QKV_Layout
from transformer_engine_jax import NVTE_Fused_Attn_Backend
36

37
38
39
40
41
from .sharding import all_reduce_max_along_all_axes_except_PP
from .sharding import all_reduce_sum_along_dp_fsdp
from .sharding import get_all_mesh_axes, num_of_devices
from .sharding import get_padded_spec as te_get_padded_spec

42
43
44
45
46
47
48
49
50
# Make every custom-call target exported by the transformer_engine_jax extension
# known to XLA on the CUDA platform, so the lowerings below can reference them by name.
_registrations = transformer_engine_jax.registrations()
for _name, _value in _registrations.items():
    xla_client.register_custom_call_target(
        _name,
        _value,
        platform="CUDA",
    )


def te_dtype_to_jax_dtype(te_dtype):
    """
    Convert a Transformer Engine dtype to the matching JAX dtype.

    Args:
        te_dtype: a ``transformer_engine_jax.DType`` member.

    Returns:
        The corresponding ``jnp`` scalar type (e.g. ``jnp.float32``).

    Raises:
        AssertionError: if ``te_dtype`` is not a ``TEDType``.
        ValueError: if ``te_dtype`` has no entry in the conversion table.
    """
    assert isinstance(te_dtype, TEDType)

    converter = {
        TEDType.kFloat32: jnp.float32,
        TEDType.kFloat16: jnp.float16,
        TEDType.kBFloat16: jnp.bfloat16,
        TEDType.kInt32: jnp.int32,
        TEDType.kInt64: jnp.int64,
        TEDType.kFloat8E4M3: jnp.float8_e4m3fn,
        TEDType.kFloat8E5M2: jnp.float8_e5m2,
    }

    try:
        return converter[te_dtype]
    except KeyError:
        raise ValueError(f"Unsupported {te_dtype=}") from None
66
67
68
69
70
71
72
73
74


def te_dtype_to_ir_dtype(te_dtype):
    """
    Map a Transformer Engine dtype to its MLIR element type.

    The TE dtype is first converted to a JAX dtype and then to the IR type.
    """
    as_jax = te_dtype_to_jax_dtype(te_dtype)
    as_numpy = np.dtype(as_jax)
    return dtype_to_ir_type(as_numpy)


75
76
77
78
79
80
81
def jax_dtype_to_ir_dtype(jax_dtype):
    """
    Map a JAX/NumPy dtype-like value to its MLIR element type.
    """
    as_numpy = np.dtype(jax_dtype)
    return dtype_to_ir_type(as_numpy)


82
83
84
85
def jax_dtype_to_te_dtype(jax_dtype):
    """
    Convert a JAX dtype to the matching Transformer Engine dtype.

    The input is normalized with ``dtypes.canonicalize_dtype`` before the lookup.
    JAX's canonicalization rules apply here. For example, 64-bit types may
    canonicalize to 32-bit ones when x64 mode is off.

    Args:
        jax_dtype: any dtype-like value accepted by ``dtypes.canonicalize_dtype``.

    Returns:
        The corresponding ``transformer_engine_jax.DType`` member.

    Raises:
        ValueError: if the canonicalized dtype has no entry in the conversion table.
    """
    jax_dtype = dtypes.canonicalize_dtype(jax_dtype)

    converter = {
        jnp.float32.dtype: TEDType.kFloat32,
        jnp.float16.dtype: TEDType.kFloat16,
        jnp.bfloat16.dtype: TEDType.kBFloat16,
        jnp.int32.dtype: TEDType.kInt32,
        jnp.int64.dtype: TEDType.kInt64,
        jnp.float8_e4m3fn.dtype: TEDType.kFloat8E4M3,
        jnp.float8_e5m2.dtype: TEDType.kFloat8E5M2,
    }

    try:
        return converter[jax_dtype]
    except KeyError:
        raise ValueError(f"Unsupported {jax_dtype=}") from None
102
103


104
105
106
107
108
109
110
111
def get_padded_spec(arg_info):
    """
    Return the partition spec of ``arg_info`` padded to the argument's rank.

    When the argument carries no sharding, ``None`` is passed as the spec, so the
    padding helper decides the fully-replicated result.
    """
    sharding = arg_info.sharding
    spec = None if sharding is None else sharding.spec
    return te_get_padded_spec(spec, arg_info.ndim)
112
113


114
def _check_valid_batch_dims(bdims):
115
    """
116
    Assert out non-supported bath dims
117
    """
118
119
120
121
    for dim in bdims:
        assert dim in [0, None], \
            "Currently only support batch_dim in [0, None], " \
            f"but got {dim=}"
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144


class BasePrimitive(metaclass=ABCMeta):
    """
    jax primitive

    Abstract interface for Transformer Engine custom-call primitives. Every hook is
    an abstract static method. A subclass must override all of them before it can
    be instantiated.
    """

    @staticmethod
    @abstractmethod
    def abstract():
        """
        to describe computing graph (abstract evaluation of output avals)
        """
        return NotImplemented

    @staticmethod
    @abstractmethod
    def lowering():
        """
        to describe MLIR
        """
        return NotImplemented

    @staticmethod
    @abstractmethod
    def impl():
        """
        to describe implementation
        """
        return NotImplemented

    @staticmethod
    @abstractmethod
    def batcher():
        """
        to describe batch rules for vmap
        """
        return NotImplemented

    @staticmethod
    @abstractmethod
    def infer_sharding_from_operands():
        """
        to describe infer_sharding_from_operands for custom_partitioning
        """
        return NotImplemented

    @staticmethod
    @abstractmethod
    def partition():
        """
        to describe partition for custom_partitioning
        """
        return NotImplemented

177
178
179
180
181

def register_primitive(cls):
    """
    register jax primitive

    Creates the two JAX primitives backing a ``BasePrimitive`` subclass and stores
    them on the class.

    * ``cls.inner_primitive`` (named ``cls.name``): eager evaluation goes through
      ``xla.apply_primitive``. Avals come from ``cls.abstract``. It is lowered by
      ``cls.lowering``, which is registered for the ``'cuda'`` platform only.
    * ``cls.outer_primitive`` (named ``cls.name + "_wrapper"``): implemented by
      ``cls.impl``, with avals from ``cls.abstract`` and vmap support from
      ``cls.batcher``. It is lowered through ``custom_partitioning(cls.impl)`` with
      ``static_argnums=cls.impl_static_args``. Sharding decisions are delegated to
      ``cls.infer_sharding_from_operands`` and ``cls.partition``.

    Args:
        cls: primitive class providing ``name``, ``multiple_results``,
            ``impl_static_args`` and the ``BasePrimitive`` hooks.
    """
    inner_p = core.Primitive(cls.name)
    inner_p.multiple_results = cls.multiple_results
    inner_p.def_impl(partial(xla.apply_primitive, inner_p))
    inner_p.def_abstract_eval(cls.abstract)
    mlir.register_lowering(inner_p, cls.lowering, platform='cuda')
    cls.inner_primitive = inner_p

    outer_p = core.Primitive(cls.name + "_wrapper")
    outer_p.multiple_results = cls.multiple_results
    outer_p.def_impl(cls.impl)
    outer_p.def_abstract_eval(cls.abstract)
    batching.primitive_batchers[outer_p] = cls.batcher
    # The outer primitive is lowered by tracing cls.impl under custom_partitioning.
    # This lets the SPMD partitioner call back into the class's sharding hooks
    # instead of treating the custom call as an opaque op.
    outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args)
    outer_p_lower.def_partition(infer_sharding_from_operands=cls.infer_sharding_from_operands,
                                partition=cls.partition)
    mlir.register_lowering(outer_p,
                           mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results))
    cls.outer_primitive = outer_p
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250


@dataclass
class CustomCallArgsWrapper:
    """
    Bundle of everything an XLA custom call needs: result types, operands, and the
    per-operand / per-result layouts.
    """

    def __init__(self,
                 output_types,
                 operands,
                 operand_shapes,
                 operand_specific_layouts=None,
                 output_specific_layouts=None):
        self.output_types = output_types
        self.operands = operands
        make_layouts = CustomCallArgsWrapper.generate_layouts
        self.operand_layouts = make_layouts(operand_shapes, operand_specific_layouts)
        result_shapes = [out_type.shape for out_type in output_types]
        self.output_layouts = make_layouts(result_shapes, output_specific_layouts)

    @staticmethod
    def generate_layouts(shapes, specific_layouts):
        """
        Build one layout per entry of ``shapes`` for an XLA custom call.

        ``specific_layouts`` optionally maps a position in ``shapes`` to an explicit
        layout. Every other position gets ``range(rank - 1, -1, -1)``, i.e. the
        dimensions in descending order.
        """
        overrides = {} if specific_layouts is None else specific_layouts
        return [
            overrides[pos] if pos in overrides else range(len(shape) - 1, -1, -1)
            for pos, shape in enumerate(shapes)
        ]


def custom_caller(name, args, opaque, has_side_effect, **kwargs):
    """
    XLA custom call wrapper

    Emits the custom call ``name`` and works with both newer and older JAX releases.

    Args:
        name: registered custom-call target name.
        args: ``CustomCallArgsWrapper`` holding result types, operands and layouts.
        opaque: descriptor forwarded to the kernel as ``backend_config``.
        has_side_effect: forwarded to the custom-call builder.
        **kwargs: extra keyword arguments forwarded verbatim to the builder.

    Returns:
        The custom call's results. With ``mlir.custom_call`` available this is the
        op's ``.results``. Otherwise it is whatever the legacy
        ``jaxlib.hlo_helpers.custom_call`` returns.
    """
    if hasattr(mlir, "custom_call"):
        out = mlir.custom_call(name,
                               result_types=args.output_types,
                               operands=args.operands,
                               operand_layouts=args.operand_layouts,
                               result_layouts=args.output_layouts,
                               backend_config=opaque,
                               has_side_effect=has_side_effect,
                               **kwargs).results
    else:
        # Older JAX: fall back to jaxlib.hlo_helpers.custom_call. The name of its
        # second parameter changed between JAX releases, so the result types are
        # passed positionally. That trips a pylint error on some versions, hence the
        # disable, which keeps this compatible with multiple JAX versions.
        out = custom_call(    # pylint: disable=too-many-function-args
            name,
            args.output_types,
            operands=args.operands,
            operand_layouts=args.operand_layouts,
            result_layouts=args.output_layouts,
            backend_config=opaque,
            has_side_effect=has_side_effect,
            **kwargs)
    return out


276
class LayerNormFwdPrimitive(BasePrimitive):
    """
    Layer Normalization Forward Primitive

    Operands: ``(x, gamma, beta)``. Results: ``(out, mu, rsigma)``.
    Static attributes: ``zero_centered_gamma`` and ``epsilon``.
    """
    name = "te_layernorm_forward"
    multiple_results = True
    impl_static_args = (3, 4)    # zero_centered_gamma, epsilon
    inner_primitive = None    # assigned by register_primitive
    outer_primitive = None    # assigned by register_primitive

    @staticmethod
    def abstract(x_aval, gamma_aval, beta_aval, **kwargs):    # pylint: disable=unused-argument
        """
        LayerNorm fwd abstract

        Returns avals for ``(out, mu, rsigma)``. ``out`` has the shape and dtype of
        ``x``. ``mu`` and ``rsigma`` have shape ``x.shape[:-1]`` and dtype float32.
        """
        x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]

        mu_rsigma_dtype = jnp.float32

        out_aval = core.raise_to_shaped(x_aval)
        mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigma_dtype)

        assert gamma_aval.size == beta_aval.size
        hidden_size = gamma_aval.size
        assert x_aval.size % hidden_size == 0

        return out_aval, mu_aval, rsigma_aval

    @staticmethod
    def lowering(ctx, x, gamma, beta, *, zero_centered_gamma, epsilon):
        """
        LayerNorm fwd lowering rules

        Emits the ``te_layernorm_forward`` custom call. ``x`` is viewed as
        ``batch_size`` rows of ``hidden_size`` elements, where ``hidden_size`` is the
        element count of ``gamma``.
        """
        x_aval, gamma_aval, beta_aval = ctx.avals_in
        assert gamma_aval.dtype == beta_aval.dtype
        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        g_type = ir.RankedTensorType(gamma.type)
        g_shape = g_type.shape
        b_type = ir.RankedTensorType(beta.type)
        b_shape = b_type.shape

        assert g_type == b_type
        assert g_shape == b_shape

        # Output shape is same as the input shape, but the output type is same as the weight type.
        # See ln_api.cpp
        # NOTE(review): abstract() reports x's dtype for `out`. The two agree only when
        # x and gamma share a dtype; confirm callers guarantee that.
        output_type = g_type.element_type
        ir_mu_dtype = ir.F32Type.get()
        ir_rsigma_dtype = ir.F32Type.get()

        out_shape = x_shape
        hidden_size = reduce(operator.mul, g_shape)
        batch_shape = out_shape[:-1]
        batch_size = reduce(operator.mul, x_shape) // hidden_size

        out_types = [
            ir.RankedTensorType.get(out_shape, output_type),
            ir.RankedTensorType.get(batch_shape, ir_mu_dtype),
            ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
        ]
        operands = [x, gamma, beta]
        operand_shapes = [x_shape, g_shape, b_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        opaque = transformer_engine_jax.pack_norm_descriptor(
            batch_size,
            hidden_size,
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(gamma_aval.dtype),
            zero_centered_gamma,
            epsilon,
        )

        out = custom_caller(LayerNormFwdPrimitive.name, args, opaque, False)

        return out

    @staticmethod
    def impl(x, gamma, beta, zero_centered_gamma, epsilon):
        """
        to describe implementation: binds the inner (custom-call) primitive.
        """
        assert LayerNormFwdPrimitive.inner_primitive is not None
        out, mu, rsigma = LayerNormFwdPrimitive.inner_primitive.bind(
            x, gamma, beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon)
        return out, mu, rsigma

    @staticmethod
    def batcher(batched_args, batch_dims, *, zero_centered_gamma, epsilon):
        """
        to describe batch rules for vmap

        Only batch dim 0 or None is accepted. All three outputs reuse x's batch dim.
        """
        _check_valid_batch_dims(batch_dims)
        assert LayerNormFwdPrimitive.outer_primitive is not None
        x, gamma, beta = batched_args
        x_bdim, _, _ = batch_dims

        out_bdims = x_bdim, x_bdim, x_bdim
        return LayerNormFwdPrimitive.outer_primitive.bind(x,
                                                          gamma,
                                                          beta,
                                                          zero_centered_gamma=zero_centered_gamma,
                                                          epsilon=epsilon), out_bdims

    @staticmethod
    def infer_sharding_from_operands(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
        """
        Infer output shardings from x's spec.

        The hidden (last) dim is forced to be unsharded, with a warning if x had it
        sharded. mu and rsigma take x's spec without the hidden dim.
        """
        del zero_centered_gamma, epsilon, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {LayerNormFwdPrimitive.name}! "
                "Force to not shard the hidden dim, which might introduce extra collective ops, "
                "and hurt performance."
            )
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        mu_sharding = rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
        return (out_sharding, mu_sharding, rsigma_sharding)

    @staticmethod
    def partition(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
        """
        Build the per-shard implementation and the argument/result shardings.

        x (and hence out) is resharded so that the hidden dim is not split. gamma
        and beta keep their own specs. mu and rsigma follow x without the hidden dim.
        """
        del result_infos
        x_spec, g_spec, b_spec = map(get_padded_spec, arg_infos)
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {LayerNormFwdPrimitive.name}! "
                "Force to not shard the hidden dim, which might introduce extra collective ops, "
                "and hurt performance."
            )
        x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        g_sharding = NamedSharding(mesh, PartitionSpec(*g_spec))
        b_sharding = NamedSharding(mesh, PartitionSpec(*b_spec))
        out_sharding = x_sharding
        mu_sharding = rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))

        arg_shardings = (x_sharding, g_sharding, b_sharding)
        out_shardings = (out_sharding, mu_sharding, rsigma_sharding)
        impl = partial(LayerNormFwdPrimitive.impl,
                       zero_centered_gamma=zero_centered_gamma,
                       epsilon=epsilon)
        return mesh, impl, out_shardings, arg_shardings
418
419


420
register_primitive(LayerNormFwdPrimitive)
421
422


423
424
def layernorm_fwd(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, zero_centered_gamma: bool,
                  epsilon: float):
    """
    Wrapper for TE layernorm fwd

    Binds ``LayerNormFwdPrimitive.outer_primitive`` and returns ``(out, mu, rsigma)``.
    """
    return LayerNormFwdPrimitive.outer_primitive.bind(x,
                                                      gamma,
                                                      beta,
                                                      zero_centered_gamma=zero_centered_gamma,
                                                      epsilon=epsilon)
433
434


435
class LayerNormBwdPrimitive(BasePrimitive):
    """
    Layer Normalization Backward Primitive

    Operands: ``(dz, x, mu, rsigma, gamma)``. Results: ``(dx, dgamma, dbeta)``.
    Static attributes: ``zero_centered_gamma`` and ``epsilon``.
    """
    name = "te_layernorm_backward"
    multiple_results = True
    impl_static_args = (5, 6)    # zero_centered_gamma, epsilon
    inner_primitive = None    # assigned by register_primitive
    outer_primitive = None    # assigned by register_primitive

    @staticmethod
    def abstract(
            dz_aval,
            x_aval,
            mu_aval,
            rsigma_aval,
            gamma_aval,
            **kwargs    # pylint: disable=unused-argument
    ):
        """
        Layernorm bwd abstract

        Returns avals for ``(dx, dgamma, dbeta)``. ``dx`` matches ``dz``, and
        ``dgamma`` and ``dbeta`` both match ``gamma``.
        """
        w_dtype = dtypes.canonicalize_dtype(gamma_aval.dtype)
        mu_dtype = dtypes.canonicalize_dtype(mu_aval.dtype)
        rsigma_dtype = dtypes.canonicalize_dtype(rsigma_aval.dtype)

        assert dtypes.canonicalize_dtype(dz_aval.dtype) == w_dtype
        assert dz_aval.shape == x_aval.shape
        assert mu_aval.shape == rsigma_aval.shape == x_aval.shape[:-1]
        assert mu_dtype == rsigma_dtype == jnp.float32

        dx_aval = core.raise_to_shaped(dz_aval)
        dgamma_aval = dbeta_aval = core.raise_to_shaped(gamma_aval)
        return dx_aval, dgamma_aval, dbeta_aval

    @staticmethod
    def lowering(ctx, dz, x, mu, rsigma, gamma, *, zero_centered_gamma, epsilon):
        """
        Layernorm bwd lowering rules

        Emits the ``te_layernorm_backward`` custom call.
        """
        _, x_aval, _, _, gamma_aval = ctx.avals_in
        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        g_type = ir.RankedTensorType(gamma.type)
        g_shape = g_type.shape
        # There is no beta operand in the backward pass. dbeta is shaped and typed like gamma.
        b_type = ir.RankedTensorType(gamma.type)
        b_shape = b_type.shape
        assert g_type == b_type
        assert g_shape == b_shape

        dz_shape = ir.RankedTensorType(dz.type).shape
        mu_shape = ir.RankedTensorType(mu.type).shape
        rsigma_shape = ir.RankedTensorType(rsigma.type).shape

        hidden_size = reduce(operator.mul, g_shape)
        batch_size = reduce(operator.mul, x_shape) // hidden_size

        out_types = [
            ir.RankedTensorType.get(x_shape, x_type.element_type),
            ir.RankedTensorType.get(g_shape, g_type.element_type),
            ir.RankedTensorType.get(b_shape, b_type.element_type),
        ]
        # NOTE(review): the custom-call operand order (dz, mu, rsigma, x, gamma) differs
        # from the primitive's argument order. Presumably it matches the C++ kernel's
        # expected order; confirm before reordering either side.
        operands = [dz, mu, rsigma, x, gamma]
        operand_shapes = [dz_shape, mu_shape, rsigma_shape, x_shape, g_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        opaque = transformer_engine_jax.pack_norm_descriptor(
            batch_size,
            hidden_size,
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(gamma_aval.dtype),
            zero_centered_gamma,
            epsilon,
        )

        out = custom_caller(LayerNormBwdPrimitive.name, args, opaque, False)

        return out

    @staticmethod
    def impl(dz, x, mu, rsigma, gamma, zero_centered_gamma, epsilon):
        """
        to describe implementation: binds the inner (custom-call) primitive.
        """
        assert LayerNormBwdPrimitive.inner_primitive is not None
        dx, dgamma, dbeta = LayerNormBwdPrimitive.inner_primitive.bind(
            dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon)
        return dx, dgamma, dbeta

    @staticmethod
    def batcher(batched_args, batch_dims, *, zero_centered_gamma, epsilon):
        """
        to describe batch rules for vmap

        Only batch dim 0 or None is accepted. dx reuses x's batch dim, and dgamma and
        dbeta reuse gamma's.
        """
        _check_valid_batch_dims(batch_dims)
        assert LayerNormBwdPrimitive.outer_primitive is not None
        dz, x, mu, rsigma, gamma = batched_args
        _, x_bdim, _, _, gamma_bdim = batch_dims

        out_bdims = x_bdim, gamma_bdim, gamma_bdim
        return LayerNormBwdPrimitive.outer_primitive.bind(dz,
                                                          x,
                                                          mu,
                                                          rsigma,
                                                          gamma,
                                                          zero_centered_gamma=zero_centered_gamma,
                                                          epsilon=epsilon), out_bdims

    @staticmethod
    def infer_sharding_from_operands(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
        """
        Infer output shardings.

        dx follows x with the hidden dim unsharded. dgamma and dbeta follow gamma.
        """
        del zero_centered_gamma, epsilon, result_infos
        x_spec = get_padded_spec(arg_infos[1])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {LayerNormBwdPrimitive.name}! "
                "Force to not shard the hidden dim, which might introduce extra collective ops, "
                "and hurt performance."
            )
        g_b_spec = get_padded_spec(arg_infos[4])
        dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        dgamma_sharding = dbeta_sharding = NamedSharding(mesh, PartitionSpec(*g_b_spec))
        return dx_sharding, dgamma_sharding, dbeta_sharding

    @staticmethod
    def partition(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
        """
        Build the per-shard implementation and the argument/result shardings.

        Each shard computes local dgamma/dbeta, which are then reduced with
        ``all_reduce_sum_along_dp_fsdp``. Presumably this sums the per-shard partial
        gradients over the data-parallel / FSDP mesh axes; see ``.sharding``.
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[1])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {LayerNormBwdPrimitive.name}! "
                "Force to not shard the hidden dim, which might introduce extra collective ops, "
                "and hurt performance."
            )
        g_b_spec = get_padded_spec(arg_infos[4])
        dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        dgamma_sharding = dbeta_sharding = NamedSharding(mesh, PartitionSpec(*g_b_spec))
        out_shardings = dx_sharding, dgamma_sharding, dbeta_sharding
        x_shardings = (dx_sharding,) * 2    # dz and x should have the same sharding.
        mu_shardings = (NamedSharding(mesh, PartitionSpec(*x_spec[:-1])),) * 2
        arg_shardings = (*x_shardings, *mu_shardings, NamedSharding(mesh, PartitionSpec(*g_b_spec)))

        def sharded_impl(dz, x, mu, rsigma, gamma):
            local_dx, local_dgamma, local_dbeta = LayerNormBwdPrimitive.impl(
                dz, x, mu, rsigma, gamma,
                zero_centered_gamma=zero_centered_gamma,
                epsilon=epsilon)
            global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma)
            global_dbeta = all_reduce_sum_along_dp_fsdp(local_dbeta)
            return local_dx, global_dgamma, global_dbeta

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(LayerNormBwdPrimitive)


def layernorm_bwd(dz: jnp.ndarray, x: jnp.ndarray, mu: jnp.ndarray, rsigma: jnp.ndarray,
                  gamma: jnp.ndarray, zero_centered_gamma: bool, epsilon: float):
    """
    Wrapper for TE layernorm bwd

    Binds ``LayerNormBwdPrimitive.outer_primitive`` and returns ``(dx, dgamma, dbeta)``.
    """
    return LayerNormBwdPrimitive.outer_primitive.bind(dz,
                                                      x,
                                                      mu,
                                                      rsigma,
                                                      gamma,
                                                      zero_centered_gamma=zero_centered_gamma,
                                                      epsilon=epsilon)
590
591


592
class RmsNormFwdPrimitive(BasePrimitive):
    """
    RMS Normalization Forward Primitive

    Operands: ``(x, gamma)``. Results: ``(out, rsigma)``. Static attribute: ``epsilon``.
    """
    name = "te_rmsnorm_forward"
    multiple_results = True
    impl_static_args = (2,)    # epsilon
    inner_primitive = None    # assigned by register_primitive
    outer_primitive = None    # assigned by register_primitive

    @staticmethod
    def abstract(x_aval, gamma_aval, **kwargs):    # pylint: disable=unused-argument
        """
        RMSNorm fwd abstract

        Returns avals for ``(out, rsigma)``. ``out`` has the shape and dtype of ``x``.
        ``rsigma`` has shape ``x.shape[:-1]`` and dtype float32.
        """
        x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]

        rsigma_dtype = jnp.float32

        out_aval = core.raise_to_shaped(x_aval)
        rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=rsigma_dtype)

        hidden_size = gamma_aval.size
        assert x_aval.size % hidden_size == 0

        return out_aval, rsigma_aval

    @staticmethod
    def lowering(ctx, x, gamma, *, epsilon):
        """
        RMSNorm fwd lowering rules

        Emits the ``te_rmsnorm_forward`` custom call. ``x`` is viewed as
        ``batch_size`` rows of ``hidden_size`` elements, where ``hidden_size`` is the
        element count of ``gamma``.
        """
        x_aval, gamma_aval = ctx.avals_in
        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        g_type = ir.RankedTensorType(gamma.type)
        g_shape = g_type.shape
        rsigma_element_type = ir.F32Type.get()

        out_shape = x_shape
        hidden_size = reduce(operator.mul, g_shape)
        batch_shape = out_shape[:-1]
        batch_size = reduce(operator.mul, x_shape) // hidden_size

        out_types = [
            ir.RankedTensorType.get(out_shape, x_type.element_type),
            ir.RankedTensorType.get(batch_shape, rsigma_element_type),
        ]
        operands = [x, gamma]
        operand_shapes = [x_shape, g_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        opaque = transformer_engine_jax.pack_norm_descriptor(
            batch_size,
            hidden_size,
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(gamma_aval.dtype),
            False,    # RMSNorm doesn't support zero_centered_gamma
            epsilon,
        )

        out = custom_caller(RmsNormFwdPrimitive.name, args, opaque, False)

        return out

    @staticmethod
    def impl(x, gamma, epsilon):
        """
        to describe implementation: binds the inner (custom-call) primitive.
        """
        assert RmsNormFwdPrimitive.inner_primitive is not None
        out, rsigma = RmsNormFwdPrimitive.inner_primitive.bind(x, gamma, epsilon=epsilon)
        return out, rsigma

    @staticmethod
    def batcher(batched_args, batch_dims, *, epsilon):
        """
        to describe batch rules for vmap

        Only batch dim 0 or None is accepted. Both outputs reuse x's batch dim.
        """
        _check_valid_batch_dims(batch_dims)
        assert RmsNormFwdPrimitive.outer_primitive is not None
        x, gamma = batched_args
        x_bdim, _ = batch_dims

        out_bdims = x_bdim, x_bdim
        return RmsNormFwdPrimitive.outer_primitive.bind(x, gamma, epsilon=epsilon), out_bdims

    @staticmethod
    def infer_sharding_from_operands(epsilon, mesh, arg_infos, result_infos):
        """
        Infer output shardings from x's spec.

        The hidden (last) dim is forced to be unsharded. rsigma takes x's spec
        without the hidden dim.
        """
        del epsilon, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {RmsNormFwdPrimitive.name}! "
                "Force to not shard the hidden dim, which might introduce extra collective ops, "
                "and hurt performance."
            )
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
        return (out_sharding, rsigma_sharding)

    @staticmethod
    def partition(epsilon, mesh, arg_infos, result_infos):
        """
        Build the per-shard implementation and the argument/result shardings.

        x (and hence out) is resharded so that the hidden dim is not split. gamma
        keeps its own spec.
        """
        del result_infos
        x_spec, g_spec = map(get_padded_spec, arg_infos)
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {RmsNormFwdPrimitive.name}! "
                "Force to not shard the hidden dim, which might introduce extra collective ops, "
                "and hurt performance."
            )
        x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        g_sharding = NamedSharding(mesh, PartitionSpec(*g_spec))
        out_sharding = x_sharding
        rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
        arg_shardings = (x_sharding, g_sharding)
        out_shardings = (out_sharding, rsigma_sharding)
        impl = partial(RmsNormFwdPrimitive.impl, epsilon=epsilon)
        return mesh, impl, out_shardings, arg_shardings
712
713


714
register_primitive(RmsNormFwdPrimitive)
715
716


717
def rmsnorm_fwd(x: jnp.ndarray, gamma: jnp.ndarray, epsilon: float):
    """
    Wrapper for TE rmsnorm fwd

    Binds ``RmsNormFwdPrimitive.outer_primitive`` and returns ``(out, rsigma)``.
    """
    return RmsNormFwdPrimitive.outer_primitive.bind(x, gamma, epsilon=epsilon)
722
723


724
class RmsNormBwdPrimitive(BasePrimitive):
    """
    RMS Normalization Backward Primitive

    Operands: ``(dz, x, rsigma, gamma)``. Results: ``(dx, dgamma)``.
    Static attribute: ``epsilon``.
    """
    name = "te_rmsnorm_backward"
    multiple_results = True
    impl_static_args = (4,)    # epsilon
    inner_primitive = None    # assigned by register_primitive
    outer_primitive = None    # assigned by register_primitive

    @staticmethod
    def abstract(
            dz_aval,
            x_aval,
            rsigma_aval,
            gamma_aval,
            **kwargs    # pylint: disable=unused-argument
    ):
        """
        RMSNorm bwd abstract

        Returns avals for ``(dx, dgamma)``. ``dx`` matches ``dz`` and ``dgamma``
        matches ``gamma``.
        """
        w_dtype = dtypes.canonicalize_dtype(gamma_aval.dtype)
        rsigma_dtype = dtypes.canonicalize_dtype(rsigma_aval.dtype)

        assert dtypes.canonicalize_dtype(dz_aval.dtype) == w_dtype
        assert dz_aval.shape == x_aval.shape
        assert rsigma_aval.shape == x_aval.shape[:-1]
        assert rsigma_dtype == jnp.float32

        dx_aval = core.raise_to_shaped(dz_aval)
        dgamma_aval = core.raise_to_shaped(gamma_aval)
        return dx_aval, dgamma_aval

    @staticmethod
    def lowering(ctx, dz, x, rsigma, gamma, *, epsilon):
        """
        RMSNorm bwd lowering rules

        Emits the ``te_rmsnorm_backward`` custom call.
        """
        _, x_aval, _, gamma_aval = ctx.avals_in
        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        g_type = ir.RankedTensorType(gamma.type)
        g_shape = g_type.shape
        dz_shape = ir.RankedTensorType(dz.type).shape
        rsigma_shape = ir.RankedTensorType(rsigma.type).shape

        hidden_size = reduce(operator.mul, g_shape)
        batch_size = reduce(operator.mul, x_shape) // hidden_size

        out_types = [
            ir.RankedTensorType.get(x_shape, x_type.element_type),
            ir.RankedTensorType.get(g_shape, g_type.element_type),
        ]
        # NOTE(review): the custom-call operand order (dz, rsigma, x, gamma) differs from
        # the primitive's argument order. Presumably it matches the C++ kernel; confirm.
        operands = [dz, rsigma, x, gamma]
        operand_shapes = [dz_shape, rsigma_shape, x_shape, g_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        opaque = transformer_engine_jax.pack_norm_descriptor(
            batch_size,
            hidden_size,
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(gamma_aval.dtype),
            False,    # RMSNorm doesn't support zero_centered_gamma
            epsilon,
        )

        out = custom_caller(RmsNormBwdPrimitive.name, args, opaque, False)

        return out

    @staticmethod
    def impl(dz, x, rsigma, gamma, epsilon):
        """
        to describe implementation: binds the inner (custom-call) primitive.
        """
        assert RmsNormBwdPrimitive.inner_primitive is not None
        dx, dgamma = RmsNormBwdPrimitive.inner_primitive.bind(dz, x, rsigma, gamma, epsilon=epsilon)
        return dx, dgamma

    @staticmethod
    def batcher(batched_args, batch_dims, *, epsilon):
        """
        to describe batch rules for vmap

        Only batch dim 0 or None is accepted. dx reuses x's batch dim and dgamma
        reuses gamma's.
        """
        _check_valid_batch_dims(batch_dims)
        assert RmsNormBwdPrimitive.outer_primitive is not None
        dz, x, rsigma, gamma = batched_args
        _, x_bdim, _, gamma_bdim = batch_dims

        out_bdims = x_bdim, gamma_bdim
        return RmsNormBwdPrimitive.outer_primitive.bind(dz, x, rsigma, gamma,
                                                        epsilon=epsilon), out_bdims

    @staticmethod
    def infer_sharding_from_operands(epsilon, mesh, arg_infos, result_infos):
        """
        Infer output shardings.

        dx follows x with the hidden dim unsharded, and dgamma follows gamma.
        """
        del epsilon, result_infos
        x_spec = get_padded_spec(arg_infos[1])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {RmsNormBwdPrimitive.name}! "
                "Force to not shard the hidden dim, which might introduce extra collective ops, "
                "and hurt performance."
            )
        g_spec = get_padded_spec(arg_infos[3])
        dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        dgamma_sharding = NamedSharding(mesh, PartitionSpec(*g_spec))
        return dx_sharding, dgamma_sharding

    @staticmethod
    def partition(epsilon, mesh, arg_infos, result_infos):
        """
        Build the per-shard implementation and the argument/result shardings.

        Each shard computes a local dgamma, which is then reduced with
        ``all_reduce_sum_along_dp_fsdp``. Presumably this sums the per-shard partial
        gradients over the data-parallel / FSDP mesh axes; see ``.sharding``.
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[1])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {RmsNormBwdPrimitive.name}! "
                "Force to not shard the hidden dim, which might introduce extra collective ops, "
                "and hurt performance."
            )
        g_spec = get_padded_spec(arg_infos[3])
        dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        dgamma_sharding = NamedSharding(mesh, PartitionSpec(*g_spec))
        out_shardings = dx_sharding, dgamma_sharding
        x_shardings = (dx_sharding,) * 2    # dz and x should have the same sharding.
        rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
        arg_shardings = (*x_shardings, rsigma_sharding, NamedSharding(mesh, PartitionSpec(*g_spec)))

        def sharded_impl(dz, x, rsigma, gamma):
            local_dx, local_dgamma = RmsNormBwdPrimitive.impl(dz, x, rsigma, gamma,
                                                              epsilon=epsilon)
            global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma)
            return local_dx, global_dgamma

        return mesh, sharded_impl, out_shardings, arg_shardings

852

853
register_primitive(RmsNormBwdPrimitive)
854
855


def rmsnorm_bwd(dz: jnp.ndarray, x: jnp.ndarray, rsigma: jnp.ndarray, gamma: jnp.ndarray,
                epsilon: float):
    """
    Wrapper for TE rmsnorm bwd

    Args:
        dz: gradient w.r.t. the rmsnorm output.
        x: input of the rmsnorm forward pass.
        rsigma: per-row statistic saved by the forward pass (presumably 1/RMS).
        gamma: rmsnorm scale parameter.
        epsilon: epsilon used by the forward pass.

    Returns:
        (dx, dgamma) produced by RmsNormBwdPrimitive.
    """
    return RmsNormBwdPrimitive.outer_primitive.bind(dz, x, rsigma, gamma, epsilon=epsilon)


class SoftmaxPrimitive(BasePrimitive):
    """
    Softmax Primitive

    Shared base for the softmax primitives in this module. Subclasses implement
    ``is_kernel_available`` and delegate their abstract / lowering / impl /
    batcher / sharding rules to the ``forward_*`` and ``backward_*`` helpers.
    """
    max_k_seqlen_supported = 4096

    @staticmethod
    @abstractmethod
    def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
                            dtype: jnp.dtype) -> bool:
        """Check Softmax kernel availability based on size"""
        raise NotImplementedError

    @staticmethod
    def get_batch_per_block(k_seqlen: int) -> int:
        """Get batch per CTA in Softmax kernels"""
        threads_per_warp = 32
        threads_per_block = 128    # Depends on the kernel implementation

        # k_seqlen rounded up to the next power of two.
        pow2 = 1 << (k_seqlen - 1).bit_length()
        # Rows shorter than a warp use a narrower warp; rows of <= 128 elements are
        # processed two per warp.
        warp_size = pow2 if pow2 < threads_per_warp else threads_per_warp
        batches_per_warp = 2 if pow2 <= 128 else 1
        warps_per_block = threads_per_block // warp_size
        batches_per_block = warps_per_block * batches_per_warp
        return batches_per_block

    @staticmethod
    def forward_abstract(logits_aval, scale_factor):
        """
        softmax_forward abstract

        Checks that logits are FP16/BF16 with K_Seqlen within the supported
        range and Q_Seqlen > 1; the output has the same shape and dtype.
        """
        del scale_factor
        i_dtype = dtypes.canonicalize_dtype(logits_aval.dtype)
        assert i_dtype in [jnp.float16, jnp.bfloat16]
        i_shape = logits_aval.shape
        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
        q_seqlen = i_shape[-2]
        k_seqlen = i_shape[-1]
        assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
        assert q_seqlen > 1

        out_aval = core.raise_to_shaped(logits_aval)
        return out_aval

    @staticmethod
    def forward_lowering(name, ctx, logits, *, scale_factor):
        """
        softmax_forward lowering rules

        Emits custom call ``name`` with logits as the only operand. All dims
        before Head are flattened into ``batch``; with no mask operand,
        ``pad_batch`` is simply ``batch``.
        """
        i_aval, = ctx.avals_in
        i_type = ir.RankedTensorType(logits.type)
        i_shape = i_type.shape
        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
        batch = reduce(operator.mul, i_shape[:-3])
        pad_batch = batch
        heads = i_shape[-3]
        q_seqlen = i_shape[-2]
        k_seqlen = i_shape[-1]

        out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)]
        operands = [logits]
        operand_shapes = [i_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        opaque = transformer_engine_jax.pack_softmax_descriptor(batch, pad_batch, heads, q_seqlen,
                                                                k_seqlen,
                                                                jax_dtype_to_te_dtype(i_aval.dtype),
                                                                scale_factor)

        out = custom_caller(name, args, opaque, False)

        return [out]

    @staticmethod
    def forward_impl(primitive, logits, scale_factor):
        """
        softmax_forward implementation
        """
        assert primitive is not None
        output = primitive.bind(logits, scale_factor=scale_factor)
        return output

    @staticmethod
    def forward_batcher(primitive, batched_args, batch_dims, *, scale_factor):
        """
        softmax_forward batcher

        The output keeps the logits batch dim.
        """
        assert primitive is not None
        logits, = batched_args
        logits_bdim, = batch_dims

        out_bdims = logits_bdim
        return primitive.bind(logits, scale_factor=scale_factor), out_bdims

    @staticmethod
    def forward_infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_forward infer_sharding_from_operands

        The output is sharded exactly like the logits.
        """
        del scale_factor, result_infos    # Unused.
        logits_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec))
        return out_sharding

    @staticmethod
    def forward_partition(impl, scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_forward partitioning

        Runs ``impl`` shard-locally with the logits sharding on input and output.
        """
        del result_infos
        logits_spec = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[0])))
        out_spec = logits_spec
        arg_shardings = (logits_spec,)
        out_shardings = out_spec
        impl = partial(impl, scale_factor=scale_factor)
        return mesh, impl, out_shardings, arg_shardings

    @staticmethod
    def backward_abstract(dz_aval, softmax_out_aval, scale_factor=None):    # pylint: disable=unused-argument
        """
        softmax_backward abstract

        dz and softmax_out must share shape and an FP16/BF16 dtype; dx has the
        same shape and dtype as softmax_out.
        """
        dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        softmax_out_dtype = dtypes.canonicalize_dtype(softmax_out_aval.dtype)
        assert dz_dtype == softmax_out_dtype
        assert dz_dtype in [jnp.float16, jnp.bfloat16]
        assert softmax_out_dtype in [jnp.float16, jnp.bfloat16]

        assert dz_aval.shape == softmax_out_aval.shape

        dx_aval = core.raise_to_shaped(softmax_out_aval)
        return dx_aval

    @staticmethod
    def backward_lowering(name, ctx, dz, softmax_out, *, scale_factor):
        """
        softmax_backward lowering rules

        Emits custom call ``name`` with operands (dz, softmax_out).
        """
        dz_aval, _ = ctx.avals_in

        dz_type = ir.RankedTensorType(dz.type)
        dz_shape = dz_type.shape

        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
        batch = reduce(operator.mul, dz_shape[:-3])
        pad_batch = batch    # unused: no mask operand in the backward pass
        heads = dz_shape[-3]
        q_seqlen = dz_shape[-2]
        k_seqlen = dz_shape[-1]

        softmax_out_type = ir.RankedTensorType(softmax_out.type)
        softmax_out_shape = softmax_out_type.shape

        out_types = [ir.RankedTensorType.get(softmax_out_shape, softmax_out_type.element_type)]
        operands = [dz, softmax_out]
        operand_shapes = [dz_shape, softmax_out_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        opaque = transformer_engine_jax.pack_softmax_descriptor(
            batch, pad_batch, heads, q_seqlen, k_seqlen, jax_dtype_to_te_dtype(dz_aval.dtype),
            scale_factor)

        out = custom_caller(name, args, opaque, False)

        return [out]

    @staticmethod
    def backward_impl(primitive, dz, softmax_out, scale_factor):
        """
        softmax_backward implementation
        """
        assert primitive is not None
        dx = primitive.bind(dz, softmax_out, scale_factor=scale_factor)
        return dx

    @staticmethod
    def backward_batcher(primitive, batched_args, batch_dims, *, scale_factor):
        """
        softmax_backward batcher

        The output keeps the softmax_out batch dim.
        """
        assert primitive is not None
        dz, softmax_out = batched_args
        _, softmax_out_bdim = batch_dims

        out_bdims = softmax_out_bdim
        return primitive.bind(dz, softmax_out, scale_factor=scale_factor), out_bdims

    @staticmethod
    def backward_infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_backward infer_sharding_from_operands

        dx is sharded exactly like softmax_out.
        """
        del scale_factor, result_infos    # Unused.
        softmax_out_spec = get_padded_spec(arg_infos[1])
        dx_sharding = NamedSharding(mesh, PartitionSpec(*softmax_out_spec))
        return dx_sharding

    @staticmethod
    def backward_partition(impl, scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_backward partition

        Runs ``impl`` shard-locally; dx uses the softmax_out sharding.
        """
        del result_infos
        dz_spec = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[0])))
        softmax_out_spec = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        dx_spec = softmax_out_spec
        arg_shardings = (dz_spec, softmax_out_spec)
        out_shardings = dx_spec
        impl = partial(impl, scale_factor=scale_factor)
        return mesh, impl, out_shardings, arg_shardings


class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Softmax Fwd Primitive
    """
    name = "te_scaled_softmax_forward"
    multiple_results = False
    impl_static_args = (1,)    # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
                            dtype: jnp.dtype) -> bool:
        """Check Softmax kernel availability based on size"""
        attn_batches = batch * heads

        dtype = dtypes.canonicalize_dtype(dtype)
        if dtype not in [jnp.float16, jnp.bfloat16]:
            return False
        # k_seqlen must lie in (16, max_k_seqlen_supported].
        if not 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported:
            return False
        # q_seqlen and batch * heads must both be multiples of 4.
        if q_seqlen % 4 != 0 or attn_batches % 4 != 0:
            return False
        batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen)
        return q_seqlen % batch_per_block == 0

    @staticmethod
    def abstract(logits_aval, scale_factor):    # pylint: disable=unused-argument
        """
        te_scaled_softmax_forward abstract
        """
        return SoftmaxPrimitive.forward_abstract(logits_aval, scale_factor)

    @staticmethod
    def lowering(ctx, logits, *, scale_factor):
        """
        te_scaled_softmax_forward lowering rules
        """
        return SoftmaxPrimitive.forward_lowering(ScaledSoftmaxFwdPrimitive.name,
                                                 ctx,
                                                 logits,
                                                 scale_factor=scale_factor)

    @staticmethod
    def impl(logits, scale_factor):
        """te_scaled_softmax_forward implementation (binds the inner primitive)."""
        return SoftmaxPrimitive.forward_impl(ScaledSoftmaxFwdPrimitive.inner_primitive, logits,
                                             scale_factor)

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        """te_scaled_softmax_forward batcher"""
        _check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.forward_batcher(ScaledSoftmaxFwdPrimitive.outer_primitive,
                                                batched_args,
                                                batch_dims,
                                                scale_factor=scale_factor)

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """te_scaled_softmax_forward infer_sharding_from_operands"""
        return SoftmaxPrimitive.forward_infer_sharding_from_operands(scale_factor, mesh, arg_infos,
                                                                     result_infos)

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        """te_scaled_softmax_forward partition"""
        return SoftmaxPrimitive.forward_partition(ScaledSoftmaxFwdPrimitive.impl, scale_factor,
                                                  mesh, arg_infos, result_infos)


# Register with JAX (presumably this populates the `outer_primitive` that
# scaled_softmax_fwd binds below).
register_primitive(ScaledSoftmaxFwdPrimitive)

def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
    """
    scaled_softmax_forward wrapper
    Return FP16/BF16 tensor

    Args:
        logits: FP16/BF16 tensor laid out as [...Batch, Head, Q_Seqlen, K_Seqlen].
        scale_factor: scalar passed to the kernel as ``scale_factor``.
    """
    return ScaledSoftmaxFwdPrimitive.outer_primitive.bind(logits, scale_factor=scale_factor)


class ScaledSoftmaxBwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Softmax Bwd Primitive
    """
    name = "te_scaled_softmax_backward"
    multiple_results = False
    impl_static_args = (2,)    # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
                            dtype: jnp.dtype) -> bool:
        """Check Softmax kernel availability based on size"""
        # The backward kernel shares the forward kernel's constraints.
        return ScaledSoftmaxFwdPrimitive.is_kernel_available(batch, heads, q_seqlen, k_seqlen,
                                                             dtype)

    @staticmethod
    def abstract(dz_aval, softmax_out_aval, scale_factor):
        """
        te_scaled_softmax_backward abstract
        """
        return SoftmaxPrimitive.backward_abstract(dz_aval, softmax_out_aval, scale_factor)

    @staticmethod
    def lowering(ctx, dz, softmax_out, *, scale_factor):
        """
        te_scaled_softmax_backward lowering rules
        """
        out = SoftmaxPrimitive.backward_lowering(ScaledSoftmaxBwdPrimitive.name,
                                                 ctx,
                                                 dz,
                                                 softmax_out,
                                                 scale_factor=scale_factor)

        return out

    @staticmethod
    def impl(dz, softmax_out, scale_factor):
        """te_scaled_softmax_backward implementation (binds the inner primitive)."""
        return SoftmaxPrimitive.backward_impl(ScaledSoftmaxBwdPrimitive.inner_primitive,
                                              dz,
                                              softmax_out,
                                              scale_factor=scale_factor)

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        """te_scaled_softmax_backward batcher"""
        _check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.backward_batcher(ScaledSoftmaxBwdPrimitive.outer_primitive,
                                                 batched_args,
                                                 batch_dims,
                                                 scale_factor=scale_factor)

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """te_scaled_softmax_backward infer_sharding_from_operands"""
        return SoftmaxPrimitive.backward_infer_sharding_from_operands(scale_factor, mesh, arg_infos,
                                                                      result_infos)

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        """te_scaled_softmax_backward partition"""
        return SoftmaxPrimitive.backward_partition(ScaledSoftmaxBwdPrimitive.impl, scale_factor,
                                                   mesh, arg_infos, result_infos)


# Register with JAX (presumably this populates the `outer_primitive` that
# scaled_softmax_bwd binds below).
register_primitive(ScaledSoftmaxBwdPrimitive)


def scaled_softmax_bwd(dz: jnp.ndarray, softmax_out: jnp.ndarray,
                       scale_factor: float) -> jnp.ndarray:
    """
    Backward of the scaled softmax (wraps ScaledSoftmaxBwdPrimitive).

    Args:
        dz: gradient w.r.t. the softmax output; same shape/dtype as softmax_out.
        softmax_out: output of the forward scaled softmax.
        scale_factor: the scale factor used in the forward pass.

    Returns:
        FP16/BF16 dx with the shape and dtype of ``softmax_out``.
    """
    primitive = ScaledSoftmaxBwdPrimitive.outer_primitive
    return primitive.bind(dz, softmax_out, scale_factor=scale_factor)


class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Masked Softmax Fwd Primitive
    """
    name = "te_scaled_masked_softmax_forward"
    multiple_results = False
    impl_static_args = (2,)    # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
                            dtype: jnp.dtype) -> bool:
        """Check Softmax kernel availability based on size"""
        attn_batches = batch * heads

        dtype = dtypes.canonicalize_dtype(dtype)
        if dtype not in [jnp.float16, jnp.bfloat16]:
            return False
        # k_seqlen must lie in (16, max_k_seqlen_supported].
        if not 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported:
            return False
        # q_seqlen and batch * heads must both be multiples of 4.
        if q_seqlen % 4 != 0 or attn_batches % 4 != 0:
            return False
        batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen)
        return q_seqlen % batch_per_block == 0

    @staticmethod
    def abstract(logits_aval, mask_aval, scale_factor):    # pylint: disable=unused-argument
        """
        te_scaled_masked_softmax_forward abstract

        logits: FP16/BF16, [...Batch, Head, Q_Seqlen, K_Seqlen].
        mask: uint8, [...Pad_Batch, 1, Q_Seqlen, K_Seqlen], where the flattened
        Pad_Batch must be 1 (broadcast over the batch) or equal to Batch.
        """
        i_dtype = dtypes.canonicalize_dtype(logits_aval.dtype)
        assert i_dtype in [jnp.float16, jnp.bfloat16]
        i_shape = logits_aval.shape

        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
        batch = reduce(operator.mul, i_shape[:-3])
        q_seqlen = i_shape[-2]
        k_seqlen = i_shape[-1]
        assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
        assert q_seqlen > 1

        mask_dtype = dtypes.canonicalize_dtype(mask_aval.dtype)
        assert mask_dtype in [
            jnp.uint8,
        ]
        mask_shape = mask_aval.shape
        # Compare the mask's flattened batch against the *logits* batch. Previously
        # `batch` was overwritten with the mask batch, making this check a no-op.
        pad_batch = reduce(operator.mul, mask_shape[:-3])
        assert pad_batch in (1, batch)    # 1 means broadcast
        assert mask_shape[-3] == 1    # 1 means broadcast
        assert mask_shape[-2] == q_seqlen
        assert mask_shape[-1] == k_seqlen

        out_aval = core.raise_to_shaped(logits_aval)
        return out_aval

    @staticmethod
    def lowering(ctx, logits, mask, *, scale_factor):
        """
        te_scaled_masked_softmax_forward lowering rules
        """
        logits_aval, _ = ctx.avals_in
        i_type = ir.RankedTensorType(logits.type)
        i_shape = i_type.shape
        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
        batch = reduce(operator.mul, i_shape[:-3])
        heads = i_shape[-3]
        q_seqlen = i_shape[-2]
        k_seqlen = i_shape[-1]

        mask_type = ir.RankedTensorType(mask.type)
        mask_shape = mask_type.shape
        # Flattened mask batch: 1 (broadcast) or equal to `batch`, per `abstract`.
        pad_batch = reduce(operator.mul, mask_shape[:-3])

        out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)]
        operands = [logits, mask]
        operand_shapes = [i_shape, mask_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        opaque = transformer_engine_jax.pack_softmax_descriptor(
            batch, pad_batch, heads, q_seqlen, k_seqlen, jax_dtype_to_te_dtype(logits_aval.dtype),
            scale_factor)

        out = custom_caller(ScaledMaskedSoftmaxFwdPrimitive.name, args, opaque, False)

        return [out]

    @staticmethod
    def impl(logits, mask, scale_factor):
        """te_scaled_masked_softmax_forward implementation (binds the inner primitive)."""
        assert ScaledMaskedSoftmaxFwdPrimitive.inner_primitive is not None
        output = ScaledMaskedSoftmaxFwdPrimitive.inner_primitive.bind(logits,
                                                                      mask,
                                                                      scale_factor=scale_factor)
        return output

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        """te_scaled_masked_softmax_forward batcher; output keeps the logits batch dim."""
        _check_valid_batch_dims(batch_dims)
        assert ScaledMaskedSoftmaxFwdPrimitive.outer_primitive is not None
        logits, mask = batched_args
        logits_bdim, _ = batch_dims

        out_bdims = logits_bdim
        return ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
            logits, mask, scale_factor=scale_factor), out_bdims

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """The output is sharded like the logits."""
        del scale_factor, result_infos    # Unused.
        logits_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec))
        return out_sharding

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        """Run shard-locally with the incoming logits/mask shardings."""
        del result_infos
        logits_spec = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[0])))
        mask_spec = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = (logits_spec, mask_spec)
        out_shardings = logits_spec
        impl = partial(ScaledMaskedSoftmaxFwdPrimitive.impl, scale_factor=scale_factor)
        return mesh, impl, out_shardings, arg_shardings


# Register with JAX (presumably this populates the `outer_primitive` that
# scaled_masked_softmax_fwd binds below).
register_primitive(ScaledMaskedSoftmaxFwdPrimitive)


def scaled_masked_softmax_fwd(logits: jnp.ndarray, mask: jnp.ndarray,
                              scale_factor: float) -> jnp.ndarray:
    """
    Forward of the scaled masked softmax (wraps ScaledMaskedSoftmaxFwdPrimitive).

    Args:
        logits: FP16/BF16 tensor, [...Batch, Head, Q_Seqlen, K_Seqlen].
        mask: uint8 tensor, [...Batch or 1, 1, Q_Seqlen, K_Seqlen].
        scale_factor: scalar forwarded to the kernel.

    Returns:
        FP16/BF16 tensor with the shape and dtype of ``logits``.
    """
    primitive = ScaledMaskedSoftmaxFwdPrimitive.outer_primitive
    return primitive.bind(logits, mask, scale_factor=scale_factor)


class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Masked Softmax Bwd Primitive
    """
    name = "te_scaled_masked_softmax_backward"
    multiple_results = False
    impl_static_args = (2,)    # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
                            dtype: jnp.dtype) -> bool:
        """Check Softmax kernel availability based on size"""
        # Mirror the masked forward primitive's constraints.
        return ScaledMaskedSoftmaxFwdPrimitive.is_kernel_available(batch, heads, q_seqlen,
                                                                   k_seqlen, dtype)

    @staticmethod
    def abstract(dz_aval, softmax_out_aval, *, scale_factor):
        """
        te_scaled_masked_softmax_backward abstract
        """
        return SoftmaxPrimitive.backward_abstract(dz_aval, softmax_out_aval, scale_factor)

    @staticmethod
    def lowering(ctx, dz, softmax_out, *, scale_factor):
        """
        te_scaled_masked_softmax_backward lowering rules
        """
        out = SoftmaxPrimitive.backward_lowering(ScaledMaskedSoftmaxBwdPrimitive.name,
                                                 ctx,
                                                 dz,
                                                 softmax_out,
                                                 scale_factor=scale_factor)

        return out

    @staticmethod
    def impl(dz, softmax_out, scale_factor):
        """te_scaled_masked_softmax_backward implementation (binds the inner primitive)."""
        return SoftmaxPrimitive.backward_impl(ScaledMaskedSoftmaxBwdPrimitive.inner_primitive,
                                              dz,
                                              softmax_out,
                                              scale_factor=scale_factor)

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        """te_scaled_masked_softmax_backward batcher"""
        _check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.backward_batcher(ScaledMaskedSoftmaxBwdPrimitive.outer_primitive,
                                                 batched_args,
                                                 batch_dims,
                                                 scale_factor=scale_factor)

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """te_scaled_masked_softmax_backward infer_sharding_from_operands"""
        return SoftmaxPrimitive.backward_infer_sharding_from_operands(scale_factor, mesh, arg_infos,
                                                                      result_infos)

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        """te_scaled_masked_softmax_backward partition"""
        return SoftmaxPrimitive.backward_partition(ScaledMaskedSoftmaxBwdPrimitive.impl,
                                                   scale_factor, mesh, arg_infos, result_infos)


# Register with JAX (presumably this populates the `outer_primitive` that
# scaled_masked_softmax_bwd binds below).
register_primitive(ScaledMaskedSoftmaxBwdPrimitive)


def scaled_masked_softmax_bwd(dz: jnp.ndarray, softmax_out: jnp.ndarray,
                              scale_factor: float) -> jnp.ndarray:
    """
    Backward of the scaled masked softmax (wraps ScaledMaskedSoftmaxBwdPrimitive).

    Args:
        dz: gradient w.r.t. the softmax output; same shape/dtype as softmax_out.
        softmax_out: output of the forward scaled masked softmax.
        scale_factor: the scale factor used in the forward pass.

    Returns:
        FP16/BF16 dx with the shape and dtype of ``softmax_out``.
    """
    primitive = ScaledMaskedSoftmaxBwdPrimitive.outer_primitive
    return primitive.bind(dz, softmax_out, scale_factor=scale_factor)


class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Upper Triang Masked Softmax Fwd Primitive
    """
    name = "te_scaled_upper_triang_masked_softmax_forward"
    multiple_results = False
    impl_static_args = (1,)    # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
                            dtype: jnp.dtype) -> bool:
        """Check Softmax kernel availability based on size"""
        attn_batches = batch * heads

        dtype = dtypes.canonicalize_dtype(dtype)
        if dtype not in [jnp.float16, jnp.bfloat16]:
            return False
        # `abstract` requires a square score matrix; report unavailability up
        # front instead of letting the caller hit that assertion.
        if q_seqlen != k_seqlen:
            return False
        # k_seqlen must lie in (16, max_k_seqlen_supported].
        if not 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported:
            return False
        # q_seqlen and batch * heads must both be multiples of 4.
        if q_seqlen % 4 != 0 or attn_batches % 4 != 0:
            return False
        batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen)
        return attn_batches % batch_per_block == 0

    @staticmethod
    def abstract(logits_aval, scale_factor):    # pylint: disable=unused-argument
        """
        te_scaled_upper_triang_masked_softmax_forward abstract

        Same checks as the plain scaled softmax, plus Q_Seqlen == K_Seqlen.
        """
        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]; index from the end so any
        # number of leading batch dims is accepted, as in the other primitives.
        q_seqlen = logits_aval.shape[-2]
        k_seqlen = logits_aval.shape[-1]
        assert q_seqlen == k_seqlen
        return SoftmaxPrimitive.forward_abstract(logits_aval, scale_factor)

    @staticmethod
    def lowering(ctx, logits, *, scale_factor):
        """
        te_scaled_upper_triang_masked_softmax_forward lowering rules
        """
        return SoftmaxPrimitive.forward_lowering(ScaledUpperTriangMaskedSoftmaxFwdPrimitive.name,
                                                 ctx,
                                                 logits,
                                                 scale_factor=scale_factor)

    @staticmethod
    def impl(logits, scale_factor):
        """te_scaled_upper_triang_masked_softmax_forward implementation."""
        return SoftmaxPrimitive.forward_impl(
            ScaledUpperTriangMaskedSoftmaxFwdPrimitive.inner_primitive, logits, scale_factor)

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        """te_scaled_upper_triang_masked_softmax_forward batcher"""
        _check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.forward_batcher(
            ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor)

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """te_scaled_upper_triang_masked_softmax_forward infer_sharding_from_operands"""
        return SoftmaxPrimitive.forward_infer_sharding_from_operands(scale_factor, mesh, arg_infos,
                                                                     result_infos)

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        """te_scaled_upper_triang_masked_softmax_forward partition"""
        return SoftmaxPrimitive.forward_partition(ScaledUpperTriangMaskedSoftmaxFwdPrimitive.impl,
                                                  scale_factor, mesh, arg_infos, result_infos)


# Register with JAX (presumably this populates the `outer_primitive` that
# scaled_upper_triang_masked_softmax_fwd binds below).
register_primitive(ScaledUpperTriangMaskedSoftmaxFwdPrimitive)


def scaled_upper_triang_masked_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
    """
    Forward of the scaled upper-triangular (causal) masked softmax.

    Args:
        logits: FP16/BF16 tensor, [...Batch, Head, Seqlen, Seqlen] (Q and K
            sequence lengths must match).
        scale_factor: scalar forwarded to the kernel.

    Returns:
        FP16/BF16 tensor with the shape and dtype of ``logits``.
    """
    primitive = ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive
    return primitive.bind(logits, scale_factor=scale_factor)


class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Upper Triang Masked Softmax Bwd Primitive
    """
    name = "te_scaled_upper_triang_masked_softmax_backward"
    multiple_results = False
    impl_static_args = (2,)    # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
                            dtype: jnp.dtype) -> bool:
        """Check Softmax kernel availability based on size"""
        # Mirror the upper-triangular forward primitive's constraints.
        return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype)

    @staticmethod
    def abstract(dz_aval, softmax_out_aval, *, scale_factor):
        """
        te_scaled_upper_triang_masked_backward abstract
        """
        return SoftmaxPrimitive.backward_abstract(dz_aval, softmax_out_aval, scale_factor)

    @staticmethod
    def lowering(ctx, dz, softmax_out, *, scale_factor):
        """
        te_scaled_upper_triang_masked_backward lowering rules
        """
        out = SoftmaxPrimitive.backward_lowering(ScaledUpperTriangMaskedSoftmaxBwdPrimitive.name,
                                                 ctx,
                                                 dz,
                                                 softmax_out,
                                                 scale_factor=scale_factor)

        return out

    @staticmethod
    def impl(dz, softmax_out, scale_factor):
        """te_scaled_upper_triang_masked_backward implementation (binds the inner primitive)."""
        return SoftmaxPrimitive.backward_impl(
            ScaledUpperTriangMaskedSoftmaxBwdPrimitive.inner_primitive,
            dz,
            softmax_out,
            scale_factor=scale_factor)

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        """te_scaled_upper_triang_masked_backward batcher"""
        _check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.backward_batcher(
            ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor)

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """te_scaled_upper_triang_masked_backward infer_sharding_from_operands"""
        return SoftmaxPrimitive.backward_infer_sharding_from_operands(scale_factor, mesh, arg_infos,
                                                                      result_infos)

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        """te_scaled_upper_triang_masked_backward partition"""
        return SoftmaxPrimitive.backward_partition(ScaledUpperTriangMaskedSoftmaxBwdPrimitive.impl,
                                                   scale_factor, mesh, arg_infos, result_infos)


# Register with JAX (presumably this populates the `outer_primitive` that
# scaled_upper_triang_masked_softmax_bwd binds below).
register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive)


def scaled_upper_triang_masked_softmax_bwd(dz: jnp.ndarray, softmax_out: jnp.ndarray,
                                           scale_factor: float) -> jnp.ndarray:
1607
    """
1608
1609
    scaled_upper_triang_masked_backward wrapper
    Return FP16/BF16 tensor
1610
    """
1611
1612
    return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive.bind(
        dz, softmax_out, scale_factor=scale_factor)
1613
1614


1615
1616
@dataclass(frozen=True)
class FusedAttnHelper:
    """
    Helper for the fused attention backend.

    Holds one attention configuration and asks transformer_engine_jax which
    fused-attention kernel backend, if any, can serve it.
    """

    q_type: jnp.dtype
    kv_type: jnp.dtype
    qkv_layout: NVTE_QKV_Layout
    attn_bias_type: NVTE_Bias_Type
    attn_mask_type: NVTE_Mask_Type
    dropout_probability: float
    max_seqlen_q: int
    max_seqlen_kv: int
    head_dim: int

    def is_fused_attn_kernel_available(self):
        """Check if there is available fused attention kernel"""
        backend = self.get_fused_attn_backend()
        return backend != NVTE_Fused_Attn_Backend.NVTE_No_Backend

    def get_fused_attn_backend(self):
        """Get the fused attention kernel backend"""
        # The C++ query expects TE dtypes rather than JAX dtypes.
        q_te_dtype = jax_dtype_to_te_dtype(self.q_type)
        kv_te_dtype = jax_dtype_to_te_dtype(self.kv_type)
        return transformer_engine_jax.get_fused_attn_backend(
            q_te_dtype, kv_te_dtype, self.qkv_layout, self.attn_bias_type, self.attn_mask_type,
            self.dropout_probability, self.max_seqlen_q, self.max_seqlen_kv, self.head_dim)


@dataclass(frozen=True)
class _FusedAttnRNGStateChecker:
    """
    Checker for guarding the fused attention rng state.

    The fused attention backend requires a 64-bit seed and a 64-bit offset. JAX does
    not enable 64-bit types by default, so the seed is emulated as two 32-bit
    elements. The offset calculation is maintained in the backend.
    """
    rng_state_dtype: jnp.dtype = jnp.uint32
    # (seed,) with internal dtype int64, i.e. two uint32 elements
    seed_size: int = 2
    # (seed, offset) with internal dtype int64, i.e. four uint32 elements
    rng_state_size: int = 2 * 2

    def check_seed(self, seed, dropout_probability, is_training):
        """
        Check the seed and convert the data type of seed if possible.

        Args:
            seed: seed array, or None when dropout is not active.
            dropout_probability: used only to reject a None seed while dropout is on.
            is_training: used only to reject a None seed while dropout is on.
        Returns:
            ``seed`` as an array of ``rng_state_dtype`` (a cast copy if needed).
        """
        if seed is None:
            # JAX cannot bind None, so a zero dummy seed is substituted. That is only
            # acceptable when dropout will not consume randomness.
            dropout_enabled = dropout_probability > 0 and is_training
            assert not dropout_enabled, "seed is not allowed to be None when dropout is enabled."
            # NOTE(review): this dummy is 1-D with 2 * num_of_devices() elements; confirm it
            # matches the shape callers use for real seeds.
            seed = jnp.repeat(jnp.zeros(2, dtype=self.rng_state_dtype), num_of_devices())

        if seed.dtype != self.rng_state_dtype:
            warnings.warn(
                f"Requested {seed.dtype=} is not available, and will be "
                f"casted to dtype {self.rng_state_dtype}. "
                f"Please use threefry/rbg/unsafe_rbg PRNG implementations to remove this warning.")
            seed = seed.astype(self.rng_state_dtype)

        assert seed.dtype == self.rng_state_dtype
        # Backend takes an int64_t seed, so only the first two u32 elements are taken
        assert seed.size >= self.seed_size
        return seed


def generate_cu_seqlen(mask):
    """
    Generating cumsum seqlen for a batch.

    Args:
        mask: attention mask in which 0 marks a valid (unmasked) position. The last
            two axes are reduced per entry; any leading axes are flattened in order
            by the cumulative sum.
    Returns:
        int32 array ``[0, n_0, n_0 + n_1, ...]``, where ``n_i`` is the number of valid
        positions in the i-th entry.
    """
    seqlen = jnp.sum(mask == 0, axis=(-1, -2), dtype=jnp.int32)
    # Pin the accumulator dtype so the offsets stay int32 (the dtype requested for
    # `seqlen` above) regardless of JAX's integer promotion for cumulative reductions,
    # e.g. when jax_enable_x64 is on.
    cu_seqlen = jnp.cumsum(seqlen, dtype=jnp.int32)
    cu_seqlen = jnp.hstack((jnp.zeros(1, dtype=jnp.int32), cu_seqlen))
    return cu_seqlen


class SelfFusedAttnFwdPrimitive(BasePrimitive):
    """
    Self Fused Attention Forward Primitive

    Takes a packed qkv tensor of shape (...batch, seqlen, 3, head, hidden) and produces
    (output, softmax_aux, rng_state).
    """
    name = "te_self_fused_attn_forward"
    multiple_results = True
    # Positions of attn_bias_type .. is_training in ``impl``'s signature.
    impl_static_args = (4, 5, 6, 7, 8)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(qkv_aval, bias_aval, mask_or_cu_seqlen_aval, seed_aval, *, attn_bias_type,
                 attn_mask_type, scaling_factor, dropout_probability, is_training):
        """
        Self fused attention fwd abstract

        The softmax_aux shape and dtype depend on which fused-attention backend serves
        this configuration; unsupported configurations raise ValueError.
        """
        # The outer primitive receives the squeezed mask and the inner one receives
        # cu_seqlen; neither influences the output avals.
        del mask_or_cu_seqlen_aval, scaling_factor, is_training
        qkv_dtype = dtypes.canonicalize_dtype(qkv_aval.dtype)
        *batch_shape, max_seqlen, nqkv, num_head, head_dim = qkv_aval.shape
        assert nqkv == 3
        assert qkv_aval.dtype == bias_aval.dtype

        helper = FusedAttnHelper(qkv_dtype, qkv_dtype, NVTE_QKV_Layout.NVTE_BS3HD, attn_bias_type,
                                 attn_mask_type, dropout_probability, max_seqlen, max_seqlen,
                                 head_dim)
        backend = helper.get_fused_attn_backend()

        if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
            # (...batch, head, seqlen, seqlen) in the qkv dtype
            softmax_aux_shape = (*batch_shape, num_head, max_seqlen, max_seqlen)
            softmax_dtype = qkv_dtype
        elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
            # (...batch, head, seqlen, 1) in float32
            softmax_aux_shape = (*batch_shape, num_head, max_seqlen, 1)
            softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
        else:
            raise ValueError(f'Not supported {backend=}')

        checker = _FusedAttnRNGStateChecker()
        seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype)
        assert seed_dtype == checker.rng_state_dtype
        rng_state_shape = (seed_aval.shape[0], checker.rng_state_size)

        out_aval = qkv_aval.update(shape=(*batch_shape, max_seqlen, num_head, head_dim),
                                   dtype=qkv_dtype)
        softmax_aux_aval = qkv_aval.update(shape=softmax_aux_shape, dtype=softmax_dtype)
        rng_state_aval = qkv_aval.update(shape=rng_state_shape, dtype=seed_dtype)
        return out_aval, softmax_aux_aval, rng_state_aval

    @staticmethod
    def lowering(ctx, qkv, bias, cu_seqlen, seed, *, attn_bias_type, attn_mask_type, scaling_factor,
                 dropout_probability, is_training):
        """
        Self fused attention fwd lowering rules

        Emits the custom call; all leading batch axes are folded into one batch size
        for the opaque descriptor.
        """
        qkv_aval, _, _, _ = ctx.avals_in
        *batch_shape, max_seqlen, _, num_head, head_dim = qkv_aval.shape
        batch = reduce(operator.mul, batch_shape)

        operands = [qkv, bias, cu_seqlen, seed]
        operand_shapes = (operand.type.shape for operand in operands)
        out_types = [
            ir.RankedTensorType.get(aval.shape, mlir.dtype_to_ir_type(aval.dtype))
            for aval in ctx.avals_out
        ]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        # Self attention: q and kv share the same max sequence length.
        opaque = transformer_engine_jax.pack_fused_attn_descriptor(
            batch, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor, dropout_probability,
            attn_bias_type, attn_mask_type, jax_dtype_to_te_dtype(qkv_aval.dtype), is_training)

        return custom_caller(SelfFusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)

    @staticmethod
    def impl(qkv, bias, squeezed_mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
             dropout_probability, is_training):
        """
        Convert the squeezed mask to cu_seqlen and bind the inner primitive.

        Returns the tuple (output, softmax_aux, rng_state).
        """
        assert SelfFusedAttnFwdPrimitive.inner_primitive is not None
        cu_seqlen = generate_cu_seqlen(squeezed_mask)
        output, softmax_aux, rng_state = SelfFusedAttnFwdPrimitive.inner_primitive.bind(
            qkv,
            bias,
            cu_seqlen,
            seed,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training)
        return output, softmax_aux, rng_state

    @staticmethod
    def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor,
                dropout_probability, is_training):
        """
        Batching rule: output and softmax_aux follow qkv's batch dim; rng_state
        follows the seed's.
        """
        _check_valid_batch_dims(batch_dims)
        assert SelfFusedAttnFwdPrimitive.outer_primitive is not None
        qkv, bias, cu_seqlen, seed = batched_args
        qkv_bdim, _, _, seed_bdim = batch_dims

        outputs = SelfFusedAttnFwdPrimitive.outer_primitive.bind(
            qkv,
            bias,
            cu_seqlen,
            seed,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training)
        return outputs, (qkv_bdim, qkv_bdim, seed_bdim)

    @staticmethod
    def _output_shardings(mesh, qkv_info):
        """
        Shardings of (output, softmax_aux, rng_state) derived from the qkv operand.
        """
        # qkv spec: (...batch, seqlen, 3, head, hidden)
        *batch_spec, seqlen_spec, _, head_spec, hidden_spec = get_padded_spec(qkv_info)
        out_sharding = NamedSharding(
            mesh, PartitionSpec(*batch_spec, seqlen_spec, head_spec, hidden_spec))
        softmax_aux_sharding = NamedSharding(
            mesh, PartitionSpec(*batch_spec, head_spec, seqlen_spec, None))
        rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
        return (out_sharding, softmax_aux_sharding, rng_state_sharding)

    @staticmethod
    def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor,
                                     dropout_probability, is_training, mesh, arg_infos,
                                     result_infos):
        """
        Output shardings inferred from the qkv operand.
        """
        del attn_bias_type, attn_mask_type, scaling_factor
        del dropout_probability, is_training, result_infos
        return SelfFusedAttnFwdPrimitive._output_shardings(mesh, arg_infos[0])

    @staticmethod
    def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training,
                  mesh, arg_infos, result_infos):
        """
        Partition rule: operands keep their shardings except the seed (last operand),
        which is resharded like the rng_state output.
        """
        del result_infos
        out_shardings = SelfFusedAttnFwdPrimitive._output_shardings(mesh, arg_infos[0])
        rng_state_sharding = out_shardings[-1]
        arg_shardings = (*(info.sharding for info in arg_infos[:-1]), rng_state_sharding)
        impl = partial(SelfFusedAttnFwdPrimitive.impl,
                       attn_bias_type=attn_bias_type,
                       attn_mask_type=attn_mask_type,
                       scaling_factor=scaling_factor,
                       dropout_probability=dropout_probability,
                       is_training=is_training)
        return mesh, impl, out_shardings, arg_shardings


register_primitive(SelfFusedAttnFwdPrimitive)


def self_fused_attn_fwd(qkv: jnp.ndarray, bias: jnp.ndarray, squeezed_mask: jnp.ndarray,
                        seed: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
                        attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
                        dropout_probability: float, is_training: bool):
    """
    Wrapper for TE self fused attention fwd
    Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2

    The seed is validated (and cast if needed) first. With NVTE_NO_BIAS, ``bias`` must be
    None and is replaced by an empty array of the qkv dtype. Returns the outer
    primitive's outputs (output, softmax_aux, rng_state).
    """
    seed = _FusedAttnRNGStateChecker().check_seed(seed, dropout_probability, is_training)

    if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
        # JAX cannot bind None; use an empty placeholder instead.
        assert bias is None
        bias = jnp.zeros(0, dtype=qkv.dtype)

    return SelfFusedAttnFwdPrimitive.outer_primitive.bind(
        qkv,
        bias,
        squeezed_mask,
        seed,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training)


class SelfFusedAttnBwdPrimitive(BasePrimitive):
    """
    Self Fused Attention Backward Primitive

    Produces (dqkv, dbias) for the self fused attention with a packed qkv input.
    """
    name = "te_self_fused_attn_backward"
    multiple_results = True
    # Positions of attn_bias_type .. is_training in ``impl``'s signature.
    impl_static_args = (6, 7, 8, 9, 10)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(qkv_aval, softmax_aux_aval, rng_state_aval, output_aval, doutput_aval,
                 mask_or_cu_seqlen_aval, *, attn_bias_type, attn_mask_type, scaling_factor,
                 dropout_probability, is_training):
        """
        Self fused attention bwd abstract

        dqkv mirrors qkv. dbias is an empty (0,) array for NVTE_NO_BIAS, otherwise
        (...batch[:-1], 1, head, seqlen, seqlen) in the qkv dtype.
        """
        del softmax_aux_aval, rng_state_aval
        # The outer primitive receives the squeezed mask, the inner one cu_seqlen.
        del mask_or_cu_seqlen_aval, attn_mask_type
        del scaling_factor, dropout_probability, is_training
        qkv_dtype = dtypes.canonicalize_dtype(qkv_aval.dtype)
        assert qkv_aval.dtype == output_aval.dtype == doutput_aval.dtype
        *batch_shape, max_seqlen, num_head, _ = output_aval.shape

        if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
            bias_shape = (0,)
        else:
            bias_shape = (*batch_shape[:-1], 1, num_head, max_seqlen, max_seqlen)
        bias_dtype = qkv_dtype

        dqkv_aval = qkv_aval.update(shape=qkv_aval.shape, dtype=qkv_dtype)
        dbias = qkv_aval.update(shape=bias_shape, dtype=bias_dtype)
        return dqkv_aval, dbias

    @staticmethod
    def lowering(ctx, qkv, softmax_aux, rng_state, output, doutput, cu_seqlen, *, attn_bias_type,
                 attn_mask_type, scaling_factor, dropout_probability, is_training):
        """
        Self fused attention bwd lowering rules

        Emits the custom call; all leading batch axes of qkv are folded into one batch
        size for the opaque descriptor.
        """
        qkv_aval, _, _, _, _, _ = ctx.avals_in
        *batch_shape, max_seqlen, _, num_head, head_dim = qkv_aval.shape
        batch = reduce(operator.mul, batch_shape)

        operands = [qkv, softmax_aux, rng_state, output, doutput, cu_seqlen]
        operand_shapes = map(lambda x: x.type.shape, operands)
        out_types = [
            ir.RankedTensorType.get(aval.shape, mlir.dtype_to_ir_type(aval.dtype))
            for aval in ctx.avals_out
        ]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        # Self attention: q and kv share the same max sequence length.
        opaque = transformer_engine_jax.pack_fused_attn_descriptor(
            batch, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor, dropout_probability,
            attn_bias_type, attn_mask_type, jax_dtype_to_te_dtype(qkv_aval.dtype), is_training)

        return custom_caller(SelfFusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)

    @staticmethod
    def impl(qkv, softmax_aux, rng_state, output, doutput, squeezed_mask, attn_bias_type,
             attn_mask_type, scaling_factor, dropout_probability, is_training):
        """
        Convert the squeezed mask to cu_seqlen and bind the inner primitive.

        Returns the tuple (dqkv, dbias).
        """
        assert SelfFusedAttnBwdPrimitive.inner_primitive is not None

        cu_seqlen = generate_cu_seqlen(squeezed_mask)

        dqkv, dbias = SelfFusedAttnBwdPrimitive.inner_primitive.bind(
            qkv,
            softmax_aux,
            rng_state,
            output,
            doutput,
            cu_seqlen,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training)
        return dqkv, dbias

    @staticmethod
    def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor,
                dropout_probability, is_training):
        """
        Batching rule: both gradients follow qkv's batch dim.
        """
        _check_valid_batch_dims(batch_dims)
        assert SelfFusedAttnBwdPrimitive.outer_primitive is not None
        qkv, softmax_aux, rng_state, output, doutput, cu_seqlen = batched_args
        qkv_bdim, *_ = batch_dims

        out_bdims = qkv_bdim, qkv_bdim
        return SelfFusedAttnBwdPrimitive.outer_primitive.bind(
            qkv,
            softmax_aux,
            rng_state,
            output,
            doutput,
            cu_seqlen,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training), out_bdims

    @staticmethod
    def _output_shardings(attn_bias_type, mesh, qkv_info):
        """
        Shardings of (dqkv, dbias) derived from the qkv operand.
        """
        x_spec = get_padded_spec(qkv_info)    # (...batch, seqlen, 3, head, hidden)
        dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        # Compare by value: pybind11 enum instances are not guaranteed to be the
        # canonical singleton objects, so identity checks are unreliable.
        if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
            dbias_spec = [None]
        else:
            # dbias is (...batch[:-1], 1, head, seqlen, seqlen): the leading batch axes
            # and the head axis follow qkv; the remaining axes are replicated.
            dbias_spec = [*x_spec[:-5], None, x_spec[-2], None, None]
        dbias_sharding = NamedSharding(mesh, PartitionSpec(*dbias_spec))
        return (dx_sharding, dbias_sharding)

    @staticmethod
    def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor,
                                     dropout_probability, is_training, mesh, arg_infos,
                                     result_infos):
        """
        Output shardings inferred from the qkv operand.
        """
        del attn_mask_type, scaling_factor, dropout_probability
        del is_training, result_infos
        return SelfFusedAttnBwdPrimitive._output_shardings(attn_bias_type, mesh, arg_infos[0])

    @staticmethod
    def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training,
                  mesh, arg_infos, result_infos):
        """
        Partition rule: operands keep their shardings. When a bias is present, the
        per-shard dbias is summed with all_reduce_sum_along_dp_fsdp.
        """
        del result_infos
        out_shardings = SelfFusedAttnBwdPrimitive._output_shardings(attn_bias_type, mesh,
                                                                    arg_infos[0])
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)

        def sharded_impl(qkv, softmax_aux, rng_state, output, doutput, squeezed_mask):
            local_dx, local_dbias = SelfFusedAttnBwdPrimitive.impl(
                qkv,
                softmax_aux,
                rng_state,
                output,
                doutput,
                squeezed_mask,
                attn_bias_type=attn_bias_type,
                attn_mask_type=attn_mask_type,
                scaling_factor=scaling_factor,
                dropout_probability=dropout_probability,
                is_training=is_training)
            global_dbias = local_dbias
            if attn_bias_type != NVTE_Bias_Type.NVTE_NO_BIAS:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
            return local_dx, global_dbias

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(SelfFusedAttnBwdPrimitive)


def self_fused_attn_bwd(qkv: jnp.ndarray, softmax_aux: jnp.ndarray, rng_state: jnp.ndarray,
                        output: jnp.ndarray, doutput: jnp.ndarray, squeezed_mask: jnp.ndarray,
                        attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type,
                        scaling_factor: float, dropout_probability: float, is_training: bool):
    """
    Wrapper for TE self fused attention bwd
    Return the gradients of self fused attention with packed qkv input

    All arguments are forwarded unchanged to the outer primitive, whose outputs are
    (dqkv, dbias).
    """
    primitive = SelfFusedAttnBwdPrimitive.outer_primitive
    return primitive.bind(qkv,
                          softmax_aux,
                          rng_state,
                          output,
                          doutput,
                          squeezed_mask,
                          attn_bias_type=attn_bias_type,
                          attn_mask_type=attn_mask_type,
                          scaling_factor=scaling_factor,
                          dropout_probability=dropout_probability,
                          is_training=is_training)


class CrossFusedAttnFwdPrimitive(BasePrimitive):
    """
    Cross Fused Attention Forward Primitive

    Takes q of shape (...batch, q_seqlen, head, hidden) and packed kv of shape
    (...batch, kv_seqlen, 2, head, hidden), and produces (output, softmax_aux).
    """
    name = "te_cross_fused_attn_forward"
    multiple_results = True
    # Positions of attn_bias_type .. is_training in ``impl``'s signature.
    impl_static_args = (5, 6, 7, 8, 9)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(q_aval, kv_aval, q_mask_or_cu_seqlen_aval, kv_mask_or_cu_seqlen_aval, seed_aval, *,
                 attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training):
        """
        Cross fused attention fwd abstract

        output mirrors q. softmax_aux is (...batch, head, q_seqlen, kv_seqlen) in the
        q dtype.
        """
        del seed_aval, attn_bias_type, attn_mask_type
        del scaling_factor, dropout_probability, is_training

        q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
        *q_batch_shape, q_max_seqlen, q_num_head, q_head_dim = q_aval.shape

        kv_dtype = dtypes.canonicalize_dtype(kv_aval.dtype)
        *kv_batch_shape, kv_max_seqlen, nkv, kv_num_head, kv_head_dim = kv_aval.shape

        assert q_dtype == kv_dtype
        assert q_batch_shape == kv_batch_shape
        assert q_num_head == kv_num_head
        assert q_head_dim == kv_head_dim
        assert nkv == 2
        # The outer primitive receives squeezed masks, the inner one cu_seqlens.
        assert q_mask_or_cu_seqlen_aval.dtype == kv_mask_or_cu_seqlen_aval.dtype

        softmax_aux_shape = (*q_batch_shape, q_num_head, q_max_seqlen, kv_max_seqlen)
        out_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
        softmax_aux_aval = q_aval.update(shape=softmax_aux_shape, dtype=q_dtype)
        return out_aval, softmax_aux_aval

    @staticmethod
    def lowering(ctx, q, kv, q_cu_seqlen, kv_cu_seqlen, seed, *, attn_bias_type, attn_mask_type,
                 scaling_factor, dropout_probability, is_training):
        """
        Cross fused attention fwd lowering rules

        Emits the custom call; all leading batch axes of q are folded into one batch
        size for the opaque descriptor.
        """
        q_aval, kv_aval, _, _, _ = ctx.avals_in
        assert q_aval.dtype == kv_aval.dtype

        *batch_shape, q_max_seqlen, num_head, head_dim = q_aval.shape
        batch = reduce(operator.mul, batch_shape)
        kv_max_seqlen = kv_aval.shape[-4]    # kv: (...batch, kv_seqlen, 2, head, hidden)

        operands = [q, kv, q_cu_seqlen, kv_cu_seqlen, seed]
        operand_shapes = (operand.type.shape for operand in operands)
        out_types = [
            ir.RankedTensorType.get(aval.shape, mlir.dtype_to_ir_type(aval.dtype))
            for aval in ctx.avals_out
        ]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        opaque = transformer_engine_jax.pack_fused_attn_descriptor(
            batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim,
            scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
            jax_dtype_to_te_dtype(q_aval.dtype), is_training)

        return custom_caller(CrossFusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)

    @staticmethod
    def impl(q, kv, q_squeezed_mask, kv_squeezed_mask, seed, attn_bias_type, attn_mask_type,
             scaling_factor, dropout_probability, is_training):
        """
        Convert both squeezed masks to cu_seqlens and bind the inner primitive.

        Returns the tuple (output, softmax_aux).
        """
        assert CrossFusedAttnFwdPrimitive.inner_primitive is not None

        q_cu_seqlen = generate_cu_seqlen(q_squeezed_mask)
        kv_cu_seqlen = generate_cu_seqlen(kv_squeezed_mask)

        output, softmax_aux = CrossFusedAttnFwdPrimitive.inner_primitive.bind(
            q,
            kv,
            q_cu_seqlen,
            kv_cu_seqlen,
            seed,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training)
        return output, softmax_aux

    @staticmethod
    def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor,
                dropout_probability, is_training):
        """
        Batching rule: both outputs follow q's batch dim.
        """
        _check_valid_batch_dims(batch_dims)
        assert CrossFusedAttnFwdPrimitive.outer_primitive is not None
        q, kv, q_cu_seqlen, kv_cu_seqlen, seed = batched_args
        q_bdim, *_ = batch_dims

        outputs = CrossFusedAttnFwdPrimitive.outer_primitive.bind(
            q,
            kv,
            q_cu_seqlen,
            kv_cu_seqlen,
            seed,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training)
        return outputs, (q_bdim, q_bdim)

    @staticmethod
    def _output_shardings(mesh, q_info, kv_info):
        """
        Shardings of (output, softmax_aux) derived from the q and kv operands.
        """
        q_spec = get_padded_spec(q_info)    # (...batch, q_seqlen, head, hidden)
        kv_spec = get_padded_spec(kv_info)    # (...batch, kv_seqlen, 2, head, hidden)
        out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        # softmax_aux: (...batch, head, q_seqlen, kv_seqlen)
        softmax_aux_sharding = NamedSharding(
            mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], kv_spec[-4]))
        return (out_sharding, softmax_aux_sharding)

    @staticmethod
    def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor,
                                     dropout_probability, is_training, mesh, arg_infos,
                                     result_infos):
        """
        Output shardings inferred from the q and kv operands.
        """
        del attn_bias_type, attn_mask_type, scaling_factor
        del dropout_probability, is_training, result_infos
        return CrossFusedAttnFwdPrimitive._output_shardings(mesh, arg_infos[0], arg_infos[1])

    @staticmethod
    def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training,
                  mesh, arg_infos, result_infos):
        """
        Partition rule: operands keep their shardings except the seed (last operand),
        which is sharded over all mesh axes.
        """
        del result_infos
        out_shardings = CrossFusedAttnFwdPrimitive._output_shardings(mesh, arg_infos[0],
                                                                     arg_infos[1])
        seed_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
        arg_shardings = (*(info.sharding for info in arg_infos[:-1]), seed_sharding)
        impl = partial(CrossFusedAttnFwdPrimitive.impl,
                       attn_bias_type=attn_bias_type,
                       attn_mask_type=attn_mask_type,
                       scaling_factor=scaling_factor,
                       dropout_probability=dropout_probability,
                       is_training=is_training)
        return mesh, impl, out_shardings, arg_shardings


register_primitive(CrossFusedAttnFwdPrimitive)


def cross_fused_attn_fwd(q: jnp.ndarray, kv: jnp.ndarray, q_squeezed_mask: jnp.ndarray,
                         kv_squeezed_mask: jnp.ndarray, seed: jnp.ndarray,
                         attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type,
                         scaling_factor: float, dropout_probability: float, is_training: bool):
    """
    Wrapper for TE cross fused attention fwd
    Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2

    The seed is validated (and cast if needed) before binding. Returns the outer
    primitive's outputs (output, softmax_aux).
    """
    validated_seed = _FusedAttnRNGStateChecker().check_seed(seed, dropout_probability,
                                                            is_training)
    primitive = CrossFusedAttnFwdPrimitive.outer_primitive
    return primitive.bind(q,
                          kv,
                          q_squeezed_mask,
                          kv_squeezed_mask,
                          validated_seed,
                          attn_bias_type=attn_bias_type,
                          attn_mask_type=attn_mask_type,
                          scaling_factor=scaling_factor,
                          dropout_probability=dropout_probability,
                          is_training=is_training)


class CrossFusedAttnBwdPrimitive(BasePrimitive):
2226
    """
2227
    Cross Fused Attention Backward Primitive
2228
    """
2229
    name = "te_cross_fused_attn_backward"
2230
    multiple_results = True
2231
2232
2233
    impl_static_args = (6, 7, 8, 9, 10)
    inner_primitive = None
    outer_primitive = None
2234
2235

    @staticmethod
2236
2237
2238
    def abstract(q_aval, kv_aval, softmax_aux_aval, doutput_aval, q_cu_seqlen_aval,
                 kv_cu_seqlen_aval, *, attn_bias_type, attn_mask_type, scaling_factor,
                 dropout_probability, is_training):
2239
        """
2240
        Cross fused attention bwd abstract
2241
        """
2242
2243
2244
2245
2246
2247
2248
2249
2250
        del attn_bias_type, attn_mask_type
        del scaling_factor, dropout_probability, is_training
        q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
        kv_dtype = dtypes.canonicalize_dtype(kv_aval.dtype)
        softmax_aux_dtype = dtypes.canonicalize_dtype(softmax_aux_aval.dtype)
        doutput_dtype = dtypes.canonicalize_dtype(doutput_aval.dtype)
        assert q_dtype == kv_dtype == softmax_aux_dtype == doutput_dtype
        # outer_primitve is squeezed_mask, inner_primitive is cu_seqlen
        assert q_cu_seqlen_aval.dtype == kv_cu_seqlen_aval.dtype
2251

2252
2253
2254
        dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
        dkv_aval = kv_aval.update(shape=kv_aval.shape, dtype=kv_dtype)
        return dq_aval, dkv_aval
2255
2256

    @staticmethod
2257
2258
    def lowering(ctx, q, kv, softmax_aux, doutput, q_cu_seqlen, kv_cu_seqlen, *, attn_bias_type,
                 attn_mask_type, scaling_factor, dropout_probability, is_training):
2259
        """
2260
        Cross fused attention bwd lowering rules
2261
        """
2262
2263
        q_aval, kv_aval, _, _, _, _ = ctx.avals_in
        assert q_aval.dtype == kv_aval.dtype
2264

2265
2266
2267
        *batch_shape, q_max_seqlen, num_head, head_dim = q_aval.shape
        batch = reduce(operator.mul, batch_shape)
        kv_max_seqlen = kv_aval.shape[-4]
2268

2269
2270
        operands = [q, kv, softmax_aux, doutput, q_cu_seqlen, kv_cu_seqlen]
        operand_shapes = map(lambda x: x.type.shape, operands)
2271
        out_types = [
2272
2273
            ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
            for output in ctx.avals_out
2274
        ]
2275

2276
2277
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

2278
2279
2280
2281
2282
2283
        # the dropout elements are encoded in the forward auxiliary tensor
        # so seed is not needed in backward
        opaque = transformer_engine_jax.pack_fused_attn_descriptor(
            batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim,
            scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
            jax_dtype_to_te_dtype(q_aval.dtype), is_training)
2284

2285
        out = custom_caller(CrossFusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)
2286
2287
2288

        return out

2289
2290
2291
2292
2293
2294
2295
2296
2297
2298
2299
2300
2301
2302
2303
2304
2305
2306
2307
2308
2309
2310
2311
2312
2313
2314
2315
2316
2317
2318
2319
2320
2321
2322
2323
2324
2325
2326
2327
2328
2329
2330
2331
2332
2333
2334
2335
2336
2337
2338
2339
2340
2341
2342
2343
2344
2345
2346
2347
2348
2349
2350
2351
2352
2353
2354
2355
2356
2357
2358
2359
2360
2361
2362
2363
2364
    @staticmethod
    def impl(q, kv, softmax_aux, doutput, q_squeezed_mask, kv_squeezed_mask, attn_bias_type,
             attn_mask_type, scaling_factor, dropout_probability, is_training):
        assert CrossFusedAttnBwdPrimitive.inner_primitive is not None

        q_cu_seqlen = generate_cu_seqlen(q_squeezed_mask)
        kv_cu_seqlen = generate_cu_seqlen(kv_squeezed_mask)

        dq, dkv = CrossFusedAttnBwdPrimitive.inner_primitive.bind(
            q,
            kv,
            softmax_aux,
            doutput,
            q_cu_seqlen,
            kv_cu_seqlen,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training)
        return dq, dkv

    @staticmethod
    def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor,
                dropout_probability, is_training):
        _check_valid_batch_dims(batch_dims)
        assert CrossFusedAttnBwdPrimitive.outer_primitive is not None
        q, kv, softmax_aux, doutput, q_cu_seqlen, kv_cu_seqlen = batched_args
        q_bdim, kv_bdim, *_ = batch_dims

        out_bdims = q_bdim, kv_bdim
        return CrossFusedAttnBwdPrimitive.outer_primitive.bind(
            q,
            kv,
            softmax_aux,
            doutput,
            q_cu_seqlen,
            kv_cu_seqlen,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training), out_bdims

    @staticmethod
    def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor,
                                     dropout_probability, is_training, mesh, arg_infos,
                                     result_infos):
        del attn_bias_type, attn_mask_type, scaling_factor
        del dropout_probability, is_training, result_infos
        q_spec = get_padded_spec(arg_infos[0])
        kv_spec = get_padded_spec(arg_infos[1])
        dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        dkv_sharding = NamedSharding(mesh, PartitionSpec(*kv_spec))
        return (dq_sharding, dkv_sharding)

    @staticmethod
    def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training,
                  mesh, arg_infos, result_infos):
        del result_infos
        q_spec = get_padded_spec(arg_infos[0])
        kv_spec = get_padded_spec(arg_infos[1])
        dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        dkv_sharding = NamedSharding(mesh, PartitionSpec(*kv_spec))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (dq_sharding, dkv_sharding)

        impl = partial(CrossFusedAttnBwdPrimitive.impl,
                       attn_bias_type=attn_bias_type,
                       attn_mask_type=attn_mask_type,
                       scaling_factor=scaling_factor,
                       dropout_probability=dropout_probability,
                       is_training=is_training)

        return mesh, impl, out_shardings, arg_shardings

2365

2366
register_primitive(CrossFusedAttnBwdPrimitive)
2367
2368


def cross_fused_attn_bwd(q: jnp.ndarray, kv: jnp.ndarray, softmax_aux: jnp.ndarray,
                         doutput: jnp.ndarray, q_squeezed_mask: jnp.ndarray,
                         kv_squeezed_mask: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
                         attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
                         dropout_probability: float, is_training: bool):
    """
    Wrapper for TE cross fused attention bwd
    Return the gradients of cross fused attention with packed kv input

    No seed is taken. The backward lowering relies on the forward auxiliary tensor
    (`softmax_aux`) to encode the dropout elements.

    Returns (dq, dkv), with the same shape and dtype as q and kv respectively.
    """
    return CrossFusedAttnBwdPrimitive.outer_primitive.bind(q,
                                                           kv,
                                                           softmax_aux,
                                                           doutput,
                                                           q_squeezed_mask,
                                                           kv_squeezed_mask,
                                                           attn_bias_type=attn_bias_type,
                                                           attn_mask_type=attn_mask_type,
                                                           scaling_factor=scaling_factor,
                                                           dropout_probability=dropout_probability,
                                                           is_training=is_training)


class GatedGeluPrimitive(BasePrimitive):
    """
    Gated Gelu Forward Primitive

    Maps x of shape (..., 2, hidden) to an output of shape (..., hidden) in the same dtype.
    """
    name = "te_gated_gelu"
    multiple_results = False
    inner_primitive = None
    outer_primitive = None
    impl_static_args = ()

    @staticmethod
    def abstract(x_aval):
        """
        gated_gelu abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        x_shape = x_aval.shape
        assert x_shape[-2] == 2    # Assume x in (....., 2, hidden)
        hidden_size = x_shape[-1]
        batch_shapes = x_shape[:-2]
        out_aval = core.raise_to_shaped(x_aval)
        out_shape = batch_shapes + (hidden_size,)
        out_aval = out_aval.update(shape=out_shape, dtype=dtype)

        return out_aval

    @staticmethod
    def lowering(ctx, x):
        """
        gated_gelu lowering rules
        """
        (x_aval,) = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        ir_x_type = ir.RankedTensorType(x.type)
        ir_x_shape = ir_x_type.shape
        # Drop the "2" axis: (..., 2, hidden) -> (..., hidden).
        out_shape = ir_x_shape[:-2] + [ir_x_shape[-1]]

        out_types = [
            ir.RankedTensorType.get(out_shape, ir_x_type.element_type),
        ]
        operands = [x]
        operand_shapes = [ir_x_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        hidden_size = ir_x_shape[-1]
        # The initial value 1 keeps a rank-2 input (2, hidden) valid (no leading axes).
        batch_size = reduce(operator.mul, ir_x_shape[:-2], 1)
        in_dtype = jax_dtype_to_te_dtype(x_aval.dtype)
        opaque = transformer_engine_jax.pack_common_descriptor((batch_size, hidden_size), in_dtype,
                                                               in_dtype)

        out = custom_caller(GatedGeluPrimitive.name, args, opaque, False)

        return [out]

    @staticmethod
    def impl(x):
        """
        gated_gelu implementation
        """
        assert GatedGeluPrimitive.inner_primitive is not None
        out = GatedGeluPrimitive.inner_primitive.bind(x)
        return out

    @staticmethod
    def batcher(batched_args, batch_dims):
        """
        gated_gelu batcher
        """
        _check_valid_batch_dims(batch_dims)
        assert GatedGeluPrimitive.outer_primitive is not None
        inputs, = batched_args
        inputs_bdim, = batch_dims

        out_bdims = inputs_bdim
        return GatedGeluPrimitive.outer_primitive.bind(inputs), out_bdims

    @staticmethod
    def infer_sharding_from_operands(mesh, arg_infos, result_infos):
        """
        gated_gelu infer_sharding_from_operands
        """
        del result_infos    # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        # The output drops the "2" axis, so its spec drops x_spec[-2].
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
        return out_sharding

    @staticmethod
    def partition(mesh, arg_infos, result_infos):
        """
        gated_gelu partitioning
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
        impl = GatedGeluPrimitive.impl
        return mesh, impl, out_sharding, arg_shardings


register_primitive(GatedGeluPrimitive)


def gated_gelu(inputs: jnp.ndarray) -> jnp.ndarray:
    """
    gated gelu wrapper
    Return geglu(inputs), computed by the TE `te_gated_gelu` custom call.

    inputs is expected in shape (..., 2, H); the primitive's abstract rule asserts the
    "2" axis. The result has shape (..., H) and the same dtype as inputs (float32,
    float16 or bfloat16). No FP8 cast is performed here.
    """
    return GatedGeluPrimitive.outer_primitive.bind(inputs)


class DgatedGeluPrimitive(BasePrimitive):
    """
    Dgated Gelu Primitive

    Takes dz of shape (..., hidden) and the forward input x of shape (..., 2, hidden).
    Produces a gradient with x's shape and dtype.
    """
    name = "te_dgated_gelu"
    multiple_results = False
    inner_primitive = None
    outer_primitive = None
    impl_static_args = ()

    @staticmethod
    def abstract(dz_aval, x_aval):
        """
        dgated_gelu abstract
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dtype
        # dz has one fewer axis than x; its leading (batch) axes must match x's.
        for axis in range(len(dz_aval.shape) - 1):
            assert dz_aval.shape[axis] == x_aval.shape[axis]

        assert x_aval.shape[-2] == 2    # Assume x in (....., 2, hidden)

        i_hidden_size = dz_aval.shape[-1]
        g_hidden_size = x_aval.shape[-1]
        assert i_hidden_size == g_hidden_size
        out_aval = core.raise_to_shaped(x_aval)
        return out_aval

    @staticmethod
    def lowering(ctx, dz, x):
        """
        dgated_gelu lowering rules
        """
        in_aval, gi_aval = ctx.avals_in
        assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert gi_aval.dtype == in_aval.dtype
        ir_in_type = ir.RankedTensorType(dz.type)
        ir_in_shape = ir_in_type.shape
        gi_type = ir.RankedTensorType(x.type)
        gi_shape = gi_type.shape
        for axis in range(len(ir_in_shape) - 1):
            assert ir_in_shape[axis] == gi_shape[axis]

        # The initial value 1 keeps a rank-1 dz (no leading axes) valid.
        ir_batch_size = reduce(operator.mul, ir_in_shape[:-1], 1)
        i_hidden_size = ir_in_shape[-1]
        g_hidden_size = gi_shape[-1]
        assert i_hidden_size == g_hidden_size
        out_dtype = ir_in_type.element_type
        out_shape = gi_shape

        out_types = [
            ir.RankedTensorType.get(out_shape, out_dtype),
        ]
        operands = [dz, x]
        operand_shapes = [ir_in_shape, gi_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        in_dtype = jax_dtype_to_te_dtype(in_aval.dtype)
        opaque = transformer_engine_jax.pack_common_descriptor((ir_batch_size, i_hidden_size),
                                                               in_dtype, in_dtype)

        out = custom_caller(DgatedGeluPrimitive.name, args, opaque, False)

        return [out]

    @staticmethod
    def impl(dz, x):
        """
        dgated_gelu implementation
        """
        assert DgatedGeluPrimitive.inner_primitive is not None
        dx = DgatedGeluPrimitive.inner_primitive.bind(dz, x)
        return dx

    @staticmethod
    def batcher(batched_args, batch_dims):
        """
        dgated_gelu batcher
        """
        _check_valid_batch_dims(batch_dims)
        assert DgatedGeluPrimitive.outer_primitive is not None
        dz, x = batched_args
        _, x_bdim = batch_dims

        # dx has x's shape, so it carries x's batch dim.
        out_bdims = x_bdim
        return DgatedGeluPrimitive.outer_primitive.bind(dz, x), out_bdims

    @staticmethod
    def infer_sharding_from_operands(mesh, arg_infos, result_infos):
        """
        dgated_gelu infer_sharding_from_operands
        """
        del result_infos    # Unused.
        gelu_out_spec = get_padded_spec(arg_infos[1])
        dx_sharding = NamedSharding(mesh, PartitionSpec(*gelu_out_spec))
        return dx_sharding

    @staticmethod
    def partition(mesh, arg_infos, result_infos):
        """
        dgated_gelu partition
        """
        del result_infos
        dx_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = dx_sharding
        impl = DgatedGeluPrimitive.impl
        return mesh, impl, out_shardings, arg_shardings


register_primitive(DgatedGeluPrimitive)


def dgated_gelu(inputs: jnp.ndarray, gelu_inputs: jnp.ndarray) -> jnp.ndarray:
    """
    dgated_gelu fusion wrapper
    Return dgeglu(inputs)

    inputs: incoming gradient of shape (..., H). Its leading axes must match those of
        gelu_inputs.
    gelu_inputs: the forward input of gated gelu, shape (..., 2, H), same dtype as inputs.

    The result has the same shape and dtype as gelu_inputs.
    """
    return DgatedGeluPrimitive.outer_primitive.bind(inputs, gelu_inputs)


def _normalize_axis_boundary(axis, ndim):
    """Return `axis` as a non-negative index for a rank-`ndim` shape (negatives count from the end)."""
    if axis < 0:
        return ndim + axis
    return axis


def _multidim_transpose(shape, static_axis_boundary, transpose_axis_boundary):
    """
    te_cast_transpose_p multi-dims transpose

    static_axis_boundary: int. Axes <= static_axis_boundary are not involved in the
        transpose; any negative value means all axes take part.
    transpose_axis_boundary: int. Indicates how to split the multi-dimensional tensor
        into a 2D matrix for the transpose. It must be greater than static_axis_boundary
        and may be negative (counted from the end).

    Returns the transposed shape as a tuple. `shape` may be any sequence, such as a shape
    tuple, an IR shape list or a padded PartitionSpec.

    examples:
        X in shape (dim0, dim1, dim2, dim3, dim4)

        static_axis_boundary == -1, transpose_axis_boundary == 2
            Xt = (dim2, dim3, dim4, dim0, dim1)

        static_axis_boundary == 0, transpose_axis_boundary == 2
            Xt = (dim0, dim2, dim3, dim4, dim1)

        static_axis_boundary == 0, transpose_axis_boundary == 3
            Xt = (dim0, dim3, dim4, dim1, dim2)
    """
    if static_axis_boundary < 0:
        static_axis_boundary = -1    # means no static axes
    assert static_axis_boundary < len(shape) - 2    # at least 2 remaining for transpose.
    transpose_start_idx = static_axis_boundary + 1
    transpose_axis_boundary = _normalize_axis_boundary(transpose_axis_boundary, len(shape))
    assert transpose_start_idx < transpose_axis_boundary
    # static axes, then the trailing group, then the leading transposed group
    return (*shape[:transpose_start_idx], *shape[transpose_axis_boundary:],
            *shape[transpose_start_idx:transpose_axis_boundary])


class CastTransposePrimitive(BasePrimitive):
    """
    Cast Transpose Primitive

    Casts x to `out_dtype` and returns the casted tensor, its multi-dimensional transpose
    (see `_multidim_transpose`) and the updated amax.
    """
    name = "te_cast_transpose"
    multiple_results = True
    impl_static_args = (4, 5, 6)    # out_dtype, static_axis_boundary, transpose_axis_boundary
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype, static_axis_boundary,
                 transpose_axis_boundary):
        """
        te_cast_transpose_p abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        transposed_x_shape = _multidim_transpose(x_aval.shape, static_axis_boundary,
                                                 transpose_axis_boundary)

        casted_x_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
        casted_xt_aval = x_aval.update(shape=transposed_x_shape, dtype=out_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)

        return casted_x_aval, casted_xt_aval, updated_amax_aval

    @staticmethod
    def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary,
                 transpose_axis_boundary):
        """
        te_cast_transpose_p lowering rules
        """
        x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_x_type = ir.RankedTensorType(x.type)
        ir_x_shape = ir_x_type.shape
        # At this level every static (non-transposed) leading axis must have extent 1.
        if static_axis_boundary >= 0:
            for i in range(static_axis_boundary + 1):
                assert ir_x_shape[i] == 1
        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape

        transposed_x_shape = _multidim_transpose(ir_x_shape, static_axis_boundary,
                                                 transpose_axis_boundary)

        out_types = [
            ir.RankedTensorType.get(ir_x_shape, ir_out_dtype),
            ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
        ]
        operands = [x, amax, scale, scale_inv]
        operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        # Collapse x to 2D around transpose_axis_boundary. The initial value 1 keeps an
        # empty side of the split valid.
        contracted_x_shape = (reduce(operator.mul, ir_x_shape[:transpose_axis_boundary], 1),
                              reduce(operator.mul, ir_x_shape[transpose_axis_boundary:], 1))
        opaque = transformer_engine_jax.pack_common_descriptor(contracted_x_shape,
                                                               jax_dtype_to_te_dtype(x_aval.dtype),
                                                               jax_dtype_to_te_dtype(out_dtype))

        # Operand 1 (amax) is aliased to output 2 (updated amax).
        out = custom_caller(CastTransposePrimitive.name,
                            args,
                            opaque,
                            False,
                            operand_output_aliases={1: 2})

        return out

    @staticmethod
    def impl(x, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis_boundary):
        """
        te_cast_transpose implementation
        """
        assert CastTransposePrimitive.inner_primitive is not None
        casted_x, casted_transposed_x, updated_amax = \
            CastTransposePrimitive.inner_primitive.bind(
                x, amax, scale, scale_inv, out_dtype=out_dtype,
                static_axis_boundary=static_axis_boundary,
                transpose_axis_boundary=transpose_axis_boundary)
        return casted_x, casted_transposed_x, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary,
                transpose_axis_boundary):
        """
        te_cast_transpose batcher

        The vmapped axis becomes the static axis, so it is excluded from the transpose.
        """
        _check_valid_batch_dims(batch_dims)
        assert CastTransposePrimitive.outer_primitive is not None
        assert static_axis_boundary < 0

        x, amax, scale, scale_inv = batched_args
        x_bdim, amax_bdim, *_ = batch_dims

        # Minus batch dim.
        transpose_axis_boundary = _normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1)
        transpose_axis_boundary += 1    # Plus batch dim

        out_bdims = x_bdim, x_bdim, amax_bdim
        return CastTransposePrimitive.outer_primitive.bind(
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=x_bdim,
            transpose_axis_boundary=transpose_axis_boundary), out_bdims

    @staticmethod
    def infer_sharding_from_operands(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh,
                                     arg_infos, result_infos):
        """
        te_cast_transpose infer_sharding_from_operands

        The transposed output's spec is x's spec permuted the same way as its shape.
        """
        del out_dtype, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        return (casted_x_sharding, casted_transposed_x_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos,
                  result_infos):
        """
        te_cast_transpose partitioning

        Each shard runs `impl` locally. The per-shard amax is then reduced with
        `all_reduce_max_along_all_axes_except_PP`.
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (casted_x_sharding, casted_transposed_x_sharding, amax_sharding)

        def sharded_impl(x, amax, scale, scale_inv):
            """Local cast-transpose followed by a max all-reduce of the amax."""
            local_cx, local_cxt, local_updated_amax = \
                CastTransposePrimitive.impl(x, amax, scale, scale_inv,
                                            out_dtype=out_dtype,
                                            static_axis_boundary=static_axis_boundary,
                                            transpose_axis_boundary=transpose_axis_boundary)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax)

            return local_cx, local_cxt, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(CastTransposePrimitive)


def cast_transpose(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
                   out_dtype: jnp.dtype, static_axis_boundary: int,
                   transpose_axis_boundary: int) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    cast transpose wrapper
    Return three tensors:
        - x cast to `out_dtype` (scaled by `scale`), same shape as x;
        - the same casted tensor transposed per `_multidim_transpose(x.shape,
          static_axis_boundary, transpose_axis_boundary)`;
        - the updated amax, same shape and dtype as `amax`.
    """
    return CastTransposePrimitive.outer_primitive.bind(
        x,
        amax,
        scale,
        scale_inv,
        out_dtype=out_dtype,
        static_axis_boundary=static_axis_boundary,
        transpose_axis_boundary=transpose_axis_boundary)


class TransposePrimitive(BasePrimitive):
    """
    Transpose Primitive

    Multi-dimensional transpose (see `_multidim_transpose`) that preserves dtype.
    """
    name = "te_transpose"
    multiple_results = False
    impl_static_args = (1, 2)    # static_axis_boundary, transpose_axis_boundary
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, *, static_axis_boundary, transpose_axis_boundary):
        """
        _transpose abstract
        """
        transposed_x_shape = _multidim_transpose(x_aval.shape, static_axis_boundary,
                                                 transpose_axis_boundary)
        xt_aval = x_aval.update(shape=transposed_x_shape, dtype=x_aval.dtype)

        return xt_aval

    @staticmethod
    def lowering(ctx, x, *, static_axis_boundary, transpose_axis_boundary):
        """
        _transpose cuda lowering
        """
        x_aval = ctx.avals_in[0]
        assert x_aval.dtype in [
            jnp.float32, jnp.float16, jnp.bfloat16, jnp.float8_e4m3fn, jnp.float8_e5m2
        ]

        ir_x_type = ir.RankedTensorType(x.type)
        ir_x_shape = ir_x_type.shape
        ir_out_dtype = jax_dtype_to_ir_dtype(x_aval.dtype)
        # At this level every static (non-transposed) leading axis must have extent 1.
        if static_axis_boundary >= 0:
            for i in range(static_axis_boundary + 1):
                assert ir_x_shape[i] == 1

        transposed_x_shape = _multidim_transpose(ir_x_shape, static_axis_boundary,
                                                 transpose_axis_boundary)

        out_types = [ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype)]
        operands = [x]
        operand_shapes = [ir_x_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        te_dtype = jax_dtype_to_te_dtype(x_aval.dtype)
        # Collapse x to 2D around transpose_axis_boundary. The initial value 1 keeps an
        # empty side of the split valid.
        contracted_x_shape = (reduce(operator.mul, ir_x_shape[:transpose_axis_boundary], 1),
                              reduce(operator.mul, ir_x_shape[transpose_axis_boundary:], 1))
        opaque = transformer_engine_jax.pack_common_descriptor(contracted_x_shape, te_dtype,
                                                               te_dtype)

        out = custom_caller(TransposePrimitive.name, args, opaque, False)

        return [out]

    @staticmethod
    def impl(x, static_axis_boundary, transpose_axis_boundary):
        """
        te_transpose implementation
        """
        assert TransposePrimitive.inner_primitive is not None
        transposed_x = \
            TransposePrimitive.inner_primitive.bind(x,
                                                    static_axis_boundary=static_axis_boundary,
                                                    transpose_axis_boundary=transpose_axis_boundary)
        return transposed_x

    @staticmethod
    def batcher(batched_args, batch_dims, *, static_axis_boundary, transpose_axis_boundary):
        """
        te_transpose batcher

        The vmapped axis becomes the static axis, so it is excluded from the transpose.
        """
        _check_valid_batch_dims(batch_dims)
        assert TransposePrimitive.outer_primitive is not None
        assert static_axis_boundary < 0

        x, = batched_args
        x_bdim, = batch_dims

        # Minus batch dim.
        transpose_axis_boundary = _normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1)
        transpose_axis_boundary += 1    # Plus batch dim

        out_bdims = x_bdim
        return TransposePrimitive.outer_primitive.bind(
            x, static_axis_boundary=x_bdim,
            transpose_axis_boundary=transpose_axis_boundary), out_bdims

    @staticmethod
    def infer_sharding_from_operands(static_axis_boundary, transpose_axis_boundary, mesh, arg_infos,
                                     result_infos):
        """
        te_transpose infer_sharding_from_operands
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        return transposed_x_sharding

    @staticmethod
    def partition(static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos):
        """
        te_transpose partitioning
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = transposed_x_sharding

        impl = partial(TransposePrimitive.impl,
                       static_axis_boundary=static_axis_boundary,
                       transpose_axis_boundary=transpose_axis_boundary)

        return mesh, impl, out_shardings, arg_shardings


register_primitive(TransposePrimitive)


def transpose(x: jnp.ndarray, static_axis_boundary: int,
              transpose_axis_boundary: int) -> jnp.ndarray:
    """
    transpose wrapper
    Return x transposed per `_multidim_transpose(x.shape, static_axis_boundary,
    transpose_axis_boundary)`:
        - axes up to and including static_axis_boundary stay in place;
        - axes from transpose_axis_boundary onward move ahead of the remaining axes.
    The dtype of x is preserved.
    """
    return TransposePrimitive.outer_primitive.bind(x,
                                                   static_axis_boundary=static_axis_boundary,
                                                   transpose_axis_boundary=transpose_axis_boundary)


2956
class LayerNormFwdFp8Primitive(BasePrimitive):
    """
    Layer Normalization Forward FP8 Primitive.

    Outputs, in order: the normalized tensor cast to `out_dtype`, float32 `mu`,
    float32 `rsigma` (both with the hidden axis dropped), and the updated amax.
    """
    name = "te_layernorm_forward_fp8"
    multiple_results = True
    impl_static_args = (6, 7, 8)    # out_dtype, zero_centered_gamma, epsilon
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, gamma_aval, beta_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype,
                 zero_centered_gamma, epsilon):
        """
        LayerNorm fwd (fp8 out) abstract.

        Checks input dtypes and returns the abstract values
        `(out, mu, rsigma, updated_amax)`: `out` has x's shape in `out_dtype`,
        `mu`/`rsigma` have x's shape without the last axis in float32, and
        `updated_amax` mirrors `amax_aval`.
        """
        del zero_centered_gamma, epsilon
        x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)

        assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        mu_rsigma_dtype = jnp.float32

        assert gamma_aval.size == beta_aval.size

        out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
        mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigma_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)

        return out_aval, mu_aval, rsigma_aval, updated_amax_aval

    @staticmethod
    def lowering(ctx, x, gamma, beta, amax, scale, scale_inv, *, out_dtype, zero_centered_gamma,
                 epsilon):
        """
        LayerNorm fwd (fp8 out) lowering rules.

        Emits the `te_layernorm_forward_fp8` custom call. The input is viewed as
        `(batch_size, hidden_size)` where `hidden_size` is gamma's element count.
        """
        x_aval, gamma_aval, beta_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in

        # Currently only support casting to E4M3 only in C side.
        assert out_dtype == jnp.float8_e4m3fn

        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert gamma_aval.dtype == beta_aval.dtype
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        g_type = ir.RankedTensorType(gamma.type)
        g_shape = g_type.shape
        b_type = ir.RankedTensorType(beta.type)
        b_shape = b_type.shape

        assert g_type == b_type
        assert g_shape == b_shape

        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_mu_dtype = ir.F32Type.get()
        ir_rsigma_dtype = ir.F32Type.get()
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        # scale / scale_inv are declared with the same shape as amax.
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape

        out_shape = x_shape
        hidden_size = reduce(operator.mul, g_shape)
        batch_shape = out_shape[:-1]
        batch_size = reduce(operator.mul, x_shape) // hidden_size

        out_types = [
            ir.RankedTensorType.get(out_shape, ir_out_dtype),
            ir.RankedTensorType.get(batch_shape, ir_mu_dtype),
            ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
        ]
        operands = [x, gamma, beta, amax, scale, scale_inv]
        operand_shapes = [
            x_shape, g_shape, b_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape
        ]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        opaque = transformer_engine_jax.pack_norm_descriptor(
            batch_size,
            hidden_size,
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(gamma_aval.dtype),
            zero_centered_gamma,
            epsilon,
        )

        # Operand 3 (amax) is aliased with output 3 (updated amax), presumably so
        # the amax buffer can be reused for the update.
        out = custom_caller(LayerNormFwdFp8Primitive.name,
                            args,
                            opaque,
                            False,
                            operand_output_aliases={3: 3})

        return out

    @staticmethod
    def impl(x, gamma, beta, amax, scale, scale_inv, out_dtype, zero_centered_gamma, epsilon):
        """
        Bind the inner primitive and return `(out, mu, rsigma, updated_amax)`.
        """
        assert LayerNormFwdFp8Primitive.inner_primitive is not None
        out, mu, rsigma, updated_amax = LayerNormFwdFp8Primitive.inner_primitive.bind(
            x,
            gamma,
            beta,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon)
        return out, mu, rsigma, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, zero_centered_gamma, epsilon):
        """
        Batch rule for vmap.

        `out`, `mu` and `rsigma` inherit x's batch dim; `updated_amax` inherits
        amax's batch dim.
        """
        _check_valid_batch_dims(batch_dims)
        assert LayerNormFwdFp8Primitive.outer_primitive is not None
        x, gamma, beta, amax, scale, scale_inv = batched_args
        x_bdim, _, _, amax_bdim, _, _ = batch_dims

        out_bdims = x_bdim, x_bdim, x_bdim, amax_bdim
        return LayerNormFwdFp8Primitive.outer_primitive.bind(
            x,
            gamma,
            beta,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon), out_bdims

    @staticmethod
    def infer_sharding_from_operands(out_dtype, zero_centered_gamma, epsilon, mesh, arg_infos,
                                     result_infos):
        """
        Infer output shardings from x's sharding.

        The hidden (last) axis is never sharded: if x is sharded along it, a warning
        is issued and the output is specified as unsharded on that axis. `mu` and
        `rsigma` use x's leading-axes spec; `updated_amax` follows amax.
        """
        del out_dtype, zero_centered_gamma, epsilon, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        if x_spec[-1] is not None:
            # Report this primitive's own name (previously the non-FP8 LayerNorm
            # primitive's name was printed here by mistake).
            warnings.warn(
                f"Does not support to shard hidden dim in {LayerNormFwdFp8Primitive.name}! " \
                "Force to not shard the hidden dim, which might introduce extra collective ops, " \
                "and hurt performance.")

        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        mu_sharding = rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[3])))
        return (out_sharding, mu_sharding, rsigma_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
        """
        Custom-partitioning rule.

        x is required unsharded along the hidden axis (with a warning if it was
        sharded there). gamma/beta keep their padded specs, and amax/scale/scale_inv
        all use amax's spec. Each shard runs `impl`, and the local amax values are
        combined with `all_reduce_max_along_all_axes_except_PP`.
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {LayerNormFwdFp8Primitive.name}! " \
                "Force to not shard the hidden dim, which might introduce extra collective ops, " \
                "and hurt performance."
            )
        x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        g_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        b_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        out_sharding = x_sharding
        mu_sharding = rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[3])))
        fp8_meta_sharding = amax_sharding
        arg_shardings = (x_sharding, g_sharding, b_sharding) + (fp8_meta_sharding,) * 3
        out_shardings = (out_sharding, mu_sharding, rsigma_sharding, amax_sharding)

        def sharded_impl(x, gamma, beta, amax, scale, scale_inv):
            local_x, local_mu, local_rsigma, local_amax = \
                LayerNormFwdFp8Primitive.impl(x, gamma, beta, amax, scale, scale_inv,
                                              out_dtype=out_dtype,
                                              zero_centered_gamma=zero_centered_gamma,
                                              epsilon=epsilon)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)

            return local_x, local_mu, local_rsigma, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(LayerNormFwdFp8Primitive)


def layernorm_fwd_fp8(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, amax: jnp.ndarray,
                      scale: jnp.ndarray, scale_inv: jnp.ndarray, out_dtype: jnp.dtype,
                      zero_centered_gamma: bool, epsilon: float):
    """
    Wrapper for TE layernorm fwd (fp8 out).

    Binds the outer `LayerNormFwdFp8Primitive` and returns its four outputs
    `(out, mu, rsigma, updated_amax)`.
    """
    operands = (x, gamma, beta, amax, scale, scale_inv)
    static_params = {
        "out_dtype": out_dtype,
        "zero_centered_gamma": zero_centered_gamma,
        "epsilon": epsilon,
    }
    return LayerNormFwdFp8Primitive.outer_primitive.bind(*operands, **static_params)
3168
3169


3170
class RmsNormFwdFp8Primitive(BasePrimitive):
    """
    RMS Normalization Forward FP8 Primitive.

    Outputs, in order: the normalized tensor cast to `out_dtype`, float32
    `rsigma` (x's shape without the last axis), and the updated amax.
    """
    name = "te_rmsnorm_forward_fp8"
    multiple_results = True
    impl_static_args = (5, 6)    # out_dtype, epsilon
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval, out_dtype, epsilon):
        """
        RMSNorm fwd (fp8 out) abstract: validate dtypes and derive output avals.
        """
        del epsilon
        assert dtypes.canonicalize_dtype(x_aval.dtype) in [jnp.float32, jnp.float16, jnp.bfloat16]
        for fp8_meta_aval in (amax_aval, scale_aval, scale_inv_aval):
            assert fp8_meta_aval.dtype == jnp.float32

        # x must split evenly into rows of gamma's size.
        assert x_aval.size % gamma_aval.size == 0

        out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
        rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=jnp.float32)
        new_amax_aval = out_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
        return out_aval, rsigma_aval, new_amax_aval

    @staticmethod
    def lowering(ctx, x, gamma, amax, scale, scale_inv, *, out_dtype, epsilon):
        """
        RMSNorm fwd (fp8 out) lowering rules: emit the `te_rmsnorm_forward_fp8`
        custom call over a `(batch_size, hidden_size)` view of x.
        """
        # Currently only support casting to E4M3 only in C side.
        assert out_dtype == jnp.float8_e4m3fn

        x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        for fp8_meta_aval in (amax_aval, scale_aval, scale_inv_aval):
            assert fp8_meta_aval.dtype == jnp.float32

        x_shape = ir.RankedTensorType(x.type).shape
        g_shape = ir.RankedTensorType(gamma.type).shape
        amax_ir_type = ir.RankedTensorType(amax.type)
        # amax, scale and scale_inv share one shape.
        fp8_meta_shape = amax_ir_type.shape

        hidden_size = reduce(operator.mul, g_shape)
        batch_size = reduce(operator.mul, x_shape) // hidden_size
        row_shape = x_shape[:-1]

        out_types = [
            ir.RankedTensorType.get(x_shape, jax_dtype_to_ir_dtype(out_dtype)),
            ir.RankedTensorType.get(row_shape, ir.F32Type.get()),
            ir.RankedTensorType.get(fp8_meta_shape, amax_ir_type.element_type),
        ]
        operands = [x, gamma, amax, scale, scale_inv]
        operand_shapes = [x_shape, g_shape] + [fp8_meta_shape] * 3
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        opaque = transformer_engine_jax.pack_norm_descriptor(
            batch_size,
            hidden_size,
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(gamma_aval.dtype),
            False,    # RMSNorm doesn't support zero_centered_gamma
            epsilon,
        )

        # Operand 2 (amax) is aliased with output 2 (updated amax).
        return custom_caller(RmsNormFwdFp8Primitive.name,
                             args,
                             opaque,
                             False,
                             operand_output_aliases={2: 2})

    @staticmethod
    def impl(x, gamma, amax, scale, scale_inv, out_dtype, epsilon):
        """
        Bind the inner primitive and return `(out, rsigma, updated_amax)`.
        """
        assert RmsNormFwdFp8Primitive.inner_primitive is not None
        results = RmsNormFwdFp8Primitive.inner_primitive.bind(x,
                                                              gamma,
                                                              amax,
                                                              scale,
                                                              scale_inv,
                                                              out_dtype=out_dtype,
                                                              epsilon=epsilon)
        out, rsigma, new_amax = results
        return out, rsigma, new_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, epsilon):
        """
        Batch rule for vmap: `out`/`rsigma` follow x's batch dim, the amax output
        follows amax's batch dim.
        """
        _check_valid_batch_dims(batch_dims)
        assert RmsNormFwdFp8Primitive.outer_primitive is not None
        x, gamma, amax, scale, scale_inv = batched_args
        x_bdim, _, amax_bdim, _, _ = batch_dims
        batched_out = RmsNormFwdFp8Primitive.outer_primitive.bind(x,
                                                                  gamma,
                                                                  amax,
                                                                  scale,
                                                                  scale_inv,
                                                                  out_dtype=out_dtype,
                                                                  epsilon=epsilon)
        return batched_out, (x_bdim, x_bdim, amax_bdim)

    @staticmethod
    def infer_sharding_from_operands(out_dtype, epsilon, mesh, arg_infos, result_infos):
        """
        Infer output shardings from x; the hidden (last) axis is forced unsharded,
        with a warning when x arrives sharded along it.
        """
        del out_dtype, epsilon, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {RmsNormFwdFp8Primitive.name}! " \
                "Force to not shard the hidden dim, which might introduce extra collective ops, " \
                "and hurt performance."
            )
        leading_spec = x_spec[:-1]
        return (NamedSharding(mesh, PartitionSpec(*leading_spec, None)),
                NamedSharding(mesh, PartitionSpec(*leading_spec)),
                NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2]))))

    @staticmethod
    def partition(out_dtype, epsilon, mesh, arg_infos, result_infos):
        """
        Custom-partitioning rule: x unsharded on the hidden axis, gamma keeps its
        spec, amax/scale/scale_inv share amax's spec; local amax values are combined
        with `all_reduce_max_along_all_axes_except_PP`.
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {RmsNormFwdFp8Primitive.name}! " \
                "Force to not shard the hidden dim, which might introduce extra collective ops, " \
                "and hurt performance."
            )
        leading_spec = x_spec[:-1]
        x_sharding = NamedSharding(mesh, PartitionSpec(*leading_spec, None))
        g_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        rsigma_sharding = NamedSharding(mesh, PartitionSpec(*leading_spec))

        arg_shardings = (x_sharding, g_sharding, amax_sharding, amax_sharding, amax_sharding)
        out_shardings = (x_sharding, rsigma_sharding, amax_sharding)

        def sharded_impl(x, gamma, amax, scale, scale_inv):
            shard_out, shard_rsigma, shard_amax = RmsNormFwdFp8Primitive.impl(
                x, gamma, amax, scale, scale_inv, out_dtype=out_dtype, epsilon=epsilon)
            return shard_out, shard_rsigma, all_reduce_max_along_all_axes_except_PP(shard_amax)

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(RmsNormFwdFp8Primitive)
3343

3344
3345
3346

def rmsnorm_fwd_fp8(x: jnp.ndarray, gamma: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
                    scale_inv: jnp.ndarray, out_dtype: jnp.dtype, epsilon: float):
    """
    Wrapper for TE rmsnorm fwd (fp8 out).

    Binds the outer `RmsNormFwdFp8Primitive` and returns its three outputs
    `(out, rsigma, updated_amax)`.
    """
    operands = (x, gamma, amax, scale, scale_inv)
    return RmsNormFwdFp8Primitive.outer_primitive.bind(*operands,
                                                       out_dtype=out_dtype,
                                                       epsilon=epsilon)
3357
3358


3359
class GatedGeluFp8Primitive(BasePrimitive):
    """
    Gated Gelu FP8 Primitive.

    Consumes x shaped `(..., 2, hidden)` and produces the FP8 result shaped
    `(..., hidden)` plus the updated amax.
    """
    name = "te_gated_gelu_fp8"
    multiple_results = True
    impl_static_args = (4,)    #out_dtype
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype):
        """
        te_gated_gelu_p abstract: validate dtypes and drop the size-2 axis.
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        # Currently only support casting to E4M3 only in C side.
        assert out_dtype == jnp.float8_e4m3fn
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        assert x_aval.shape[-2] == 2    # Assume x in (....., 2, hidden)
        hidden_size = x_aval.shape[-1]
        batch_shape = x_aval.shape[:-2]
        out_shape = (batch_shape) + (hidden_size,)
        out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)

        return out_aval, updated_amax_aval

    @staticmethod
    def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype):
        """
        te_gated_gelu_p lowering rules.

        The descriptor describes the output as `(batch_size, hidden)`, where
        `batch_size` is the product of x's leading dims (1 when x is 2-D).
        """
        x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_x_type = ir.RankedTensorType(x.type)
        ir_x_shape = ir_x_type.shape
        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape

        hidden_size = ir_x_shape[-1]
        batch_shape = ir_x_shape[:-2]
        # Initializer 1 keeps this valid for a 2-D x of shape (2, hidden), where
        # batch_shape is empty and reduce() without an initializer would raise.
        batch_size = reduce(operator.mul, batch_shape, 1)
        out_shape = batch_shape + [hidden_size]
        out_types = [
            ir.RankedTensorType.get(out_shape, ir_out_dtype),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
        ]
        operands = [x, amax, scale, scale_inv]
        operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        opaque = transformer_engine_jax.pack_common_descriptor((batch_size, out_shape[-1]),
                                                               jax_dtype_to_te_dtype(x_aval.dtype),
                                                               jax_dtype_to_te_dtype(out_dtype))

        # Operand 1 (amax) is aliased with output 1 (updated amax).
        out = custom_caller(GatedGeluFp8Primitive.name,
                            args,
                            opaque,
                            False,
                            operand_output_aliases={1: 1})

        return out

    @staticmethod
    def impl(x, amax, scale, scale_inv, out_dtype):
        """
        Bind the inner primitive and return `(out, updated_amax)`.
        """
        assert GatedGeluFp8Primitive.inner_primitive is not None
        out, updated_amax = GatedGeluFp8Primitive.inner_primitive.bind(x,
                                                                       amax,
                                                                       scale,
                                                                       scale_inv,
                                                                       out_dtype=out_dtype)
        return out, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype):
        """
        Batch rule for vmap: `out` follows x's batch dim, the amax output follows
        amax's batch dim.
        """
        _check_valid_batch_dims(batch_dims)
        assert GatedGeluFp8Primitive.outer_primitive is not None
        x, amax, scale, scale_inv = batched_args
        x_bdim, amax_bdim, _, _ = batch_dims

        out_bdims = x_bdim, amax_bdim
        return GatedGeluFp8Primitive.outer_primitive.bind(x,
                                                          amax,
                                                          scale,
                                                          scale_inv,
                                                          out_dtype=out_dtype), out_bdims

    @staticmethod
    def infer_sharding_from_operands(out_dtype, mesh, arg_infos, result_infos):
        """
        Output sharding is x's spec with the size-2 axis (second to last) removed.
        """
        del out_dtype, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        return (out_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, mesh, arg_infos, result_infos):
        """
        Custom-partitioning rule: operands keep their shardings. Each shard runs
        `impl`, and local amax values are combined with
        `all_reduce_max_along_all_axes_except_PP`.
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (out_sharding, amax_sharding)

        def sharded_impl(x, amax, scale, scale_inv):
            local_x, local_amax = GatedGeluFp8Primitive.impl(x,
                                                             amax,
                                                             scale,
                                                             scale_inv,
                                                             out_dtype=out_dtype)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)

            return local_x, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(GatedGeluFp8Primitive)
3495

3496
3497
3498

def gated_gelu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
                   out_dtype: jnp.dtype) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    gated gelu wrapper
    Return FP8(geglu(x))

    Binds the outer `GatedGeluFp8Primitive`, which has two outputs:
    `(out, updated_amax)`. The annotation previously claimed a 3-tuple.
    """
    return GatedGeluFp8Primitive.outer_primitive.bind(x,
                                                      amax,
                                                      scale,
                                                      scale_inv,
                                                      out_dtype=out_dtype)
3508
3509


3510
class DgatedGeluCastTransposePrimitive(BasePrimitive):
    """
    Dgated Gelu Cast Transpose Primitive.

    Takes `dz` shaped `(..., hidden)` and `x` shaped `(..., 2, hidden)`. It returns
    an FP8 tensor with x's shape, its `_multidim_transpose` (around axis -2), and
    the updated amax.
    """
    name = "te_dgated_gelu_cast_transpose"
    multiple_results = True
    impl_static_args = (5, 6)    # out_dtype, static_axis_boundary
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype,
                 static_axis_boundary):
        """
        te_dgated_gelu_cast_transpose_p abstract: validate inputs and return
        `(out, t_out, updated_amax)` avals.
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dtype
        assert x_aval.shape[-2] == 2    # Linear + GeLU
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        dz_hidden_size = dz_aval.shape[-1]
        x_hidden_size = x_aval.shape[-1]
        assert dz_hidden_size == x_hidden_size
        t_shape = _multidim_transpose(x_aval.shape, static_axis_boundary, -2)
        out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype)
        t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
        return out, t_out, updated_amax_aval

    @staticmethod
    def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary):
        """
        te_dgated_gelu_cast_transpose_p lowering rules.

        The descriptor describes x as `(batch_size, hidden)`, where `batch_size` is
        the product of x's dims before the size-2 axis (1 when there are none).
        """
        dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_aval.dtype
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_dz_type = ir.RankedTensorType(dz.type)
        ir_dz_shape = ir_dz_type.shape
        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        # Initializer 1 keeps these valid when dz is 1-D / x is 2-D and the
        # leading-dims slice is empty (reduce() without it would raise).
        dz_batch_size = reduce(operator.mul, ir_dz_shape[:-1], 1)
        x_batch_size = reduce(operator.mul, x_shape[:-2], 1)
        assert dz_batch_size == x_batch_size
        assert x_shape[-2] == 2    # Linear + GeLU
        dz_hidden_size = ir_dz_shape[-1]
        x_hidden_size = x_shape[-1]
        assert dz_hidden_size == x_hidden_size
        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape
        transposed_x_shape = _multidim_transpose(x_shape, static_axis_boundary, -2)
        out_types = [
            ir.RankedTensorType.get(x_shape, ir_out_dtype),
            ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
        ]
        operands = [dz, x, amax, scale, scale_inv]
        operand_shapes = [ir_dz_shape, x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
        contracted_x_shape = (x_batch_size, x_shape[-1])
        opaque = transformer_engine_jax.pack_common_descriptor(contracted_x_shape,
                                                               jax_dtype_to_te_dtype(dz_aval.dtype),
                                                               jax_dtype_to_te_dtype(out_dtype))

        # Operand 2 (amax) is aliased with output 2 (updated amax).
        out = custom_caller(DgatedGeluCastTransposePrimitive.name,
                            args,
                            opaque,
                            False,
                            operand_output_aliases={2: 2})

        return out

    @staticmethod
    def impl(dz, x, amax, scale, scale_inv, out_dtype, static_axis_boundary):
        """
        Bind the inner primitive and return `(out, t_out, updated_amax)`.
        """
        assert DgatedGeluCastTransposePrimitive.inner_primitive is not None
        out, t_out, updated_amax = DgatedGeluCastTransposePrimitive.inner_primitive.bind(
            dz,
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary)
        return out, t_out, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary):
        """
        Batch rule for vmap.

        Both FP8 outputs have x's shape, so they follow x's batch dim, and that dim
        is passed as `static_axis_boundary`. The amax output follows amax's batch
        dim.
        """
        del static_axis_boundary
        _check_valid_batch_dims(batch_dims)
        assert DgatedGeluCastTransposePrimitive.outer_primitive is not None
        dz, x, amax, scale, scale_inv = batched_args
        # batch_dims is ordered like batched_args: (dz, x, amax, scale, scale_inv).
        # Take x's entry (index 1); index 0 is dz's.
        _, x_bdim, amax_bdim, _, _ = batch_dims

        out_bdims = x_bdim, x_bdim, amax_bdim
        return DgatedGeluCastTransposePrimitive.outer_primitive.bind(
            dz, x, amax, scale, scale_inv, out_dtype=out_dtype,
            static_axis_boundary=x_bdim), out_bdims

    @staticmethod
    def infer_sharding_from_operands(out_dtype, static_axis_boundary, mesh, arg_infos,
                                     result_infos):
        """
        Output shardings come from x's spec: `out` uses it as-is, `t_out` uses it
        permuted by `_multidim_transpose`, and the amax output uses amax's spec.
        """
        del out_dtype, result_infos
        x_spec = get_padded_spec(arg_infos[1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, -2)
        transposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        return (out_sharding, transposed_out_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, static_axis_boundary, mesh, arg_infos, result_infos):
        """
        Custom-partitioning rule: operands keep their shardings. Each shard runs
        `impl`, and local amax values are combined with
        `all_reduce_max_along_all_axes_except_PP`.
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[1])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, -2)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))

        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (casted_x_sharding, casted_transposed_x_sharding, amax_sharding)

        def sharded_impl(dz, x, amax, scale, scale_inv):
            local_out, local_t_out, local_amax = DgatedGeluCastTransposePrimitive.impl(
                dz,
                x,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=static_axis_boundary)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
            return local_out, local_t_out, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(DgatedGeluCastTransposePrimitive)
3663

3664
3665
3666
3667
3668

def dgated_gelu_cast_transpose(
        dz: jnp.ndarray, x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
        scale_inv: jnp.ndarray, out_dtype: jnp.dtype,
        static_axis_boundary: int) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    cast transpose d_gated_gelu fusion wrapper
    Return FP8(dgeglu(inputs))

    Binds the outer `DgatedGeluCastTransposePrimitive` and returns
    `(out, t_out, updated_amax)`. `out_dtype` is a JAX dtype: the primitive passes
    it to `jax_dtype_to_te_dtype`/`jax_dtype_to_ir_dtype` and uses it as an aval
    dtype. It was previously annotated as `TEDType`.
    """
    return DgatedGeluCastTransposePrimitive.outer_primitive.bind(
        dz,
        x,
        amax,
        scale,
        scale_inv,
        out_dtype=out_dtype,
        static_axis_boundary=static_axis_boundary)