# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX te custom call"""
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Tuple
from functools import partial, reduce
import operator
10
import os
11
12
import warnings

13
14
15
16
17
18
19
import transformer_engine_jax
from transformer_engine_jax import DType as TEDType
from transformer_engine_jax import NVTE_Bias_Type
from transformer_engine_jax import NVTE_Mask_Type
from transformer_engine_jax import NVTE_QKV_Layout
from transformer_engine_jax import NVTE_Fused_Attn_Backend

20
21
22
23
24
import numpy as np
import jax.numpy as jnp
from jax.lib import xla_client
from jax import core, dtypes
from jax.interpreters import xla, mlir
25
from jax.experimental.custom_partitioning import custom_partitioning
26
from jax.interpreters.mlir import ir, dtype_to_ir_type
27
28
from jax.sharding import PartitionSpec, NamedSharding
from jax._src.interpreters import batching
29
30
31
32
33
34
from jax._src import dispatch

from .sharding import all_reduce_max_along_all_axes_except_PP
from .sharding import all_reduce_sum_along_dp_fsdp
from .sharding import get_all_mesh_axes, num_of_devices
from .sharding import get_padded_spec as te_get_padded_spec
35

36
37
38
39
40
41
42
try:
    from jaxlib.hlo_helpers import custom_call
except ImportError:
    # Newer JAX changed its API. But we want to support a few JAX
    # version, so we still need this import.
    pass

43
44
45
46
47
48
49
50
51
# Register every custom-call target exposed by the transformer_engine_jax C++
# extension with XLA under the "CUDA" platform, so lowered HLO custom calls
# emitted below (e.g. "te_layernorm_forward") can be resolved at runtime.
for _name, _value in transformer_engine_jax.registrations().items():
    xla_client.register_custom_call_target(_name, _value, platform="CUDA")


def te_dtype_to_jax_dtype(te_dtype):
    """
    Convert a TE dtype enum value to the matching JAX dtype.

    Raises ValueError when the TE dtype has no JAX counterpart.
    """
    assert isinstance(te_dtype, TEDType)

    te_to_jax = {
        TEDType.kFloat32: jnp.float32,
        TEDType.kFloat16: jnp.float16,
        TEDType.kBFloat16: jnp.bfloat16,
        TEDType.kInt32: jnp.int32,
        TEDType.kInt64: jnp.int64,
        TEDType.kFloat8E4M3: jnp.float8_e4m3fn,
        TEDType.kFloat8E5M2: jnp.float8_e5m2,
    }

    # No entry maps to None, so a None lookup result means "unsupported".
    jax_dtype = te_to_jax.get(te_dtype)
    if jax_dtype is None:
        raise ValueError(f"Unsupported {te_dtype=}")
    return jax_dtype
67
68
69
70
71
72
73
74
75


def te_dtype_to_ir_dtype(te_dtype):
    """
    Convert a TE dtype enum value to the equivalent MLIR element type.
    """
    # Route through the JAX dtype so the TE->JAX table stays the single
    # source of truth for dtype support.
    jax_dtype = te_dtype_to_jax_dtype(te_dtype)
    return dtype_to_ir_type(np.dtype(jax_dtype))


76
77
78
79
80
81
82
def jax_dtype_to_ir_dtype(jax_dtype):
    """
    Convert a JAX dtype to the equivalent MLIR element type.
    """
    np_dtype = np.dtype(jax_dtype)
    return dtype_to_ir_type(np_dtype)


83
84
85
86
def jax_dtype_to_te_dtype(jax_dtype):
    """
    Convert a JAX dtype to the matching TE dtype enum value.

    Raises ValueError when the JAX dtype has no TE counterpart.
    """
    # Canonicalize first so aliases (e.g. Python float, x64 settings)
    # collapse onto a single dict key.
    jax_dtype = dtypes.canonicalize_dtype(jax_dtype)

    jax_to_te = {
        jnp.float32.dtype: TEDType.kFloat32,
        jnp.float16.dtype: TEDType.kFloat16,
        jnp.bfloat16.dtype: TEDType.kBFloat16,
        jnp.int32.dtype: TEDType.kInt32,
        jnp.int64.dtype: TEDType.kInt64,
        jnp.float8_e4m3fn.dtype: TEDType.kFloat8E4M3,
        jnp.float8_e5m2.dtype: TEDType.kFloat8E5M2,
    }

    # No entry maps to None, so a None lookup result means "unsupported".
    te_dtype = jax_to_te.get(jax_dtype)
    if te_dtype is None:
        raise ValueError(f"Unsupported {jax_dtype=}")
    return te_dtype
103
104


105
106
107
108
109
110
111
112
def get_padded_spec(arg_info):
    """
    Return the padded partition spec for one argument of a
    custom_partitioning rule, padding with None up to the argument's rank.
    """
    sharding = arg_info.sharding
    if sharding is None:
        # Unsharded argument: let the helper build an all-None spec.
        return te_get_padded_spec(None, arg_info.ndim)
    return te_get_padded_spec(sharding.spec, arg_info.ndim)
113
114


115
def _check_valid_batch_dims(bdims):
116
    """
117
    Assert out non-supported bath dims
118
    """
119
120
121
122
    for dim in bdims:
        assert dim in [0, None], \
            "Currently only support batch_dim in [0, None], " \
            f"but got {dim=}"
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145


class BasePrimitive(metaclass=ABCMeta):
    """
    Abstract base class for a TE JAX primitive.

    Subclasses provide the pieces that register_primitive() wires into a
    pair of JAX primitives (an inner lowering primitive and an outer
    custom-partitioned wrapper primitive).
    """

    @staticmethod
    @abstractmethod
    def abstract():
        """
        Abstract evaluation rule: describes output avals for the
        computation graph.
        """
        return NotImplemented

    @staticmethod
    @abstractmethod
    def lowering():
        """
        MLIR lowering rule for the inner primitive.
        """
        return NotImplemented

    @staticmethod
    @abstractmethod
    def impl():
        """
        Implementation of the outer primitive; binds the inner primitive.
        """
        return NotImplemented

    @staticmethod
    @abstractmethod
    def batcher():
        """
        Batch rule used by vmap for the outer primitive.
        """
        return NotImplemented

    @staticmethod
    @abstractmethod
    def infer_sharding_from_operands():
        """
        infer_sharding_from_operands rule for custom_partitioning.
        """
        return NotImplemented

    @staticmethod
    @abstractmethod
    def partition():
        """
        partition rule for custom_partitioning.
        """
        return NotImplemented

178
179
180
181
182

def register_primitive(cls):
    """
    Register a BasePrimitive subclass as two JAX primitives.

    Creates an inner primitive (abstract eval + CUDA MLIR lowering) and an
    outer "_wrapper" primitive (custom_partitioning + vmap batching) and
    stores them on the class as cls.inner_primitive / cls.outer_primitive.
    """

    def name_of_wrapper_p():
        return cls.name + "_wrapper"

    # Inner primitive: the one actually lowered to the XLA custom call.
    inner_p = core.Primitive(cls.name)
    dispatch.prim_requires_devices_during_lowering.add(inner_p)
    inner_p.multiple_results = cls.multiple_results
    inner_p.def_impl(partial(xla.apply_primitive, inner_p))
    inner_p.def_abstract_eval(cls.abstract)
    mlir.register_lowering(inner_p, cls.lowering, platform='cuda')
    cls.inner_primitive = inner_p

    # Outer primitive: user-facing wrapper carrying the sharding and
    # batching rules; its lowering dispatches to cls.impl, which binds the
    # inner primitive.
    outer_p = core.Primitive(name_of_wrapper_p())
    dispatch.prim_requires_devices_during_lowering.add(outer_p)
    outer_p.multiple_results = cls.multiple_results
    outer_p.def_impl(cls.impl)
    outer_p.def_abstract_eval(cls.abstract)
    batching.primitive_batchers[outer_p] = cls.batcher
    outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args)
    outer_p_lower.def_partition(infer_sharding_from_operands=cls.infer_sharding_from_operands,
                                partition=cls.partition)
    mlir.register_lowering(outer_p,
                           mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results))
    cls.outer_primitive = outer_p
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253


@dataclass
class CustomCallArgsWrapper:
    """
    Bundles the operands, result types and layouts of an XLA custom call.
    """

    def __init__(self,
                 output_types,
                 operands,
                 operand_shapes,
                 operand_specific_layouts=None,
                 output_specific_layouts=None):
        self.output_types = output_types
        self.operands = operands
        self.operand_layouts = self.generate_layouts(operand_shapes,
                                                     operand_specific_layouts)
        self.output_layouts = self.generate_layouts([out.shape for out in output_types],
                                                    output_specific_layouts)

    @staticmethod
    def generate_layouts(shapes, specific_layouts):
        """
        Build one layout per shape: a caller-provided layout when present in
        *specific_layouts* (keyed by position), otherwise the default
        row-major (minor-to-major) layout.
        """
        if specific_layouts is None:
            specific_layouts = {}
        return [
            specific_layouts[idx] if idx in specific_layouts else range(len(shape) - 1, -1, -1)
            for idx, shape in enumerate(shapes)
        ]


def custom_caller(name, args, opaque, has_side_effect, **kwargs):
    """
    Version-agnostic wrapper for emitting an XLA custom call.

    Newer JAX exposes mlir.custom_call; older JAX only has
    jaxlib.hlo_helpers.custom_call with a slightly different signature.
    """
    if hasattr(mlir, "custom_call"):
        return mlir.custom_call(name,
                                result_types=args.output_types,
                                operands=args.operands,
                                operand_layouts=args.operand_layouts,
                                result_layouts=args.output_layouts,
                                backend_config=opaque,
                                has_side_effect=has_side_effect,
                                **kwargs).results
    # Older JAX fallback: the second parameter was renamed across JAX
    # versions, so pass it positionally.
    return custom_call(    # pylint: disable=too-many-function-args
        name,
        args.output_types,
        operands=args.operands,
        operand_layouts=args.operand_layouts,
        result_layouts=args.output_layouts,
        backend_config=opaque,
        has_side_effect=has_side_effect,
        **kwargs)


279
class LayerNormFwdPrimitive(BasePrimitive):
    """
    Layer Normalization Forward Primitive
    """
    # Name must match the custom-call target registered by the C++ extension.
    name = "te_layernorm_forward"
    multiple_results = True
    impl_static_args = (3, 4)    # zero_centered_gamma, epsilon
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, gamma_aval, beta_aval, **kwargs):    # pylint: disable=unused-argument
        """
        LayerNorm fwd abstract

        Returns avals for (output, mu, rsigma): output matches x; mu and
        rsigma drop the last (hidden) dim and are always fp32.
        """
        x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]

        # Statistics are computed in fp32 regardless of the input dtype.
        mu_rsigama_dtype = jnp.float32

        out_aval = core.raise_to_shaped(x_aval)
        mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype)

        # gamma and beta must cover the same number of hidden elements, and
        # x must be a whole number of hidden-size rows.
        assert gamma_aval.size == beta_aval.size
        hidden_size = gamma_aval.size
        assert x_aval.size % hidden_size == 0

        return out_aval, mu_aval, rsigma_aval

    @staticmethod
    def lowering(ctx, x, gamma, beta, *, zero_centered_gamma, epsilon):
        """
        LayerNorm fwd lowering rules

        Emits the "te_layernorm_forward" XLA custom call with a packed
        norm descriptor as the opaque backend config.
        """
        x_aval, gamma_aval, beta_aval = ctx.avals_in
        assert gamma_aval.dtype == beta_aval.dtype
        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        g_type = ir.RankedTensorType(gamma.type)
        g_shape = g_type.shape
        b_type = ir.RankedTensorType(beta.type)
        b_shape = b_type.shape

        assert g_type == b_type
        assert g_shape == b_shape

        # Output shape is same as the input shape, but the output type is same as the weight type.
        # See ln_api.cpp
        output_type = g_type.element_type
        ir_mu_dtype = ir.F32Type.get()
        ir_rsigma_dtype = ir.F32Type.get()

        out_shape = x_shape
        hidden_size = reduce(operator.mul, g_shape)
        # All leading dims are flattened into one batch dim for the kernel.
        batch_shape = out_shape[:-1]
        batch_size = reduce(operator.mul, x_shape) // hidden_size

        out_types = [
            ir.RankedTensorType.get(out_shape, output_type),
            ir.RankedTensorType.get(batch_shape, ir_mu_dtype),
            ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
        ]
        operands = [x, gamma, beta]
        operand_shapes = [x_shape, g_shape, b_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        # Number of SMs to leave free for overlapping communication kernels.
        sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))

        # Argument order here must match pack_norm_descriptor in the C++ extension.
        opaque = transformer_engine_jax.pack_norm_descriptor(
            batch_size,
            hidden_size,
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(gamma_aval.dtype),
            zero_centered_gamma,
            epsilon,
            sm_margin,
        )

        out = custom_caller(LayerNormFwdPrimitive.name, args, opaque, False)

        return out

    @staticmethod
    def impl(x, gamma, beta, zero_centered_gamma, epsilon):
        """
        to describe implementation
        """
        assert LayerNormFwdPrimitive.inner_primitive is not None
        out, mu, rsigma = LayerNormFwdPrimitive.inner_primitive.bind(
            x, gamma, beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon)
        return out, mu, rsigma

    @staticmethod
    def batcher(batched_args, batch_dims, *, zero_centered_gamma, epsilon):
        """
        to describe batch rules for vmap
        """
        _check_valid_batch_dims(batch_dims)
        assert LayerNormFwdPrimitive.outer_primitive is not None
        x, gamma, beta = batched_args
        x_bdim, _, _ = batch_dims

        # All three outputs inherit x's batch dim; gamma/beta are unbatched.
        out_bdims = x_bdim, x_bdim, x_bdim
        return LayerNormFwdPrimitive.outer_primitive.bind(x,
                                                          gamma,
                                                          beta,
                                                          zero_centered_gamma=zero_centered_gamma,
                                                          epsilon=epsilon), out_bdims

    @staticmethod
    def infer_sharding_from_operands(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
        """
        Infer output shardings from x's sharding; the hidden (last) dim is
        always forced to be replicated.
        """
        del zero_centered_gamma, epsilon, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {LayerNormFwdPrimitive.name}! " \
                f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
                f"and hurt performance."
            )
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        mu_sharding = rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
        return (out_sharding, mu_sharding, rsigma_sharding)

    @staticmethod
    def partition(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
        """
        custom_partitioning rule: shard batch dims only and run the
        unpartitioned impl per shard.
        """
        del result_infos
        x_spec, g_spec, b_spec = map(get_padded_spec, arg_infos)
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {LayerNormFwdPrimitive.name}! " \
                f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
                f"and hurt performance."
            )
        x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        g_sharding = NamedSharding(mesh, PartitionSpec(*g_spec))
        b_sharding = NamedSharding(mesh, PartitionSpec(*b_spec))
        out_sharding = x_sharding
        mu_sharding = rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))

        arg_shardings = (x_sharding, g_sharding, b_sharding)
        out_shardings = (out_sharding, mu_sharding, rsigma_sharding)
        impl = partial(LayerNormFwdPrimitive.impl,
                       zero_centered_gamma=zero_centered_gamma,
                       epsilon=epsilon)
        return mesh, impl, out_shardings, arg_shardings
424
425


426
# Wire the LayerNorm forward class into inner/outer JAX primitives.
register_primitive(LayerNormFwdPrimitive)
427
428


429
430
def layernorm_fwd(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, zero_centered_gamma: bool,
                  epsilon: float):
    """
    Wrapper for TE layernorm fwd

    Binds the outer primitive; returns (output, mu, rsigma), where mu and
    rsigma are the fp32 per-row statistics needed by the backward pass.
    """
    return LayerNormFwdPrimitive.outer_primitive.bind(x,
                                                      gamma,
                                                      beta,
                                                      zero_centered_gamma=zero_centered_gamma,
                                                      epsilon=epsilon)
439
440


