# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX te custom call"""

from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Tuple
from functools import partial, reduce
import operator
11
import os
12
13
import warnings

14
15
16
17
18
import numpy as np
import jax.numpy as jnp
from jax.lib import xla_client
from jax import core, dtypes
from jax.interpreters import xla, mlir
19
from jax.experimental.custom_partitioning import custom_partitioning
20
from jax.interpreters.mlir import ir, dtype_to_ir_type
21
22
from jax.sharding import PartitionSpec, NamedSharding
from jax._src.interpreters import batching
23

24
25
26
27
28
29
30
try:
    from jaxlib.hlo_helpers import custom_call
except ImportError:
    # Newer JAX changed its API. But we want to support a few JAX
    # version, so we still need this import.
    pass

31
32
import transformer_engine_jax
from transformer_engine_jax import DType as TEDType
33
34
from transformer_engine_jax import NVTE_Bias_Type
from transformer_engine_jax import NVTE_Mask_Type
35
36
from transformer_engine_jax import NVTE_QKV_Layout
from transformer_engine_jax import NVTE_Fused_Attn_Backend
37

38
39
40
41
42
from .sharding import all_reduce_max_along_all_axes_except_PP
from .sharding import all_reduce_sum_along_dp_fsdp
from .sharding import get_all_mesh_axes, num_of_devices
from .sharding import get_padded_spec as te_get_padded_spec

43
44
45
46
47
48
49
50
51
# Register every custom-call target exported by the TE C++ extension with
# XLA's CUDA platform so lowered HLO can resolve them by name.
for _name, _value in transformer_engine_jax.registrations().items():
    xla_client.register_custom_call_target(_name, _value, platform="CUDA")


def te_dtype_to_jax_dtype(te_dtype):
    """
    Convert a TE dtype enum value to the corresponding JAX dtype.

    Raises ValueError for TE dtypes with no JAX equivalent.
    """
    assert isinstance(te_dtype, TEDType)

    mapping = {
        TEDType.kFloat32: jnp.float32,
        TEDType.kFloat16: jnp.float16,
        TEDType.kBFloat16: jnp.bfloat16,
        TEDType.kInt32: jnp.int32,
        TEDType.kInt64: jnp.int64,
        TEDType.kFloat8E4M3: jnp.float8_e4m3fn,
        TEDType.kFloat8E5M2: jnp.float8_e5m2,
    }

    jax_dtype = mapping.get(te_dtype)
    if jax_dtype is None:
        raise ValueError(f"Unsupported {te_dtype=}")
    return jax_dtype
67
68
69
70
71
72
73
74
75


def te_dtype_to_ir_dtype(te_dtype):
    """
    Convert a TE dtype to the corresponding MLIR dtype.

    Goes through the JAX dtype as an intermediate representation.
    """
    np_dtype = np.dtype(te_dtype_to_jax_dtype(te_dtype))
    return dtype_to_ir_type(np_dtype)


76
77
78
79
80
81
82
def jax_dtype_to_ir_dtype(jax_dtype):
    """
    Convert a JAX dtype to the corresponding MLIR dtype.
    """
    np_dtype = np.dtype(jax_dtype)
    return dtype_to_ir_type(np_dtype)


83
84
85
86
def jax_dtype_to_te_dtype(jax_dtype):
    """
    Convert a JAX dtype to the corresponding TE dtype enum value.

    The dtype is canonicalized first, so both scalar types (e.g. jnp.float32)
    and dtype objects are accepted. Raises ValueError for unsupported dtypes.
    """
    jax_dtype = dtypes.canonicalize_dtype(jax_dtype)

    mapping = {
        jnp.float32.dtype: TEDType.kFloat32,
        jnp.float16.dtype: TEDType.kFloat16,
        jnp.bfloat16.dtype: TEDType.kBFloat16,
        jnp.int32.dtype: TEDType.kInt32,
        jnp.int64.dtype: TEDType.kInt64,
        jnp.float8_e4m3fn.dtype: TEDType.kFloat8E4M3,
        jnp.float8_e5m2.dtype: TEDType.kFloat8E5M2,
    }

    te_dtype = mapping.get(jax_dtype)
    if te_dtype is None:
        raise ValueError(f"Unsupported {jax_dtype=}")
    return te_dtype
103
104


105
106
107
108
109
110
111
112
def get_padded_spec(arg_info):
    """
    Get the padded partition spec from an argument's ShapedArray info.

    Falls back to an all-None spec (replicated) when no sharding is attached.
    """
    ndim = arg_info.ndim
    sharding = arg_info.sharding
    if sharding is None:
        return te_get_padded_spec(None, ndim)
    return te_get_padded_spec(sharding.spec, ndim)
113
114


115
def _check_valid_batch_dims(bdims):
116
    """
117
    Assert out non-supported bath dims
118
    """
119
120
121
122
    for dim in bdims:
        assert dim in [0, None], \
            "Currently only support batch_dim in [0, None], " \
            f"but got {dim=}"
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145


class BasePrimitive(metaclass=ABCMeta):
    """
    Abstract interface every TE JAX primitive class implements.

    Subclasses supply the six hooks consumed by ``register_primitive``:
    abstract evaluation, MLIR lowering, a concrete impl, a vmap batch rule,
    and the two custom_partitioning callbacks.
    """

    @staticmethod
    @abstractmethod
    def abstract():
        """
        Abstract evaluation rule describing the computing graph.
        """
        return NotImplemented

    @staticmethod
    @abstractmethod
    def lowering():
        """
        Lowering rule describing the MLIR emitted for this primitive.
        """
        return NotImplemented

    @staticmethod
    @abstractmethod
    def impl():
        """
        Concrete implementation of the primitive.
        """
        return NotImplemented

    @staticmethod
    @abstractmethod
    def batcher():
        """
        Batch rule used by vmap.
        """
        return NotImplemented

    @staticmethod
    @abstractmethod
    def infer_sharding_from_operands():
        """
        infer_sharding_from_operands callback for custom_partitioning.
        """
        return NotImplemented

    @staticmethod
    @abstractmethod
    def partition():
        """
        partition callback for custom_partitioning.
        """
        return NotImplemented

178
179
180
181
182

def register_primitive(cls):
    """
    register jax primitive

    Registers two primitives for ``cls`` and stores them on the class:
      * ``cls.inner_primitive`` — carries the real MLIR lowering; this is
        what ``cls.impl`` binds inside the partitioned computation.
      * ``cls.outer_primitive`` — the "wrapper" primitive user code binds;
        it is lowered through custom_partitioning so output shardings and
        per-shard execution can be customized.
    """

    def name_of_wrapper_p():
        return cls.name + "_wrapper"

    # Inner primitive: abstract eval + CUDA lowering. Its impl just defers
    # to XLA's default apply_primitive path.
    inner_p = core.Primitive(cls.name)
    inner_p.multiple_results = cls.multiple_results
    inner_p.def_impl(partial(xla.apply_primitive, inner_p))
    inner_p.def_abstract_eval(cls.abstract)
    mlir.register_lowering(inner_p, cls.lowering, platform='cuda')
    cls.inner_primitive = inner_p

    # Outer primitive: same abstract eval, but impl/batching come from the
    # class, and lowering goes through the custom_partitioning wrapper
    # instead of a direct MLIR rule.
    outer_p = core.Primitive(name_of_wrapper_p())
    outer_p.multiple_results = cls.multiple_results
    outer_p.def_impl(cls.impl)
    outer_p.def_abstract_eval(cls.abstract)
    batching.primitive_batchers[outer_p] = cls.batcher
    outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args)
    outer_p_lower.def_partition(infer_sharding_from_operands=cls.infer_sharding_from_operands,
                                partition=cls.partition)
    mlir.register_lowering(outer_p,
                           mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results))
    cls.outer_primitive = outer_p
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251


@dataclass
class CustomCallArgsWrapper:
    """
    Bundle of operand/result metadata for an XLA custom call: result types,
    operand values, and the operand/result layouts.
    """

    def __init__(self,
                 output_types,
                 operands,
                 operand_shapes,
                 operand_specific_layouts=None,
                 output_specific_layouts=None):
        self.output_types = output_types
        self.operands = operands
        self.operand_layouts = CustomCallArgsWrapper.generate_layouts(operand_shapes,
                                                                      operand_specific_layouts)
        output_shapes = [t.shape for t in output_types]
        self.output_layouts = CustomCallArgsWrapper.generate_layouts(output_shapes,
                                                                     output_specific_layouts)

    @staticmethod
    def generate_layouts(shapes, specific_layouts):
        """
        Build one layout per shape: use the caller-provided layout when one
        exists for that index, otherwise fall back to the default
        row-major (descending dimension order) layout.
        """
        overrides = {} if specific_layouts is None else specific_layouts
        return [
            overrides[idx] if idx in overrides else range(len(shape) - 1, -1, -1)
            for idx, shape in enumerate(shapes)
        ]


def custom_caller(name, args, opaque, has_side_effect, **kwargs):
    """
    XLA custom call wrapper.

    Dispatches to whichever helper the installed JAX version provides:
    ``mlir.custom_call`` on newer JAX, otherwise the older
    ``jaxlib.hlo_helpers.custom_call`` imported at module top.

    ``args`` is a CustomCallArgsWrapper; ``opaque`` is the packed descriptor
    passed to the TE kernel as backend_config.
    """
    if hasattr(mlir, "custom_call"):
        out = mlir.custom_call(name,
                               result_types=args.output_types,
                               operands=args.operands,
                               operand_layouts=args.operand_layouts,
                               result_layouts=args.output_layouts,
                               backend_config=opaque,
                               has_side_effect=has_side_effect,
                               **kwargs).results
    else:
        # Need to disable one pylint error as the second function
        # parameter name changed recently in JAX. Otherwise we won't be
        # compatible with multiple JAX versions.
        out = custom_call(    # pylint: disable=too-many-function-args
            name,
            args.output_types,
            operands=args.operands,
            operand_layouts=args.operand_layouts,
            result_layouts=args.output_layouts,
            backend_config=opaque,
            has_side_effect=has_side_effect,
            **kwargs)
    return out


277
class LayerNormFwdPrimitive(BasePrimitive):
    """
    Layer Normalization Forward Primitive

    Binds the TE ``te_layernorm_forward`` CUDA custom call. Returns the
    normalized output plus the per-row statistics (mu, rsigma) consumed by
    the backward pass.
    """
    name = "te_layernorm_forward"
    multiple_results = True    # (out, mu, rsigma)
    impl_static_args = (3, 4)    # zero_centered_gamma, epsilon
    inner_primitive = None    # filled in by register_primitive
    outer_primitive = None    # filled in by register_primitive

    @staticmethod
    def abstract(x_aval, gamma_aval, beta_aval, **kwargs):    # pylint: disable=unused-argument
        """
        LayerNorm fwd abstract: shapes/dtypes of (out, mu, rsigma).
        """
        x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]

        # Statistics are always fp32, independent of the input dtype.
        mu_rsigama_dtype = jnp.float32

        out_aval = core.raise_to_shaped(x_aval)
        # One mu/rsigma value per row: input shape without the hidden (last) dim.
        mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype)

        assert gamma_aval.size == beta_aval.size
        hidden_size = gamma_aval.size
        assert x_aval.size % hidden_size == 0

        return out_aval, mu_aval, rsigma_aval

    @staticmethod
    def lowering(ctx, x, gamma, beta, *, zero_centered_gamma, epsilon):
        """
        LayerNorm fwd lowering rules: emit the TE XLA custom call.
        """
        x_aval, gamma_aval, beta_aval = ctx.avals_in
        assert gamma_aval.dtype == beta_aval.dtype
        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        g_type = ir.RankedTensorType(gamma.type)
        g_shape = g_type.shape
        b_type = ir.RankedTensorType(beta.type)
        b_shape = b_type.shape

        assert g_type == b_type
        assert g_shape == b_shape

        # Output shape is same as the input shape, but the output type is same as the weight type.
        # See ln_api.cpp
        output_type = g_type.element_type
        ir_mu_dtype = ir.F32Type.get()
        ir_rsigma_dtype = ir.F32Type.get()

        out_shape = x_shape
        hidden_size = reduce(operator.mul, g_shape)
        batch_shape = out_shape[:-1]
        batch_size = reduce(operator.mul, x_shape) // hidden_size

        out_types = [
            ir.RankedTensorType.get(out_shape, output_type),
            ir.RankedTensorType.get(batch_shape, ir_mu_dtype),
            ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
        ]
        operands = [x, gamma, beta]
        operand_shapes = [x_shape, g_shape, b_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        # SM margin taken from the environment (default 0).
        sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))

        # Pack the scalar kernel parameters into the opaque descriptor the
        # TE custom call reads on the C++ side.
        opaque = transformer_engine_jax.pack_norm_descriptor(
            batch_size,
            hidden_size,
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(gamma_aval.dtype),
            zero_centered_gamma,
            epsilon,
            sm_margin,
        )

        out = custom_caller(LayerNormFwdPrimitive.name, args, opaque, False)

        return out

    @staticmethod
    def impl(x, gamma, beta, zero_centered_gamma, epsilon):
        """
        Concrete implementation: bind the registered inner primitive.
        """
        assert LayerNormFwdPrimitive.inner_primitive is not None
        out, mu, rsigma = LayerNormFwdPrimitive.inner_primitive.bind(
            x, gamma, beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon)
        return out, mu, rsigma

    @staticmethod
    def batcher(batched_args, batch_dims, *, zero_centered_gamma, epsilon):
        """
        Batch rule for vmap; only dim 0 (or None) batching is supported.
        """
        _check_valid_batch_dims(batch_dims)
        assert LayerNormFwdPrimitive.outer_primitive is not None
        x, gamma, beta = batched_args
        x_bdim, _, _ = batch_dims

        # All three outputs (out, mu, rsigma) carry x's batch dim.
        out_bdims = x_bdim, x_bdim, x_bdim
        return LayerNormFwdPrimitive.outer_primitive.bind(x,
                                                          gamma,
                                                          beta,
                                                          zero_centered_gamma=zero_centered_gamma,
                                                          epsilon=epsilon), out_bdims

    @staticmethod
    def infer_sharding_from_operands(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
        """
        Derive output shardings from x's spec; the hidden (last) dim is
        forced to be replicated.
        """
        del zero_centered_gamma, epsilon, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {LayerNormFwdPrimitive.name}! " \
                f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
                f"and hurt performance."
            )
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        # mu/rsigma have one fewer dim (no hidden dim), so drop the last entry.
        mu_sharding = rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
        return (out_sharding, mu_sharding, rsigma_sharding)

    @staticmethod
    def partition(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
        """
        custom_partitioning rule: shard only the batch dims and run the
        unmodified impl on each shard.
        """
        del result_infos
        x_spec, g_spec, b_spec = map(get_padded_spec, arg_infos)
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {LayerNormFwdPrimitive.name}! " \
                f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
                f"and hurt performance."
            )
        x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        g_sharding = NamedSharding(mesh, PartitionSpec(*g_spec))
        b_sharding = NamedSharding(mesh, PartitionSpec(*b_spec))
        out_sharding = x_sharding
        mu_sharding = rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))

        arg_shardings = (x_sharding, g_sharding, b_sharding)
        out_shardings = (out_sharding, mu_sharding, rsigma_sharding)
        impl = partial(LayerNormFwdPrimitive.impl,
                       zero_centered_gamma=zero_centered_gamma,
                       epsilon=epsilon)
        return mesh, impl, out_shardings, arg_shardings
422
423


424
# Register inner/outer primitives for the layernorm forward pass.
register_primitive(LayerNormFwdPrimitive)
425
426


427
428
def layernorm_fwd(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, zero_centered_gamma: bool,
                  epsilon: float):
    """
    Wrapper for the TE layernorm forward pass.

    Returns (out, mu, rsigma) from the registered outer primitive.
    """
    return LayerNormFwdPrimitive.outer_primitive.bind(
        x, gamma, beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon)
437
438


439
class LayerNormBwdPrimitive(BasePrimitive):
    """
    Layer Normalization Backward Primitive

    Binds the TE ``te_layernorm_backward`` CUDA custom call. Given the
    incoming gradient dz and the forward-pass statistics (mu, rsigma),
    produces (dx, dgamma, dbeta).
    """
    name = "te_layernorm_backward"
    multiple_results = True    # (dx, dgamma, dbeta)
    impl_static_args = (5, 6)    # zero_centered_gamma, epsilon
    inner_primitive = None    # filled in by register_primitive
    outer_primitive = None    # filled in by register_primitive

    @staticmethod
    def abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval, **kwargs):    # pylint: disable=unused-argument
        """
        Layernorm bwd abstract: shapes/dtypes of (dx, dgamma, dbeta).
        """
        w_dtype = dtypes.canonicalize_dtype(gamma_aval.dtype)
        mu_dtype = dtypes.canonicalize_dtype(mu_aval.dtype)
        rsigma_dtype = dtypes.canonicalize_dtype(rsigma_aval.dtype)

        assert dtypes.canonicalize_dtype(dz_aval.dtype) == w_dtype
        assert dz_aval.shape == x_aval.shape
        # Statistics were produced per-row by the forward pass, in fp32.
        assert mu_aval.shape == rsigma_aval.shape == x_aval.shape[:-1]
        assert mu_dtype == rsigma_dtype == jnp.float32

        dx_aval = core.raise_to_shaped(dz_aval)
        dgamma_aval = dbeta_aval = core.raise_to_shaped(gamma_aval)
        return dx_aval, dgamma_aval, dbeta_aval

    @staticmethod
    def lowering(ctx, dz, x, mu, rsigma, gamma, *, zero_centered_gamma, epsilon):
        """
        Layernorm bwd lowering rules: emit the TE XLA custom call.
        """
        _, x_aval, _, _, gamma_aval = ctx.avals_in
        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        g_type = ir.RankedTensorType(gamma.type)
        g_shape = g_type.shape
        # dbeta shares gamma's type/shape, so b_type is derived from gamma here.
        b_type = ir.RankedTensorType(gamma.type)
        b_shape = b_type.shape
        assert g_type == b_type
        assert g_shape == b_shape

        dz_shape = ir.RankedTensorType(dz.type).shape
        mu_shape = ir.RankedTensorType(mu.type).shape
        rsigma_shape = ir.RankedTensorType(rsigma.type).shape

        hidden_size = reduce(operator.mul, g_shape)
        batch_size = reduce(operator.mul, x_shape) // hidden_size

        out_types = [
            ir.RankedTensorType.get(x_shape, x_type.element_type),
            ir.RankedTensorType.get(g_shape, g_type.element_type),
            ir.RankedTensorType.get(b_shape, b_type.element_type),
        ]
        # NOTE: operand order differs from this function's signature order;
        # this is the order the custom call consumes them in.
        operands = [dz, mu, rsigma, x, gamma]
        operand_shapes = [dz_shape, mu_shape, rsigma_shape, x_shape, g_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        # SM margin taken from the environment (default 0).
        sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))

        opaque = transformer_engine_jax.pack_norm_descriptor(
            batch_size,
            hidden_size,
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(gamma_aval.dtype),
            zero_centered_gamma,
            epsilon,
            sm_margin,
        )

        out = custom_caller(LayerNormBwdPrimitive.name, args, opaque, False)

        return out

    @staticmethod
    def impl(dz, x, mu, rsigma, gamma, zero_centered_gamma, epsilon):
        """
        Concrete implementation: bind the registered inner primitive.
        """
        assert LayerNormBwdPrimitive.inner_primitive is not None
        dx, dgamma, dbeta = LayerNormBwdPrimitive.inner_primitive.bind(
            dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon)
        return dx, dgamma, dbeta

    @staticmethod
    def batcher(batched_args, batch_dims, *, zero_centered_gamma, epsilon):
        """
        Batch rule for vmap; only dim 0 (or None) batching is supported.
        """
        _check_valid_batch_dims(batch_dims)
        assert LayerNormBwdPrimitive.outer_primitive is not None
        dz, x, mu, rsigma, gamma = batched_args
        _, x_bdim, _, _, gamma_bdim = batch_dims

        # dx follows x's batch dim; dgamma/dbeta follow gamma's.
        out_bdims = x_bdim, gamma_bdim, gamma_bdim
        return LayerNormBwdPrimitive.outer_primitive.bind(dz,
                                                          x,
                                                          mu,
                                                          rsigma,
                                                          gamma,
                                                          zero_centered_gamma=zero_centered_gamma,
                                                          epsilon=epsilon), out_bdims

    @staticmethod
    def infer_sharding_from_operands(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
        """
        Derive output shardings: dx mirrors x (hidden dim replicated);
        dgamma/dbeta mirror gamma's spec.
        """
        del zero_centered_gamma, epsilon, result_infos
        x_spec = get_padded_spec(arg_infos[1])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {LayerNormBwdPrimitive.name}! " \
                f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
                f"and hurt performance."
            )
        g_b_spec = get_padded_spec(arg_infos[4])
        dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        dgamma_sharding = dbeta_sharding = NamedSharding(mesh, PartitionSpec(*g_b_spec))
        return dx_sharding, dgamma_sharding, dbeta_sharding

    @staticmethod
    def partition(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
        """
        custom_partitioning rule: run per-shard, then all-reduce the weight
        gradients across DP/FSDP axes (each shard only sees part of the batch).
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[1])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {LayerNormBwdPrimitive.name}! " \
                f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
                f"and hurt performance."
            )
        g_b_spec = get_padded_spec(arg_infos[4])
        dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        dgamma_sharding = dbeta_sharding = NamedSharding(mesh, PartitionSpec(*g_b_spec))
        out_shardings = dx_sharding, dgamma_sharding, dbeta_sharding
        x_shardings = (dx_sharding,) * 2    # dz and x should have the same sharding.
        mu_shardings = (NamedSharding(mesh, PartitionSpec(*x_spec[:-1])),) * 2
        arg_shardings = (*x_shardings, *mu_shardings, NamedSharding(mesh, PartitionSpec(*g_b_spec)))

        def sharded_impl(dz, x, mu, rsigma, gamma):
            # Local (per-shard) grads; weight grads need a cross-shard sum.
            local_dx, local_dgamma, local_dbeta = \
                LayerNormBwdPrimitive.impl(dz, x, mu, rsigma, gamma,
                     zero_centered_gamma=zero_centered_gamma,
                     epsilon=epsilon)
            global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma)
            global_dbeta = all_reduce_sum_along_dp_fsdp(local_dbeta)
            return local_dx, global_dgamma, global_dbeta

        return mesh, sharded_impl, out_shardings, arg_shardings


# Register inner/outer primitives for the layernorm backward pass.
register_primitive(LayerNormBwdPrimitive)


def layernorm_bwd(dz: jnp.ndarray, x: jnp.ndarray, mu: jnp.ndarray, rsigma: jnp.ndarray,
                  gamma: jnp.ndarray, zero_centered_gamma: bool, epsilon: float):
    """
    Wrapper for the TE layernorm backward pass.

    Returns (dx, dgamma, dbeta) from the registered outer primitive.
    """
    return LayerNormBwdPrimitive.outer_primitive.bind(
        dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon)
597
598


599
class RmsNormFwdPrimitive(BasePrimitive):
    """
    RMS Normalization Forward Primitive

    Binds the TE ``te_rmsnorm_forward`` CUDA custom call. Unlike layernorm,
    RMSNorm has no beta/mu; outputs are (out, rsigma).
    """
    name = "te_rmsnorm_forward"
    multiple_results = True    # (out, rsigma)
    impl_static_args = (2,)    # epsilon
    inner_primitive = None    # filled in by register_primitive
    outer_primitive = None    # filled in by register_primitive

    @staticmethod
    def abstract(x_aval, gamma_aval, **kwargs):    # pylint: disable=unused-argument
        """
        RMSNorm fwd abstract: shapes/dtypes of (out, rsigma).
        """
        x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]

        # rsigma statistics are always fp32, independent of the input dtype.
        rsigama_dtype = jnp.float32

        out_aval = core.raise_to_shaped(x_aval)
        # One rsigma value per row: input shape without the hidden (last) dim.
        rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=rsigama_dtype)

        hidden_size = gamma_aval.size
        assert x_aval.size % hidden_size == 0

        return out_aval, rsigma_aval

    @staticmethod
    def lowering(ctx, x, gamma, *, epsilon):
        """
        RMSNorm fwd lowering rules: emit the TE XLA custom call.
        """
        x_aval, gamma_aval = ctx.avals_in
        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        g_type = ir.RankedTensorType(gamma.type)
        g_shape = g_type.shape
        rsigma_element_type = ir.F32Type.get()

        out_shape = x_shape
        hidden_size = reduce(operator.mul, g_shape)
        batch_shape = out_shape[:-1]
        batch_size = reduce(operator.mul, x_shape) // hidden_size

        out_types = [
            ir.RankedTensorType.get(out_shape, x_type.element_type),
            ir.RankedTensorType.get(batch_shape, rsigma_element_type),
        ]
        operands = [x, gamma]
        operand_shapes = [x_shape, g_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        # SM margin taken from the environment (default 0); shares the
        # layernorm env var.
        sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))

        opaque = transformer_engine_jax.pack_norm_descriptor(
            batch_size,
            hidden_size,
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(gamma_aval.dtype),
            False,    # RMSNorm doesn't support zero_centered_gamma
            epsilon,
            sm_margin,
        )

        out = custom_caller(RmsNormFwdPrimitive.name, args, opaque, False)

        return out

    @staticmethod
    def impl(x, gamma, epsilon):
        """
        Concrete implementation: bind the registered inner primitive.
        """
        assert RmsNormFwdPrimitive.inner_primitive is not None
        out, rsigma = RmsNormFwdPrimitive.inner_primitive.bind(x, gamma, epsilon=epsilon)
        return out, rsigma

    @staticmethod
    def batcher(batched_args, batch_dims, *, epsilon):
        """
        Batch rule for vmap; only dim 0 (or None) batching is supported.
        """
        _check_valid_batch_dims(batch_dims)
        assert RmsNormFwdPrimitive.outer_primitive is not None
        x, gamma = batched_args
        x_bdim, _ = batch_dims

        # Both outputs (out, rsigma) carry x's batch dim.
        out_bdims = x_bdim, x_bdim
        return RmsNormFwdPrimitive.outer_primitive.bind(x, gamma, epsilon=epsilon), out_bdims

    @staticmethod
    def infer_sharding_from_operands(epsilon, mesh, arg_infos, result_infos):
        """
        Derive output shardings from x's spec; the hidden (last) dim is
        forced to be replicated.
        """
        del epsilon, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {RmsNormFwdPrimitive.name}! " \
                f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
                f"and hurt performance."
            )
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
        return (out_sharding, rsigma_sharding)

    @staticmethod
    def partition(epsilon, mesh, arg_infos, result_infos):
        """
        custom_partitioning rule: shard only the batch dims and run the
        unmodified impl on each shard.
        """
        del result_infos
        x_spec, g_spec = map(get_padded_spec, arg_infos)
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {RmsNormFwdPrimitive.name}! " \
                f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
                f"and hurt performance."
            )
        x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        g_sharding = NamedSharding(mesh, PartitionSpec(*g_spec))
        out_sharding = x_sharding
        rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
        arg_shardings = (x_sharding, g_sharding)
        out_shardings = (out_sharding, rsigma_sharding)
        impl = partial(RmsNormFwdPrimitive.impl, epsilon=epsilon)
        return mesh, impl, out_shardings, arg_shardings
722
723


724
# Register inner/outer primitives for the RMSNorm forward pass.
register_primitive(RmsNormFwdPrimitive)
725
726


727
def rmsnorm_fwd(x: jnp.ndarray, gamma: jnp.ndarray, epsilon: float):
    """
    Wrapper for the TE rmsnorm forward pass.

    Returns (out, rsigma) from the registered outer primitive.
    """
    primitive = RmsNormFwdPrimitive.outer_primitive
    return primitive.bind(x, gamma, epsilon=epsilon)
732
733


734
class RmsNormBwdPrimitive(BasePrimitive):
    """
    RMS Normalization Backward Primitive

    Binds the TE ``te_rmsnorm_backward`` CUDA custom call. Given the
    incoming gradient dz and the forward-pass rsigma, produces (dx, dgamma).
    """
    name = "te_rmsnorm_backward"
    multiple_results = True    # (dx, dgamma)
    impl_static_args = (4,)    # epsilon
    inner_primitive = None    # filled in by register_primitive
    outer_primitive = None    # filled in by register_primitive

    @staticmethod
    def abstract(
            dz_aval,
            x_aval,
            rsigma_aval,
            gamma_aval,
            **kwargs    # pylint: disable=unused-argument
    ):
        """
        RMSNorm bwd abstract: shapes/dtypes of (dx, dgamma).
        """
        w_dtype = dtypes.canonicalize_dtype(gamma_aval.dtype)
        rsigma_dtype = dtypes.canonicalize_dtype(rsigma_aval.dtype)

        assert dtypes.canonicalize_dtype(dz_aval.dtype) == w_dtype
        assert dz_aval.shape == x_aval.shape
        # rsigma was produced per-row by the forward pass, in fp32.
        assert rsigma_aval.shape == x_aval.shape[:-1]
        assert rsigma_dtype == jnp.float32

        dx_aval = core.raise_to_shaped(dz_aval)
        dgamma_aval = core.raise_to_shaped(gamma_aval)
        return dx_aval, dgamma_aval

    @staticmethod
    def lowering(ctx, dz, x, rsigma, gamma, *, epsilon):
        """
        RMSNorm bwd lowering rules: emit the TE XLA custom call.
        """
        _, x_aval, _, gamma_aval = ctx.avals_in
        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        g_type = ir.RankedTensorType(gamma.type)
        g_shape = g_type.shape
        dz_shape = ir.RankedTensorType(dz.type).shape
        rsigma_shape = ir.RankedTensorType(rsigma.type).shape

        hidden_size = reduce(operator.mul, g_shape)
        batch_size = reduce(operator.mul, x_shape) // hidden_size

        out_types = [
            ir.RankedTensorType.get(x_shape, x_type.element_type),
            ir.RankedTensorType.get(g_shape, g_type.element_type),
        ]
        # NOTE: operand order differs from this function's signature order;
        # this is the order the custom call consumes them in.
        operands = [dz, rsigma, x, gamma]
        operand_shapes = [dz_shape, rsigma_shape, x_shape, g_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        # SM margin taken from the environment (default 0); shares the
        # layernorm env var.
        sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))

        opaque = transformer_engine_jax.pack_norm_descriptor(
            batch_size,
            hidden_size,
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(gamma_aval.dtype),
            False,    # RMSNorm doesn't support zero_centered_gamma
            epsilon,
            sm_margin,
        )

        out = custom_caller(RmsNormBwdPrimitive.name, args, opaque, False)

        return out

    @staticmethod
    def impl(dz, x, rsigma, gamma, epsilon):
        """
        Concrete implementation: bind the registered inner primitive.
        """
        assert RmsNormBwdPrimitive.inner_primitive is not None
        dx, dgamma = RmsNormBwdPrimitive.inner_primitive.bind(dz, x, rsigma, gamma, epsilon=epsilon)
        return dx, dgamma

    @staticmethod
    def batcher(batched_args, batch_dims, *, epsilon):
        """
        Batch rule for vmap; only dim 0 (or None) batching is supported.
        """
        _check_valid_batch_dims(batch_dims)
        assert RmsNormBwdPrimitive.outer_primitive is not None
        dz, x, rsigma, gamma = batched_args
        _, x_bdim, _, gamma_bdim = batch_dims

        # dx follows x's batch dim; dgamma follows gamma's.
        out_bdims = x_bdim, gamma_bdim
        return RmsNormBwdPrimitive.outer_primitive.bind(dz, x, rsigma, gamma,
                                                        epsilon=epsilon), out_bdims

    @staticmethod
    def infer_sharding_from_operands(epsilon, mesh, arg_infos, result_infos):
        """
        Derive output shardings: dx mirrors x (hidden dim replicated);
        dgamma mirrors gamma's spec.
        """
        del epsilon, result_infos
        x_spec = get_padded_spec(arg_infos[1])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {RmsNormBwdPrimitive.name}! " \
                f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
                f"and hurt performance."
            )
        g_spec = get_padded_spec(arg_infos[3])
        dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        dgamma_sharding = NamedSharding(mesh, PartitionSpec(*g_spec))
        return dx_sharding, dgamma_sharding

    @staticmethod
    def partition(epsilon, mesh, arg_infos, result_infos):
        """
        custom_partitioning rule: run per-shard, then all-reduce dgamma
        across DP/FSDP axes (each shard only sees part of the batch).
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[1])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {RmsNormBwdPrimitive.name}! " \
                f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
                f"and hurt performance."
            )
        g_spec = get_padded_spec(arg_infos[3])
        dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        dgamma_sharding = NamedSharding(mesh, PartitionSpec(*g_spec))
        out_shardings = dx_sharding, dgamma_sharding
        x_shardings = (dx_sharding,) * 2    # dz and x should have the same sharding.
        rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
        arg_shardings = (*x_shardings, rsigma_sharding, NamedSharding(mesh, PartitionSpec(*g_spec)))

        def sharded_impl(dz, x, rsigma, gamma):
            # Local (per-shard) grads; the weight grad needs a cross-shard sum.
            local_dx, local_dgamma = \
                RmsNormBwdPrimitive.impl(dz, x, rsigma, gamma, epsilon=epsilon)
            global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma)
            return local_dx, global_dgamma

        return mesh, sharded_impl, out_shardings, arg_shardings