441
class LayerNormBwdPrimitive(BasePrimitive):
    """
    Layer Normalization Backward Primitive
    """
    # Name must match the custom-call target registered by the C++ extension.
    name = "te_layernorm_backward"
    multiple_results = True
    impl_static_args = (5, 6)    # zero_centered_gamma, epsilon
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval, **kwargs):    # pylint: disable=unused-argument
        """
        Layernorm bwd abstract

        Returns avals for (dx, dgamma, dbeta): dx matches dz, the weight
        grads match gamma.
        """
        w_dtype = dtypes.canonicalize_dtype(gamma_aval.dtype)
        mu_dtype = dtypes.canonicalize_dtype(mu_aval.dtype)
        rsigma_dtype = dtypes.canonicalize_dtype(rsigma_aval.dtype)

        # Incoming grad must match the weight dtype; stats must be the fp32
        # per-row values produced by the forward pass.
        assert dtypes.canonicalize_dtype(dz_aval.dtype) == w_dtype
        assert dz_aval.shape == x_aval.shape
        assert mu_aval.shape == rsigma_aval.shape == x_aval.shape[:-1]
        assert mu_dtype == rsigma_dtype == jnp.float32

        dx_aval = core.raise_to_shaped(dz_aval)
        dgamma_aval = dbeta_aval = core.raise_to_shaped(gamma_aval)
        return dx_aval, dgamma_aval, dbeta_aval

    @staticmethod
    def lowering(ctx, dz, x, mu, rsigma, gamma, *, zero_centered_gamma, epsilon):
        """
        Layernorm bwd lowering rules

        Emits the "te_layernorm_backward" XLA custom call producing
        (dx, dgamma, dbeta).
        """
        _, x_aval, _, _, gamma_aval = ctx.avals_in
        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        g_type = ir.RankedTensorType(gamma.type)
        g_shape = g_type.shape
        # dbeta has the same type/shape as gamma, so reuse gamma's type here.
        b_type = ir.RankedTensorType(gamma.type)
        b_shape = b_type.shape
        assert g_type == b_type
        assert g_shape == b_shape

        dz_shape = ir.RankedTensorType(dz.type).shape
        mu_shape = ir.RankedTensorType(mu.type).shape
        rsigma_shape = ir.RankedTensorType(rsigma.type).shape

        hidden_size = reduce(operator.mul, g_shape)
        batch_size = reduce(operator.mul, x_shape) // hidden_size

        out_types = [
            ir.RankedTensorType.get(x_shape, x_type.element_type),
            ir.RankedTensorType.get(g_shape, g_type.element_type),
            ir.RankedTensorType.get(b_shape, b_type.element_type),
        ]

        # NOTE: the kernel's operand order (dz, mu, rsigma, x, gamma)
        # differs from the Python-level argument order (dz, x, mu, rsigma, gamma).
        operands = [dz, mu, rsigma, x, gamma]
        operand_shapes = [dz_shape, mu_shape, rsigma_shape, x_shape, g_shape]

        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        # Number of SMs to leave free for overlapping communication kernels.
        sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))

        # Argument order here must match pack_norm_descriptor in the C++ extension.
        opaque = transformer_engine_jax.pack_norm_descriptor(
            batch_size,
            hidden_size,
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(gamma_aval.dtype),
            zero_centered_gamma,
            epsilon,
            sm_margin,
        )

        out = custom_caller(LayerNormBwdPrimitive.name, args, opaque, False)

        return out

    @staticmethod
    def impl(dz, x, mu, rsigma, gamma, zero_centered_gamma, epsilon):
        """
        to describe implementation
        """
        assert LayerNormBwdPrimitive.inner_primitive is not None
        dx, dgamma, dbeta = LayerNormBwdPrimitive.inner_primitive.bind(
            dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon)
        return dx, dgamma, dbeta

    @staticmethod
    def batcher(batched_args, batch_dims, *, zero_centered_gamma, epsilon):
        """
        to describe batch rules for vmap
        """
        _check_valid_batch_dims(batch_dims)
        assert LayerNormBwdPrimitive.outer_primitive is not None
        dz, x, mu, rsigma, gamma = batched_args
        _, x_bdim, _, _, gamma_bdim = batch_dims

        # dx follows x's batch dim; the weight grads follow gamma's.
        out_bdims = x_bdim, gamma_bdim, gamma_bdim
        return LayerNormBwdPrimitive.outer_primitive.bind(dz,
                                                          x,
                                                          mu,
                                                          rsigma,
                                                          gamma,
                                                          zero_centered_gamma=zero_centered_gamma,
                                                          epsilon=epsilon), out_bdims

    @staticmethod
    def infer_sharding_from_operands(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
        """
        Infer output shardings from x (arg 1) and gamma (arg 4); the hidden
        (last) dim of dx is always forced to be replicated.
        """
        del zero_centered_gamma, epsilon, result_infos
        x_spec = get_padded_spec(arg_infos[1])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {LayerNormBwdPrimitive.name}! " \
                f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
                f"and hurt performance."
            )
        g_b_spec = get_padded_spec(arg_infos[4])
        dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        dgamma_sharding = dbeta_sharding = NamedSharding(mesh, PartitionSpec(*g_b_spec))
        return dx_sharding, dgamma_sharding, dbeta_sharding

    @staticmethod
    def partition(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
        """
        custom_partitioning rule: run the impl per shard, then all-reduce
        the partial weight grads across data-parallel/FSDP axes.
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[1])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {LayerNormBwdPrimitive.name}! " \
                f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
                f"and hurt performance."
            )
        g_b_spec = get_padded_spec(arg_infos[4])
        dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        dgamma_sharding = dbeta_sharding = NamedSharding(mesh, PartitionSpec(*g_b_spec))
        out_shardings = dx_sharding, dgamma_sharding, dbeta_sharding
        x_shardings = (dx_sharding,) * 2    # dz and x should have the same sharding.
        mu_shardings = (NamedSharding(mesh, PartitionSpec(*x_spec[:-1])),) * 2
        arg_shardings = (*x_shardings, *mu_shardings, NamedSharding(mesh, PartitionSpec(*g_b_spec)))

        def sharded_impl(dz, x, mu, rsigma, gamma):
            # Each shard computes partial dgamma/dbeta; sum them across
            # dp/fsdp axes to get the full weight gradients.
            local_dx, local_dgamma, local_dbeta = \
                LayerNormBwdPrimitive.impl(dz, x, mu, rsigma, gamma,
                     zero_centered_gamma=zero_centered_gamma,
                     epsilon=epsilon)
            global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma)
            global_dbeta = all_reduce_sum_along_dp_fsdp(local_dbeta)
            return local_dx, global_dgamma, global_dbeta

        return mesh, sharded_impl, out_shardings, arg_shardings


# Wire the LayerNorm backward class into inner/outer JAX primitives.
register_primitive(LayerNormBwdPrimitive)


def layernorm_bwd(dz: jnp.ndarray, x: jnp.ndarray, mu: jnp.ndarray, rsigma: jnp.ndarray,
                  gamma: jnp.ndarray, zero_centered_gamma: bool, epsilon: float):
    """
    Wrapper for TE layernorm bwd

    Binds the outer primitive; returns (dx, dgamma, dbeta). mu/rsigma are
    the fp32 statistics returned by layernorm_fwd.
    """
    return LayerNormBwdPrimitive.outer_primitive.bind(dz,
                                                      x,
                                                      mu,
                                                      rsigma,
                                                      gamma,
                                                      zero_centered_gamma=zero_centered_gamma,
                                                      epsilon=epsilon)
599
600


601
class RmsNormFwdPrimitive(BasePrimitive):
    """
    RMS Normalization Forward Primitive
    """
    # Name must match the custom-call target registered by the C++ extension.
    name = "te_rmsnorm_forward"
    multiple_results = True
    impl_static_args = (2,)    # epsilon
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, gamma_aval, **kwargs):    # pylint: disable=unused-argument
        """
        RMSNorm fwd abstract

        Returns avals for (output, rsigma): output matches x; rsigma drops
        the last (hidden) dim and is always fp32. RMSNorm has no mu/beta.
        """
        x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]

        # The rsigma statistic is computed in fp32 regardless of input dtype.
        rsigama_dtype = jnp.float32

        out_aval = core.raise_to_shaped(x_aval)
        rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=rsigama_dtype)

        # x must be a whole number of hidden-size rows.
        hidden_size = gamma_aval.size
        assert x_aval.size % hidden_size == 0

        return out_aval, rsigma_aval

    @staticmethod
    def lowering(ctx, x, gamma, *, epsilon):
        """
        RMSNorm fwd lowering rules

        Emits the "te_rmsnorm_forward" XLA custom call with a packed norm
        descriptor as the opaque backend config.
        """
        x_aval, gamma_aval = ctx.avals_in
        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        g_type = ir.RankedTensorType(gamma.type)
        g_shape = g_type.shape
        rsigma_element_type = ir.F32Type.get()

        out_shape = x_shape
        hidden_size = reduce(operator.mul, g_shape)
        # All leading dims are flattened into one batch dim for the kernel.
        batch_shape = out_shape[:-1]
        batch_size = reduce(operator.mul, x_shape) // hidden_size

        out_types = [
            ir.RankedTensorType.get(out_shape, x_type.element_type),
            ir.RankedTensorType.get(batch_shape, rsigma_element_type),
        ]
        operands = [x, gamma]
        operand_shapes = [x_shape, g_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        # NOTE(review): reuses the *LAYERNORM* forward SM-margin env var for
        # RMSNorm as well — presumably intentional; confirm against the
        # kernel-side descriptor handling.
        sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))

        # Argument order here must match pack_norm_descriptor in the C++ extension.
        opaque = transformer_engine_jax.pack_norm_descriptor(
            batch_size,
            hidden_size,
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(gamma_aval.dtype),
            False,    # RMSNorm doesn't support zero_centered_gamma
            epsilon,
            sm_margin,
        )

        out = custom_caller(RmsNormFwdPrimitive.name, args, opaque, False)

        return out

    @staticmethod
    def impl(x, gamma, epsilon):
        """
        to describe implementation
        """
        assert RmsNormFwdPrimitive.inner_primitive is not None
        out, rsigma = RmsNormFwdPrimitive.inner_primitive.bind(x, gamma, epsilon=epsilon)
        return out, rsigma

    @staticmethod
    def batcher(batched_args, batch_dims, *, epsilon):
        """
        to describe batch rules for vmap
        """
        _check_valid_batch_dims(batch_dims)
        assert RmsNormFwdPrimitive.outer_primitive is not None
        x, gamma = batched_args
        x_bdim, _ = batch_dims

        # Both outputs inherit x's batch dim; gamma is unbatched.
        out_bdims = x_bdim, x_bdim
        return RmsNormFwdPrimitive.outer_primitive.bind(x, gamma, epsilon=epsilon), out_bdims

    @staticmethod
    def infer_sharding_from_operands(epsilon, mesh, arg_infos, result_infos):
        """
        Infer output shardings from x's sharding; the hidden (last) dim is
        always forced to be replicated.
        """
        del epsilon, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {RmsNormFwdPrimitive.name}! " \
                f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
                f"and hurt performance."
            )
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
        return (out_sharding, rsigma_sharding)

    @staticmethod
    def partition(epsilon, mesh, arg_infos, result_infos):
        """
        custom_partitioning rule: shard batch dims only and run the
        unpartitioned impl per shard.
        """
        del result_infos
        x_spec, g_spec = map(get_padded_spec, arg_infos)
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {RmsNormFwdPrimitive.name}! " \
                f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
                f"and hurt performance."
            )
        x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        g_sharding = NamedSharding(mesh, PartitionSpec(*g_spec))
        out_sharding = x_sharding
        rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
        arg_shardings = (x_sharding, g_sharding)
        out_shardings = (out_sharding, rsigma_sharding)
        impl = partial(RmsNormFwdPrimitive.impl, epsilon=epsilon)
        return mesh, impl, out_shardings, arg_shardings
724
725


726
# Wire the RMSNorm forward class into inner/outer JAX primitives.
register_primitive(RmsNormFwdPrimitive)
727
728


729
def rmsnorm_fwd(x: jnp.ndarray, gamma: jnp.ndarray, epsilon: float):
    """
    Wrapper for TE rmsnorm fwd

    Binds the outer primitive; returns (output, rsigma), where rsigma is
    the fp32 per-row statistic needed by the backward pass.
    """
    return RmsNormFwdPrimitive.outer_primitive.bind(x, gamma, epsilon=epsilon)
734
735


736
class RmsNormBwdPrimitive(BasePrimitive):
    """
    RMS Normalization Backward Primitive

    Computes dx and dgamma via the TE custom call.
    """
    name = "te_rmsnorm_backward"
    multiple_results = True
    impl_static_args = (4,)    # epsilon
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
            dz_aval,
            x_aval,
            rsigma_aval,
            gamma_aval,
            **kwargs    # pylint: disable=unused-argument
    ):
        """
        RMSNorm bwd abstract
        """
        weight_dtype = dtypes.canonicalize_dtype(gamma_aval.dtype)
        assert dtypes.canonicalize_dtype(dz_aval.dtype) == weight_dtype
        assert dz_aval.shape == x_aval.shape
        # rsigma holds one fp32 value per row (all leading dims of x).
        assert rsigma_aval.shape == x_aval.shape[:-1]
        assert dtypes.canonicalize_dtype(rsigma_aval.dtype) == jnp.float32

        return core.raise_to_shaped(dz_aval), core.raise_to_shaped(gamma_aval)

    @staticmethod
    def lowering(ctx, dz, x, rsigma, gamma, *, epsilon):
        """
        RMSNorm bwd lowering rules
        """
        _, x_aval, _, gamma_aval = ctx.avals_in

        x_type = ir.RankedTensorType(x.type)
        g_type = ir.RankedTensorType(gamma.type)
        x_shape = x_type.shape
        g_shape = g_type.shape
        dz_shape = ir.RankedTensorType(dz.type).shape
        rsigma_shape = ir.RankedTensorType(rsigma.type).shape

        # Flatten to (batch, hidden) as the TE kernel expects.
        hidden_size = reduce(operator.mul, g_shape)
        batch_size = reduce(operator.mul, x_shape) // hidden_size

        out_types = [
            ir.RankedTensorType.get(x_shape, x_type.element_type),
            ir.RankedTensorType.get(g_shape, g_type.element_type),
        ]
        operands = [dz, rsigma, x, gamma]
        operand_shapes = [dz_shape, rsigma_shape, x_shape, g_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        # Optionally reserve SMs via environment variable.
        sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))

        opaque = transformer_engine_jax.pack_norm_descriptor(
            batch_size,
            hidden_size,
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(gamma_aval.dtype),
            False,    # RMSNorm doesn't support zero_centered_gamma
            epsilon,
            sm_margin,
        )

        return custom_caller(RmsNormBwdPrimitive.name, args, opaque, False)

    @staticmethod
    def impl(dz, x, rsigma, gamma, epsilon):
        """Bind the inner primitive."""
        assert RmsNormBwdPrimitive.inner_primitive is not None
        return RmsNormBwdPrimitive.inner_primitive.bind(dz, x, rsigma, gamma, epsilon=epsilon)

    @staticmethod
    def batcher(batched_args, batch_dims, *, epsilon):
        """Batching rule: the custom call already handles leading batch dims."""
        _check_valid_batch_dims(batch_dims)
        assert RmsNormBwdPrimitive.outer_primitive is not None
        dz, x, rsigma, gamma = batched_args
        _, x_bdim, _, gamma_bdim = batch_dims

        results = RmsNormBwdPrimitive.outer_primitive.bind(dz, x, rsigma, gamma, epsilon=epsilon)
        return results, (x_bdim, gamma_bdim)

    @staticmethod
    def infer_sharding_from_operands(epsilon, mesh, arg_infos, result_infos):
        """Derive (dx, dgamma) shardings from the operand shardings."""
        del epsilon, result_infos
        x_spec = get_padded_spec(arg_infos[1])
        if x_spec[-1] is not None:
            # The hidden dim must stay replicated; warn that the request is overridden.
            warnings.warn(
                f"Does not support to shard hidden dim in {RmsNormBwdPrimitive.name}! "
                f"Force to not shard the hidden dim, which might introduce extra collective ops, "
                f"and hurt performance."
            )
        g_spec = get_padded_spec(arg_infos[3])
        dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        dgamma_sharding = NamedSharding(mesh, PartitionSpec(*g_spec))
        return dx_sharding, dgamma_sharding

    @staticmethod
    def partition(epsilon, mesh, arg_infos, result_infos):
        """Shard the bwd computation; all-reduce dgamma over DP/FSDP axes."""
        del result_infos
        x_spec = get_padded_spec(arg_infos[1])
        if x_spec[-1] is not None:
            # The hidden dim must stay replicated; warn that the request is overridden.
            warnings.warn(
                f"Does not support to shard hidden dim in {RmsNormBwdPrimitive.name}! "
                f"Force to not shard the hidden dim, which might introduce extra collective ops, "
                f"and hurt performance."
            )
        g_spec = get_padded_spec(arg_infos[3])
        dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        dgamma_sharding = NamedSharding(mesh, PartitionSpec(*g_spec))
        out_shardings = dx_sharding, dgamma_sharding
        rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
        # dz and x share the same sharding.
        arg_shardings = (dx_sharding, dx_sharding, rsigma_sharding,
                         NamedSharding(mesh, PartitionSpec(*g_spec)))

        def sharded_impl(dz, x, rsigma, gamma):
            local_dx, local_dgamma = RmsNormBwdPrimitive.impl(dz, x, rsigma, gamma,
                                                              epsilon=epsilon)
            # Each shard computes a partial dgamma; sum it across DP/FSDP axes.
            return local_dx, all_reduce_sum_along_dp_fsdp(local_dgamma)

        return mesh, sharded_impl, out_shardings, arg_shardings

867

868
# Register the RMSNorm backward primitive with JAX.
register_primitive(RmsNormBwdPrimitive)
869
870


871
872
def rmsnorm_bwd(dz: jnp.ndarray, x: jnp.ndarray, rsigma: jnp.ndarray, gamma: jnp.ndarray,
                epsilon: float):
    """Apply the TE RMSNorm backward primitive.

    Returns the primitive's outputs (dx and dgamma).
    """
    outer = RmsNormBwdPrimitive.outer_primitive
    return outer.bind(dz, x, rsigma, gamma, epsilon=epsilon)
877
878


879
class SoftmaxPrimitive(BasePrimitive):
    """
    Softmax Primitive

    Shared machinery for the fused scaled-softmax custom calls: abstract
    eval, MLIR lowering, batching and partitioning helpers for both the
    forward and backward passes.
    """
    max_k_seqlen_supported = 4096

    @staticmethod
    @abstractmethod
    def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
                            dtype: jnp.dtype) -> bool:
        """Check Softmax kernel availability based on size"""
        raise NotImplementedError

    @staticmethod
    def get_batch_per_block(k_seqlen: int) -> int:
        """Get batch per CTA in Softmax kernels"""
        threads_per_warp = 32
        threads_per_block = 128    # Depends on the kernel implementation

        # Next power of two >= k_seqlen sizes the (logical) warp.
        pow2 = 1 << (k_seqlen - 1).bit_length()
        warp_size = min(pow2, threads_per_warp)
        batches_per_warp = 2 if pow2 <= 128 else 1
        warps_per_block = threads_per_block // warp_size
        return warps_per_block * batches_per_warp

    @staticmethod
    def forward_abstract(logits_aval, scale_factor):
        """
        softmax_forward abstract
        """
        del scale_factor
        logits_dtype = dtypes.canonicalize_dtype(logits_aval.dtype)
        assert logits_dtype in [jnp.float16, jnp.bfloat16]
        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
        q_seqlen = logits_aval.shape[-2]
        k_seqlen = logits_aval.shape[-1]
        assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
        assert q_seqlen > 1

        return core.raise_to_shaped(logits_aval)

    @staticmethod
    def forward_lowering(name, ctx, logits, *, scale_factor):
        """
        softmax_forward lowering rules
        """
        logits_aval, = ctx.avals_in
        logits_type = ir.RankedTensorType(logits.type)
        logits_shape = logits_type.shape
        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
        batch = reduce(operator.mul, logits_shape[:-3])
        pad_batch = batch
        heads = logits_shape[-3]
        q_seqlen = logits_shape[-2]
        k_seqlen = logits_shape[-1]

        out_types = [ir.RankedTensorType.get(logits_shape, logits_type.element_type)]
        args = CustomCallArgsWrapper(out_types, [logits], [logits_shape])

        opaque = transformer_engine_jax.pack_softmax_descriptor(
            batch, pad_batch, heads, q_seqlen, k_seqlen,
            jax_dtype_to_te_dtype(logits_aval.dtype), scale_factor)

        return [custom_caller(name, args, opaque, False)]

    @staticmethod
    def forward_impl(primitive, logits, scale_factor):
        """
        softmax_forward implementation
        """
        assert primitive is not None
        return primitive.bind(logits, scale_factor=scale_factor)

    @staticmethod
    def forward_batcher(primitive, batched_args, batch_dims, *, scale_factor):
        """
        softmax_forward batcher
        """
        assert primitive is not None
        logits, = batched_args
        logits_bdim, = batch_dims

        # Output carries the same batch dim as the logits.
        return primitive.bind(logits, scale_factor=scale_factor), logits_bdim

    @staticmethod
    def forward_infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_forward infer_sharding_from_operands
        """
        del scale_factor, result_infos    # Unused.
        logits_spec = get_padded_spec(arg_infos[0])
        return NamedSharding(mesh, PartitionSpec(*logits_spec))

    @staticmethod
    def forward_partition(impl, scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_forward partitioning
        """
        del result_infos
        logits_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[0])))
        arg_shardings = (logits_sharding,)
        out_shardings = logits_sharding
        impl = partial(impl, scale_factor=scale_factor)
        return mesh, impl, out_shardings, arg_shardings

    @staticmethod
    def backward_abstract(dz_aval, softmax_out_aval, scale_factor=None):    # pylint: disable=unused-argument
        """
        softmax_backward abstract
        """
        dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        softmax_out_dtype = dtypes.canonicalize_dtype(softmax_out_aval.dtype)
        assert dz_dtype == softmax_out_dtype
        assert dz_dtype in [jnp.float16, jnp.bfloat16]
        assert softmax_out_dtype in [jnp.float16, jnp.bfloat16]

        assert dz_aval.shape == softmax_out_aval.shape

        return core.raise_to_shaped(softmax_out_aval)

    @staticmethod
    def backward_lowering(name, ctx, dz, softmax_out, *, scale_factor):
        """
        softmax_backward lowering rules
        """
        dz_aval, _ = ctx.avals_in

        dz_type = ir.RankedTensorType(dz.type)
        dz_shape = dz_type.shape

        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
        batch = reduce(operator.mul, dz_shape[:-3])
        pad_batch = batch    # unused
        heads = dz_shape[-3]
        q_seqlen = dz_shape[-2]
        k_seqlen = dz_shape[-1]

        softmax_out_type = ir.RankedTensorType(softmax_out.type)
        softmax_out_shape = softmax_out_type.shape

        out_types = [ir.RankedTensorType.get(softmax_out_shape, softmax_out_type.element_type)]
        args = CustomCallArgsWrapper(out_types, [dz, softmax_out],
                                     [dz_shape, softmax_out_shape])

        opaque = transformer_engine_jax.pack_softmax_descriptor(
            batch, pad_batch, heads, q_seqlen, k_seqlen, jax_dtype_to_te_dtype(dz_aval.dtype),
            scale_factor)

        return [custom_caller(name, args, opaque, False)]

    @staticmethod
    def backward_impl(primitive, dz, softmax_out, scale_factor):
        """
        softmax_backward implementation
        """
        assert primitive is not None
        return primitive.bind(dz, softmax_out, scale_factor=scale_factor)

    @staticmethod
    def backward_batcher(primitive, batched_args, batch_dims, *, scale_factor):
        """
        softmax_backward batcher
        """
        assert primitive is not None
        dz, softmax_out = batched_args
        _, softmax_out_bdim = batch_dims

        # Output carries the same batch dim as softmax_out.
        return primitive.bind(dz, softmax_out, scale_factor=scale_factor), softmax_out_bdim

    @staticmethod
    def backward_infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_backward infer_sharding_from_operands
        """
        del scale_factor, result_infos    # Unused.
        softmax_out_spec = get_padded_spec(arg_infos[1])
        return NamedSharding(mesh, PartitionSpec(*softmax_out_spec))

    @staticmethod
    def backward_partition(impl, scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_backward partition
        """
        del result_infos
        dz_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[0])))
        softmax_out_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        dx_sharding = softmax_out_sharding
        arg_shardings = (dz_sharding, softmax_out_sharding)
        out_shardings = dx_sharding
        impl = partial(impl, scale_factor=scale_factor)
        return mesh, impl, out_shardings, arg_shardings
1089
1090


1091
1092
1093
1094
1095
1096
1097
1098
1099
class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Softmax Fwd Primitive
    """
    name = "te_scaled_softmax_forward"
    multiple_results = False
    impl_static_args = (1,)    # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
                            dtype: jnp.dtype) -> bool:
        """Check Softmax kernel availability based on size"""
        dtype = dtypes.canonicalize_dtype(dtype)
        if dtype not in (jnp.float16, jnp.bfloat16):
            return False
        # k_seqlen must be within 16 ~ 4096
        if not 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported:
            return False
        if q_seqlen % 4 != 0:    # q_seqlen must be a multiple of 4
            return False
        if (batch * heads) % 4 != 0:    # batch * heads must be a multiple of 4
            return False
        return q_seqlen % SoftmaxPrimitive.get_batch_per_block(k_seqlen) == 0

    @staticmethod
    def abstract(logits_aval, scale_factor):    # pylint: disable=unused-argument
        """
        te_scaled_softmax_forward abstract
        """
        return SoftmaxPrimitive.forward_abstract(logits_aval, scale_factor)

    @staticmethod
    def lowering(ctx, logits, *, scale_factor):
        """
        te_scaled_softmax_forward lowering rules
        """
        return SoftmaxPrimitive.forward_lowering(ScaledSoftmaxFwdPrimitive.name, ctx, logits,
                                                 scale_factor=scale_factor)

    @staticmethod
    def impl(logits, scale_factor):
        """Bind the inner forward primitive."""
        return SoftmaxPrimitive.forward_impl(ScaledSoftmaxFwdPrimitive.inner_primitive, logits,
                                             scale_factor)

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        """Batching rule, delegated to the shared forward batcher."""
        _check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.forward_batcher(ScaledSoftmaxFwdPrimitive.outer_primitive,
                                                batched_args, batch_dims,
                                                scale_factor=scale_factor)

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """Output keeps the logits sharding."""
        return SoftmaxPrimitive.forward_infer_sharding_from_operands(scale_factor, mesh, arg_infos,
                                                                     result_infos)

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        """Partition rule, delegated to the shared forward partitioner."""
        return SoftmaxPrimitive.forward_partition(ScaledSoftmaxFwdPrimitive.impl, scale_factor,
                                                  mesh, arg_infos, result_infos)
1158
1159


1160
# Register the scaled softmax forward primitive with JAX.
register_primitive(ScaledSoftmaxFwdPrimitive)
1161

1162
1163

def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
    """Forward pass of the fused scaled softmax.

    Returns an FP16/BF16 tensor.
    """
    primitive = ScaledSoftmaxFwdPrimitive.outer_primitive
    return primitive.bind(logits, scale_factor=scale_factor)
1169
1170


1171
class ScaledSoftmaxBwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Softmax Bwd Primitive
    """
    name = "te_scaled_softmax_backward"
    multiple_results = False
    impl_static_args = (2,)    # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
                            dtype: jnp.dtype) -> bool:
        """Check Softmax kernel availability based on size"""
        # Same size constraints as the forward kernel.
        return ScaledSoftmaxFwdPrimitive.is_kernel_available(batch, heads, q_seqlen, k_seqlen,
                                                             dtype)

    @staticmethod
    def abstract(dz_aval, softmax_out_aval, scale_factor):
        """
        te_scaled_softmax_backward abstract
        """
        return SoftmaxPrimitive.backward_abstract(dz_aval, softmax_out_aval, scale_factor)

    @staticmethod
    def lowering(ctx, dz, softmax_out, *, scale_factor):
        """
        te_scaled_softmax_backward lowering rules
        """
        return SoftmaxPrimitive.backward_lowering(ScaledSoftmaxBwdPrimitive.name, ctx, dz,
                                                  softmax_out, scale_factor=scale_factor)

    @staticmethod
    def impl(dz, softmax_out, scale_factor):
        """Bind the inner backward primitive."""
        return SoftmaxPrimitive.backward_impl(ScaledSoftmaxBwdPrimitive.inner_primitive, dz,
                                              softmax_out, scale_factor=scale_factor)

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        """Batching rule, delegated to the shared backward batcher."""
        _check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.backward_batcher(ScaledSoftmaxBwdPrimitive.outer_primitive,
                                                 batched_args, batch_dims,
                                                 scale_factor=scale_factor)

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """dx keeps the softmax_out sharding."""
        return SoftmaxPrimitive.backward_infer_sharding_from_operands(scale_factor, mesh, arg_infos,
                                                                      result_infos)

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        """Partition rule, delegated to the shared backward partitioner."""
        return SoftmaxPrimitive.backward_partition(ScaledSoftmaxBwdPrimitive.impl, scale_factor,
                                                   mesh, arg_infos, result_infos)


# Register the scaled softmax backward primitive with JAX.
register_primitive(ScaledSoftmaxBwdPrimitive)


def scaled_softmax_bwd(dz: jnp.ndarray, softmax_out: jnp.ndarray,
                       scale_factor: float) -> jnp.ndarray:
    """Backward pass of the fused scaled softmax.

    Returns an FP16/BF16 tensor.
    """
    primitive = ScaledSoftmaxBwdPrimitive.outer_primitive
    return primitive.bind(dz, softmax_out, scale_factor=scale_factor)


class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Masked Softmax Fwd Primitive
    """
    name = "te_scaled_masked_softmax_forward"
    multiple_results = False
    impl_static_args = (2,)    # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
                            dtype: jnp.dtype) -> bool:
        """Check Softmax kernel availability based on size"""
        attn_batches = batch * heads

        dtype = dtypes.canonicalize_dtype(dtype)
        if (dtype in [jnp.float16, jnp.bfloat16]
                and 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
        # k_seqlen must be 16 ~ 4096
                and q_seqlen % 4 == 0    # q_seqlen must be divisor of 4
                and attn_batches % 4 == 0    # batch * heads must be divisor of 4
           ):
            batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen)
            return q_seqlen % batch_per_block == 0
        return False

    @staticmethod
    def abstract(logits_aval, mask_aval, scale_factor):    # pylint: disable=unused-argument
        """
        te_scaled_masked_softmax_forward abstract

        Validates (via assert) that the mask is uint8 and broadcastable to
        the logits shape.
        """
        i_dtype = dtypes.canonicalize_dtype(logits_aval.dtype)
        assert i_dtype in [jnp.float16, jnp.bfloat16]
        i_shape = logits_aval.shape

        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
        batch = reduce(operator.mul, i_shape[:-3])
        q_seqlen = i_shape[-2]
        k_seqlen = i_shape[-1]
        assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
        assert q_seqlen > 1

        mask_dtype = dtypes.canonicalize_dtype(mask_aval.dtype)
        assert mask_dtype in [
            jnp.uint8,
        ]
        mask_shape = mask_aval.shape
        # BUG FIX: the original `pad_batch = batch = reduce(...)` clobbered the
        # logits batch, so `pad_batch in (1, batch)` was always true and the
        # broadcast check never fired. Compare the mask batch against the
        # logits batch instead.
        pad_batch = reduce(operator.mul, mask_shape[:-3])
        assert pad_batch in (1, batch)    # 1 means broadcast
        assert mask_shape[-3] == 1    # 1 means broadcast
        assert mask_shape[-2] == q_seqlen
        assert mask_shape[-1] == k_seqlen

        out_aval = core.raise_to_shaped(logits_aval)
        return out_aval

    @staticmethod
    def lowering(ctx, logits, mask, *, scale_factor):
        """
        te_scaled_masked_softmax_forward lowering rules
        """
        logits_aval, _ = ctx.avals_in
        i_type = ir.RankedTensorType(logits.type)
        i_shape = i_type.shape
        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
        batch = reduce(operator.mul, i_shape[:-3])
        heads = i_shape[-3]
        q_seqlen = i_shape[-2]
        k_seqlen = i_shape[-1]

        mask_type = ir.RankedTensorType(mask.type)
        mask_shape = mask_type.shape
        # The kernel reads the (possibly broadcast) batch size from the mask.
        pad_batch = reduce(operator.mul, mask_shape[:-3])

        out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)]
        operands = [logits, mask]
        operand_shapes = [i_shape, mask_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        opaque = transformer_engine_jax.pack_softmax_descriptor(
            batch, pad_batch, heads, q_seqlen, k_seqlen, jax_dtype_to_te_dtype(logits_aval.dtype),
            scale_factor)

        out = custom_caller(ScaledMaskedSoftmaxFwdPrimitive.name, args, opaque, False)

        return [out]

    @staticmethod
    def impl(logits, mask, scale_factor):
        """Bind the inner primitive."""
        assert ScaledMaskedSoftmaxFwdPrimitive.inner_primitive is not None
        output = ScaledMaskedSoftmaxFwdPrimitive.inner_primitive.bind(logits,
                                                                      mask,
                                                                      scale_factor=scale_factor)
        return output

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        """Batching rule: output carries the logits batch dim."""
        _check_valid_batch_dims(batch_dims)
        assert ScaledMaskedSoftmaxFwdPrimitive.outer_primitive is not None
        logits, mask = batched_args
        logits_bdim, _ = batch_dims

        out_bdims = logits_bdim
        return ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
            logits, mask, scale_factor=scale_factor), out_bdims

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """Output keeps the logits sharding."""
        del scale_factor, result_infos    # Unused.
        logits_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec))
        return out_sharding

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        """Partition: logits/mask keep their shardings; output follows logits."""
        del result_infos
        logits_spec = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[0])))
        mask_spec = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = (logits_spec, mask_spec)
        out_shardings = logits_spec
        impl = partial(ScaledMaskedSoftmaxFwdPrimitive.impl, scale_factor=scale_factor)
        return mesh, impl, out_shardings, arg_shardings