865

866
# Register inner/outer primitives for the RMSNorm backward pass.
register_primitive(RmsNormBwdPrimitive)
867
868


869
870
def rmsnorm_bwd(dz: jnp.ndarray, x: jnp.ndarray, rsigma: jnp.ndarray, gamma: jnp.ndarray,
                epsilon: float):
871
    """
872
    Wrapper for TE layernorm bwd
873
    """
874
    return RmsNormBwdPrimitive.outer_primitive.bind(dz, x, rsigma, gamma, epsilon=epsilon)
875
876


877
class SoftmaxPrimitive(BasePrimitive):
    """
    Softmax Primitive

    Shared forward/backward abstract-eval, lowering, impl, batching and
    sharding rules reused by the concrete scaled-softmax primitives below.
    """
    # Kernel limit on the key sequence length (last dim of the logits).
    max_k_seqlen_supported = 4096

    @staticmethod
    @abstractmethod
    def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
                            dtype: jnp.dtype) -> bool:
        """Check Softmax kernel availability based on size"""
        raise NotImplementedError

    @staticmethod
    def get_batch_per_block(k_seqlen: int) -> int:
        """Get batch per CTA in Softmax kernels"""
        threads_per_warp = 32
        threads_per_block = 128    # Depends on the kernel implementation

        # Round k_seqlen up to the next power of two to size the warp.
        pow2 = 1 << (k_seqlen - 1).bit_length()
        warp_size = pow2 if pow2 < threads_per_warp else threads_per_warp
        batches_per_warp = 2 if pow2 <= 128 else 1
        warps_per_block = threads_per_block // warp_size
        batches_per_block = warps_per_block * batches_per_warp
        return batches_per_block

    @staticmethod
    def forward_abstract(logits_aval, scale_factor):
        """
        softmax_forward abstract

        Output has the same shape/dtype as the FP16/BF16 logits.
        """
        del scale_factor
        i_dtype = dtypes.canonicalize_dtype(logits_aval.dtype)
        assert i_dtype in [jnp.float16, jnp.bfloat16]
        i_shape = logits_aval.shape
        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
        q_seqlen = i_shape[-2]
        k_seqlen = i_shape[-1]
        assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
        assert q_seqlen > 1

        out_aval = core.raise_to_shaped(logits_aval)
        return out_aval

    @staticmethod
    def forward_lowering(name, ctx, logits, *, scale_factor):
        """
        softmax_forward lowering rules
        """
        i_aval, = ctx.avals_in
        i_type = ir.RankedTensorType(logits.type)
        i_shape = i_type.shape
        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
        batch = reduce(operator.mul, i_shape[:-3])
        pad_batch = batch
        heads = i_shape[-3]
        q_seqlen = i_shape[-2]
        k_seqlen = i_shape[-1]

        out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)]
        operands = [logits]
        operand_shapes = [i_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        # Descriptor is decoded by the C++ custom call to launch the kernel.
        opaque = transformer_engine_jax.pack_softmax_descriptor(batch, pad_batch, heads, q_seqlen,
                                                                k_seqlen,
                                                                jax_dtype_to_te_dtype(i_aval.dtype),
                                                                scale_factor)

        out = custom_caller(name, args, opaque, False)

        return [out]

    @staticmethod
    def forward_impl(primitive, logits, scale_factor):
        """
        softmax_forward implementation
        """
        assert primitive is not None
        output = primitive.bind(logits, scale_factor=scale_factor)
        return output

    @staticmethod
    def forward_batcher(primitive, batched_args, batch_dims, *, scale_factor):
        """
        softmax_forward batcher
        """
        assert primitive is not None
        logits, = batched_args
        logits_bdim, = batch_dims

        out_bdims = logits_bdim
        return primitive.bind(logits, scale_factor=scale_factor), out_bdims

    @staticmethod
    def forward_infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_forward infer_sharding_from_operands

        Output sharding mirrors the logits sharding.
        """
        del scale_factor, result_infos    # Unused.
        logits_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec))
        return out_sharding

    @staticmethod
    def forward_partition(impl, scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_forward partitioning
        """
        del result_infos
        logits_spec = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[0])))
        out_spec = logits_spec
        arg_shardings = (logits_spec,)
        out_shardings = out_spec
        impl = partial(impl, scale_factor=scale_factor)
        return mesh, impl, out_shardings, arg_shardings

    @staticmethod
    def backward_abstract(dz_aval, softmax_out_aval, scale_factor=None):    # pylint: disable=unused-argument
        """
        softmax_backward abstract

        dz and softmax_out must agree in shape and FP16/BF16 dtype.
        """
        dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        softmax_out_dtype = dtypes.canonicalize_dtype(softmax_out_aval.dtype)
        assert dz_dtype == softmax_out_dtype
        assert dz_dtype in [jnp.float16, jnp.bfloat16]
        assert softmax_out_dtype in [jnp.float16, jnp.bfloat16]

        assert dz_aval.shape == softmax_out_aval.shape

        dx_aval = core.raise_to_shaped(softmax_out_aval)
        return dx_aval

    @staticmethod
    def backward_lowering(name, ctx, dz, softmax_out, *, scale_factor):
        """
        softmax_backward lowering rules
        """
        dz_aval, _ = ctx.avals_in

        dz_type = ir.RankedTensorType(dz.type)
        dz_shape = dz_type.shape

        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
        batch = reduce(operator.mul, dz_shape[:-3])
        pad_batch = batch    # unused
        heads = dz_shape[-3]
        q_seqlen = dz_shape[-2]
        k_seqlen = dz_shape[-1]

        softmax_out_type = ir.RankedTensorType(softmax_out.type)
        softmax_out_shape = softmax_out_type.shape

        out_types = [ir.RankedTensorType.get(softmax_out_shape, softmax_out_type.element_type)]
        operands = [dz, softmax_out]
        operand_shapes = [dz_shape, softmax_out_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        opaque = transformer_engine_jax.pack_softmax_descriptor(
            batch, pad_batch, heads, q_seqlen, k_seqlen, jax_dtype_to_te_dtype(dz_aval.dtype),
            scale_factor)

        out = custom_caller(name, args, opaque, False)

        return [out]

    @staticmethod
    def backward_impl(primitive, dz, softmax_out, scale_factor):
        """
        softmax_backward implementation
        """
        assert primitive is not None
        dx = primitive.bind(dz, softmax_out, scale_factor=scale_factor)
        return dx

    @staticmethod
    def backward_batcher(primitive, batched_args, batch_dims, *, scale_factor):
        """
        softmax_backward batcher
        """
        assert primitive is not None
        dz, softmax_out = batched_args
        _, softmax_out_bdim = batch_dims

        out_bdims = softmax_out_bdim
        return primitive.bind(dz, softmax_out, scale_factor=scale_factor), out_bdims

    @staticmethod
    def backward_infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_backward infer_sharding_from_operands

        dx sharding mirrors the softmax output sharding.
        """
        del scale_factor, result_infos    # Unused.
        softmax_out_spec = get_padded_spec(arg_infos[1])
        dx_sharding = NamedSharding(mesh, PartitionSpec(*softmax_out_spec))
        return dx_sharding

    @staticmethod
    def backward_partition(impl, scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_backward partition
        """
        del result_infos
        dz_spec = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[0])))
        softmax_out_spec = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        dx_spec = softmax_out_spec
        arg_shardings = (dz_spec, softmax_out_spec)
        out_shardings = dx_spec
        impl = partial(impl, scale_factor=scale_factor)
        return mesh, impl, out_shardings, arg_shardings


class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Softmax Fwd Primitive
    """
    name = "te_scaled_softmax_forward"
    multiple_results = False
    impl_static_args = (1,)    # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
                            dtype: jnp.dtype) -> bool:
        """Check Softmax kernel availability based on size"""
        attn_batches = batch * heads

        dtype = dtypes.canonicalize_dtype(dtype)
        if (dtype in [jnp.float16, jnp.bfloat16]
                and 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
        # k_seqlen must be 16 ~ 4096
                and q_seqlen % 4 == 0    # q_seqlen must be divisor of 4
                and attn_batches % 4 == 0    # batch * heads must be divisor of 4
           ):
            if 0 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported:
                batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen)
                return q_seqlen % batch_per_block == 0
        return False

    @staticmethod
    def abstract(logits_aval, scale_factor):    # pylint: disable=unused-argument
        """
        te_scaled_softmax_forward abstract
        """
        return SoftmaxPrimitive.forward_abstract(logits_aval, scale_factor)

    @staticmethod
    def lowering(ctx, logits, *, scale_factor):
        """
        te_scaled_softmax_forward lowering rules
        """
        return SoftmaxPrimitive.forward_lowering(ScaledSoftmaxFwdPrimitive.name,
                                                 ctx,
                                                 logits,
                                                 scale_factor=scale_factor)

    @staticmethod
    def impl(logits, scale_factor):
        """Forward to the shared softmax forward implementation."""
        return SoftmaxPrimitive.forward_impl(ScaledSoftmaxFwdPrimitive.inner_primitive, logits,
                                             scale_factor)

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        """Forward to the shared softmax forward batcher."""
        _check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.forward_batcher(ScaledSoftmaxFwdPrimitive.outer_primitive,
                                                batched_args,
                                                batch_dims,
                                                scale_factor=scale_factor)

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """Forward to the shared softmax forward sharding inference."""
        return SoftmaxPrimitive.forward_infer_sharding_from_operands(scale_factor, mesh, arg_infos,
                                                                     result_infos)

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        """Forward to the shared softmax forward partition rule."""
        return SoftmaxPrimitive.forward_partition(ScaledSoftmaxFwdPrimitive.impl, scale_factor,
                                                  mesh, arg_infos, result_infos)


register_primitive(ScaledSoftmaxFwdPrimitive)


def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
    """
    scaled_softmax_forward wrapper
    Return FP16/BF16 tensor
    """
    return ScaledSoftmaxFwdPrimitive.outer_primitive.bind(logits, scale_factor=scale_factor)


class ScaledSoftmaxBwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Softmax Bwd Primitive
    """
    name = "te_scaled_softmax_backward"
    multiple_results = False
    impl_static_args = (2,)    # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
                            dtype: jnp.dtype) -> bool:
        """Check Softmax kernel availability based on size"""
        return ScaledSoftmaxFwdPrimitive.is_kernel_available(batch, heads, q_seqlen, k_seqlen,
                                                             dtype)

    @staticmethod
    def abstract(dz_aval, softmax_out_aval, scale_factor):
        """
        te_scaled_softmax_backward abstract
        """
        return SoftmaxPrimitive.backward_abstract(dz_aval, softmax_out_aval, scale_factor)

    @staticmethod
    def lowering(ctx, dz, softmax_out, *, scale_factor):
        """
        te_scaled_softmax_backward lowering rules
        """
        out = SoftmaxPrimitive.backward_lowering(ScaledSoftmaxBwdPrimitive.name,
                                                 ctx,
                                                 dz,
                                                 softmax_out,
                                                 scale_factor=scale_factor)

        return out

    @staticmethod
    def impl(dz, softmax_out, scale_factor):
        """Forward to the shared softmax backward implementation."""
        return SoftmaxPrimitive.backward_impl(ScaledSoftmaxBwdPrimitive.inner_primitive,
                                              dz,
                                              softmax_out,
                                              scale_factor=scale_factor)

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        """Forward to the shared softmax backward batcher."""
        _check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.backward_batcher(ScaledSoftmaxBwdPrimitive.outer_primitive,
                                                 batched_args,
                                                 batch_dims,
                                                 scale_factor=scale_factor)

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """Forward to the shared softmax backward sharding inference."""
        return SoftmaxPrimitive.backward_infer_sharding_from_operands(scale_factor, mesh, arg_infos,
                                                                      result_infos)

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        """Forward to the shared softmax backward partition rule."""
        return SoftmaxPrimitive.backward_partition(ScaledSoftmaxBwdPrimitive.impl, scale_factor,
                                                   mesh, arg_infos, result_infos)


register_primitive(ScaledSoftmaxBwdPrimitive)


def scaled_softmax_bwd(dz: jnp.ndarray, softmax_out: jnp.ndarray,
                       scale_factor: float) -> jnp.ndarray:
    """
    scaled_backward wrapper
    Return FP16/BF16 tensor
    """
    return ScaledSoftmaxBwdPrimitive.outer_primitive.bind(dz,
                                                          softmax_out,
                                                          scale_factor=scale_factor)