# Register the scaled masked softmax forward primitive with JAX.
register_primitive(ScaledMaskedSoftmaxFwdPrimitive)


def scaled_masked_softmax_fwd(logits: jnp.ndarray, mask: jnp.ndarray,
                              scale_factor: float) -> jnp.ndarray:
    """Forward pass of the fused scaled masked softmax.

    Returns an FP16/BF16 tensor.
    """
    primitive = ScaledMaskedSoftmaxFwdPrimitive.outer_primitive
    return primitive.bind(logits, mask, scale_factor=scale_factor)


class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Masked Softmax Bwd Primitive
    """
    name = "te_scaled_masked_softmax_backward"
    multiple_results = False
    impl_static_args = (2,)    # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
                            dtype: jnp.dtype) -> bool:
        """Check Softmax kernel availability based on size"""
        # Same size constraints as the forward kernel.
        return ScaledSoftmaxFwdPrimitive.is_kernel_available(batch, heads, q_seqlen, k_seqlen,
                                                             dtype)

    @staticmethod
    def abstract(dz_aval, softmax_out_aval, *, scale_factor):
        """
        te_scaled_masked_softmax_backward abstract
        """
        return SoftmaxPrimitive.backward_abstract(dz_aval, softmax_out_aval, scale_factor)

    @staticmethod
    def lowering(ctx, dz, softmax_out, *, scale_factor):
        """
        te_scaled_masked_softmax_backward lowering rules
        """
        return SoftmaxPrimitive.backward_lowering(ScaledMaskedSoftmaxBwdPrimitive.name, ctx, dz,
                                                  softmax_out, scale_factor=scale_factor)

    @staticmethod
    def impl(dz, softmax_out, scale_factor):
        """Bind the inner backward primitive."""
        return SoftmaxPrimitive.backward_impl(ScaledMaskedSoftmaxBwdPrimitive.inner_primitive,
                                              dz, softmax_out, scale_factor=scale_factor)

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        """Batching rule, delegated to the shared backward batcher."""
        _check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.backward_batcher(ScaledMaskedSoftmaxBwdPrimitive.outer_primitive,
                                                 batched_args, batch_dims,
                                                 scale_factor=scale_factor)

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """dx keeps the softmax_out sharding."""
        return SoftmaxPrimitive.backward_infer_sharding_from_operands(scale_factor, mesh, arg_infos,
                                                                      result_infos)

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        """Partition rule, delegated to the shared backward partitioner."""
        return SoftmaxPrimitive.backward_partition(ScaledMaskedSoftmaxBwdPrimitive.impl,
                                                   scale_factor, mesh, arg_infos, result_infos)


# Register the scaled masked softmax backward primitive with JAX.
register_primitive(ScaledMaskedSoftmaxBwdPrimitive)


def scaled_masked_softmax_bwd(dz: jnp.ndarray, softmax_out: jnp.ndarray,
                              scale_factor: float) -> jnp.ndarray:
    """Backward pass of the fused scaled masked softmax.

    Returns an FP16/BF16 tensor.
    """
    primitive = ScaledMaskedSoftmaxBwdPrimitive.outer_primitive
    return primitive.bind(dz, softmax_out, scale_factor=scale_factor)


class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Upper Triang Masked Softmax Fwd Primitive
    """
    name = "te_scaled_upper_triang_masked_softmax_forward"
    multiple_results = False
    impl_static_args = (1,)    # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
                            dtype: jnp.dtype) -> bool:
        """Check Softmax kernel availability based on size"""
        attn_batches = batch * heads

        dtype = dtypes.canonicalize_dtype(dtype)
        if (dtype in [jnp.float16, jnp.bfloat16]
                # k_seqlen must be in (16, max_k_seqlen_supported]
                and 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
                and q_seqlen % 4 == 0    # q_seqlen must be divisible by 4
                and attn_batches % 4 == 0    # batch * heads must be divisible by 4
           ):
            # The range check on k_seqlen above already holds, so the only
            # remaining requirement is the per-block batching constraint.
            batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen)
            return attn_batches % batch_per_block == 0
        return False

    @staticmethod
    def abstract(logits_aval, scale_factor):    # pylint: disable=unused-argument
        """
        te_scaled_upper_triang_masked_softmax_forward abstract
        """
        q_seqlen = logits_aval.shape[-2]
        k_seqlen = logits_aval.shape[-1]
        # The upper-triangular (causal) mask is only defined for square logits.
        assert q_seqlen == k_seqlen
        return SoftmaxPrimitive.forward_abstract(logits_aval, scale_factor)

    @staticmethod
    def lowering(ctx, logits, *, scale_factor):
        """
        te_scaled_upper_triang_masked_softmax_forward lowering rules
        """
        return SoftmaxPrimitive.forward_lowering(ScaledUpperTriangMaskedSoftmaxFwdPrimitive.name,
                                                 ctx,
                                                 logits,
                                                 scale_factor=scale_factor)

    @staticmethod
    def impl(logits, scale_factor):
        """Concrete evaluation: delegate to the shared softmax forward impl."""
        return SoftmaxPrimitive.forward_impl(
            ScaledUpperTriangMaskedSoftmaxFwdPrimitive.inner_primitive, logits, scale_factor)

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        """Batching (vmap) rule."""
        _check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.forward_batcher(
            ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor)

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """Derive the output sharding from the operand shardings."""
        return SoftmaxPrimitive.forward_infer_sharding_from_operands(scale_factor, mesh, arg_infos,
                                                                     result_infos)

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        """Build the custom-partitioning rule around this primitive's impl."""
        return SoftmaxPrimitive.forward_partition(ScaledUpperTriangMaskedSoftmaxFwdPrimitive.impl,
                                                  scale_factor, mesh, arg_infos, result_infos)


register_primitive(ScaledUpperTriangMaskedSoftmaxFwdPrimitive)


def scaled_upper_triang_masked_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
    """
    Forward pass of the scaled upper-triangular (causal) masked softmax.

    Binds the outer primitive (so sharding/batching rules apply) and
    returns an FP16/BF16 tensor.
    """
    primitive = ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive
    return primitive.bind(logits, scale_factor=scale_factor)


class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Upper Triang Masked Softmax Bwd Primitive
    """
    name = "te_scaled_upper_triang_masked_softmax_backward"
    multiple_results = False
    impl_static_args = (2,)    # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
                            dtype: jnp.dtype) -> bool:
        """Check Softmax kernel availability based on size"""
        # Same size constraints as the matching forward kernel.
        return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype)

    @staticmethod
    def abstract(dz_aval, softmax_out_aval, *, scale_factor):
        """
        te_scaled_upper_triang_masked_backward abstract
        """
        return SoftmaxPrimitive.backward_abstract(dz_aval, softmax_out_aval, scale_factor)

    @staticmethod
    def lowering(ctx, dz, softmax_out, *, scale_factor):
        """
        te_scaled_upper_triang_masked_backward lowering rules
        """
        return SoftmaxPrimitive.backward_lowering(
            ScaledUpperTriangMaskedSoftmaxBwdPrimitive.name,
            ctx,
            dz,
            softmax_out,
            scale_factor=scale_factor)

    @staticmethod
    def impl(dz, softmax_out, scale_factor):
        """Concrete evaluation: delegate to the shared softmax backward impl."""
        inner = ScaledUpperTriangMaskedSoftmaxBwdPrimitive.inner_primitive
        return SoftmaxPrimitive.backward_impl(inner, dz, softmax_out, scale_factor=scale_factor)

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        """Batching (vmap) rule."""
        _check_valid_batch_dims(batch_dims)
        outer = ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive
        return SoftmaxPrimitive.backward_batcher(outer, batched_args, batch_dims,
                                                 scale_factor=scale_factor)

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """Derive the output sharding from the operand shardings."""
        return SoftmaxPrimitive.backward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos)

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        """Build the custom-partitioning rule around this primitive's impl."""
        return SoftmaxPrimitive.backward_partition(
            ScaledUpperTriangMaskedSoftmaxBwdPrimitive.impl,
            scale_factor, mesh, arg_infos, result_infos)


register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive)


def scaled_upper_triang_masked_softmax_bwd(dz: jnp.ndarray, softmax_out: jnp.ndarray,
                                           scale_factor: float) -> jnp.ndarray:
    """
    Backward pass of the scaled upper-triangular (causal) masked softmax.

    Binds the outer primitive (so sharding/batching rules apply) and
    returns an FP16/BF16 tensor.
    """
    primitive = ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive
    return primitive.bind(dz, softmax_out, scale_factor=scale_factor)


@dataclass(frozen=True)
class FusedAttnHelper:
    """
    Helper for the fused attention backend
    """

    q_type: jnp.dtype
    kv_type: jnp.dtype
    qkv_layout: NVTE_QKV_Layout
    attn_bias_type: NVTE_Bias_Type
    attn_mask_type: NVTE_Mask_Type
    dropout_probability: float
    num_heads_q: int
    num_heads_kv: int
    max_seqlen_q: int
    max_seqlen_kv: int
    head_dim: int

    def is_fused_attn_kernel_available(self):
        """Check if there is available fused attention kernel"""
        return self.get_fused_attn_backend() != NVTE_Fused_Attn_Backend.NVTE_No_Backend

    def get_fused_attn_backend(self):
        """Get the fused attention kernel backend"""
        # Ask the TE C++ library which backend (if any) supports this problem.
        return transformer_engine_jax.get_fused_attn_backend(
            jax_dtype_to_te_dtype(self.q_type), jax_dtype_to_te_dtype(self.kv_type),
            self.qkv_layout, self.attn_bias_type, self.attn_mask_type,
            self.dropout_probability, self.num_heads_q, self.num_heads_kv,
            self.max_seqlen_q, self.max_seqlen_kv, self.head_dim)


@dataclass(frozen=True)
class _FusedAttnRNGStateChecker:
    """
    Checker for guarding the fused attention rng state.
    The fused attention backend requires a 64 bits seed and a 64 bits offset.
    However, JAX doesn't enable 64 bits by default,
    so we have to emulate seed as two 32 bits array.
    The offset calculation is maintained in the backend.
    """
    rng_state_dtype: jnp.dtype = jnp.uint32
    # (seed,) with internal dtype int64
    seed_size: int = 2
    # (seed, offset) with internal dtype int64
    rng_state_size: int = 2 * 2

    def check_seed(self, seed, dropout_probability, is_training):
        """
        Check the seed and convert the data type of seed if possible.
        """
        # Jax can't bind None, create a dummy tensor for None
        if seed is None:
            dropout_enabled = dropout_probability > 0 and is_training
            assert not dropout_enabled, "seed is not allowed to be None when dropout is enabled."
            seed = jnp.zeros(2, dtype=self.rng_state_dtype)
            seed = jnp.repeat(seed, num_of_devices())

        if seed.dtype != self.rng_state_dtype:
            warnings.warn(
                f"Requested {seed.dtype=} is not available, and will be "
                f"casted to dtype {self.rng_state_dtype}. "
                f"Please use threefry/rbg/unsafe_rbg PRNG implementations to remove this warning.")
            seed = seed.astype(self.rng_state_dtype)

        assert seed.dtype == self.rng_state_dtype
        # Backend takes an int64_t seed, so only the first two u32 elements are taken
        assert seed.size >= self.seed_size

        return seed


def generate_cu_seqlen(mask):
    """
    Build the cumulative-sequence-length vector for a batch.

    A position is treated as valid when its mask value is 0; valid positions
    are counted over the last two axes per sample, then turned into a running
    prefix sum with a leading 0.
    """
    valid_counts = (mask == 0).sum(axis=(-1, -2), dtype=jnp.int32)
    running_total = jnp.cumsum(valid_counts)
    return jnp.hstack((0, running_total))


class SelfFusedAttnFwdPrimitive(BasePrimitive):
    """
    Self Fused Attention Forward Primitive
    """
    name = "te_self_fused_attn_forward"
    multiple_results = True
    # static args: attn_bias_type, attn_mask_type, scaling_factor,
    # dropout_probability, is_training
    impl_static_args = (4, 5, 6, 7, 8)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(qkv_aval, bias_aval, mask_or_cu_seqlen_aval, seed_aval, *, attn_bias_type,
                 attn_mask_type, scaling_factor, dropout_probability, is_training):
        """
        Self fused attention fwd abstract
        """
        # The outer primitive binds a squeezed mask; the inner one binds cu_seqlen.
        del mask_or_cu_seqlen_aval, scaling_factor, is_training
        qkv_dtype = dtypes.canonicalize_dtype(qkv_aval.dtype)
        *batch_shape, max_seqlen, nqkv, num_head, head_dim = qkv_aval.shape
        assert nqkv == 3    # packed QKV layout
        assert qkv_aval.dtype == bias_aval.dtype

        output_shape = (*batch_shape, max_seqlen, num_head, head_dim)
        output_dtype = qkv_dtype

        backend = FusedAttnHelper(qkv_dtype, qkv_dtype, NVTE_QKV_Layout.NVTE_BS3HD, attn_bias_type,
                                  attn_mask_type, dropout_probability, num_head, num_head,
                                  max_seqlen, max_seqlen, head_dim).get_fused_attn_backend()

        # The softmax auxiliary buffer's shape/dtype depends on the chosen backend.
        if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
            softmax_aux_shape = (*batch_shape, num_head, max_seqlen, max_seqlen)
            softmax_dtype = qkv_dtype
        elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
            softmax_aux_shape = (*batch_shape, num_head, max_seqlen, 1)
            softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
        else:
            raise ValueError(f'Unsupported {backend=}')

        checker = _FusedAttnRNGStateChecker()
        seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype)
        assert seed_dtype == checker.rng_state_dtype
        rng_state_shape = (seed_aval.shape[0], checker.rng_state_size)
        rng_state_dtype = seed_dtype

        out_aval = qkv_aval.update(shape=output_shape, dtype=output_dtype)
        softmax_aux_aval = qkv_aval.update(shape=softmax_aux_shape, dtype=softmax_dtype)
        rng_state_aval = qkv_aval.update(shape=rng_state_shape, dtype=rng_state_dtype)
        return out_aval, softmax_aux_aval, rng_state_aval

    @staticmethod
    def lowering(ctx, qkv, bias, cu_seqlen, seed, *, attn_bias_type, attn_mask_type, scaling_factor,
                 dropout_probability, is_training):
        """
        Self fused attention fwd lowering rules
        """
        qkv_aval, _, _, _ = ctx.avals_in

        *batch_shape, max_seqlen, _, num_head, head_dim = qkv_aval.shape
        batch = reduce(operator.mul, batch_shape)

        operands = [qkv, bias, cu_seqlen, seed]
        operand_shapes = [op.type.shape for op in operands]
        out_types = [
            ir.RankedTensorType.get(o.shape, mlir.dtype_to_ir_type(o.dtype))
            for o in ctx.avals_out
        ]

        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        # Serialize the attention configuration for the C++ custom call.
        opaque = transformer_engine_jax.pack_fused_attn_descriptor(
            batch, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor, dropout_probability,
            attn_bias_type, attn_mask_type, jax_dtype_to_te_dtype(qkv_aval.dtype), is_training)

        return custom_caller(SelfFusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)

    @staticmethod
    def impl(qkv, bias, squeezed_mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
             dropout_probability, is_training):
        """Concrete eval: derive cu_seqlen from the mask, then bind the inner primitive."""
        assert SelfFusedAttnFwdPrimitive.inner_primitive is not None

        cu_seqlen = generate_cu_seqlen(squeezed_mask)

        return SelfFusedAttnFwdPrimitive.inner_primitive.bind(
            qkv, bias, cu_seqlen, seed,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training)

    @staticmethod
    def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor,
                dropout_probability, is_training):
        """Batching rule: out/softmax_aux follow qkv's batch dim; rng_state follows seed's."""
        _check_valid_batch_dims(batch_dims)
        assert SelfFusedAttnFwdPrimitive.outer_primitive is not None
        qkv_bdim, _, _, seed_bdim = batch_dims

        out_bdims = qkv_bdim, qkv_bdim, seed_bdim
        outputs = SelfFusedAttnFwdPrimitive.outer_primitive.bind(
            *batched_args,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training)
        return outputs, out_bdims

    @staticmethod
    def _output_shardings(mesh, arg_infos):
        """Shared sharding logic; qkv spec is (...batch, seqlen, 3, head, hidden)."""
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-3], *x_spec[-2:]))
        softmax_aux_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec[:-4], x_spec[-2], x_spec[-4], None))
        rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
        return out_sharding, softmax_aux_sharding, rng_state_sharding

    @staticmethod
    def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor,
                                     dropout_probability, is_training, mesh, arg_infos,
                                     result_infos):
        """Derive the three output shardings from the qkv operand sharding."""
        del attn_bias_type, attn_mask_type, scaling_factor
        del dropout_probability, is_training, result_infos
        return SelfFusedAttnFwdPrimitive._output_shardings(mesh, arg_infos)

    @staticmethod
    def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training,
                  mesh, arg_infos, result_infos):
        """Build the custom-partitioning rule around this primitive's impl."""
        del result_infos
        out_sharding, softmax_aux_sharding, rng_state_sharding = \
            SelfFusedAttnFwdPrimitive._output_shardings(mesh, arg_infos)
        # The seed operand (last arg) is forced onto the rng-state sharding.
        arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [rng_state_sharding])
        out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
        impl = partial(SelfFusedAttnFwdPrimitive.impl,
                       attn_bias_type=attn_bias_type,
                       attn_mask_type=attn_mask_type,
                       scaling_factor=scaling_factor,
                       dropout_probability=dropout_probability,
                       is_training=is_training)
        return mesh, impl, out_shardings, arg_shardings