class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Masked Softmax Fwd Primitive
    """
    name = "te_scaled_masked_softmax_forward"
    multiple_results = False
    impl_static_args = (2,)    # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
                            dtype: jnp.dtype) -> bool:
        """Check Softmax kernel availability based on size"""
        attn_batches = batch * heads

        dtype = dtypes.canonicalize_dtype(dtype)
        if (dtype in [jnp.float16, jnp.bfloat16]
                and 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
        # k_seqlen must be 16 ~ 4096
                and q_seqlen % 4 == 0    # q_seqlen must be divisor of 4
                and attn_batches % 4 == 0    # batch * heads must be divisor of 4
           ):
            if 0 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported:
                batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen)
                return q_seqlen % batch_per_block == 0
        return False

    @staticmethod
    def abstract(logits_aval, mask_aval, scale_factor):    # pylint: disable=unused-argument
        """
        te_scaled_masked_softmax_forward abstract
        """
        i_dtype = dtypes.canonicalize_dtype(logits_aval.dtype)
        assert i_dtype in [jnp.float16, jnp.bfloat16]
        i_shape = logits_aval.shape

        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
        batch = reduce(operator.mul, i_shape[:-3])
        q_seqlen = i_shape[-2]
        k_seqlen = i_shape[-1]
        assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
        assert q_seqlen > 1

        mask_dtype = dtypes.canonicalize_dtype(mask_aval.dtype)
        assert mask_dtype in [
            jnp.uint8,
        ]
        mask_shape = mask_aval.shape
        # NOTE(review): `batch` is rebound here, which makes the assert below
        # trivially true; presumably it was meant to compare against the logits
        # batch computed above — confirm intended semantics before changing.
        pad_batch = batch = reduce(operator.mul, mask_shape[:-3])
        assert pad_batch in (1, batch)    # 1 means broadcast
        assert mask_shape[-3] == 1    # 1 means broadcast
        assert mask_shape[-2] == q_seqlen
        assert mask_shape[-1] == k_seqlen

        out_aval = core.raise_to_shaped(logits_aval)
        return out_aval

    @staticmethod
    def lowering(ctx, logits, mask, *, scale_factor):
        """
        te_scaled_masked_softmax_forward lowering rules
        """
        logits_aval, _ = ctx.avals_in
        i_type = ir.RankedTensorType(logits.type)
        i_shape = i_type.shape
        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
        batch = reduce(operator.mul, i_shape[:-3])
        heads = i_shape[-3]
        q_seqlen = i_shape[-2]
        k_seqlen = i_shape[-1]

        mask_type = ir.RankedTensorType(mask.type)
        mask_shape = mask_type.shape
        pad_batch = reduce(operator.mul, mask_shape[:-3])

        out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)]
        operands = [logits, mask]
        operand_shapes = [i_shape, mask_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        opaque = transformer_engine_jax.pack_softmax_descriptor(
            batch, pad_batch, heads, q_seqlen, k_seqlen, jax_dtype_to_te_dtype(logits_aval.dtype),
            scale_factor)

        out = custom_caller(ScaledMaskedSoftmaxFwdPrimitive.name, args, opaque, False)

        return [out]

    @staticmethod
    def impl(logits, mask, scale_factor):
        """Bind the inner primitive for the masked softmax forward."""
        assert ScaledMaskedSoftmaxFwdPrimitive.inner_primitive is not None
        output = ScaledMaskedSoftmaxFwdPrimitive.inner_primitive.bind(logits,
                                                                      mask,
                                                                      scale_factor=scale_factor)
        return output

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        """Batching rule: the output batch dim follows the logits batch dim."""
        _check_valid_batch_dims(batch_dims)
        assert ScaledMaskedSoftmaxFwdPrimitive.outer_primitive is not None
        logits, mask = batched_args
        logits_bdim, _ = batch_dims

        out_bdims = logits_bdim
        return ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
            logits, mask, scale_factor=scale_factor), out_bdims

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """Output sharding mirrors the logits sharding."""
        del scale_factor, result_infos    # Unused.
        logits_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec))
        return out_sharding

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        """Partition rule: logits and mask keep their own shardings."""
        del result_infos
        logits_spec = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[0])))
        mask_spec = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = (logits_spec, mask_spec)
        out_shardings = logits_spec
        impl = partial(ScaledMaskedSoftmaxFwdPrimitive.impl, scale_factor=scale_factor)
        return mesh, impl, out_shardings, arg_shardings


register_primitive(ScaledMaskedSoftmaxFwdPrimitive)


def scaled_masked_softmax_fwd(logits: jnp.ndarray, mask: jnp.ndarray,
                              scale_factor: float) -> jnp.ndarray:
    """
    scaled_masked_softmax_forward wrapper
    Return FP16/BF16 tensor
    """
    return ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind(logits,
                                                                mask,
                                                                scale_factor=scale_factor)


class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Masked Softmax Bwd Primitive
    """
    name = "te_scaled_masked_softmax_backward"
    multiple_results = False
    impl_static_args = (2,)    # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
                            dtype: jnp.dtype) -> bool:
        """Check Softmax kernel availability based on size"""
        return ScaledSoftmaxFwdPrimitive.is_kernel_available(batch, heads, q_seqlen, k_seqlen,
                                                             dtype)

    @staticmethod
    def abstract(dz_aval, softmax_out_aval, *, scale_factor):
        """
        te_scaled_masked_softmax_backward abstract
        """
        return SoftmaxPrimitive.backward_abstract(dz_aval, softmax_out_aval, scale_factor)

    @staticmethod
    def lowering(ctx, dz, softmax_out, *, scale_factor):
        """
        te_scaled_masked_softmax_backward lowering rules
        """
        out = SoftmaxPrimitive.backward_lowering(ScaledMaskedSoftmaxBwdPrimitive.name,
                                                 ctx,
                                                 dz,
                                                 softmax_out,
                                                 scale_factor=scale_factor)

        return out

    @staticmethod
    def impl(dz, softmax_out, scale_factor):
        """Forward to the shared softmax backward implementation."""
        return SoftmaxPrimitive.backward_impl(ScaledMaskedSoftmaxBwdPrimitive.inner_primitive,
                                              dz,
                                              softmax_out,
                                              scale_factor=scale_factor)

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        """Forward to the shared softmax backward batcher."""
        _check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.backward_batcher(ScaledMaskedSoftmaxBwdPrimitive.outer_primitive,
                                                 batched_args,
                                                 batch_dims,
                                                 scale_factor=scale_factor)

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """Forward to the shared softmax backward sharding inference."""
        return SoftmaxPrimitive.backward_infer_sharding_from_operands(scale_factor, mesh, arg_infos,
                                                                      result_infos)

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        """Forward to the shared softmax backward partition rule."""
        return SoftmaxPrimitive.backward_partition(ScaledMaskedSoftmaxBwdPrimitive.impl,
                                                   scale_factor, mesh, arg_infos, result_infos)


register_primitive(ScaledMaskedSoftmaxBwdPrimitive)


def scaled_masked_softmax_bwd(dz: jnp.ndarray, softmax_out: jnp.ndarray,
                              scale_factor: float) -> jnp.ndarray:
    """
    scaled_masked_backward wrapper
    Return FP16/BF16 tensor
    """
    return ScaledMaskedSoftmaxBwdPrimitive.outer_primitive.bind(dz,
                                                                softmax_out,
                                                                scale_factor=scale_factor)


class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Upper Triang Masked Softmax Fwd Primitive
    """
    name = "te_scaled_upper_triang_masked_softmax_forward"
    multiple_results = False
    impl_static_args = (1,)    # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
                            dtype: jnp.dtype) -> bool:
        """Check Softmax kernel availability based on size"""
        attn_batches = batch * heads

        dtype = dtypes.canonicalize_dtype(dtype)
        if (dtype in [jnp.float16, jnp.bfloat16]
                and 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
        # k_seqlen must be 16 ~ 4096
                and q_seqlen % 4 == 0    # q_seqlen must be divisor of 4
                and attn_batches % 4 == 0    # batch * heads must be divisor of 4
           ):
            if 0 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported:
                batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen)
                return attn_batches % batch_per_block == 0
        return False

    @staticmethod
    def abstract(logits_aval, scale_factor):    # pylint: disable=unused-argument
        """
        te_scaled_upper_triang_masked_softmax_forward abstract
        """
        q_seqlen = logits_aval.shape[-2]
        k_seqlen = logits_aval.shape[-1]
        # Upper-triangular (causal) masking requires a square attention matrix.
        assert q_seqlen == k_seqlen
        return SoftmaxPrimitive.forward_abstract(logits_aval, scale_factor)

    @staticmethod
    def lowering(ctx, logits, *, scale_factor):
        """
        te_scaled_upper_triang_masked_softmax_forward lowering rules
        """
        return SoftmaxPrimitive.forward_lowering(ScaledUpperTriangMaskedSoftmaxFwdPrimitive.name,
                                                 ctx,
                                                 logits,
                                                 scale_factor=scale_factor)

    @staticmethod
    def impl(logits, scale_factor):
        """Forward to the shared softmax forward implementation."""
        return SoftmaxPrimitive.forward_impl(
            ScaledUpperTriangMaskedSoftmaxFwdPrimitive.inner_primitive, logits, scale_factor)

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        """Forward to the shared softmax forward batcher."""
        _check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.forward_batcher(
            ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor)

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """Forward to the shared softmax forward sharding inference."""
        return SoftmaxPrimitive.forward_infer_sharding_from_operands(scale_factor, mesh, arg_infos,
                                                                     result_infos)

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        """Forward to the shared softmax forward partition rule."""
        return SoftmaxPrimitive.forward_partition(ScaledUpperTriangMaskedSoftmaxFwdPrimitive.impl,
                                                  scale_factor, mesh, arg_infos, result_infos)


register_primitive(ScaledUpperTriangMaskedSoftmaxFwdPrimitive)


def scaled_upper_triang_masked_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
    """
    scaled_upper_triang_masked_softmax_forward wrapper
    Return FP16/BF16 tensor
    """
    return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
        logits, scale_factor=scale_factor)


class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Upper Triang Masked Softmax Bwd Primitive
    """
    name = "te_scaled_upper_triang_masked_softmax_backward"
    multiple_results = False
    impl_static_args = (2,)    # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
                            dtype: jnp.dtype) -> bool:
        """Check Softmax kernel availability based on size"""
        return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype)

    @staticmethod
    def abstract(dz_aval, softmax_out_aval, *, scale_factor):
        """
        te_scaled_upper_triang_masked_backward abstract
        """
        return SoftmaxPrimitive.backward_abstract(dz_aval, softmax_out_aval, scale_factor)

    @staticmethod
    def lowering(ctx, dz, softmax_out, *, scale_factor):
        """
        te_scaled_upper_triang_masked_backward lowering rules
        """
        out = SoftmaxPrimitive.backward_lowering(ScaledUpperTriangMaskedSoftmaxBwdPrimitive.name,
                                                 ctx,
                                                 dz,
                                                 softmax_out,
                                                 scale_factor=scale_factor)
1584
1585
1586

        return out

1587
1588
1589
1590
1591
1592
1593
    @staticmethod
    def impl(dz, softmax_out, scale_factor):
        return SoftmaxPrimitive.backward_impl(
            ScaledUpperTriangMaskedSoftmaxBwdPrimitive.inner_primitive,
            dz,
            softmax_out,
            scale_factor=scale_factor)
1594

1595
1596
1597
1598
1599
1600
1601
1602
    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        _check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.backward_batcher(
            ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor)
1603

1604
1605
1606
1607
    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        return SoftmaxPrimitive.backward_infer_sharding_from_operands(scale_factor, mesh, arg_infos,
                                                                      result_infos)
1608

1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        return SoftmaxPrimitive.backward_partition(ScaledUpperTriangMaskedSoftmaxBwdPrimitive.impl,
                                                   scale_factor, mesh, arg_infos, result_infos)


register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive)


def scaled_upper_triang_masked_softmax_bwd(dz: jnp.ndarray, softmax_out: jnp.ndarray,
                                           scale_factor: float) -> jnp.ndarray:
    """
    scaled_upper_triang_masked_backward wrapper
    Return FP16/BF16 tensor
    """
    prim = ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive
    return prim.bind(dz, softmax_out, scale_factor=scale_factor)
1626
1627


1628
1629
@dataclass(frozen=True)
class FusedAttnHelper:
    """
    Helper for the fused attention backend.

    Bundles the attention problem description (dtypes, QKV layout, bias/mask
    types, dropout probability, head counts, max sequence lengths, head size)
    and asks the C++ extension which fused-attention backend kernel matches.
    """

    # dtype of the query tensor
    q_type: jnp.dtype
    # dtype of the key/value tensor
    kv_type: jnp.dtype
    qkv_layout: NVTE_QKV_Layout
    attn_bias_type: NVTE_Bias_Type
    attn_mask_type: NVTE_Mask_Type
    dropout_probability: float
    num_heads_q: int
    num_heads_kv: int
    max_seqlen_q: int
    max_seqlen_kv: int
    head_dim: int

    def is_fused_attn_kernel_available(self):
        """Check if there is available fused attention kernel"""
        return self.get_fused_attn_backend() != NVTE_Fused_Attn_Backend.NVTE_No_Backend

    def get_fused_attn_backend(self):
        """Get the fused attention kernel backend"""
        # Backend selection is made by the C++ extension from the full
        # problem description held by this dataclass.
        return transformer_engine_jax.get_fused_attn_backend(jax_dtype_to_te_dtype(self.q_type),
                                                             jax_dtype_to_te_dtype(self.kv_type),
                                                             self.qkv_layout, self.attn_bias_type,
                                                             self.attn_mask_type,
                                                             self.dropout_probability,
                                                             self.num_heads_q, self.num_heads_kv,
                                                             self.max_seqlen_q, self.max_seqlen_kv,
                                                             self.head_dim)


@dataclass(frozen=True)
class _FusedAttnRNGStateChecker:
    """
    Checker for guarding the fused attention rng state.
    The fused attention backend requires a 64 bits seed and a 64 bits offset.
    However, JAX doesn't enable 64 bits by default,
    so we have to emulate seed as two 32 bits array.
    The offset calculation is maintained in the backend.
    """
    rng_state_dtype: jnp.dtype = jnp.uint32
    # (seed,) with internal dtype int64
    seed_size: int = 2
    # (seed, offset) with internal dtype int64
    rng_state_size: int = 2 * 2

    def check_seed(self, seed, dropout_probability, is_training):
        """
        Check the seed and convert the data type of seed if possible.
        """
        # Jax can't bind None, create a dummy tensor for None
        if seed is None:
            dropout_enabled = dropout_probability > 0 and is_training
            assert not dropout_enabled, "seed is not allowed to be None when dropout is enabled."
            seed = jnp.zeros(2, dtype=self.rng_state_dtype)
            seed = jnp.repeat(seed, num_of_devices())

        if seed.dtype != self.rng_state_dtype:
            warnings.warn(
                f"Requested {seed.dtype=} is not available, and will be "
                f"casted to dtype {self.rng_state_dtype}. "
                f"Please use threefry/rbg/unsafe_rbg PRNG implementations to remove this warning.")
            seed = seed.astype(self.rng_state_dtype)

        assert seed.dtype == self.rng_state_dtype
        # Backend takes an int64_t seed, so only the first two u32 elements are taken
        assert seed.size >= self.seed_size

        return seed


def generate_cu_seqlen(mask):
    """
    Compute cumulative sequence lengths for a batch.

    Each entry's sequence length is the number of positions where
    ``mask == 0`` over the last two axes; the result is the running sum of
    those lengths with a leading 0.
    """
    seqlens = jnp.sum(mask == 0, axis=(-1, -2), dtype=jnp.int32)
    return jnp.hstack((0, jnp.cumsum(seqlens)))