register_primitive(SelfFusedAttnFwdPrimitive)


def self_fused_attn_fwd(qkv: jnp.ndarray, bias: jnp.ndarray, squeezed_mask: jnp.ndarray,
                        seed: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
                        attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
                        dropout_probability: float, is_training: bool):
    """
    Wrapper for TE self fused attention fwd
    Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
    """
    seed = _FusedAttnRNGStateChecker().check_seed(seed, dropout_probability, is_training)

    if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
        # No bias: bind a typed empty tensor, since JAX cannot bind None.
        assert bias is None
        bias = jnp.zeros(0, dtype=qkv.dtype)

    return SelfFusedAttnFwdPrimitive.outer_primitive.bind(
        qkv, bias, squeezed_mask, seed,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training)


class SelfFusedAttnBwdPrimitive(BasePrimitive):
    """
    Self Fused Attention Backward Primitive
    """
    name = "te_self_fused_attn_backward"
    multiple_results = True
    # static args: attn_bias_type, attn_mask_type, scaling_factor,
    # dropout_probability, is_training
    impl_static_args = (7, 8, 9, 10, 11)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(qkv_aval, bias_aval, softmax_aux_aval, rng_state_aval, output_aval, doutput_aval,
                 mask_or_cu_seqlen_aval, *, attn_bias_type, attn_mask_type, scaling_factor,
                 dropout_probability, is_training):
        """
        Self fused attention bwd abstract
        """
        del softmax_aux_aval, rng_state_aval
        # The outer primitive binds a squeezed mask; the inner one binds cu_seqlen.
        del mask_or_cu_seqlen_aval, attn_bias_type, attn_mask_type
        del scaling_factor, dropout_probability, is_training
        qkv_dtype = dtypes.canonicalize_dtype(qkv_aval.dtype)
        bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
        assert qkv_aval.dtype == bias_aval.dtype == output_aval.dtype == doutput_aval.dtype

        # Gradients match their corresponding inputs in shape and dtype.
        dqkv_aval = qkv_aval.update(shape=qkv_aval.shape, dtype=qkv_dtype)
        dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype)
        return dqkv_aval, dbias_aval

    @staticmethod
    def lowering(ctx, qkv, bias, softmax_aux, rng_state, output, doutput, cu_seqlen, *,
                 attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training):
        """
        Self fused attention bwd lowering rules
        """
        qkv_aval, _, _, _, _, _, _ = ctx.avals_in

        *batch_shape, max_seqlen, _, num_head, head_dim = qkv_aval.shape
        batch = reduce(operator.mul, batch_shape)

        operands = [qkv, bias, softmax_aux, rng_state, output, doutput, cu_seqlen]
        operand_shapes = [op.type.shape for op in operands]
        out_types = [
            ir.RankedTensorType.get(o.shape, mlir.dtype_to_ir_type(o.dtype))
            for o in ctx.avals_out
        ]

        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        # Serialize the attention configuration for the C++ custom call.
        opaque = transformer_engine_jax.pack_fused_attn_descriptor(
            batch, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor, dropout_probability,
            attn_bias_type, attn_mask_type, jax_dtype_to_te_dtype(qkv_aval.dtype), is_training)

        return custom_caller(SelfFusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)

    @staticmethod
    def impl(qkv, bias, softmax_aux, rng_state, output, doutput, squeezed_mask, attn_bias_type,
             attn_mask_type, scaling_factor, dropout_probability, is_training):
        """Concrete eval: derive cu_seqlen from the mask, then bind the inner primitive."""
        assert SelfFusedAttnBwdPrimitive.inner_primitive is not None

        cu_seqlen = generate_cu_seqlen(squeezed_mask)

        return SelfFusedAttnBwdPrimitive.inner_primitive.bind(
            qkv, bias, softmax_aux, rng_state, output, doutput, cu_seqlen,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training)

    @staticmethod
    def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor,
                dropout_probability, is_training):
        """Batching rule: both gradients follow qkv's batch dim."""
        _check_valid_batch_dims(batch_dims)
        assert SelfFusedAttnBwdPrimitive.outer_primitive is not None
        qkv_bdim, *_ = batch_dims

        out_bdims = qkv_bdim, qkv_bdim
        outputs = SelfFusedAttnBwdPrimitive.outer_primitive.bind(
            *batched_args,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training)
        return outputs, out_bdims

    @staticmethod
    def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor,
                                     dropout_probability, is_training, mesh, arg_infos,
                                     result_infos):
        """Gradients inherit the shardings of qkv and bias respectively."""
        del attn_bias_type, attn_mask_type, scaling_factor, dropout_probability
        del is_training, result_infos
        dx_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[0])))
        dbias_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        return (dx_sharding, dbias_sharding)

    @staticmethod
    def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training,
                  mesh, arg_infos, result_infos):
        """Build the custom-partitioning rule; dbias needs a cross-shard reduction."""
        del result_infos
        dx_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[0])))
        dbias_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (dx_sharding, dbias_sharding)

        def sharded_impl(qkv, bias, softmax_aux, rng_state, output, doutput, cu_seqlen):
            local_dx, local_dbias = SelfFusedAttnBwdPrimitive.impl(
                qkv, bias, softmax_aux, rng_state, output, doutput, cu_seqlen,
                attn_bias_type=attn_bias_type,
                attn_mask_type=attn_mask_type,
                scaling_factor=scaling_factor,
                dropout_probability=dropout_probability,
                is_training=is_training)
            # Each dp/fsdp shard only holds a partial bias gradient; sum them
            # (skipped entirely when there is no bias).
            global_dbias = local_dbias
            if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
            return local_dx, global_dbias

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(SelfFusedAttnBwdPrimitive)


def self_fused_attn_bwd(qkv: jnp.ndarray, bias: jnp.ndarray, softmax_aux: jnp.ndarray,
                        rng_state: jnp.ndarray, output: jnp.ndarray, doutput: jnp.ndarray,
                        squeezed_mask: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
                        attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
                        dropout_probability: float, is_training: bool):
    """
    Wrapper for TE self fused attention bwd
    Return the gradients of self fused attention with packed qkv input
    """
    if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
        # No bias: bind a typed empty tensor, since JAX cannot bind None.
        assert bias is None
        bias = jnp.zeros(0, dtype=qkv.dtype)

    return SelfFusedAttnBwdPrimitive.outer_primitive.bind(
        qkv, bias, softmax_aux, rng_state, output, doutput, squeezed_mask,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training)
class CrossFusedAttnFwdPrimitive(BasePrimitive):
2057
    """
2058
    Cross Fused Attention Forward Primitive
2059
    """
2060
    name = "te_cross_fused_attn_forward"
2061
    multiple_results = True
2062
    impl_static_args = (6, 7, 8, 9, 10)
2063
2064
    inner_primitive = None
    outer_primitive = None
2065
2066

    @staticmethod
2067
2068
2069
    def abstract(q_aval, kv_aval, bias_aval, q_mask_or_cu_seqlen_aval, kv_mask_or_cu_seqlen_aval,
                 seed_aval, *, attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
                 is_training):
2070
        """
2071
        Cross fused attention fwd abstract
2072
        """
2073
2074
        # outer_primitve is squeezed_mask, inner_primitive is cu_seqlen
        del scaling_factor, is_training
2075

2076
2077
        q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
        *q_batch_shape, q_max_seqlen, q_num_head, q_head_dim = q_aval.shape
2078

2079
2080
        kv_dtype = dtypes.canonicalize_dtype(kv_aval.dtype)
        *kv_batch_shape, kv_max_seqlen, nkv, kv_num_head, kv_head_dim = kv_aval.shape
2081

2082
2083
2084
        bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)

        assert q_dtype == kv_dtype == bias_dtype
2085
2086
2087
2088
2089
2090
2091
2092
        assert q_batch_shape == kv_batch_shape
        assert q_num_head == kv_num_head
        assert q_head_dim == kv_head_dim
        assert nkv == 2
        assert q_mask_or_cu_seqlen_aval.dtype == kv_mask_or_cu_seqlen_aval.dtype

        output_shape = q_aval.shape
        output_dtype = q_dtype
2093
2094

        backend = FusedAttnHelper(q_dtype, kv_dtype, NVTE_QKV_Layout.NVTE_BSHD_BS2HD,
2095
2096
2097
                                  attn_bias_type, attn_mask_type, dropout_probability,
                                  q_num_head, kv_num_head,
                                  q_max_seqlen, kv_max_seqlen, q_head_dim).get_fused_attn_backend()
2098
2099
2100
2101
2102
2103
2104
2105
2106
2107
2108
2109
2110
2111
2112

        if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
            softmax_aux_shape = (*q_batch_shape, q_num_head, q_max_seqlen, kv_max_seqlen)
            softmax_aux_dtype = q_dtype
        elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
            softmax_aux_shape = (*q_batch_shape, q_num_head, q_max_seqlen, 1)
            softmax_aux_dtype = dtypes.canonicalize_dtype(jnp.float32)
        else:
            raise ValueError(f'Unsupported {backend=}')

        checker = _FusedAttnRNGStateChecker()
        seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype)
        assert seed_dtype == checker.rng_state_dtype
        rng_state_shape = (seed_aval.shape[0], checker.rng_state_size)
        rng_state_dtype = seed_dtype
2113
2114
2115

        out_aval = q_aval.update(shape=output_shape, dtype=output_dtype)
        softmax_aux_aval = q_aval.update(shape=softmax_aux_shape, dtype=softmax_aux_dtype)
2116
2117
2118
        rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=rng_state_dtype)

        return out_aval, softmax_aux_aval, rng_state_aval
2119
2120

    @staticmethod
2121
2122
    def lowering(ctx, q, kv, bias, q_cu_seqlen, kv_cu_seqlen, seed, *, attn_bias_type,
                 attn_mask_type, scaling_factor, dropout_probability, is_training):
2123
        """
2124
        Cross fused attention fwd lowering rules
2125
        """
2126
        q_aval, kv_aval, *_ = ctx.avals_in
2127
        assert q_aval.dtype == kv_aval.dtype
2128

2129
2130
2131
        *batch_shape, q_max_seqlen, num_head, head_dim = q_aval.shape
        batch = reduce(operator.mul, batch_shape)
        kv_max_seqlen = kv_aval.shape[-4]
2132

2133
        operands = [q, kv, bias, q_cu_seqlen, kv_cu_seqlen, seed]
2134
        operand_shapes = map(lambda x: x.type.shape, operands)
2135
        out_types = [
2136
2137
            ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
            for output in ctx.avals_out
2138
2139
        ]

2140
2141
2142
2143
2144
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
        opaque = transformer_engine_jax.pack_fused_attn_descriptor(
            batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim,
            scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
            jax_dtype_to_te_dtype(q_aval.dtype), is_training)
2145

2146
        out = custom_caller(CrossFusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)
2147
2148
2149

        return out

2150
    @staticmethod
2151
    def impl(q, kv, bias, q_squeezed_mask, kv_squeezed_mask, seed, attn_bias_type, attn_mask_type,
2152
2153
             scaling_factor, dropout_probability, is_training):
        assert CrossFusedAttnFwdPrimitive.inner_primitive is not None
2154

2155
2156
        q_cu_seqlen = generate_cu_seqlen(q_squeezed_mask)
        kv_cu_seqlen = generate_cu_seqlen(kv_squeezed_mask)
2157

2158
        output, softmax_aux, rng_state = CrossFusedAttnFwdPrimitive.inner_primitive.bind(
2159
2160
            q,
            kv,
2161
            bias,
2162
2163
2164
2165
2166
2167
2168
2169
            q_cu_seqlen,
            kv_cu_seqlen,
            seed,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training)
2170
        return output, softmax_aux, rng_state
2171

2172
2173
2174
2175
2176
    @staticmethod
    def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor,
                dropout_probability, is_training):
        _check_valid_batch_dims(batch_dims)
        assert CrossFusedAttnFwdPrimitive.outer_primitive is not None
2177
        q_bdim, *_, seed_bdim = batch_dims
2178

2179
        out_bdims = q_bdim, q_bdim, seed_bdim
2180
        return CrossFusedAttnFwdPrimitive.outer_primitive.bind(
2181
            *batched_args,
2182
2183
2184
2185
2186
2187
2188
2189
2190
2191
2192
2193
2194
2195
2196
2197
2198
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training), out_bdims

    @staticmethod
    def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor,
                                     dropout_probability, is_training, mesh, arg_infos,
                                     result_infos):
        del attn_bias_type, attn_mask_type, scaling_factor
        del dropout_probability, is_training, result_infos
        q_spec = get_padded_spec(arg_infos[0])    # (...batch, q_seqlen, head, hidden)
        kv_spec = get_padded_spec(arg_infos[1])    # (...batch, kv_seqlen, 2, head, hidden)
        out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        softmax_aux_sharding = NamedSharding(
            mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], kv_spec[-4]))
2199
2200
        rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
        return (out_sharding, softmax_aux_sharding, rng_state_sharding)
2201
2202
2203
2204
2205
2206
2207
2208
2209
2210

    @staticmethod
    def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training,
                  mesh, arg_infos, result_infos):
        del result_infos
        q_spec = get_padded_spec(arg_infos[0])    # (...batch, q_seqlen, head, hidden)
        kv_spec = get_padded_spec(arg_infos[1])    # (...batch, kv_seqlen, 2, head, hidden)
        out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        softmax_aux_sharding = NamedSharding(
            mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], kv_spec[-4]))
2211
2212
        rng_state_sharding = seed_sharding = NamedSharding(mesh,
                                                           PartitionSpec(get_all_mesh_axes(), None))
2213
        arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding])
2214
        out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
2215
2216
2217
2218
2219
2220
2221
2222
2223
2224
2225
2226
        impl = partial(CrossFusedAttnFwdPrimitive.impl,
                       attn_bias_type=attn_bias_type,
                       attn_mask_type=attn_mask_type,
                       scaling_factor=scaling_factor,
                       dropout_probability=dropout_probability,
                       is_training=is_training)
        return mesh, impl, out_shardings, arg_shardings


# Install the cross fused-attention forward primitive via the module-level
# register_primitive helper (defined earlier in this file).
register_primitive(CrossFusedAttnFwdPrimitive)


def cross_fused_attn_fwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray,
                         q_squeezed_mask: jnp.ndarray, kv_squeezed_mask: jnp.ndarray,
                         seed: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
                         attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
                         dropout_probability: float, is_training: bool):
    """
    Wrapper for TE cross fused attention fwd
    Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2

    `seed` is validated/normalized against the dropout configuration before use.
    When no bias is requested (NVTE_NO_BIAS), `bias` must be None and a
    zero-size placeholder is bound, since the primitive always takes a bias
    operand.
    """
    checker = _FusedAttnRNGStateChecker()
    seed = checker.check_seed(seed, dropout_probability, is_training)

    if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
        assert bias is None
        bias = jnp.zeros(0, dtype=q.dtype)

    return CrossFusedAttnFwdPrimitive.outer_primitive.bind(q,
                                                           kv,
                                                           bias,
                                                           q_squeezed_mask,
                                                           kv_squeezed_mask,
                                                           seed,
                                                           attn_bias_type=attn_bias_type,
                                                           attn_mask_type=attn_mask_type,
                                                           scaling_factor=scaling_factor,
                                                           dropout_probability=dropout_probability,
                                                           is_training=is_training)