class SelfFusedAttnFwdPrimitive(BasePrimitive):
    """
    Self Fused Attention Forward Primitive

    Custom JAX primitive wrapping the TE fused self-attention forward kernel
    for packed QKV input of shape (..., seqlen, 3, head, hidden).
    """
    name = "te_self_fused_attn_forward"
    multiple_results = True
    # Positions of the static (compile-time) arguments in impl's signature:
    # attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training.
    impl_static_args = (4, 5, 6, 7, 8)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(qkv_aval, bias_aval, mask_or_cu_seqlen_aval, seed_aval, *, attn_bias_type,
                 attn_mask_type, scaling_factor, dropout_probability, is_training):
        """
        Self fused attention fwd abstract

        Returns avals for (output, softmax_aux, rng_state); softmax_aux shape
        and dtype depend on which fused-attention backend is selected.
        """
        # outer_primitive binds squeezed_mask, inner_primitive binds cu_seqlen
        del mask_or_cu_seqlen_aval, scaling_factor, is_training
        qkv_dtype = dtypes.canonicalize_dtype(qkv_aval.dtype)
        # Packed QKV layout: (..., seqlen, 3, head, hidden)
        *batch_shape, max_seqlen, nqkv, num_head, head_dim = qkv_aval.shape
        assert nqkv == 3
        assert qkv_aval.dtype == bias_aval.dtype

        output_shape = (*batch_shape, max_seqlen, num_head, head_dim)
        output_dtype = qkv_dtype

        backend = FusedAttnHelper(qkv_dtype, qkv_dtype, NVTE_QKV_Layout.NVTE_BS3HD, attn_bias_type,
                                  attn_mask_type, dropout_probability, num_head, num_head,
                                  max_seqlen, max_seqlen, head_dim).get_fused_attn_backend()

        if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
            # max512 backend: full (seqlen x seqlen) auxiliary in the input dtype.
            softmax_aux_shape = (*batch_shape, num_head, max_seqlen, max_seqlen)
            softmax_dtype = qkv_dtype
        elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
            # arbitrary-seqlen backend: one fp32 value per (head, row).
            softmax_aux_shape = (*batch_shape, num_head, max_seqlen, 1)
            softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
        else:
            raise ValueError(f'Unsupported {backend=}')

        checker = _FusedAttnRNGStateChecker()
        seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype)
        assert seed_dtype == checker.rng_state_dtype
        # rng state per leading seed entry — presumably one entry per device
        # (check_seed repeats the seed by num_of_devices); TODO confirm.
        rng_state_shape = (seed_aval.shape[0], checker.rng_state_size)
        rng_state_dtype = seed_dtype

        out_aval = qkv_aval.update(shape=output_shape, dtype=output_dtype)
        softmax_aux_aval = qkv_aval.update(shape=softmax_aux_shape, dtype=softmax_dtype)
        rng_state_aval = qkv_aval.update(shape=rng_state_shape, dtype=rng_state_dtype)
        return out_aval, softmax_aux_aval, rng_state_aval

    @staticmethod
    def lowering(ctx, qkv, bias, cu_seqlen, seed, *, attn_bias_type, attn_mask_type, scaling_factor,
                 dropout_probability, is_training):
        """
        Self fused attention fwd lowering rules

        Emits the XLA custom call; the static attention configuration travels
        in the call's opaque descriptor.
        """
        qkv_aval, _, _, _ = ctx.avals_in

        *batch_shape, max_seqlen, _, num_head, head_dim = qkv_aval.shape
        # Leading dims are flattened into a single batch count for the kernel.
        batch = reduce(operator.mul, batch_shape)

        operands = [qkv, bias, cu_seqlen, seed]
        operand_shapes = map(lambda x: x.type.shape, operands)
        out_types = [
            ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
            for output in ctx.avals_out
        ]

        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
        opaque = transformer_engine_jax.pack_fused_attn_descriptor(
            batch, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor, dropout_probability,
            attn_bias_type, attn_mask_type, jax_dtype_to_te_dtype(qkv_aval.dtype), is_training)

        out = custom_caller(SelfFusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)

        return out

    @staticmethod
    def impl(qkv, bias, squeezed_mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
             dropout_probability, is_training):
        """Outer-primitive impl: derive cu_seqlen from the mask, then bind inner."""
        assert SelfFusedAttnFwdPrimitive.inner_primitive is not None

        cu_seqlen = generate_cu_seqlen(squeezed_mask)

        output, softmax_aux, rng_state = SelfFusedAttnFwdPrimitive.inner_primitive.bind(
            qkv,
            bias,
            cu_seqlen,
            seed,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training)
        return output, softmax_aux, rng_state

    @staticmethod
    def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor,
                dropout_probability, is_training):
        """Batching (vmap) rule: out/aux follow qkv's batch dim, rng_state follows seed's."""
        _check_valid_batch_dims(batch_dims)
        assert SelfFusedAttnFwdPrimitive.outer_primitive is not None
        qkv_bdim, _, _, seed_bdim = batch_dims

        out_bdims = qkv_bdim, qkv_bdim, seed_bdim
        return SelfFusedAttnFwdPrimitive.outer_primitive.bind(
            *batched_args,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training), out_bdims

    @staticmethod
    def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor,
                                     dropout_probability, is_training, mesh, arg_infos,
                                     result_infos):
        """Derive output shardings from the qkv operand's partition spec."""
        del attn_bias_type, attn_mask_type, scaling_factor
        del dropout_probability, is_training, result_infos
        x_spec = get_padded_spec(arg_infos[0])    # (...batch, seqlen, 3, head, hidden)
        # output (...batch, seqlen, head, hidden): drop the packed-QKV axis spec.
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-3], *x_spec[-2:]))
        # softmax_aux (...batch, head, seqlen, seqlen-or-1): head/seqlen specs swapped.
        softmax_aux_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec[:-4], x_spec[-2], x_spec[-4], None))
        # rng_state dim 0 is sharded across every mesh axis.
        rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
        return (out_sharding, softmax_aux_sharding, rng_state_sharding)

    @staticmethod
    def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training,
                  mesh, arg_infos, result_infos):
        """Partitioned evaluation: per-shard impl with shardings mirroring inference."""
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])    # (...batch, seqlen, 3, head, hidden)
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-3], *x_spec[-2:]))
        softmax_aux_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec[:-4], x_spec[-2], x_spec[-4], None))
        rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
        # Keep operand shardings as-is except the last arg (seed), which is
        # forced to the rng_state sharding.
        arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [rng_state_sharding])
        out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
        impl = partial(SelfFusedAttnFwdPrimitive.impl,
                       attn_bias_type=attn_bias_type,
                       attn_mask_type=attn_mask_type,
                       scaling_factor=scaling_factor,
                       dropout_probability=dropout_probability,
                       is_training=is_training)
        return mesh, impl, out_shardings, arg_shardings


# Register the self-attention forward primitive with JAX at import time.
register_primitive(SelfFusedAttnFwdPrimitive)


def self_fused_attn_fwd(qkv: jnp.ndarray, bias: jnp.ndarray, squeezed_mask: jnp.ndarray,
                        seed: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
                        attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
                        dropout_probability: float, is_training: bool):
    """
    Wrapper for TE self fused attention fwd
    Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
    """
    # Normalize the RNG seed to the dtype/layout the backend expects.
    seed = _FusedAttnRNGStateChecker().check_seed(seed, dropout_probability, is_training)

    if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
        # JAX can't bind None, so stand in an empty tensor for the bias.
        assert bias is None
        bias = jnp.zeros(0, dtype=qkv.dtype)

    return SelfFusedAttnFwdPrimitive.outer_primitive.bind(
        qkv,
        bias,
        squeezed_mask,
        seed,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training)
1883
1884


1885
class SelfFusedAttnBwdPrimitive(BasePrimitive):
    """
    Self Fused Attention Backward Primitive

    Custom JAX primitive wrapping the TE fused self-attention backward kernel;
    produces gradients for the packed QKV input and the bias.
    """
    name = "te_self_fused_attn_backward"
    multiple_results = True
    # Positions of the static arguments in impl's signature:
    # attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training.
    impl_static_args = (7, 8, 9, 10, 11)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(qkv_aval, bias_aval, softmax_aux_aval, rng_state_aval, output_aval, doutput_aval,
                 mask_or_cu_seqlen_aval, *, attn_bias_type, attn_mask_type, scaling_factor,
                 dropout_probability, is_training):
        """
        Self fused attention bwd abstract

        Gradients mirror the shapes/dtypes of their forward inputs.
        """
        del softmax_aux_aval, rng_state_aval
        # outer_primitive binds squeezed_mask, inner_primitive binds cu_seqlen
        del mask_or_cu_seqlen_aval, attn_bias_type, attn_mask_type
        del scaling_factor, dropout_probability, is_training
        qkv_dtype = dtypes.canonicalize_dtype(qkv_aval.dtype)
        bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
        assert qkv_aval.dtype == bias_aval.dtype == output_aval.dtype == doutput_aval.dtype

        dqkv_aval = qkv_aval.update(shape=qkv_aval.shape, dtype=qkv_dtype)
        dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype)
        return dqkv_aval, dbias_aval

    @staticmethod
    def lowering(ctx, qkv, bias, softmax_aux, rng_state, output, doutput, cu_seqlen, *,
                 attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training):
        """
        Self fused attention bwd lowering rules

        Emits the XLA custom call with the static attention configuration
        packed into the opaque descriptor.
        """
        qkv_aval, _, _, _, _, _, _ = ctx.avals_in

        *batch_shape, max_seqlen, _, num_head, head_dim = qkv_aval.shape
        # Leading dims are flattened into a single batch count for the kernel.
        batch = reduce(operator.mul, batch_shape)

        operands = [qkv, bias, softmax_aux, rng_state, output, doutput, cu_seqlen]
        operand_shapes = map(lambda x: x.type.shape, operands)
        out_types = [
            ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
            for output in ctx.avals_out
        ]

        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        opaque = transformer_engine_jax.pack_fused_attn_descriptor(
            batch, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor, dropout_probability,
            attn_bias_type, attn_mask_type, jax_dtype_to_te_dtype(qkv_aval.dtype), is_training)

        out = custom_caller(SelfFusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)

        return out

    @staticmethod
    def impl(qkv, bias, softmax_aux, rng_state, output, doutput, squeezed_mask, attn_bias_type,
             attn_mask_type, scaling_factor, dropout_probability, is_training):
        """Outer-primitive impl: derive cu_seqlen from the mask, then bind inner."""
        assert SelfFusedAttnBwdPrimitive.inner_primitive is not None

        cu_seqlen = generate_cu_seqlen(squeezed_mask)

        dqkv, dbias = SelfFusedAttnBwdPrimitive.inner_primitive.bind(
            qkv,
            bias,
            softmax_aux,
            rng_state,
            output,
            doutput,
            cu_seqlen,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training)
        return dqkv, dbias

    @staticmethod
    def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor,
                dropout_probability, is_training):
        """Batching (vmap) rule: both gradients follow qkv's batch dim."""
        _check_valid_batch_dims(batch_dims)
        assert SelfFusedAttnBwdPrimitive.outer_primitive is not None
        qkv_bdim, *_ = batch_dims

        out_bdims = qkv_bdim, qkv_bdim
        return SelfFusedAttnBwdPrimitive.outer_primitive.bind(
            *batched_args,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training), out_bdims

    @staticmethod
    def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor,
                                     dropout_probability, is_training, mesh, arg_infos,
                                     result_infos):
        """Gradients adopt the same sharding as their forward inputs."""
        # NOTE(review): the trailing comma on the next `del` line is harmless
        # (valid Python), kept byte-identical on purpose.
        del attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
        del is_training, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        bias_spec = get_padded_spec(arg_infos[1])
        dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
        return (dx_sharding, dbias_sharding)

    @staticmethod
    def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training,
                  mesh, arg_infos, result_infos):
        """Partitioned evaluation: per-shard impl plus a cross-shard dbias reduction."""
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        bias_spec = get_padded_spec(arg_infos[1])
        dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (dx_sharding, dbias_sharding)

        def sharded_impl(qkv, bias, softmax_aux, rng_state, output, doutput, cu_seqlen):
            # Run the local backward, then sum the bias gradient across
            # DP/FSDP shards when a bias is actually in use.
            local_dx, local_dbias = SelfFusedAttnBwdPrimitive.impl(
                qkv,
                bias,
                softmax_aux,
                rng_state,
                output,
                doutput,
                cu_seqlen,
                attn_bias_type=attn_bias_type,
                attn_mask_type=attn_mask_type,
                scaling_factor=scaling_factor,
                dropout_probability=dropout_probability,
                is_training=is_training)
            global_dbias = local_dbias
            if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
            return local_dx, global_dbias

        return mesh, sharded_impl, out_shardings, arg_shardings


# Register the self-attention backward primitive with JAX at import time.
register_primitive(SelfFusedAttnBwdPrimitive)


2028
2029
2030
2031
2032
def self_fused_attn_bwd(qkv: jnp.ndarray, bias: jnp.ndarray, softmax_aux: jnp.ndarray,
                        rng_state: jnp.ndarray, output: jnp.ndarray, doutput: jnp.ndarray,
                        squeezed_mask: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
                        attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
                        dropout_probability: float, is_training: bool):
    """
    Wrapper for TE self fused attention bwd
    Return the gradients of self fused attention with packed qkv input
    """
    if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
        # No bias in the forward pass: bind a placeholder instead of None.
        assert bias is None
        bias = jnp.zeros(0, dtype=qkv.dtype)

    return SelfFusedAttnBwdPrimitive.outer_primitive.bind(
        qkv,
        bias,
        softmax_aux,
        rng_state,
        output,
        doutput,
        squeezed_mask,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training)
2052
2053


2054
class CrossFusedAttnFwdPrimitive(BasePrimitive):
    """
    Cross Fused Attention Forward Primitive

    Custom JAX primitive wrapping the TE fused cross-attention forward kernel
    for separate query (..., q_seqlen, head, hidden) and packed KV
    (..., kv_seqlen, 2, head, hidden) inputs.
    """
    name = "te_cross_fused_attn_forward"
    multiple_results = True
    # Positions of the static arguments in impl's signature:
    # attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training.
    impl_static_args = (6, 7, 8, 9, 10)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(q_aval, kv_aval, bias_aval, q_mask_or_cu_seqlen_aval, kv_mask_or_cu_seqlen_aval,
                 seed_aval, *, attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
                 is_training):
        """
        Cross fused attention fwd abstract

        Returns avals for (output, softmax_aux, rng_state); softmax_aux shape
        and dtype depend on which fused-attention backend is selected.
        """
        # outer_primitive binds squeezed masks, inner_primitive binds cu_seqlens
        del scaling_factor, is_training

        q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
        # Query layout: (..., q_seqlen, head, hidden)
        *q_batch_shape, q_max_seqlen, q_num_head, q_head_dim = q_aval.shape

        kv_dtype = dtypes.canonicalize_dtype(kv_aval.dtype)
        # Packed KV layout: (..., kv_seqlen, 2, head, hidden)
        *kv_batch_shape, kv_max_seqlen, nkv, kv_num_head, kv_head_dim = kv_aval.shape

        bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)

        assert q_dtype == kv_dtype == bias_dtype
        assert q_batch_shape == kv_batch_shape
        assert q_num_head == kv_num_head
        assert q_head_dim == kv_head_dim
        assert nkv == 2
        assert q_mask_or_cu_seqlen_aval.dtype == kv_mask_or_cu_seqlen_aval.dtype

        output_shape = q_aval.shape
        output_dtype = q_dtype

        backend = FusedAttnHelper(q_dtype, kv_dtype, NVTE_QKV_Layout.NVTE_BSHD_BS2HD,
                                  attn_bias_type, attn_mask_type, dropout_probability,
                                  q_num_head, kv_num_head,
                                  q_max_seqlen, kv_max_seqlen, q_head_dim).get_fused_attn_backend()

        if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
            # max512 backend: full (q_seqlen x kv_seqlen) auxiliary in input dtype.
            softmax_aux_shape = (*q_batch_shape, q_num_head, q_max_seqlen, kv_max_seqlen)
            softmax_aux_dtype = q_dtype
        elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
            # arbitrary-seqlen backend: one fp32 value per (head, query row).
            softmax_aux_shape = (*q_batch_shape, q_num_head, q_max_seqlen, 1)
            softmax_aux_dtype = dtypes.canonicalize_dtype(jnp.float32)
        else:
            raise ValueError(f'Unsupported {backend=}')

        checker = _FusedAttnRNGStateChecker()
        seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype)
        assert seed_dtype == checker.rng_state_dtype
        rng_state_shape = (seed_aval.shape[0], checker.rng_state_size)
        rng_state_dtype = seed_dtype

        out_aval = q_aval.update(shape=output_shape, dtype=output_dtype)
        softmax_aux_aval = q_aval.update(shape=softmax_aux_shape, dtype=softmax_aux_dtype)
        rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=rng_state_dtype)

        return out_aval, softmax_aux_aval, rng_state_aval

    @staticmethod
    def lowering(ctx, q, kv, bias, q_cu_seqlen, kv_cu_seqlen, seed, *, attn_bias_type,
                 attn_mask_type, scaling_factor, dropout_probability, is_training):
        """
        Cross fused attention fwd lowering rules

        Emits the XLA custom call with the static attention configuration
        packed into the opaque descriptor.
        """
        q_aval, kv_aval, *_ = ctx.avals_in
        assert q_aval.dtype == kv_aval.dtype

        *batch_shape, q_max_seqlen, num_head, head_dim = q_aval.shape
        # Leading dims are flattened into a single batch count for the kernel.
        batch = reduce(operator.mul, batch_shape)
        kv_max_seqlen = kv_aval.shape[-4]

        operands = [q, kv, bias, q_cu_seqlen, kv_cu_seqlen, seed]
        operand_shapes = map(lambda x: x.type.shape, operands)
        out_types = [
            ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
            for output in ctx.avals_out
        ]

        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
        opaque = transformer_engine_jax.pack_fused_attn_descriptor(
            batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim,
            scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
            jax_dtype_to_te_dtype(q_aval.dtype), is_training)

        out = custom_caller(CrossFusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)

        return out

    @staticmethod
    def impl(q, kv, bias, q_squeezed_mask, kv_squeezed_mask, seed, attn_bias_type, attn_mask_type,
             scaling_factor, dropout_probability, is_training):
        """Outer-primitive impl: derive both cu_seqlens from the masks, then bind inner."""
        assert CrossFusedAttnFwdPrimitive.inner_primitive is not None

        q_cu_seqlen = generate_cu_seqlen(q_squeezed_mask)
        kv_cu_seqlen = generate_cu_seqlen(kv_squeezed_mask)

        output, softmax_aux, rng_state = CrossFusedAttnFwdPrimitive.inner_primitive.bind(
            q,
            kv,
            bias,
            q_cu_seqlen,
            kv_cu_seqlen,
            seed,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training)
        return output, softmax_aux, rng_state

    @staticmethod
    def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor,
                dropout_probability, is_training):
        """Batching (vmap) rule: out/aux follow q's batch dim, rng_state follows seed's."""
        _check_valid_batch_dims(batch_dims)
        assert CrossFusedAttnFwdPrimitive.outer_primitive is not None
        q_bdim, *_, seed_bdim = batch_dims

        out_bdims = q_bdim, q_bdim, seed_bdim
        return CrossFusedAttnFwdPrimitive.outer_primitive.bind(
            *batched_args,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training), out_bdims

    @staticmethod
    def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor,
                                     dropout_probability, is_training, mesh, arg_infos,
                                     result_infos):
        """Derive output shardings from the q and kv operands' partition specs."""
        del attn_bias_type, attn_mask_type, scaling_factor
        del dropout_probability, is_training, result_infos
        q_spec = get_padded_spec(arg_infos[0])    # (...batch, q_seqlen, head, hidden)
        kv_spec = get_padded_spec(arg_infos[1])    # (...batch, kv_seqlen, 2, head, hidden)
        out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        # softmax_aux (...batch, head, q_seqlen, kv_seqlen): head/seqlen specs swapped.
        softmax_aux_sharding = NamedSharding(
            mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], kv_spec[-4]))
        # rng_state dim 0 is sharded across every mesh axis.
        rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
        return (out_sharding, softmax_aux_sharding, rng_state_sharding)

    @staticmethod
    def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training,
                  mesh, arg_infos, result_infos):
        """Partitioned evaluation: per-shard impl with shardings mirroring inference."""
        del result_infos
        q_spec = get_padded_spec(arg_infos[0])    # (...batch, q_seqlen, head, hidden)
        kv_spec = get_padded_spec(arg_infos[1])    # (...batch, kv_seqlen, 2, head, hidden)
        out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        softmax_aux_sharding = NamedSharding(
            mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], kv_spec[-4]))
        # Keep operand shardings as-is except the last arg (seed), which is
        # forced to the rng_state sharding.
        rng_state_sharding = seed_sharding = NamedSharding(mesh,
                                                           PartitionSpec(get_all_mesh_axes(), None))
        arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding])
        out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
        impl = partial(CrossFusedAttnFwdPrimitive.impl,
                       attn_bias_type=attn_bias_type,
                       attn_mask_type=attn_mask_type,
                       scaling_factor=scaling_factor,
                       dropout_probability=dropout_probability,
                       is_training=is_training)
        return mesh, impl, out_shardings, arg_shardings