class CrossFusedAttnBwdPrimitive(BasePrimitive):
    """
    Cross Fused Attention Backward Primitive
    """
    name = "te_cross_fused_attn_backward"
    multiple_results = True
    # q, kv, bias, softmax_aux, rng_state, output, doutput, q_mask, kv_mask are
    # operands; the rest are static kwargs.
    impl_static_args = (9, 10, 11, 12, 13)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(q_aval, kv_aval, bias_aval, softmax_aux_aval, rng_state_aval, output_aval,
                 doutput_aval, q_cu_seqlen_aval, kv_cu_seqlen_aval, *, attn_bias_type,
                 attn_mask_type, scaling_factor, dropout_probability, is_training):
        """
        Cross fused attention bwd abstract
        """
        del softmax_aux_aval, rng_state_aval, output_aval
        del attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training

        q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
        kv_dtype = dtypes.canonicalize_dtype(kv_aval.dtype)
        bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
        doutput_dtype = dtypes.canonicalize_dtype(doutput_aval.dtype)
        assert q_dtype == kv_dtype == bias_dtype == doutput_dtype
        assert q_cu_seqlen_aval.dtype == kv_cu_seqlen_aval.dtype

        # Gradients mirror the shapes/dtypes of their corresponding inputs.
        dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
        dkv_aval = kv_aval.update(shape=kv_aval.shape, dtype=kv_dtype)
        dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype)
        return dq_aval, dkv_aval, dbias_aval

    @staticmethod
    def lowering(ctx, q, kv, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen,
                 kv_cu_seqlen, *, attn_bias_type, attn_mask_type, scaling_factor,
                 dropout_probability, is_training):
        """
        Cross fused attention bwd lowering rules
        """
        q_aval, kv_aval, *_ = ctx.avals_in
        assert q_aval.dtype == kv_aval.dtype

        *batch_shape, q_max_seqlen, num_head, head_dim = q_aval.shape
        batch = reduce(operator.mul, batch_shape)
        kv_max_seqlen = kv_aval.shape[-4]

        operands = [q, kv, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, kv_cu_seqlen]
        operand_shapes = map(lambda x: x.type.shape, operands)
        out_types = [
            ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
            for output in ctx.avals_out
        ]

        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        # the dropout elements are encoded in the forward auxiliary tensor
        # so seed is not needed in backward
        opaque = transformer_engine_jax.pack_fused_attn_descriptor(
            batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim,
            scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
            jax_dtype_to_te_dtype(q_aval.dtype), is_training)

        out = custom_caller(CrossFusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)

        return out

    @staticmethod
    def impl(q, kv, bias, softmax_aux, rng_state, output, doutput, q_squeezed_mask,
             kv_squeezed_mask, attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
             is_training):
        assert CrossFusedAttnBwdPrimitive.inner_primitive is not None

        # The outer primitive takes squeezed masks; convert them to the
        # cumulative sequence lengths the inner primitive expects.
        q_cu_seqlen = generate_cu_seqlen(q_squeezed_mask)
        kv_cu_seqlen = generate_cu_seqlen(kv_squeezed_mask)

        dq, dkv, dbias = CrossFusedAttnBwdPrimitive.inner_primitive.bind(
            q,
            kv,
            bias,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_cu_seqlen,
            kv_cu_seqlen,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training)
        return dq, dkv, dbias

    @staticmethod
    def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor,
                dropout_probability, is_training):
        _check_valid_batch_dims(batch_dims)
        assert CrossFusedAttnBwdPrimitive.outer_primitive is not None
        q_bdim, kv_bdim, *_ = batch_dims

        out_bdims = q_bdim, kv_bdim, q_bdim
        return CrossFusedAttnBwdPrimitive.outer_primitive.bind(
            *batched_args,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training), out_bdims

    @staticmethod
    def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor,
                                     dropout_probability, is_training, mesh, arg_infos,
                                     result_infos):
        del attn_bias_type, attn_mask_type, scaling_factor
        del dropout_probability, is_training, result_infos
        q_spec = get_padded_spec(arg_infos[0])
        kv_spec = get_padded_spec(arg_infos[1])
        bias_spec = get_padded_spec(arg_infos[2])
        dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        dkv_sharding = NamedSharding(mesh, PartitionSpec(*kv_spec))
        dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
        return (dq_sharding, dkv_sharding, dbias_sharding)

    @staticmethod
    def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training,
                  mesh, arg_infos, result_infos):
        del result_infos
        q_spec = get_padded_spec(arg_infos[0])
        kv_spec = get_padded_spec(arg_infos[1])
        bias_spec = get_padded_spec(arg_infos[2])
        dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        dkv_sharding = NamedSharding(mesh, PartitionSpec(*kv_spec))
        dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (dq_sharding, dkv_sharding, dbias_sharding)

        def sharded_impl(q, kv, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen,
                         kv_cu_seqlen):
            local_dq, local_dkv, local_dbias = CrossFusedAttnBwdPrimitive.impl(
                q,
                kv,
                bias,
                softmax_aux,
                rng_state,
                output,
                doutput,
                q_cu_seqlen,
                kv_cu_seqlen,
                attn_bias_type=attn_bias_type,
                attn_mask_type=attn_mask_type,
                scaling_factor=scaling_factor,
                dropout_probability=dropout_probability,
                is_training=is_training)
            # dbias is replicated math across data-parallel shards, so it must
            # be summed over dp/fsdp axes; dq/dkv stay local.
            global_dbias = local_dbias
            if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
            return local_dq, local_dkv, global_dbias

        return mesh, sharded_impl, out_shardings, arg_shardings
2413

2414

2415
# Install the cross fused-attention backward primitive via the module-level
# register_primitive helper (defined earlier in this file).
register_primitive(CrossFusedAttnBwdPrimitive)
2416
2417


def cross_fused_attn_bwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray,
                         softmax_aux: jnp.ndarray, rng_state: jnp.ndarray, output: jnp.ndarray,
                         doutput: jnp.ndarray, q_squeezed_mask: jnp.ndarray,
                         kv_squeezed_mask: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
                         attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
                         dropout_probability: float, is_training: bool):
    """
    Wrapper for TE cross fused attention bwd
    Return the gradients of cross fused attention with packed kv input

    When the forward pass used no bias (NVTE_NO_BIAS), `bias` must be None and a
    zero-size placeholder is bound instead, since the primitive always expects a
    bias operand.
    """
    if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
        assert bias is None
        bias = jnp.zeros(0, dtype=q.dtype)
    return CrossFusedAttnBwdPrimitive.outer_primitive.bind(q,
                                                           kv,
                                                           bias,
                                                           softmax_aux,
                                                           rng_state,
                                                           output,
                                                           doutput,
                                                           q_squeezed_mask,
                                                           kv_squeezed_mask,
                                                           attn_bias_type=attn_bias_type,
                                                           attn_mask_type=attn_mask_type,
                                                           scaling_factor=scaling_factor,
                                                           dropout_probability=dropout_probability,
                                                           is_training=is_training)
2445
2446


class GatedGeluPrimitive(BasePrimitive):
    """
    Gated Gelu Forward Primitive
    """
    name = "te_gated_gelu"
    multiple_results = False
    inner_primitive = None
    outer_primitive = None
    impl_static_args = ()

    @staticmethod
    def abstract(x_aval):
        """
        gated_gelu abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        x_shape = x_aval.shape
        assert x_shape[-2] == 2    # Assume x in (....., 2, hidden)
        hidden_size = x_shape[-1]
        batch_shapes = x_shape[:-2]
        # The gate axis (size 2) is consumed: output is (...batch, hidden).
        out_aval = core.raise_to_shaped(x_aval)
        out_shape = (batch_shapes) + (hidden_size,)
        out_aval = out_aval.update(shape=out_shape, dtype=dtype)

        return out_aval

    @staticmethod
    def lowering(ctx, x):
        """
        gated_gelu lowering rules
        """
        (x_aval,) = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        ir_x_type = ir.RankedTensorType(x.type)
        ir_x_shape = ir_x_type.shape
        out_shape = ir_x_shape[:-2] + [ir_x_shape[-1]]

        out_types = [
            ir.RankedTensorType.get(out_shape, ir_x_type.element_type),
        ]
        operands = [x]
        operand_shapes = [ir_x_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        # The kernel sees the input as a flattened (batch, hidden) problem.
        hidden_size = ir_x_shape[-1]
        batch_size = reduce(operator.mul, ir_x_shape[:-2])
        in_dtype = jax_dtype_to_te_dtype(x_aval.dtype)
        opaque = transformer_engine_jax.pack_common_descriptor((batch_size, hidden_size), in_dtype,
                                                               in_dtype)

        out = custom_caller(GatedGeluPrimitive.name, args, opaque, False)

        return [out]

    @staticmethod
    def impl(x):
        assert GatedGeluPrimitive.inner_primitive is not None
        out = GatedGeluPrimitive.inner_primitive.bind(x)
        return out

    @staticmethod
    def batcher(batched_args, batch_dims):
        """
        gated_gelu batcher
        """
        _check_valid_batch_dims(batch_dims)
        assert GatedGeluPrimitive.outer_primitive is not None
        inputs, = batched_args
        inputs_bdim, = batch_dims

        out_bdims = inputs_bdim
        return GatedGeluPrimitive.outer_primitive.bind(inputs), out_bdims

    @staticmethod
    def infer_sharding_from_operands(mesh, arg_infos, result_infos):
        """
        gated_gelu infer_sharding_from_operands
        """
        del result_infos    # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        # Drop the gate axis spec; keep batch and hidden shardings.
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
        return out_sharding

    @staticmethod
    def partition(mesh, arg_infos, result_infos):
        """
        gated_gelu partitioning
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
        impl = GatedGeluPrimitive.impl
        return mesh, impl, out_sharding, arg_shardings
2543
2544


2545
# Install the gated-gelu primitive via the module-level register_primitive
# helper (defined earlier in this file).
register_primitive(GatedGeluPrimitive)
2546
2547


def gated_gelu(inputs: jnp.ndarray) -> jnp.ndarray:
    """
    gated gelu wrapper
    Return geglu(inputs) in the input dtype (no FP8 casting happens here;
    GatedGeluPrimitive takes no scale/amax operands).
    Assume inputs has two dimensions shape and the memory layout is (N, 2, H)
    """
    return GatedGeluPrimitive.outer_primitive.bind(inputs)
2555
2556


class DgatedGeluPrimitive(BasePrimitive):
    """
    Dgated Gelu Primitive
    """
    name = "te_dgated_gelu"
    multiple_results = False
    inner_primitive = None
    outer_primitive = None
    impl_static_args = ()

    @staticmethod
    def abstract(dz_aval, x_aval):
        """
        dgated_gelu abstract
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dtype
        # dz shares all leading axes with x (x carries one extra gate axis).
        for axis in range(len(dz_aval.shape) - 1):
            assert dz_aval.shape[axis] == x_aval.shape[axis]

        assert x_aval.shape[-2] == 2    # Assume x in (....., 2, hidden)

        i_hidden_size = dz_aval.shape[-1]
        g_hidden_size = x_aval.shape[-1]
        assert i_hidden_size == g_hidden_size
        # Gradient w.r.t. the gated input keeps x's shape/dtype.
        out_aval = core.raise_to_shaped(x_aval)
        return out_aval

    @staticmethod
    def lowering(ctx, dz, x):
        """
        dgated_gelu lowering rules
        """
        in_aval, gi_aval = ctx.avals_in
        assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert gi_aval.dtype == in_aval.dtype
        ir_in_type = ir.RankedTensorType(dz.type)
        ir_in_shape = ir_in_type.shape
        gi_type = ir.RankedTensorType(x.type)
        gi_shape = gi_type.shape
        for axis in range(len(ir_in_shape) - 1):
            assert ir_in_shape[axis] == gi_shape[axis]

        ir_batch_size = reduce(operator.mul, ir_in_shape[:-1])
        i_hidden_size = ir_in_shape[-1]
        g_hidden_size = gi_shape[-1]
        assert i_hidden_size == g_hidden_size
        out_dtype = ir_in_type.element_type
        out_shape = gi_shape

        out_types = [
            ir.RankedTensorType.get(out_shape, out_dtype),
        ]
        operands = [dz, x]
        operand_shapes = [ir_in_shape, gi_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        in_dtype = jax_dtype_to_te_dtype(in_aval.dtype)
        opaque = transformer_engine_jax.pack_common_descriptor((ir_batch_size, i_hidden_size),
                                                               in_dtype, in_dtype)

        out = custom_caller(DgatedGeluPrimitive.name, args, opaque, False)

        return [out]

    @staticmethod
    def impl(dz, x):
        """
        dgated_gelu implementation
        """
        assert DgatedGeluPrimitive.inner_primitive is not None
        dx = DgatedGeluPrimitive.inner_primitive.bind(dz, x)
        return dx

    @staticmethod
    def batcher(batched_args, batch_dims):
        """
        dgated_gelu batcher
        """
        _check_valid_batch_dims(batch_dims)
        assert DgatedGeluPrimitive.outer_primitive is not None
        dz, x = batched_args
        _, x_bdim = batch_dims

        out_bdims = x_bdim
        return DgatedGeluPrimitive.outer_primitive.bind(dz, x), out_bdims

    @staticmethod
    def infer_sharding_from_operands(mesh, arg_infos, result_infos):
        """
        dgated_gelu infer_sharding_from_operands
        """
        del result_infos    # Unused.
        gelu_out_spec = get_padded_spec(arg_infos[1])
        dx_sharding = NamedSharding(mesh, PartitionSpec(*gelu_out_spec))
        return dx_sharding

    @staticmethod
    def partition(mesh, arg_infos, result_infos):
        """
        dgated_gelu partition
        """
        del result_infos
        dx_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = dx_sharding
        impl = DgatedGeluPrimitive.impl
        return mesh, impl, out_shardings, arg_shardings
2666
2667


2668
# Install the dgated-gelu primitive via the module-level register_primitive
# helper (defined earlier in this file).
register_primitive(DgatedGeluPrimitive)
2669
2670


def dgated_gelu(inputs: jnp.ndarray, gelu_inputs: jnp.ndarray) -> jnp.ndarray:
    """
    dgated_gelu fusion wrapper
    Return dgeglu(inputs)

    `inputs` is the incoming gradient dz; `gelu_inputs` is the original gated
    input x with layout (..., 2, hidden). The result has x's shape.
    """
    return DgatedGeluPrimitive.outer_primitive.bind(inputs, gelu_inputs)
2677
2678


2679
2680
def _normalize_axis_boundary(axis, ndim):
    return axis if axis >= 0 else ndim + axis
2681
2682


2683
def _multidim_transpose(shape, static_axis_boundary, transpose_axis_boundary):
2684
    """
2685
2686
2687
2688
2689
2690
2691
2692
2693
2694
2695
2696
2697
2698
2699
2700
2701
2702
    te_cast_transpose_p multi-dims transpose

    static_axis_boundary: int, Indicate those axes <= static_axis_boundary would not be
        involved into transpose, -1 means all axes involve into transpose.
    transpose_axis_boundary: int, Indicate how to split multi-dimensions tensors to 2D matrix for
        transpose. Note, transpose_axis_boundary should be greater than static_axis_boundary

    examples:
        X in shape (dim0, dim1, dim2, dim3, dim4)

        static_axis_boundary == -1, transpose_axis_boundary == 2
            Xt = (dim2, dim3, dim4, dim0, dim1)

        static_axis_boundary == 0, transpose_axis_boundary == 2
            Xt = (dim0, dim2, dim3, dim4, dim1)

        static_axis_boundary == 0, transpose_axis_boundary == 3
            Xt = (dim0, dim3, dim4, dim1. dim2)
2703
    """
2704
2705
2706
2707
2708
2709
2710
2711
    if static_axis_boundary < 0:
        static_axis_boundary = -1    # means no static axes
    assert static_axis_boundary < len(shape) - 2    # at least 2 remaining for transpose.
    transpose_start_idx = static_axis_boundary + 1
    transpose_axis_boundary = _normalize_axis_boundary(transpose_axis_boundary, len(shape))
    assert transpose_start_idx < transpose_axis_boundary
    return (*shape[:transpose_start_idx], *shape[transpose_axis_boundary:],
            *shape[transpose_start_idx:transpose_axis_boundary])
2712
2713


2714
class CastTransposePrimitive(BasePrimitive):
    """
    Cast Transpose Primitive

    Fused FP8 cast + transpose custom call: a single TE kernel produces the
    casted tensor, the casted-and-transposed tensor, and the updated amax.
    """
    name = "te_cast_transpose"
    multiple_results = True
    # Static (non-traced) impl args: out_dtype, static_axis_boundary,
    # transpose_axis_boundary.
    impl_static_args = (4, 5, 6)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype, static_axis_boundary,
                 transpose_axis_boundary):
        """
        te_cast_transpose_p abstract

        Returns avals for (casted_x, casted_x_transposed, updated_amax).
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        # Input must be a 32/16-bit float; FP8 meta tensors are always float32.
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        transposed_x_shape = _multidim_transpose(x_aval.shape, static_axis_boundary,
                                                 transpose_axis_boundary)

        casted_x_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
        casted_xt_aval = x_aval.update(shape=transposed_x_shape, dtype=out_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)

        return casted_x_aval, casted_xt_aval, updated_amax_aval

    @staticmethod
    def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary,
                 transpose_axis_boundary):
        """
        te_cast_transpose_p lowering rules
        """
        x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_x_type = ir.RankedTensorType(x.type)
        ir_x_shape = ir_x_type.shape
        if static_axis_boundary >= 0:
            # All axes up to the static boundary are required to be size 1 here.
            for i in range(static_axis_boundary + 1):
                assert ir_x_shape[i] == 1
        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        # scale and scale_inv share the amax tensor's (scalar) shape.
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape

        transposed_x_shape = _multidim_transpose(ir_x_shape, static_axis_boundary,
                                                 transpose_axis_boundary)

        out_types = [
            ir.RankedTensorType.get(ir_x_shape, ir_out_dtype),
            ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
        ]
        operands = [x, amax, scale, scale_inv]
        operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        # Collapse the tensor into the 2D (rows, cols) view the TE kernel expects.
        contracted_x_shape = (reduce(operator.mul, ir_x_shape[:transpose_axis_boundary]),
                              reduce(operator.mul, ir_x_shape[transpose_axis_boundary:]))
        opaque = transformer_engine_jax.pack_common_descriptor(contracted_x_shape,
                                                               jax_dtype_to_te_dtype(x_aval.dtype),
                                                               jax_dtype_to_te_dtype(out_dtype))

        # amax (operand 1) is aliased with updated_amax (output 2): in-place update.
        out = custom_caller(CastTransposePrimitive.name,
                            args,
                            opaque,
                            False,
                            operand_output_aliases={1: 2})

        return out

    @staticmethod
    def impl(x, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis_boundary):
        """
        te_cast_transpose implementation
        """
        assert CastTransposePrimitive.inner_primitive is not None
        casted_x, casted_transposed_x, updated_amax = \
            CastTransposePrimitive.inner_primitive.bind(
                x, amax, scale, scale_inv, out_dtype=out_dtype,
                static_axis_boundary=static_axis_boundary,
                transpose_axis_boundary=transpose_axis_boundary)
        return casted_x, casted_transposed_x, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary,
                transpose_axis_boundary):
        """Batch rule for vmap: the batch axis becomes the static axis boundary."""
        _check_valid_batch_dims(batch_dims)
        assert CastTransposePrimitive.outer_primitive is not None
        assert static_axis_boundary < 0

        x, amax, scale, scale_inv = batched_args
        x_bdim, amax_bdim, *_ = batch_dims

        # Minus batch dim.
        transpose_axis_boundary = _normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1)
        transpose_axis_boundary += 1    # Plus batch dim

        out_bdims = x_bdim, x_bdim, amax_bdim
        return CastTransposePrimitive.outer_primitive.bind(
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=x_bdim,
            transpose_axis_boundary=transpose_axis_boundary), out_bdims

    @staticmethod
    def infer_sharding_from_operands(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh,
                                     arg_infos, result_infos):
        """Outputs: casted x keeps x's spec; the transposed output gets the permuted spec."""
        del out_dtype, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        return (casted_x_sharding, casted_transposed_x_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos,
                  result_infos):
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (casted_x_sharding, casted_transposed_x_sharding, amax_sharding)

        def sharded_impl(x, amax, scale, scale_inv):
            # Run the local kernel, then max-reduce amax across all mesh axes
            # except pipeline-parallel so every replica holds the global amax.
            local_cx, local_cxt, local_updated_amax = \
                CastTransposePrimitive.impl(x, amax, scale, scale_inv,
                     out_dtype=out_dtype,
                       static_axis_boundary=static_axis_boundary,
                       transpose_axis_boundary=transpose_axis_boundary)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax)

            return local_cx, local_cxt, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


# Register the cast-transpose primitive (inner/outer primitives) with JAX.
register_primitive(CastTransposePrimitive)


def cast_transpose(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
                   out_dtype: jnp.dtype, static_axis_boundary: int,
                   transpose_axis_boundary: int) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    Cast-transpose wrapper.
    Return (FP8(x), FP8(x.T), updated_amax); both FP8 outputs are scaled by `scale`.
    """
    primitive = CastTransposePrimitive.outer_primitive
    return primitive.bind(x, amax, scale, scale_inv,
                          out_dtype=out_dtype,
                          static_axis_boundary=static_axis_boundary,
                          transpose_axis_boundary=transpose_axis_boundary)


class TransposePrimitive(BasePrimitive):
    """
    Transpose Primitive

    Multi-dimensional transpose backed by a TE custom call; output shape
    semantics follow _multidim_transpose.
    """
    name = "te_transpose"
    multiple_results = False
    # Static (non-traced) impl args: static_axis_boundary, transpose_axis_boundary.
    impl_static_args = (1, 2)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, *, static_axis_boundary, transpose_axis_boundary):
        """
        _transpose abstract: output aval has the permuted shape, same dtype.
        """
        transposed_x_shape = _multidim_transpose(x_aval.shape, static_axis_boundary,
                                                 transpose_axis_boundary)
        xt_aval = x_aval.update(shape=transposed_x_shape, dtype=x_aval.dtype)

        return xt_aval

    @staticmethod
    def lowering(ctx, x, *, static_axis_boundary, transpose_axis_boundary):
        """
        _transpose cuda lowering
        """
        x_aval = ctx.avals_in[0]
        # Unlike the cast variant, FP8 inputs are also accepted here.
        assert x_aval.dtype in [
            jnp.float32, jnp.float16, jnp.bfloat16, jnp.float8_e4m3fn, jnp.float8_e5m2
        ]

        ir_x_type = ir.RankedTensorType(x.type)
        ir_x_shape = ir_x_type.shape
        ir_out_dtype = jax_dtype_to_ir_dtype(x_aval.dtype)
        if static_axis_boundary >= 0:
            # All axes up to the static boundary are required to be size 1 here.
            for i in range(static_axis_boundary + 1):
                assert ir_x_shape[i] == 1

        transposed_x_shape = _multidim_transpose(ir_x_shape, static_axis_boundary,
                                                 transpose_axis_boundary)

        out_types = [ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype)]
        operands = [x]
        operand_shapes = [ir_x_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        # Collapse the tensor into the 2D (rows, cols) view the TE kernel expects.
        te_dtype = jax_dtype_to_te_dtype(x_aval.dtype)
        contracted_x_shape = (reduce(operator.mul, ir_x_shape[:transpose_axis_boundary]),
                              reduce(operator.mul, ir_x_shape[transpose_axis_boundary:]))
        opaque = transformer_engine_jax.pack_common_descriptor(contracted_x_shape, te_dtype,
                                                               te_dtype)

        out = custom_caller(TransposePrimitive.name, args, opaque, False)

        return [out]

    @staticmethod
    def impl(x, static_axis_boundary, transpose_axis_boundary):
        """
        te_transpose implementation
        """
        assert TransposePrimitive.inner_primitive is not None
        transposed_x = \
            TransposePrimitive.inner_primitive.bind(x,
                                                    static_axis_boundary=static_axis_boundary,
                                                    transpose_axis_boundary=transpose_axis_boundary)
        return transposed_x

    @staticmethod
    def batcher(batched_args, batch_dims, *, static_axis_boundary, transpose_axis_boundary):
        """Batch rule for vmap: the batch axis becomes the static axis boundary."""
        _check_valid_batch_dims(batch_dims)
        assert TransposePrimitive.outer_primitive is not None
        assert static_axis_boundary < 0

        x, = batched_args
        x_bdim, = batch_dims

        # Minus batch dim.
        transpose_axis_boundary = _normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1)
        transpose_axis_boundary += 1    # Plus batch dim

        out_bdims = x_bdim
        return TransposePrimitive.outer_primitive.bind(
            x, static_axis_boundary=x_bdim,
            transpose_axis_boundary=transpose_axis_boundary), out_bdims

    @staticmethod
    def infer_sharding_from_operands(static_axis_boundary, transpose_axis_boundary, mesh, arg_infos,
                                     result_infos):
        """Output sharding is the input spec permuted by _multidim_transpose."""
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        return transposed_x_sharding

    @staticmethod
    def partition(static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos):
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = transposed_x_sharding

        # No cross-device reduction needed: each shard transposes locally.
        impl = partial(TransposePrimitive.impl,
                       static_axis_boundary=static_axis_boundary,
                       transpose_axis_boundary=transpose_axis_boundary)

        return mesh, impl, out_shardings, arg_shardings
2997
2998


2999
register_primitive(TransposePrimitive)
3000
3001


3002
3003
def transpose(x: jnp.ndarray, static_axis_boundary: int,
              transpose_axis_boundary: int) -> jnp.ndarray:
    """
    Transpose wrapper: dispatch to the registered TE transpose primitive.
    """
    primitive = TransposePrimitive.outer_primitive
    return primitive.bind(x,
                          static_axis_boundary=static_axis_boundary,
                          transpose_axis_boundary=transpose_axis_boundary)


class LayerNormFwdFp8Primitive(BasePrimitive):
    """
    Layer Normalization Forward FP8 Primitive

    Fused LayerNorm forward with FP8 output: one TE kernel returns the
    normalized tensor (in FP8), the per-row mean (mu), the per-row reciprocal
    stddev (rsigma), and the updated amax statistic.
    """
    name = "te_layernorm_forward_fp8"
    multiple_results = True
    impl_static_args = (6, 7, 8)    # out_dtype, zero_centered_gamma, epsilon
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, gamma_aval, beta_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype,
                 zero_centered_gamma, epsilon):
        """
        LayerNorm fwd (fp8 out) abstract

        Returns avals for (out, mu, rsigma, updated_amax).
        """
        del zero_centered_gamma, epsilon
        x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)

        # Input must be a 32/16-bit float; FP8 meta tensors are always float32.
        assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        mu_rsigama_dtype = jnp.float32

        assert gamma_aval.size == beta_aval.size

        out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
        # mu/rsigma are per-row statistics: one value per leading (batch) index.
        mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)

        return out_aval, mu_aval, rsigma_aval, updated_amax_aval

    @staticmethod
    def lowering(ctx, x, gamma, beta, amax, scale, scale_inv, *, out_dtype, zero_centered_gamma,
                 epsilon):
        """
        LayerNorm fwd (fp8 out) lowering rules
        """
        x_aval, gamma_aval, beta_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in

        # Currently only support casting to E4M3 only in C side.
        assert out_dtype == jnp.float8_e4m3fn

        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert gamma_aval.dtype == beta_aval.dtype
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        g_type = ir.RankedTensorType(gamma.type)
        g_shape = g_type.shape
        b_type = ir.RankedTensorType(beta.type)
        b_shape = b_type.shape

        # gamma and beta must agree in both dtype and shape.
        assert g_type == b_type
        assert g_shape == b_shape

        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_mu_dtype = ir.F32Type.get()
        ir_rsigma_dtype = ir.F32Type.get()
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        # scale and scale_inv share the amax tensor's (scalar) shape.
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape

        # The kernel views x as (batch_size, hidden_size).
        out_shape = x_shape
        hidden_size = reduce(operator.mul, g_shape)
        batch_shape = out_shape[:-1]
        batch_size = reduce(operator.mul, x_shape) // hidden_size

        out_types = [
            ir.RankedTensorType.get(out_shape, ir_out_dtype),
            ir.RankedTensorType.get(batch_shape, ir_mu_dtype),
            ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
        ]
        operands = [x, gamma, beta, amax, scale, scale_inv]
        operand_shapes = [
            x_shape, g_shape, b_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape
        ]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        # Number of SMs to reserve (keep free) for overlapping kernels.
        sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))

        opaque = transformer_engine_jax.pack_norm_descriptor(
            batch_size,
            hidden_size,
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(gamma_aval.dtype),
            zero_centered_gamma,
            epsilon,
            sm_margin,
        )

        # amax (operand 3) is aliased with updated_amax (output 3): in-place update.
        out = custom_caller(LayerNormFwdFp8Primitive.name,
                            args,
                            opaque,
                            False,
                            operand_output_aliases={3: 3})

        return out

    @staticmethod
    def impl(x, gamma, beta, amax, scale, scale_inv, out_dtype, zero_centered_gamma, epsilon):
        """
        LayerNorm fwd (fp8 out) implementation: bind the inner primitive.
        """
        assert LayerNormFwdFp8Primitive.inner_primitive is not None
        out, mu, rsigma, updated_amax = LayerNormFwdFp8Primitive.inner_primitive.bind(
            x,
            gamma,
            beta,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon)
        return out, mu, rsigma, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, zero_centered_gamma, epsilon):
        """
        Batch rule for vmap: outputs follow x's batch dim; amax keeps its own.
        """
        _check_valid_batch_dims(batch_dims)
        assert LayerNormFwdFp8Primitive.outer_primitive is not None
        x, gamma, beta, amax, scale, scale_inv = batched_args
        x_bdim, _, _, amax_bdim, _, _ = batch_dims

        out_bdims = x_bdim, x_bdim, x_bdim, amax_bdim
        return LayerNormFwdFp8Primitive.outer_primitive.bind(
            x,
            gamma,
            beta,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon), out_bdims

    @staticmethod
    def infer_sharding_from_operands(out_dtype, zero_centered_gamma, epsilon, mesh, arg_infos,
                                     result_infos):
        del out_dtype, zero_centered_gamma, epsilon, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        if x_spec[-1] is not None:
            # Fix: this warning previously named LayerNormFwdPrimitive (the
            # non-FP8 variant) instead of this primitive.
            warnings.warn(
                f"Does not support to shard hidden dim in {LayerNormFwdFp8Primitive.name}! "
                "Force to not shard the hidden dim, which might introduce extra collective ops, "
                "and hurt performance.")

        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        mu_sharding = rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[3])))
        return (out_sharding, mu_sharding, rsigma_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        if x_spec[-1] is not None:
            # Fix: this warning previously named LayerNormFwdPrimitive (the
            # non-FP8 variant) instead of this primitive.
            warnings.warn(
                f"Does not support to shard hidden dim in {LayerNormFwdFp8Primitive.name}! "
                "Force to not shard the hidden dim, which might introduce extra collective ops, "
                "and hurt performance."
            )
        x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        g_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        b_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        out_sharding = x_sharding
        mu_sharding = rsigma_sharding = NamedSharding(
            mesh, PartitionSpec(*get_padded_spec(arg_infos[0])[:-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[3])))
        fp8_meta_sharding = amax_sharding
        arg_shardings = (x_sharding, g_sharding, b_sharding) + (fp8_meta_sharding,) * 3
        out_shardings = (out_sharding, mu_sharding, rsigma_sharding, amax_sharding)

        def sharded_impl(x, gamma, beta, amax, scale, scale_inv):
            # Run the local kernel, then max-reduce amax across all mesh axes
            # except pipeline-parallel so every replica holds the global amax.
            local_x, local_mu, local_rsigma, local_amax = \
                LayerNormFwdFp8Primitive.impl(x, gamma, beta, amax, scale, scale_inv,
                                              out_dtype=out_dtype,
                                              zero_centered_gamma=zero_centered_gamma,
                                              epsilon=epsilon)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)

            return local_x, local_mu, local_rsigma, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings
3207

3208
3209
3210
3211
3212
3213
3214

register_primitive(LayerNormFwdFp8Primitive)


def layernorm_fwd_fp8(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, amax: jnp.ndarray,
                      scale: jnp.ndarray, scale_inv: jnp.ndarray, out_dtype: jnp.dtype,
                      zero_centered_gamma: bool, epsilon: float):
    """
    Wrapper for the TE LayerNorm forward primitive with FP8 output.
    """
    primitive = LayerNormFwdFp8Primitive.outer_primitive
    return primitive.bind(x, gamma, beta, amax, scale, scale_inv,
                          out_dtype=out_dtype,
                          zero_centered_gamma=zero_centered_gamma,
                          epsilon=epsilon)


class RmsNormFwdFp8Primitive(BasePrimitive):
    """
    RMS Normalization Forward FP8 Primitive

    Fused RMSNorm forward with FP8 output: one TE kernel returns the
    normalized tensor (in FP8), the per-row reciprocal RMS (rsigma), and the
    updated amax statistic.
    """
    name = "te_rmsnorm_forward_fp8"
    multiple_results = True
    impl_static_args = (5, 6)    # out_dtype, epsilon
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval, out_dtype, epsilon):
        """
        RMSNorm fwd (fp8 out) abstract

        Returns avals for (out, rsigma, updated_amax).

        NOTE(review): out_dtype/epsilon are positional here but keyword-only in
        the LayerNorm sibling; bind always passes them by keyword, so both work.
        """
        del epsilon
        x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)

        # Input must be a 32/16-bit float; FP8 meta tensors are always float32.
        assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        # x must be a whole number of rows of hidden_size elements.
        hidden_size = gamma_aval.size
        assert x_aval.size % hidden_size == 0

        rsigma_dtype = jnp.float32

        out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
        # rsigma is a per-row statistic: one value per leading (batch) index.
        rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=rsigma_dtype)
        # Fix: derive the updated amax aval from amax_aval (was out_aval), so
        # aval metadata comes from amax itself, matching LayerNormFwdFp8Primitive.
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)

        return out_aval, rsigma_aval, updated_amax_aval

    @staticmethod
    def lowering(ctx, x, gamma, amax, scale, scale_inv, *, out_dtype, epsilon):
        """
        RMSNorm fwd (fp8 out) lowering rules
        """
        # Currently only support casting to E4M3 only in C side.
        assert out_dtype == jnp.float8_e4m3fn

        x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in

        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        g_type = ir.RankedTensorType(gamma.type)
        g_shape = g_type.shape

        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_rsigma_dtype = ir.F32Type.get()
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        # scale and scale_inv share the amax tensor's (scalar) shape.
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape

        # The kernel views x as (batch_size, hidden_size).
        out_shape = x_shape
        hidden_size = reduce(operator.mul, g_shape)
        batch_shape = out_shape[:-1]
        batch_size = reduce(operator.mul, x_shape) // hidden_size

        out_types = [
            ir.RankedTensorType.get(out_shape, ir_out_dtype),
            ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
        ]
        operands = [x, gamma, amax, scale, scale_inv]
        operand_shapes = [x_shape, g_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        # Number of SMs to reserve (keep free) for overlapping kernels.
        sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))

        opaque = transformer_engine_jax.pack_norm_descriptor(
            batch_size,
            hidden_size,
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(gamma_aval.dtype),
            False,    # RMSNorm doesn't support zero_centered_gamma
            epsilon,
            sm_margin,
        )

        # amax (operand 2) is aliased with updated_amax (output 2): in-place update.
        out = custom_caller(RmsNormFwdFp8Primitive.name,
                            args,
                            opaque,
                            False,
                            operand_output_aliases={2: 2})

        return out

    @staticmethod
    def impl(x, gamma, amax, scale, scale_inv, out_dtype, epsilon):
        """
        RMSNorm fwd (fp8 out) implementation: bind the inner primitive.
        """
        assert RmsNormFwdFp8Primitive.inner_primitive is not None
        out, rsigma, amax = RmsNormFwdFp8Primitive.inner_primitive.bind(x,
                                                                        gamma,
                                                                        amax,
                                                                        scale,
                                                                        scale_inv,
                                                                        out_dtype=out_dtype,
                                                                        epsilon=epsilon)
        return out, rsigma, amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, epsilon):
        """
        Batch rule for vmap: outputs follow x's batch dim; amax keeps its own.
        """
        _check_valid_batch_dims(batch_dims)
        assert RmsNormFwdFp8Primitive.outer_primitive is not None
        x, gamma, amax, scale, scale_inv = batched_args
        x_bdim, _, amax_bdim, _, _ = batch_dims
        out_bdims = x_bdim, x_bdim, amax_bdim
        return RmsNormFwdFp8Primitive.outer_primitive.bind(x,
                                                           gamma,
                                                           amax,
                                                           scale,
                                                           scale_inv,
                                                           out_dtype=out_dtype,
                                                           epsilon=epsilon), out_bdims

    @staticmethod
    def infer_sharding_from_operands(out_dtype, epsilon, mesh, arg_infos, result_infos):
        del out_dtype, epsilon, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {RmsNormFwdFp8Primitive.name}! "
                "Force to not shard the hidden dim, which might introduce extra collective ops, "
                "and hurt performance."
            )
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        return (out_sharding, rsigma_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, epsilon, mesh, arg_infos, result_infos):
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {RmsNormFwdFp8Primitive.name}! "
                "Force to not shard the hidden dim, which might introduce extra collective ops, "
                "and hurt performance."
            )
        x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        g_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        out_sharding = x_sharding
        rsigma_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[0])[:-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        fp8_meta_sharding = amax_sharding
        arg_shardings = (x_sharding, g_sharding) + (fp8_meta_sharding,) * 3
        out_shardings = (out_sharding, rsigma_sharding, amax_sharding)

        def sharded_impl(x, gamma, amax, scale, scale_inv):
            # Run the local kernel, then max-reduce amax across all mesh axes
            # except pipeline-parallel so every replica holds the global amax.
            local_x, local_rsigma, local_amax = \
                RmsNormFwdFp8Primitive.impl(x, gamma, amax, scale, scale_inv,
                                            out_dtype=out_dtype, epsilon=epsilon)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)

            return local_x, local_rsigma, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings
3402
3403


3404
register_primitive(RmsNormFwdFp8Primitive)
3405

3406
3407
3408

def rmsnorm_fwd_fp8(x: jnp.ndarray, gamma: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
                    scale_inv: jnp.ndarray, out_dtype: jnp.dtype, epsilon: float):
    """
    Wrapper for the TE RMSNorm forward primitive with FP8 output.

    Binds ``RmsNormFwdFp8Primitive.outer_primitive`` on the given operands and
    returns its results (output, rsigma, updated amax).
    """
    primitive = RmsNormFwdFp8Primitive.outer_primitive
    return primitive.bind(x, gamma, amax, scale, scale_inv,
                          out_dtype=out_dtype, epsilon=epsilon)
3419
3420


3421
class GatedGeluFp8Primitive(BasePrimitive):
    """
    Gated GELU with cast to FP8 (forward).

    Consumes an input of shape (..., 2, hidden) holding the two halves of the
    gated activation and produces a (..., hidden) FP8 result plus an updated
    amax. Only the E4M3 output dtype is supported on the C++ side.
    """
    name = "te_gated_gelu_fp8"
    multiple_results = True
    impl_static_args = (4,)    # out_dtype
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype):
        """
        Shape/dtype inference for te_gated_gelu_fp8.

        Returns (out_aval, updated_amax_aval); the output drops the size-2
        axis, mapping (..., 2, hidden) -> (..., hidden) with dtype out_dtype.
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        # Currently only support casting to E4M3 only in C side.
        assert out_dtype == jnp.float8_e4m3fn
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        assert x_aval.shape[-2] == 2    # Assume x in (....., 2, hidden)
        hidden_size = x_aval.shape[-1]
        batch_shape = x_aval.shape[:-2]
        out_shape = (batch_shape) + (hidden_size,)
        out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)
        # amax keeps its shape/dtype; returned as its own result because the
        # lowering aliases it in place (operand 1 -> output 1).
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)

        return out_aval, updated_amax_aval

    @staticmethod
    def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype):
        """
        MLIR lowering for te_gated_gelu_fp8: builds operand/result types and
        emits the TE custom call. amax is aliased in place ({1: 1}).
        """
        x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_x_type = ir.RankedTensorType(x.type)
        ir_x_shape = ir_x_type.shape
        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        # scale and scale_inv reuse amax's shape.
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape

        # Output drops the size-2 gate axis: (..., 2, hidden) -> (..., hidden).
        hidden_size = ir_x_shape[-1]
        batch_shape = ir_x_shape[:-2]
        batch_size = reduce(operator.mul, batch_shape)
        out_shape = batch_shape + [hidden_size]
        out_types = [
            ir.RankedTensorType.get(out_shape, ir_out_dtype),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
        ]
        operands = [x, amax, scale, scale_inv]
        operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        # Descriptor carries the flattened (rows, hidden) view and the
        # input/output TE dtypes for the C++ kernel.
        opaque = transformer_engine_jax.pack_common_descriptor((batch_size, out_shape[-1]),
                                                               jax_dtype_to_te_dtype(x_aval.dtype),
                                                               jax_dtype_to_te_dtype(out_dtype))

        out = custom_caller(GatedGeluFp8Primitive.name,
                            args,
                            opaque,
                            False,
                            operand_output_aliases={1: 1})

        return out

    @staticmethod
    def impl(x, amax, scale, scale_inv, out_dtype):
        """
        Concrete implementation: forwards to the inner (non-partitioned)
        primitive and returns (out, updated_amax).
        """
        assert GatedGeluFp8Primitive.inner_primitive is not None
        out, updated_amax = GatedGeluFp8Primitive.inner_primitive.bind(x,
                                                                       amax,
                                                                       scale,
                                                                       scale_inv,
                                                                       out_dtype=out_dtype)
        return out, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype):
        """
        vmap batching rule: rebinds the outer primitive on the batched args;
        the outputs carry the batch dims of x and amax respectively.
        """
        _check_valid_batch_dims(batch_dims)
        assert GatedGeluFp8Primitive.outer_primitive is not None
        x, amax, scale, scale_inv = batched_args
        x_bdim, amax_bdim, _, _ = batch_dims

        out_bdims = x_bdim, amax_bdim
        return GatedGeluFp8Primitive.outer_primitive.bind(x,
                                                          amax,
                                                          scale,
                                                          scale_inv,
                                                          out_dtype=out_dtype), out_bdims

    @staticmethod
    def infer_sharding_from_operands(out_dtype, mesh, arg_infos, result_infos):
        """
        Output sharding follows x with the size-2 axis's spec entry dropped;
        the amax output follows the amax operand's sharding.
        """
        del out_dtype, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        return (out_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, mesh, arg_infos, result_infos):
        """
        Custom-partitioning rule: run the kernel on local shards, then
        all-reduce amax with a max across every mesh axis except the
        pipeline-parallel ones.
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        # Operands keep the shardings they arrive with.
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (out_sharding, amax_sharding)

        def sharded_impl(x, amax, scale, scale_inv):
            local_x, local_amax = GatedGeluFp8Primitive.impl(x,
                                                             amax,
                                                             scale,
                                                             scale_inv,
                                                             out_dtype=out_dtype)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)

            return local_x, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings
3554
3555


3556
# Register the FP8 gated-GELU primitive with JAX.
register_primitive(GatedGeluFp8Primitive)
3557

3558
3559
3560

def gated_gelu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
                   out_dtype: jnp.dtype) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Gated GELU wrapper: returns FP8(geglu(x)) and the updated amax.

    Args:
        x: input of shape (..., 2, hidden) holding the two activation halves.
        amax, scale, scale_inv: float32 FP8 meta tensors.
        out_dtype: FP8 output dtype (only jnp.float8_e4m3fn is accepted by the
            primitive's abstract rule).

    Returns:
        A 2-tuple (casted_output, updated_amax) — matching the two results of
        GatedGeluFp8Primitive (the previous 3-tuple annotation was wrong).
    """
    return GatedGeluFp8Primitive.outer_primitive.bind(x,
                                                      amax,
                                                      scale,
                                                      scale_inv,
                                                      out_dtype=out_dtype)