# Register the cross-attention forward primitive with JAX at import time.
register_primitive(CrossFusedAttnFwdPrimitive)


2225
2226
2227
2228
2229
def cross_fused_attn_fwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray,
                         q_squeezed_mask: jnp.ndarray, kv_squeezed_mask: jnp.ndarray,
                         seed: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
                         attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
                         dropout_probability: float, is_training: bool):
    """
    Wrapper for TE cross fused attention fwd
    Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
    """
    # Normalize the RNG seed to the dtype/layout the backend expects.
    seed = _FusedAttnRNGStateChecker().check_seed(seed, dropout_probability, is_training)

    if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
        # JAX can't bind None, so stand in an empty tensor for the bias.
        assert bias is None
        bias = jnp.zeros(0, dtype=q.dtype)

    return CrossFusedAttnFwdPrimitive.outer_primitive.bind(
        q,
        kv,
        bias,
        q_squeezed_mask,
        kv_squeezed_mask,
        seed,
        attn_bias_type=attn_bias_type,
        attn_mask_type=attn_mask_type,
        scaling_factor=scaling_factor,
        dropout_probability=dropout_probability,
        is_training=is_training)
2252
2253


class CrossFusedAttnBwdPrimitive(BasePrimitive):
    """
    Cross Fused Attention Backward Primitive
    """
    name = "te_cross_fused_attn_backward"
    multiple_results = True
    impl_static_args = (9, 10, 11, 12, 13)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(q_aval, kv_aval, bias_aval, softmax_aux_aval, rng_state_aval, output_aval,
                 doutput_aval, q_cu_seqlen_aval, kv_cu_seqlen_aval, *, attn_bias_type,
                 attn_mask_type, scaling_factor, dropout_probability, is_training):
        """
        Cross fused attention bwd abstract
        """
        del softmax_aux_aval, rng_state_aval, output_aval
        del attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training

        q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
        kv_dtype = dtypes.canonicalize_dtype(kv_aval.dtype)
        bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
        doutput_dtype = dtypes.canonicalize_dtype(doutput_aval.dtype)
        assert q_dtype == kv_dtype == bias_dtype == doutput_dtype
        assert q_cu_seqlen_aval.dtype == kv_cu_seqlen_aval.dtype

        # Gradients have the same shapes/dtypes as their forward counterparts.
        dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
        dkv_aval = kv_aval.update(shape=kv_aval.shape, dtype=kv_dtype)
        dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype)
        return dq_aval, dkv_aval, dbias_aval

    @staticmethod
    def lowering(ctx, q, kv, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen,
                 kv_cu_seqlen, *, attn_bias_type, attn_mask_type, scaling_factor,
                 dropout_probability, is_training):
        """
        Cross fused attention bwd lowering rules
        """
        q_aval, kv_aval, *_ = ctx.avals_in
        assert q_aval.dtype == kv_aval.dtype

        *batch_shape, q_max_seqlen, num_head, head_dim = q_aval.shape
        batch = reduce(operator.mul, batch_shape)
        kv_max_seqlen = kv_aval.shape[-4]

        operands = [q, kv, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, kv_cu_seqlen]
        operand_shapes = map(lambda x: x.type.shape, operands)
        out_types = [
            ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
            for output in ctx.avals_out
        ]

        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        # the dropout elements are encoded in the forward auxiliary tensor
        # so seed is not needed in backward
        opaque = transformer_engine_jax.pack_fused_attn_descriptor(
            batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim,
            scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
            jax_dtype_to_te_dtype(q_aval.dtype), is_training)

        out = custom_caller(CrossFusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)

        return out

    @staticmethod
    def impl(q, kv, bias, softmax_aux, rng_state, output, doutput, q_squeezed_mask,
             kv_squeezed_mask, attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
             is_training):
        """
        Cross fused attention bwd implementation: converts the squeezed masks to
        cumulative sequence lengths, then binds the inner primitive.
        """
        assert CrossFusedAttnBwdPrimitive.inner_primitive is not None

        q_cu_seqlen = generate_cu_seqlen(q_squeezed_mask)
        kv_cu_seqlen = generate_cu_seqlen(kv_squeezed_mask)

        dq, dkv, dbias = CrossFusedAttnBwdPrimitive.inner_primitive.bind(
            q,
            kv,
            bias,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_cu_seqlen,
            kv_cu_seqlen,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training)
        return dq, dkv, dbias

    @staticmethod
    def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor,
                dropout_probability, is_training):
        """
        Cross fused attention bwd batcher
        """
        _check_valid_batch_dims(batch_dims)
        assert CrossFusedAttnBwdPrimitive.outer_primitive is not None
        q_bdim, kv_bdim, *_ = batch_dims

        # dq follows q, dkv follows kv, dbias follows q.
        out_bdims = q_bdim, kv_bdim, q_bdim
        return CrossFusedAttnBwdPrimitive.outer_primitive.bind(
            *batched_args,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training), out_bdims

    @staticmethod
    def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor,
                                     dropout_probability, is_training, mesh, arg_infos,
                                     result_infos):
        del attn_bias_type, attn_mask_type, scaling_factor
        del dropout_probability, is_training, result_infos
        q_spec = get_padded_spec(arg_infos[0])
        kv_spec = get_padded_spec(arg_infos[1])
        bias_spec = get_padded_spec(arg_infos[2])
        dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        dkv_sharding = NamedSharding(mesh, PartitionSpec(*kv_spec))
        dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
        return (dq_sharding, dkv_sharding, dbias_sharding)

    @staticmethod
    def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training,
                  mesh, arg_infos, result_infos):
        del result_infos
        q_spec = get_padded_spec(arg_infos[0])
        kv_spec = get_padded_spec(arg_infos[1])
        bias_spec = get_padded_spec(arg_infos[2])
        dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        dkv_sharding = NamedSharding(mesh, PartitionSpec(*kv_spec))
        dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (dq_sharding, dkv_sharding, dbias_sharding)

        def sharded_impl(q, kv, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen,
                         kv_cu_seqlen):
            local_dq, local_dkv, local_dbias = CrossFusedAttnBwdPrimitive.impl(
                q,
                kv,
                bias,
                softmax_aux,
                rng_state,
                output,
                doutput,
                q_cu_seqlen,
                kv_cu_seqlen,
                attn_bias_type=attn_bias_type,
                attn_mask_type=attn_mask_type,
                scaling_factor=scaling_factor,
                dropout_probability=dropout_probability,
                is_training=is_training)
            # dbias is partial on each data-parallel rank; sum it globally.
            global_dbias = local_dbias
            if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
            return local_dq, local_dkv, global_dbias

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(CrossFusedAttnBwdPrimitive)


def cross_fused_attn_bwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray,
                         softmax_aux: jnp.ndarray, rng_state: jnp.ndarray, output: jnp.ndarray,
                         doutput: jnp.ndarray, q_squeezed_mask: jnp.ndarray,
                         kv_squeezed_mask: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
                         attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
                         dropout_probability: float, is_training: bool):
    """
    Wrapper for TE cross fused attention bwd
    Return the gradients of cross fused attention with packed kv input
    """
    # The primitive always takes a bias operand; substitute an empty placeholder
    # tensor when no bias is requested.
    if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
        assert bias is None
        bias = jnp.zeros(0, dtype=q.dtype)
    return CrossFusedAttnBwdPrimitive.outer_primitive.bind(q,
                                                           kv,
                                                           bias,
                                                           softmax_aux,
                                                           rng_state,
                                                           output,
                                                           doutput,
                                                           q_squeezed_mask,
                                                           kv_squeezed_mask,
                                                           attn_bias_type=attn_bias_type,
                                                           attn_mask_type=attn_mask_type,
                                                           scaling_factor=scaling_factor,
                                                           dropout_probability=dropout_probability,
                                                           is_training=is_training)
class GatedGeluPrimitive(BasePrimitive):
    """
    Gated Gelu Forward Primitive
    """
    name = "te_gated_gelu"
    multiple_results = False
    inner_primitive = None
    outer_primitive = None
    impl_static_args = ()

    @staticmethod
    def abstract(x_aval):
        """
        gated_gelu abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        x_shape = x_aval.shape
        assert x_shape[-2] == 2    # Assume x in (....., 2, hidden)
        hidden_size = x_shape[-1]
        batch_shapes = x_shape[:-2]
        out_aval = core.raise_to_shaped(x_aval)
        # Output collapses the gate dimension: (..., 2, hidden) -> (..., hidden)
        out_shape = batch_shapes + (hidden_size,)
        out_aval = out_aval.update(shape=out_shape, dtype=dtype)

        return out_aval

    @staticmethod
    def lowering(ctx, x):
        """
        gated_gelu lowering rules
        """
        (x_aval,) = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        ir_x_type = ir.RankedTensorType(x.type)
        ir_x_shape = ir_x_type.shape
        out_shape = ir_x_shape[:-2] + [ir_x_shape[-1]]

        out_types = [
            ir.RankedTensorType.get(out_shape, ir_x_type.element_type),
        ]
        operands = [x]
        operand_shapes = [ir_x_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        hidden_size = ir_x_shape[-1]
        batch_size = reduce(operator.mul, ir_x_shape[:-2])
        in_dtype = jax_dtype_to_te_dtype(x_aval.dtype)
        opaque = transformer_engine_jax.pack_common_descriptor((batch_size, hidden_size), in_dtype,
                                                               in_dtype)

        out = custom_caller(GatedGeluPrimitive.name, args, opaque, False)

        return [out]

    @staticmethod
    def impl(x):
        """
        gated_gelu implementation
        """
        assert GatedGeluPrimitive.inner_primitive is not None
        out = GatedGeluPrimitive.inner_primitive.bind(x)
        return out

    @staticmethod
    def batcher(batched_args, batch_dims):
        """
        gated_gelu batcher
        """
        _check_valid_batch_dims(batch_dims)
        assert GatedGeluPrimitive.outer_primitive is not None
        inputs, = batched_args
        inputs_bdim, = batch_dims

        out_bdims = inputs_bdim
        return GatedGeluPrimitive.outer_primitive.bind(inputs), out_bdims

    @staticmethod
    def infer_sharding_from_operands(mesh, arg_infos, result_infos):
        """
        gated_gelu infer_sharding_from_operands
        """
        del result_infos    # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        # Drop the gate axis (-2) from the output partition spec.
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
        return out_sharding

    @staticmethod
    def partition(mesh, arg_infos, result_infos):
        """
        gated_gelu partitioning
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
        impl = GatedGeluPrimitive.impl
        return mesh, impl, out_sharding, arg_shardings


register_primitive(GatedGeluPrimitive)


def gated_gelu(inputs: jnp.ndarray) -> jnp.ndarray:
    """
    gated gelu wrapper
    Return geglu(inputs) in the same dtype as the input
    Assume inputs has shape (..., 2, H), i.e. gate and activation interleaved
    on the second-to-last axis.
    """
    return GatedGeluPrimitive.outer_primitive.bind(inputs)
class DgatedGeluPrimitive(BasePrimitive):
    """
    Dgated Gelu Primitive
    """
    name = "te_dgated_gelu"
    multiple_results = False
    inner_primitive = None
    outer_primitive = None
    impl_static_args = ()

    @staticmethod
    def abstract(dz_aval, x_aval):
        """
        dgated_gelu abstract
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dtype
        # dz and x must agree on all leading (batch) axes.
        for axis in range(len(dz_aval.shape) - 1):
            assert dz_aval.shape[axis] == x_aval.shape[axis]

        assert x_aval.shape[-2] == 2    # Assume x in (....., 2, hidden)

        i_hidden_size = dz_aval.shape[-1]
        g_hidden_size = x_aval.shape[-1]
        assert i_hidden_size == g_hidden_size
        out_aval = core.raise_to_shaped(x_aval)
        return out_aval

    @staticmethod
    def lowering(ctx, dz, x):
        """
        dgated_gelu lowering rules
        """
        in_aval, gi_aval = ctx.avals_in
        assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert gi_aval.dtype == in_aval.dtype
        ir_in_type = ir.RankedTensorType(dz.type)
        ir_in_shape = ir_in_type.shape
        gi_type = ir.RankedTensorType(x.type)
        gi_shape = gi_type.shape
        for axis in range(len(ir_in_shape) - 1):
            assert ir_in_shape[axis] == gi_shape[axis]

        ir_batch_size = reduce(operator.mul, ir_in_shape[:-1])
        i_hidden_size = ir_in_shape[-1]
        g_hidden_size = gi_shape[-1]
        assert i_hidden_size == g_hidden_size
        out_dtype = ir_in_type.element_type
        # Gradient has the same (..., 2, hidden) shape as the gelu input.
        out_shape = gi_shape

        out_types = [
            ir.RankedTensorType.get(out_shape, out_dtype),
        ]
        operands = [dz, x]
        operand_shapes = [ir_in_shape, gi_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        in_dtype = jax_dtype_to_te_dtype(in_aval.dtype)
        opaque = transformer_engine_jax.pack_common_descriptor((ir_batch_size, i_hidden_size),
                                                               in_dtype, in_dtype)

        out = custom_caller(DgatedGeluPrimitive.name, args, opaque, False)

        return [out]

    @staticmethod
    def impl(dz, x):
        """
        dgated_gelu implementation
        """
        assert DgatedGeluPrimitive.inner_primitive is not None
        dx = DgatedGeluPrimitive.inner_primitive.bind(dz, x)
        return dx

    @staticmethod
    def batcher(batched_args, batch_dims):
        """
        dgated_gelu batcher
        """
        _check_valid_batch_dims(batch_dims)
        assert DgatedGeluPrimitive.outer_primitive is not None
        dz, x = batched_args
        _, x_bdim = batch_dims

        out_bdims = x_bdim
        return DgatedGeluPrimitive.outer_primitive.bind(dz, x), out_bdims

    @staticmethod
    def infer_sharding_from_operands(mesh, arg_infos, result_infos):
        """
        dgated_gelu infer_sharding_from_operands
        """
        del result_infos    # Unused.
        gelu_out_spec = get_padded_spec(arg_infos[1])
        dx_sharding = NamedSharding(mesh, PartitionSpec(*gelu_out_spec))
        return dx_sharding

    @staticmethod
    def partition(mesh, arg_infos, result_infos):
        """
        dgated_gelu partition
        """
        del result_infos
        dx_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = dx_sharding
        impl = DgatedGeluPrimitive.impl
        return mesh, impl, out_shardings, arg_shardings


register_primitive(DgatedGeluPrimitive)


def dgated_gelu(inputs: jnp.ndarray, gelu_inputs: jnp.ndarray) -> jnp.ndarray:
    """
    dgated_gelu fusion wrapper
    Return dgeglu(inputs), the gradient w.r.t. the original (..., 2, H) gelu input
    """
    return DgatedGeluPrimitive.outer_primitive.bind(inputs, gelu_inputs)
def _normalize_axis_boundary(axis, ndim):
    return axis if axis >= 0 else ndim + axis
2679
2680


2681
def _multidim_transpose(shape, static_axis_boundary, transpose_axis_boundary):
2682
    """
2683
2684
2685
2686
2687
2688
2689
2690
2691
2692
2693
2694
2695
2696
2697
2698
2699
2700
    te_cast_transpose_p multi-dims transpose

    static_axis_boundary: int, Indicate those axes <= static_axis_boundary would not be
        involved into transpose, -1 means all axes involve into transpose.
    transpose_axis_boundary: int, Indicate how to split multi-dimensions tensors to 2D matrix for
        transpose. Note, transpose_axis_boundary should be greater than static_axis_boundary

    examples:
        X in shape (dim0, dim1, dim2, dim3, dim4)

        static_axis_boundary == -1, transpose_axis_boundary == 2
            Xt = (dim2, dim3, dim4, dim0, dim1)

        static_axis_boundary == 0, transpose_axis_boundary == 2
            Xt = (dim0, dim2, dim3, dim4, dim1)

        static_axis_boundary == 0, transpose_axis_boundary == 3
            Xt = (dim0, dim3, dim4, dim1. dim2)
2701
    """
2702
2703
2704
2705
2706
2707
2708
2709
    if static_axis_boundary < 0:
        static_axis_boundary = -1    # means no static axes
    assert static_axis_boundary < len(shape) - 2    # at least 2 remaining for transpose.
    transpose_start_idx = static_axis_boundary + 1
    transpose_axis_boundary = _normalize_axis_boundary(transpose_axis_boundary, len(shape))
    assert transpose_start_idx < transpose_axis_boundary
    return (*shape[:transpose_start_idx], *shape[transpose_axis_boundary:],
            *shape[transpose_start_idx:transpose_axis_boundary])
2710
2711


class CastTransposePrimitive(BasePrimitive):
    """
    Cast Transpose Primitive
    """
    name = "te_cast_transpose"
    multiple_results = True
    impl_static_args = (4, 5, 6)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype, static_axis_boundary,
                 transpose_axis_boundary):
        """
        te_cast_transpose_p abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        transposed_x_shape = _multidim_transpose(x_aval.shape, static_axis_boundary,
                                                 transpose_axis_boundary)

        casted_x_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
        casted_xt_aval = x_aval.update(shape=transposed_x_shape, dtype=out_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)

        return casted_x_aval, casted_xt_aval, updated_amax_aval

    @staticmethod
    def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary,
                 transpose_axis_boundary):
        """
        te_cast_transpose_p lowering rules
        """
        x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_x_type = ir.RankedTensorType(x.type)
        ir_x_shape = ir_x_type.shape
        if static_axis_boundary >= 0:
            # Static (batch) axes must be squeezed out to size 1.
            for i in range(static_axis_boundary + 1):
                assert ir_x_shape[i] == 1
        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape

        transposed_x_shape = _multidim_transpose(ir_x_shape, static_axis_boundary,
                                                 transpose_axis_boundary)

        out_types = [
            ir.RankedTensorType.get(ir_x_shape, ir_out_dtype),
            ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
        ]
        operands = [x, amax, scale, scale_inv]
        operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        # Collapse the tensor to 2D around the transpose boundary for the kernel.
        contracted_x_shape = (reduce(operator.mul, ir_x_shape[:transpose_axis_boundary]),
                              reduce(operator.mul, ir_x_shape[transpose_axis_boundary:]))
        opaque = transformer_engine_jax.pack_common_descriptor(contracted_x_shape,
                                                               jax_dtype_to_te_dtype(x_aval.dtype),
                                                               jax_dtype_to_te_dtype(out_dtype))

        out = custom_caller(CastTransposePrimitive.name,
                            args,
                            opaque,
                            False,
                            operand_output_aliases={1: 2})

        return out

    @staticmethod
    def impl(x, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis_boundary):
        """
        te_cast_transpose implementation
        """
        assert CastTransposePrimitive.inner_primitive is not None
        casted_x, casted_transposed_x, updated_amax = \
            CastTransposePrimitive.inner_primitive.bind(
                x, amax, scale, scale_inv, out_dtype=out_dtype,
                static_axis_boundary=static_axis_boundary,
                transpose_axis_boundary=transpose_axis_boundary)
        return casted_x, casted_transposed_x, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary,
                transpose_axis_boundary):
        _check_valid_batch_dims(batch_dims)
        assert CastTransposePrimitive.outer_primitive is not None
        assert static_axis_boundary < 0

        x, amax, scale, scale_inv = batched_args
        x_bdim, amax_bdim, *_ = batch_dims

        # Minus batch dim.
        transpose_axis_boundary = _normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1)
        transpose_axis_boundary += 1    # Plus batch dim

        out_bdims = x_bdim, x_bdim, amax_bdim
        return CastTransposePrimitive.outer_primitive.bind(
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=x_bdim,
            transpose_axis_boundary=transpose_axis_boundary), out_bdims

    @staticmethod
    def infer_sharding_from_operands(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh,
                                     arg_infos, result_infos):
        del out_dtype, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        return (casted_x_sharding, casted_transposed_x_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos,
                  result_infos):
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (casted_x_sharding, casted_transposed_x_sharding, amax_sharding)

        def sharded_impl(x, amax, scale, scale_inv):
            local_cx, local_cxt, local_updated_amax = \
                CastTransposePrimitive.impl(x, amax, scale, scale_inv,
                                            out_dtype=out_dtype,
                                            static_axis_boundary=static_axis_boundary,
                                            transpose_axis_boundary=transpose_axis_boundary)
            # amax is a local maximum; reduce it globally (except over PP).
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax)

            return local_cx, local_cxt, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(CastTransposePrimitive)