3570
3571


3572
class DgatedGeluCastTransposePrimitive(BasePrimitive):
    """
    Backward of gated GELU fused with FP8 cast and transpose.

    Given dz of shape (..., hidden) and the forward input x of shape
    (..., 2, hidden), produces the FP8 gradient in both the natural layout and
    the layout transposed by _multidim_transpose, plus an updated amax.
    """
    name = "te_dgated_gelu_cast_transpose"
    multiple_results = True
    impl_static_args = (5, 6)    # out_dtype, static_axis_boundary
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype,
                 static_axis_boundary):
        """
        Shape/dtype inference for te_dgated_gelu_cast_transpose.

        Returns (out, t_out, updated_amax): out has x's shape; t_out has the
        shape produced by _multidim_transpose(x.shape, static_axis_boundary,
        -2); both carry dtype out_dtype.
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dtype
        assert x_aval.shape[-2] == 2    # Linear + GeLU
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        # dz and x must share the same hidden size.
        ir_hidden_szie = dz_aval.shape[-1]
        gi_hidden_size = x_aval.shape[-1]
        assert ir_hidden_szie == gi_hidden_size
        t_shape = _multidim_transpose(x_aval.shape, static_axis_boundary, -2)
        out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype)
        t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)
        # amax keeps its shape/dtype; the lowering aliases it ({2: 2}).
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
        return out, t_out, updated_amax_aval

    @staticmethod
    def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary):
        """
        MLIR lowering: builds operand/result types and emits the TE custom
        call. amax is aliased in place (operand 2 -> output 2).
        """
        dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_aval.dtype
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_dz_type = ir.RankedTensorType(dz.type)
        ir_dz_shape = ir_dz_type.shape
        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        # dz is (..., hidden) and x is (..., 2, hidden): flattened batch sizes
        # must match.
        dz_batch_szie = reduce(operator.mul, ir_dz_shape[:-1])
        x_batch_size = reduce(operator.mul, x_shape[:-2])
        assert dz_batch_szie == x_batch_size
        assert x_shape[-2] == 2    # Linear + GeLU
        ir_hidden_szie = ir_dz_shape[-1]
        gi_hidden_size = x_shape[-1]
        assert ir_hidden_szie == gi_hidden_size
        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        # scale and scale_inv reuse amax's shape.
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape
        transposed_x_shape = _multidim_transpose(x_shape, static_axis_boundary, -2)
        out_types = [
            ir.RankedTensorType.get(x_shape, ir_out_dtype),
            ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
        ]
        operands = [dz, x, amax, scale, scale_inv]
        operand_shapes = [ir_dz_shape, x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
        # Descriptor carries the flattened (rows, hidden) view of x plus the
        # input/output TE dtypes.
        contracted_x_shape = (x_batch_size, x_shape[-1])
        opaque = transformer_engine_jax.pack_common_descriptor(contracted_x_shape,
                                                               jax_dtype_to_te_dtype(dz_aval.dtype),
                                                               jax_dtype_to_te_dtype(out_dtype))

        out = custom_caller(DgatedGeluCastTransposePrimitive.name,
                            args,
                            opaque,
                            False,
                            operand_output_aliases={2: 2})

        return out

    @staticmethod
    def impl(dz, x, amax, scale, scale_inv, out_dtype, static_axis_boundary):
        """
        Concrete implementation: forwards to the inner primitive and returns
        (out, t_out, updated_amax).
        """
        assert DgatedGeluCastTransposePrimitive.inner_primitive is not None
        out, t_out, updated_amax = DgatedGeluCastTransposePrimitive.inner_primitive.bind(
            dz,
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary)
        return out, t_out, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary):
        """
        vmap batching rule. The incoming static_axis_boundary is discarded and
        the batch dim is passed as the new static_axis_boundary — presumably so
        the added batch axis stays in place under _multidim_transpose; confirm
        against _multidim_transpose's definition.
        """
        del static_axis_boundary
        _check_valid_batch_dims(batch_dims)
        assert DgatedGeluCastTransposePrimitive.outer_primitive is not None
        dz, x, amax, scale, scale_inv = batched_args
        x_bdim, _, amax_bdim, _, _ = batch_dims

        out_bdims = x_bdim, x_bdim, amax_bdim
        return DgatedGeluCastTransposePrimitive.outer_primitive.bind(
            dz, x, amax, scale, scale_inv, out_dtype=out_dtype,
            static_axis_boundary=x_bdim), out_bdims

    @staticmethod
    def infer_sharding_from_operands(out_dtype, static_axis_boundary, mesh, arg_infos,
                                     result_infos):
        """
        Output shardings: out follows x; the transposed output follows x's
        spec permuted by _multidim_transpose; amax follows the amax operand.
        """
        del out_dtype, result_infos
        x_spec = get_padded_spec(arg_infos[1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, -2)
        tranposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        return (out_sharding, tranposed_out_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, static_axis_boundary, mesh, arg_infos, result_infos):
        """
        Custom-partitioning rule: run the kernel on local shards, then
        all-reduce amax with a max across every mesh axis except the
        pipeline-parallel ones.
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[1])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, -2)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))

        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        # Operands keep the shardings they arrive with.
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (casted_x_sharding, casted_transposed_x_sharding, amax_sharding)

        def sharded_impl(dz, x, amax, scale, scale_inv):
            local_out, local_t_out, local_amax = DgatedGeluCastTransposePrimitive.impl(
                dz,
                x,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=static_axis_boundary)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
            return local_out, local_t_out, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings
3722
3723


3724
# Register the dgated-GELU cast-transpose primitive with JAX.
register_primitive(DgatedGeluCastTransposePrimitive)
3725

3726
3727
3728
3729
3730

def dgated_gelu_cast_transpose(
        dz: jnp.ndarray, x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
        scale_inv: jnp.ndarray, out_dtype: jnp.dtype,
        static_axis_boundary: int) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    Fused d_gated_gelu + FP8 cast + transpose wrapper.
    Return FP8(dgeglu(inputs)) in both natural and transposed layouts.

    Args:
        dz: incoming gradient of shape (..., hidden).
        x: forward input of shape (..., 2, hidden).
        amax, scale, scale_inv: float32 FP8 meta tensors.
        out_dtype: FP8 output dtype as a JAX dtype — the primitive applies it
            via aval.update(dtype=...), so a jnp dtype (not TEDType) is
            expected, consistent with the other wrappers in this module.
        static_axis_boundary: leading-axis boundary forwarded to the
            primitive's _multidim_transpose-based shape logic.

    Returns:
        (casted_grad, casted_transposed_grad, updated_amax).
    """
    return DgatedGeluCastTransposePrimitive.outer_primitive.bind(
        dz,
        x,
        amax,
        scale,
        scale_inv,
        out_dtype=out_dtype,
        static_axis_boundary=static_axis_boundary)