def cast_transpose(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
                   out_dtype: jnp.dtype, static_axis_boundary: int,
                   transpose_axis_boundary: int) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    cast transpose wrapper
    Return two tensors, FP8(inputs) and FP8(inputs.T), which are scaled by `scale`,
    plus the updated amax.
    """
    return CastTransposePrimitive.outer_primitive.bind(
        x,
        amax,
        scale,
        scale_inv,
        out_dtype=out_dtype,
        static_axis_boundary=static_axis_boundary,
        transpose_axis_boundary=transpose_axis_boundary)
class TransposePrimitive(BasePrimitive):
    """
    Transpose Primitive
    """
    name = "te_transpose"
    multiple_results = False
    impl_static_args = (1, 2)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, *, static_axis_boundary, transpose_axis_boundary):
        """
        _transpose abstract
        """
        transposed_x_shape = _multidim_transpose(x_aval.shape, static_axis_boundary,
                                                 transpose_axis_boundary)
        xt_aval = x_aval.update(shape=transposed_x_shape, dtype=x_aval.dtype)

        return xt_aval

    @staticmethod
    def lowering(ctx, x, *, static_axis_boundary, transpose_axis_boundary):
        """
        _transpose cuda lowering
        """
        x_aval = ctx.avals_in[0]
        assert x_aval.dtype in [
            jnp.float32, jnp.float16, jnp.bfloat16, jnp.float8_e4m3fn, jnp.float8_e5m2
        ]

        ir_x_type = ir.RankedTensorType(x.type)
        ir_x_shape = ir_x_type.shape
        ir_out_dtype = jax_dtype_to_ir_dtype(x_aval.dtype)
        if static_axis_boundary >= 0:
            # Static (batch) axes must be squeezed out to size 1.
            for i in range(static_axis_boundary + 1):
                assert ir_x_shape[i] == 1

        transposed_x_shape = _multidim_transpose(ir_x_shape, static_axis_boundary,
                                                 transpose_axis_boundary)

        out_types = [ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype)]
        operands = [x]
        operand_shapes = [ir_x_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        te_dtype = jax_dtype_to_te_dtype(x_aval.dtype)
        # Collapse the tensor to 2D around the transpose boundary for the kernel.
        contracted_x_shape = (reduce(operator.mul, ir_x_shape[:transpose_axis_boundary]),
                              reduce(operator.mul, ir_x_shape[transpose_axis_boundary:]))
        opaque = transformer_engine_jax.pack_common_descriptor(contracted_x_shape, te_dtype,
                                                               te_dtype)

        out = custom_caller(TransposePrimitive.name, args, opaque, False)

        return [out]

    @staticmethod
    def impl(x, static_axis_boundary, transpose_axis_boundary):
        """
        transpose implementation
        """
        assert TransposePrimitive.inner_primitive is not None
        transposed_x = \
            TransposePrimitive.inner_primitive.bind(x,
                                                    static_axis_boundary=static_axis_boundary,
                                                    transpose_axis_boundary=transpose_axis_boundary)
        return transposed_x

    @staticmethod
    def batcher(batched_args, batch_dims, *, static_axis_boundary, transpose_axis_boundary):
        _check_valid_batch_dims(batch_dims)
        assert TransposePrimitive.outer_primitive is not None
        assert static_axis_boundary < 0

        x, = batched_args
        x_bdim, = batch_dims

        # Minus batch dim.
        transpose_axis_boundary = _normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1)
        transpose_axis_boundary += 1    # Plus batch dim

        out_bdims = x_bdim
        return TransposePrimitive.outer_primitive.bind(
            x, static_axis_boundary=x_bdim,
            transpose_axis_boundary=transpose_axis_boundary), out_bdims

    @staticmethod
    def infer_sharding_from_operands(static_axis_boundary, transpose_axis_boundary, mesh, arg_infos,
                                     result_infos):
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        return transposed_x_sharding

    @staticmethod
    def partition(static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos):
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = transposed_x_sharding

        impl = partial(TransposePrimitive.impl,
                       static_axis_boundary=static_axis_boundary,
                       transpose_axis_boundary=transpose_axis_boundary)

        return mesh, impl, out_shardings, arg_shardings


register_primitive(TransposePrimitive)


def transpose(x: jnp.ndarray, static_axis_boundary: int,
              transpose_axis_boundary: int) -> jnp.ndarray:
    """
    Transpose wrapper.

    Binds the TE transpose primitive, permuting the axes of ``x`` after
    ``transpose_axis_boundary`` ahead of the earlier ones while keeping the
    leading ``static_axis_boundary`` axes fixed.
    """
    return TransposePrimitive.outer_primitive.bind(x,
                                                   static_axis_boundary=static_axis_boundary,
                                                   transpose_axis_boundary=transpose_axis_boundary)


3010
class LayerNormFwdFp8Primitive(BasePrimitive):
    """
    Layer Normalization Forward FP8 Primitive.

    Runs the TE layernorm forward custom call and casts the normalized output
    to FP8, returning (out, mu, rsigma, updated_amax).
    """
    name = "te_layernorm_forward_fp8"
    multiple_results = True
    impl_static_args = (6, 7, 8)    # out_dtype, zero_centered_gamma, epsilon
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, gamma_aval, beta_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype,
                 zero_centered_gamma, epsilon):
        """
        LayerNorm fwd (fp8 out) abstract
        """
        del zero_centered_gamma, epsilon
        x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)

        assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        mu_rsigama_dtype = jnp.float32

        assert gamma_aval.size == beta_aval.size

        out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
        # mu/rsigma are per-row statistics: one fp32 value per all-but-last axes.
        mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)

        return out_aval, mu_aval, rsigma_aval, updated_amax_aval

    @staticmethod
    def lowering(ctx, x, gamma, beta, amax, scale, scale_inv, *, out_dtype, zero_centered_gamma,
                 epsilon):
        """
        LayerNorm fwd (fp8 out) lowering rules
        """
        x_aval, gamma_aval, beta_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in

        # Currently only support casting to E4M3 only in C side.
        assert out_dtype == jnp.float8_e4m3fn

        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert gamma_aval.dtype == beta_aval.dtype
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        g_type = ir.RankedTensorType(gamma.type)
        g_shape = g_type.shape
        b_type = ir.RankedTensorType(beta.type)
        b_shape = b_type.shape

        assert g_type == b_type
        assert g_shape == b_shape

        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_mu_dtype = ir.F32Type.get()
        ir_rsigma_dtype = ir.F32Type.get()
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape

        out_shape = x_shape
        hidden_size = reduce(operator.mul, g_shape)
        batch_shape = out_shape[:-1]
        batch_size = reduce(operator.mul, x_shape) // hidden_size

        out_types = [
            ir.RankedTensorType.get(out_shape, ir_out_dtype),
            ir.RankedTensorType.get(batch_shape, ir_mu_dtype),
            ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
        ]
        operands = [x, gamma, beta, amax, scale, scale_inv]
        operand_shapes = [
            x_shape, g_shape, b_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape
        ]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        # Reserve SMs for communication overlap if requested via env var.
        sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))

        opaque = transformer_engine_jax.pack_norm_descriptor(
            batch_size,
            hidden_size,
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(gamma_aval.dtype),
            zero_centered_gamma,
            epsilon,
            sm_margin,
        )

        # amax (operand 3) is updated in place by the custom call (output 3).
        out = custom_caller(LayerNormFwdFp8Primitive.name,
                            args,
                            opaque,
                            False,
                            operand_output_aliases={3: 3})

        return out

    @staticmethod
    def impl(x, gamma, beta, amax, scale, scale_inv, out_dtype, zero_centered_gamma, epsilon):
        """
        to describe implementation
        """
        assert LayerNormFwdFp8Primitive.inner_primitive is not None
        out, mu, rsigma, updated_amax = LayerNormFwdFp8Primitive.inner_primitive.bind(
            x,
            gamma,
            beta,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon)
        return out, mu, rsigma, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, zero_centered_gamma, epsilon):
        """
        to describe batch rules for vmap
        """
        _check_valid_batch_dims(batch_dims)
        assert LayerNormFwdFp8Primitive.outer_primitive is not None
        x, gamma, beta, amax, scale, scale_inv = batched_args
        x_bdim, _, _, amax_bdim, _, _ = batch_dims

        out_bdims = x_bdim, x_bdim, x_bdim, amax_bdim
        return LayerNormFwdFp8Primitive.outer_primitive.bind(
            x,
            gamma,
            beta,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon), out_bdims

    @staticmethod
    def infer_sharding_from_operands(out_dtype, zero_centered_gamma, epsilon, mesh, arg_infos,
                                     result_infos):
        del out_dtype, zero_centered_gamma, epsilon, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {LayerNormFwdFp8Primitive.name}! " \
                f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
                f"and hurt performance.")

        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        mu_sharding = rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[3])))
        return (out_sharding, mu_sharding, rsigma_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {LayerNormFwdFp8Primitive.name}! " \
                f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
                f"and hurt performance."
            )
        x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        g_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        b_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        out_sharding = x_sharding
        mu_sharding = rsigma_sharding = NamedSharding(
            mesh, PartitionSpec(*get_padded_spec(arg_infos[0])[:-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[3])))
        fp8_meta_sharding = amax_sharding
        arg_shardings = (x_sharding, g_sharding, b_sharding) + (fp8_meta_sharding,) * 3
        out_shardings = (out_sharding, mu_sharding, rsigma_sharding, amax_sharding)

        def sharded_impl(x, gamma, beta, amax, scale, scale_inv):
            local_x, local_mu, local_rsigma, local_amax = \
                LayerNormFwdFp8Primitive.impl(x, gamma, beta, amax, scale, scale_inv,
                                              out_dtype=out_dtype,
                                              zero_centered_gamma=zero_centered_gamma,
                                              epsilon=epsilon)
            # amax must be the global max across all shards (except PP group).
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)

            return local_x, local_mu, local_rsigma, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings

# Hook the FP8 LayerNorm forward primitive into JAX (abstract eval, lowering,
# vmap batching, and custom partitioning rules defined on the class above).
register_primitive(LayerNormFwdFp8Primitive)


def layernorm_fwd_fp8(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, amax: jnp.ndarray,
                      scale: jnp.ndarray, scale_inv: jnp.ndarray, out_dtype: jnp.dtype,
                      zero_centered_gamma: bool, epsilon: float):
    """
    Wrapper for TE layernorm fwd (fp8 out).

    Returns (out, mu, rsigma, updated_amax); ``out`` is cast to ``out_dtype``
    (currently only ``jnp.float8_e4m3fn`` is supported by the lowering).
    """
    return LayerNormFwdFp8Primitive.outer_primitive.bind(x,
                                                         gamma,
                                                         beta,
                                                         amax,
                                                         scale,
                                                         scale_inv,
                                                         out_dtype=out_dtype,
                                                         zero_centered_gamma=zero_centered_gamma,
                                                         epsilon=epsilon)


3227
class RmsNormFwdFp8Primitive(BasePrimitive):
    """
    RMS Normalization Forward FP8 Primitive.

    Runs the TE rmsnorm forward custom call and casts the normalized output
    to FP8, returning (out, rsigma, updated_amax).
    """
    name = "te_rmsnorm_forward_fp8"
    multiple_results = True
    impl_static_args = (5, 6)    # out_dtype, epsilon
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype, epsilon):
        """
        RMSNorm fwd (fp8 out) abstract
        """
        del epsilon
        x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)

        assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        hidden_size = gamma_aval.size
        assert x_aval.size % hidden_size == 0

        rsigama_dtype = jnp.float32

        out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
        # rsigma is a per-row fp32 statistic (one value per all-but-last axes).
        rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=rsigama_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)

        return out_aval, rsigma_aval, updated_amax_aval

    @staticmethod
    def lowering(ctx, x, gamma, amax, scale, scale_inv, *, out_dtype, epsilon):
        """
        RMSNorm fwd (fp8 out) lowering rules
        """
        # Currently only support casting to E4M3 only in C side.
        assert out_dtype == jnp.float8_e4m3fn

        x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in

        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        g_type = ir.RankedTensorType(gamma.type)
        g_shape = g_type.shape

        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_rsigma_dtype = ir.F32Type.get()
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape

        out_shape = x_shape
        hidden_size = reduce(operator.mul, g_shape)
        batch_shape = out_shape[:-1]
        batch_size = reduce(operator.mul, x_shape) // hidden_size

        out_types = [
            ir.RankedTensorType.get(out_shape, ir_out_dtype),
            ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
        ]
        operands = [x, gamma, amax, scale, scale_inv]
        operand_shapes = [x_shape, g_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        # Reserve SMs for communication overlap if requested via env var.
        sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))

        opaque = transformer_engine_jax.pack_norm_descriptor(
            batch_size,
            hidden_size,
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(gamma_aval.dtype),
            False,    # RMSNorm doesn't support zero_centered_gamma
            epsilon,
            sm_margin,
        )

        # amax (operand 2) is updated in place by the custom call (output 2).
        out = custom_caller(RmsNormFwdFp8Primitive.name,
                            args,
                            opaque,
                            False,
                            operand_output_aliases={2: 2})

        return out

    @staticmethod
    def impl(x, gamma, amax, scale, scale_inv, out_dtype, epsilon):
        """
        to describe implementation
        """
        assert RmsNormFwdFp8Primitive.inner_primitive is not None
        out, rsigma, amax = RmsNormFwdFp8Primitive.inner_primitive.bind(x,
                                                                        gamma,
                                                                        amax,
                                                                        scale,
                                                                        scale_inv,
                                                                        out_dtype=out_dtype,
                                                                        epsilon=epsilon)
        return out, rsigma, amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, epsilon):
        """
        to describe batch rules for vmap
        """
        _check_valid_batch_dims(batch_dims)
        assert RmsNormFwdFp8Primitive.outer_primitive is not None
        x, gamma, amax, scale, scale_inv = batched_args
        x_bdim, _, amax_bdim, _, _ = batch_dims
        out_bdims = x_bdim, x_bdim, amax_bdim
        return RmsNormFwdFp8Primitive.outer_primitive.bind(x,
                                                           gamma,
                                                           amax,
                                                           scale,
                                                           scale_inv,
                                                           out_dtype=out_dtype,
                                                           epsilon=epsilon), out_bdims

    @staticmethod
    def infer_sharding_from_operands(out_dtype, epsilon, mesh, arg_infos, result_infos):
        del out_dtype, epsilon, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {RmsNormFwdFp8Primitive.name}! " \
                f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
                f"and hurt performance."
            )
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        return (out_sharding, rsigma_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, epsilon, mesh, arg_infos, result_infos):
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {RmsNormFwdFp8Primitive.name}! " \
                f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
                f"and hurt performance."
            )
        x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        g_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        out_sharding = x_sharding
        rsigma_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[0])[:-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        fp8_meta_sharding = amax_sharding
        arg_shardings = (x_sharding, g_sharding) + (fp8_meta_sharding,) * 3
        out_shardings = (out_sharding, rsigma_sharding, amax_sharding)

        def sharded_impl(x, gamma, amax, scale, scale_inv):
            local_x, local_rsigma, local_amax = \
                RmsNormFwdFp8Primitive.impl(x, gamma, amax, scale, scale_inv,
                                            out_dtype=out_dtype, epsilon=epsilon)
            # amax must be the global max across all shards (except PP group).
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)

            return local_x, local_rsigma, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


3402
# Hook the FP8 RMSNorm forward primitive into JAX.
register_primitive(RmsNormFwdFp8Primitive)

3404
3405
3406

def rmsnorm_fwd_fp8(x: jnp.ndarray, gamma: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
                    scale_inv: jnp.ndarray, out_dtype: jnp.dtype, epsilon: float):
    """
    Wrapper for TE rmsnorm fwd (fp8 out).

    Returns (out, rsigma, updated_amax); ``out`` is cast to ``out_dtype``
    (currently only ``jnp.float8_e4m3fn`` is supported by the lowering).
    """
    return RmsNormFwdFp8Primitive.outer_primitive.bind(x,
                                                       gamma,
                                                       amax,
                                                       scale,
                                                       scale_inv,
                                                       out_dtype=out_dtype,
                                                       epsilon=epsilon)


3419
class GatedGeluFp8Primitive(BasePrimitive):
    """
    Gated Gelu FP8 Primitive.

    Computes geglu over an input laid out as (..., 2, hidden) and casts the
    result to FP8, returning (out, updated_amax).
    """
    name = "te_gated_gelu_fp8"
    multiple_results = True
    impl_static_args = (4,)    # out_dtype
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype):
        """
        te_gated_gelu_p abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        # Currently only support casting to E4M3 only in C side.
        assert out_dtype == jnp.float8_e4m3fn
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        assert x_aval.shape[-2] == 2    # Assume x in (....., 2, hidden)
        hidden_size = x_aval.shape[-1]
        batch_shape = x_aval.shape[:-2]
        # The gate axis (size 2) is consumed by the geglu.
        out_shape = batch_shape + (hidden_size,)
        out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)

        return out_aval, updated_amax_aval

    @staticmethod
    def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype):
        """
        te_gated_gelu_p lowering rules
        """
        x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_x_type = ir.RankedTensorType(x.type)
        ir_x_shape = ir_x_type.shape
        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape

        hidden_size = ir_x_shape[-1]
        batch_shape = ir_x_shape[:-2]
        batch_size = reduce(operator.mul, batch_shape)
        out_shape = batch_shape + [hidden_size]
        out_types = [
            ir.RankedTensorType.get(out_shape, ir_out_dtype),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
        ]
        operands = [x, amax, scale, scale_inv]
        operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        opaque = transformer_engine_jax.pack_common_descriptor((batch_size, out_shape[-1]),
                                                               jax_dtype_to_te_dtype(x_aval.dtype),
                                                               jax_dtype_to_te_dtype(out_dtype))

        # amax (operand 1) is updated in place by the custom call (output 1).
        out = custom_caller(GatedGeluFp8Primitive.name,
                            args,
                            opaque,
                            False,
                            operand_output_aliases={1: 1})

        return out

    @staticmethod
    def impl(x, amax, scale, scale_inv, out_dtype):
        """
        to describe implementation
        """
        assert GatedGeluFp8Primitive.inner_primitive is not None
        out, updated_amax = GatedGeluFp8Primitive.inner_primitive.bind(x,
                                                                       amax,
                                                                       scale,
                                                                       scale_inv,
                                                                       out_dtype=out_dtype)
        return out, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype):
        """
        to describe batch rules for vmap
        """
        _check_valid_batch_dims(batch_dims)
        assert GatedGeluFp8Primitive.outer_primitive is not None
        x, amax, scale, scale_inv = batched_args
        x_bdim, amax_bdim, _, _ = batch_dims

        out_bdims = x_bdim, amax_bdim
        return GatedGeluFp8Primitive.outer_primitive.bind(x,
                                                          amax,
                                                          scale,
                                                          scale_inv,
                                                          out_dtype=out_dtype), out_bdims

    @staticmethod
    def infer_sharding_from_operands(out_dtype, mesh, arg_infos, result_infos):
        del out_dtype, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        # Drop the gate axis (-2) from the spec; keep hidden-dim spec as-is.
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        return (out_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, mesh, arg_infos, result_infos):
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (out_sharding, amax_sharding)

        def sharded_impl(x, amax, scale, scale_inv):
            local_x, local_amax = GatedGeluFp8Primitive.impl(x,
                                                             amax,
                                                             scale,
                                                             scale_inv,
                                                             out_dtype=out_dtype)
            # amax must be the global max across all shards (except PP group).
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)

            return local_x, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


3554
# Hook the FP8 gated-GeLU primitive into JAX.
register_primitive(GatedGeluFp8Primitive)

3556
3557
3558

def gated_gelu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
                   out_dtype: jnp.dtype) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    gated gelu wrapper
    Return FP8(geglu(x))

    ``x`` is expected in layout (..., 2, hidden); the output drops the gate
    axis. Returns (out, updated_amax).
    """
    return GatedGeluFp8Primitive.outer_primitive.bind(x,
                                                      amax,
                                                      scale,
                                                      scale_inv,
                                                      out_dtype=out_dtype)


3570
class DgatedGeluCastTransposePrimitive(BasePrimitive):
    """
    Dgated Gelu Cast Transpose Primitive.

    Fused backward of gated GeLU: computes dgeglu(dz, x), casts the result to
    FP8, and also emits its multidim transpose, returning
    (out, t_out, updated_amax).
    """
    name = "te_dgated_gelu_cast_transpose"
    multiple_results = True
    impl_static_args = (5, 6)    # out_dtype, static_axis_boundary
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype,
                 static_axis_boundary):
        """
        te_dgated_gelu_cast_transpose_p abstract
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dtype
        assert x_aval.shape[-2] == 2    # Linear + GeLU
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        dz_hidden_size = dz_aval.shape[-1]
        gi_hidden_size = x_aval.shape[-1]
        assert dz_hidden_size == gi_hidden_size
        t_shape = _multidim_transpose(x_aval.shape, static_axis_boundary, -2)
        out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype)
        t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
        return out, t_out, updated_amax_aval

    @staticmethod
    def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary):
        """
        te_dgated_gelu_cast_transpose_p lowering rules
        """
        dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_aval.dtype
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_dz_type = ir.RankedTensorType(dz.type)
        ir_dz_shape = ir_dz_type.shape
        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        dz_batch_size = reduce(operator.mul, ir_dz_shape[:-1])
        x_batch_size = reduce(operator.mul, x_shape[:-2])
        assert dz_batch_size == x_batch_size
        assert x_shape[-2] == 2    # Linear + GeLU
        ir_hidden_size = ir_dz_shape[-1]
        gi_hidden_size = x_shape[-1]
        assert ir_hidden_size == gi_hidden_size
        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape
        transposed_x_shape = _multidim_transpose(x_shape, static_axis_boundary, -2)
        out_types = [
            ir.RankedTensorType.get(x_shape, ir_out_dtype),
            ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
        ]
        operands = [dz, x, amax, scale, scale_inv]
        operand_shapes = [ir_dz_shape, x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
        contracted_x_shape = (x_batch_size, x_shape[-1])
        opaque = transformer_engine_jax.pack_common_descriptor(contracted_x_shape,
                                                               jax_dtype_to_te_dtype(dz_aval.dtype),
                                                               jax_dtype_to_te_dtype(out_dtype))

        # amax (operand 2) is updated in place by the custom call (output 2).
        out = custom_caller(DgatedGeluCastTransposePrimitive.name,
                            args,
                            opaque,
                            False,
                            operand_output_aliases={2: 2})

        return out

    @staticmethod
    def impl(dz, x, amax, scale, scale_inv, out_dtype, static_axis_boundary):
        """
        to describe implementation
        """
        assert DgatedGeluCastTransposePrimitive.inner_primitive is not None
        out, t_out, updated_amax = DgatedGeluCastTransposePrimitive.inner_primitive.bind(
            dz,
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary)
        return out, t_out, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary):
        """
        to describe batch rules for vmap
        """
        del static_axis_boundary
        _check_valid_batch_dims(batch_dims)
        assert DgatedGeluCastTransposePrimitive.outer_primitive is not None
        dz, x, amax, scale, scale_inv = batched_args
        x_bdim, _, amax_bdim, _, _ = batch_dims

        out_bdims = x_bdim, x_bdim, amax_bdim
        # The vmapped batch axis becomes the static (untransposed) boundary.
        return DgatedGeluCastTransposePrimitive.outer_primitive.bind(
            dz, x, amax, scale, scale_inv, out_dtype=out_dtype,
            static_axis_boundary=x_bdim), out_bdims

    @staticmethod
    def infer_sharding_from_operands(out_dtype, static_axis_boundary, mesh, arg_infos,
                                     result_infos):
        del out_dtype, result_infos
        x_spec = get_padded_spec(arg_infos[1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, -2)
        transposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        return (out_sharding, transposed_out_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, static_axis_boundary, mesh, arg_infos, result_infos):
        del result_infos
        x_spec = get_padded_spec(arg_infos[1])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, -2)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))

        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (casted_x_sharding, casted_transposed_x_sharding, amax_sharding)

        def sharded_impl(dz, x, amax, scale, scale_inv):
            local_out, local_t_out, local_amax = DgatedGeluCastTransposePrimitive.impl(
                dz,
                x,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=static_axis_boundary)
            # amax must be the global max across all shards (except PP group).
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
            return local_out, local_t_out, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


3722
# Hook the fused dgated-GeLU cast-transpose primitive into JAX.
register_primitive(DgatedGeluCastTransposePrimitive)

3724
3725
3726
3727
3728

def dgated_gelu_cast_transpose(
        dz: jnp.ndarray, x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
        scale_inv: jnp.ndarray, out_dtype: jnp.dtype,
        static_axis_boundary: int) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    cast transpose d_gated_gelu fusion wrapper
    Return FP8(dgeglu(inputs))

    Returns (out, t_out, updated_amax) where ``t_out`` is the multidim
    transpose of ``out``. (Annotation fixed: callers pass a jnp dtype such as
    ``jnp.float8_e4m3fn``, not a TEDType.)
    """
    return DgatedGeluCastTransposePrimitive.outer_primitive.bind(
        dz,
        x,
        amax,
        scale,
        scale_inv,
        out_dtype=out_dtype,
        static_axis_boundary=static_axis_boundary)