cpp_extensions.py 148 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX te custom call"""

from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Tuple
from functools import partial, reduce
import operator
11
import os
12
13
import warnings

14
15
16
17
18
import numpy as np
import jax.numpy as jnp
from jax.lib import xla_client
from jax import core, dtypes
from jax.interpreters import xla, mlir
19
from jax.experimental.custom_partitioning import custom_partitioning
20
from jax.interpreters.mlir import ir, dtype_to_ir_type
21
22
from jax.sharding import PartitionSpec, NamedSharding
from jax._src.interpreters import batching
23

24
25
26
27
28
29
30
try:
    from jaxlib.hlo_helpers import custom_call
except ImportError:
    # Newer JAX changed its API. But we want to support a few JAX
    # version, so we still need this import.
    pass

31
32
import transformer_engine_jax
from transformer_engine_jax import DType as TEDType
33
34
from transformer_engine_jax import NVTE_Bias_Type
from transformer_engine_jax import NVTE_Mask_Type
35
36
from transformer_engine_jax import NVTE_QKV_Layout
from transformer_engine_jax import NVTE_Fused_Attn_Backend
37

38
39
40
41
42
from .sharding import all_reduce_max_along_all_axes_except_PP
from .sharding import all_reduce_sum_along_dp_fsdp
from .sharding import get_all_mesh_axes, num_of_devices
from .sharding import get_padded_spec as te_get_padded_spec

43
44
45
46
47
48
49
50
51
# Register every custom-call target returned by transformer_engine_jax.registrations()
# with XLA for the "CUDA" platform, keyed by the target's name.
for _name, _value in transformer_engine_jax.registrations().items():
    xla_client.register_custom_call_target(_name, _value, platform="CUDA")


def te_dtype_to_jax_dtype(te_dtype):
    """
    convert TE dtype to jax dtype

    `te_dtype` must be a `TEDType` member (checked with an assert). The matching
    `jax.numpy` scalar type is returned; a `ValueError` is raised when the member
    has no entry in the conversion table.
    """
    assert isinstance(te_dtype, TEDType)

    converter = {
        TEDType.kFloat32: jnp.float32,
        TEDType.kFloat16: jnp.float16,
        TEDType.kBFloat16: jnp.bfloat16,
        TEDType.kInt32: jnp.int32,
        TEDType.kInt64: jnp.int64,
        TEDType.kFloat8E4M3: jnp.float8_e4m3fn,
        TEDType.kFloat8E5M2: jnp.float8_e5m2,
    }

    if te_dtype not in converter:
        raise ValueError(f"Unsupported {te_dtype=}")

    # Membership was verified above, so plain indexing cannot fail here.
    return converter[te_dtype]
67
68
69
70
71
72
73
74
75


def te_dtype_to_ir_dtype(te_dtype):
    """
    Translate a TE dtype into the corresponding MLIR element type.

    The TE dtype is first mapped to its jax dtype, then to a numpy dtype, and
    finally handed to `dtype_to_ir_type`.
    """
    jax_dtype = te_dtype_to_jax_dtype(te_dtype)
    numpy_dtype = np.dtype(jax_dtype)
    return dtype_to_ir_type(numpy_dtype)


76
77
78
79
80
81
82
def jax_dtype_to_ir_dtype(jax_dtype):
    """
    Translate a jax dtype into the corresponding MLIR element type.

    The argument is normalized through `np.dtype` before being handed to
    `dtype_to_ir_type`.
    """
    numpy_dtype = np.dtype(jax_dtype)
    return dtype_to_ir_type(numpy_dtype)


83
84
85
86
def jax_dtype_to_te_dtype(jax_dtype):
    """
    convert jax dtype to TE dtype

    `jax_dtype` is passed through `dtypes.canonicalize_dtype` before the lookup,
    so the table below is keyed by the canonical dtype objects. The matching
    `TEDType` member is returned; a `ValueError` is raised when the canonical
    dtype has no entry in the conversion table.
    """
    jax_dtype = dtypes.canonicalize_dtype(jax_dtype)

    converter = {
        jnp.float32.dtype: TEDType.kFloat32,
        jnp.float16.dtype: TEDType.kFloat16,
        jnp.bfloat16.dtype: TEDType.kBFloat16,
        jnp.int32.dtype: TEDType.kInt32,
        jnp.int64.dtype: TEDType.kInt64,
        jnp.float8_e4m3fn.dtype: TEDType.kFloat8E4M3,
        jnp.float8_e5m2.dtype: TEDType.kFloat8E5M2,
    }

    if jax_dtype not in converter:
        raise ValueError(f"Unsupported {jax_dtype=}")

    # Membership was verified above, so plain indexing cannot fail here.
    return converter[jax_dtype]
103
104


105
106
107
108
109
110
111
112
def get_padded_spec(arg_info):
    """
    Build the padded partition spec for one argument's sharding information.

    When `arg_info.sharding` is None, None is forwarded as the spec; otherwise
    the sharding's `spec` is used. Either way the result comes from
    `te_get_padded_spec` called with the spec and `arg_info.ndim`.
    """
    sharding = arg_info.sharding
    if sharding is not None:
        return te_get_padded_spec(sharding.spec, arg_info.ndim)
    return te_get_padded_spec(None, arg_info.ndim)
113
114


115
def _check_valid_batch_dims(bdims):
116
    """
117
    Assert out non-supported bath dims
118
    """
119
120
121
122
    for dim in bdims:
        assert dim in [0, None], \
            "Currently only support batch_dim in [0, None], " \
            f"but got {dim=}"
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145


class BasePrimitive(metaclass=ABCMeta):
    """
    jax primitive

    Abstract interface for the primitive classes in this file. Each hook is an
    abstract static method; `register_primitive` reads `abstract`, `lowering`,
    `impl`, `batcher`, `infer_sharding_from_operands` and `partition` from the
    concrete subclass when wiring it into JAX.
    """

    @staticmethod
    @abstractmethod
    def abstract():
        """
        to describe computing graph
        """
        return NotImplemented

    @staticmethod
    @abstractmethod
    def lowering():
        """
        to describe MLIR
        """
        return NotImplemented

    @staticmethod
    @abstractmethod
    def impl():
        """
        to describe implementation
        """
        return NotImplemented

    @staticmethod
    @abstractmethod
    def batcher():
        """
        to describe batch rules for vmap
        """
        return NotImplemented

    @staticmethod
    @abstractmethod
    def infer_sharding_from_operands():
        """
        to describe infer_sharding_from_operands for custom_partitioning
        """
        return NotImplemented

    @staticmethod
    @abstractmethod
    def partition():
        """
        to describe partition for custom_partitioning
        """
        return NotImplemented

178
179
180
181
182

def register_primitive(cls):
    """
    register jax primitive

    Two primitives are created for `cls` and stored back on it:

    * `cls.inner_primitive`, named `cls.name`, which is lowered with
      `cls.lowering` on the 'cuda' platform.
    * `cls.outer_primitive`, named `cls.name + "_wrapper"`, which is implemented
      by `cls.impl`, batched by `cls.batcher`, and lowered through a
      `custom_partitioning` wrapper around `cls.impl` that uses
      `cls.infer_sharding_from_operands` and `cls.partition`.

    Both primitives use `cls.abstract` for abstract evaluation and inherit
    `cls.multiple_results`.
    """

    def name_of_wrapper_p():
        return cls.name + "_wrapper"

    # Inner primitive: bound by cls.impl, lowered by cls.lowering.
    inner_p = core.Primitive(cls.name)
    inner_p.multiple_results = cls.multiple_results
    inner_p.def_impl(partial(xla.apply_primitive, inner_p))
    inner_p.def_abstract_eval(cls.abstract)
    mlir.register_lowering(inner_p, cls.lowering, platform='cuda')
    cls.inner_primitive = inner_p

    # Outer primitive: the one bound by the public wrappers; carries the batching
    # rule and the custom-partitioning rules.
    outer_p = core.Primitive(name_of_wrapper_p())
    outer_p.multiple_results = cls.multiple_results
    outer_p.def_impl(cls.impl)
    outer_p.def_abstract_eval(cls.abstract)
    batching.primitive_batchers[outer_p] = cls.batcher
    outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args)
    outer_p_lower.def_partition(infer_sharding_from_operands=cls.infer_sharding_from_operands,
                                partition=cls.partition)
    mlir.register_lowering(outer_p,
                           mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results))
    cls.outer_primitive = outer_p
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251


@dataclass
class CustomCallArgsWrapper:
    """
    Bundle of the arguments needed to emit an XLA custom call.

    Stores the result types and operands as given, and derives one layout per
    operand and per result (see `generate_layouts`).
    """

    def __init__(self,
                 output_types,
                 operands,
                 operand_shapes,
                 operand_specific_layouts=None,
                 output_specific_layouts=None):
        self.output_types = output_types
        self.operands = operands
        self.operand_layouts = CustomCallArgsWrapper.generate_layouts(
            operand_shapes, operand_specific_layouts)
        # Result shapes are read from the `shape` attribute of each output type.
        result_shapes = [out_type.shape for out_type in output_types]
        self.output_layouts = CustomCallArgsWrapper.generate_layouts(
            result_shapes, output_specific_layouts)

    @staticmethod
    def generate_layouts(shapes, specific_layouts):
        """
        Produce one layout per entry of `shapes`.

        An entry whose index appears in `specific_layouts` gets that layout
        verbatim; every other entry gets the descending range
        `rank - 1, ..., 0`. `specific_layouts=None` means no overrides.
        """
        overrides = {} if specific_layouts is None else specific_layouts
        return [
            overrides[pos] if pos in overrides else range(len(dims) - 1, -1, -1)
            for pos, dims in enumerate(shapes)
        ]


def custom_caller(name, args, opaque, has_side_effect, **kwargs):
    """
    XLA custom call wrapper

    Emits a custom call named `name`, taking result types, operands and layouts
    from `args` (a `CustomCallArgsWrapper`) and passing `opaque` as the
    backend config. Extra keyword arguments are forwarded unchanged.

    When `mlir.custom_call` exists it is used and the `.results` of the returned
    op are returned; otherwise the `custom_call` helper imported at module level
    is called and its return value is returned as is.
    """
    if hasattr(mlir, "custom_call"):
        out = mlir.custom_call(name,
                               result_types=args.output_types,
                               operands=args.operands,
                               operand_layouts=args.operand_layouts,
                               result_layouts=args.output_layouts,
                               backend_config=opaque,
                               has_side_effect=has_side_effect,
                               **kwargs).results
    else:
        # The result types are passed positionally (hence the pylint suppression):
        # presumably the name of that second parameter changed between JAX
        # releases, and a positional argument keeps this compatible with several
        # JAX versions.
        out = custom_call(    # pylint: disable=too-many-function-args
            name,
            args.output_types,
            operands=args.operands,
            operand_layouts=args.operand_layouts,
            result_layouts=args.output_layouts,
            backend_config=opaque,
            has_side_effect=has_side_effect,
            **kwargs)
    return out


277
class LayerNormFwdPrimitive(BasePrimitive):
    """
    Layer Normalization Forward Primitive
    """
    name = "te_layernorm_forward"
    multiple_results = True
    impl_static_args = (3, 4)    # zero_centered_gamma, epsilon
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, gamma_aval, beta_aval, **kwargs):    # pylint: disable=unused-argument
        """
        LayerNorm fwd abstract

        Returns the abstract values of (out, mu, rsigma): `out` mirrors `x_aval`,
        while `mu` and `rsigma` are float32 with the shape of `x_aval` minus its
        last axis.
        """
        x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]

        mu_rsigma_dtype = jnp.float32

        out_aval = core.raise_to_shaped(x_aval)
        mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigma_dtype)

        assert gamma_aval.size == beta_aval.size
        hidden_size = gamma_aval.size
        assert x_aval.size % hidden_size == 0

        return out_aval, mu_aval, rsigma_aval

    @staticmethod
    def lowering(ctx, x, gamma, beta, *, zero_centered_gamma, epsilon):
        """
        LayerNorm fwd lowering rules

        Emits the `te_layernorm_forward` custom call with operands
        (x, gamma, beta) and results (out, mu, rsigma).
        """
        x_aval, gamma_aval, beta_aval = ctx.avals_in
        assert gamma_aval.dtype == beta_aval.dtype
        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        g_type = ir.RankedTensorType(gamma.type)
        g_shape = g_type.shape
        b_type = ir.RankedTensorType(beta.type)
        b_shape = b_type.shape

        assert g_type == b_type
        assert g_shape == b_shape

        # Output shape is same as the input shape, but the output type is same as the weight type.
        # See ln_api.cpp
        output_type = g_type.element_type
        ir_mu_dtype = ir.F32Type.get()
        ir_rsigma_dtype = ir.F32Type.get()

        out_shape = x_shape
        hidden_size = reduce(operator.mul, g_shape)
        batch_shape = out_shape[:-1]
        # All elements of x that do not belong to the hidden dimension form the batch.
        batch_size = reduce(operator.mul, x_shape) // hidden_size

        out_types = [
            ir.RankedTensorType.get(out_shape, output_type),
            ir.RankedTensorType.get(batch_shape, ir_mu_dtype),
            ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
        ]
        operands = [x, gamma, beta]
        operand_shapes = [x_shape, g_shape, b_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        # Read from the environment at lowering time; defaults to 0 when unset.
        sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))

        opaque = transformer_engine_jax.pack_norm_descriptor(
            batch_size,
            hidden_size,
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(gamma_aval.dtype),
            zero_centered_gamma,
            epsilon,
            sm_margin,
        )

        out = custom_caller(LayerNormFwdPrimitive.name, args, opaque, False)

        return out

    @staticmethod
    def impl(x, gamma, beta, zero_centered_gamma, epsilon):
        """
        to describe implementation

        Binds the inner primitive and returns (out, mu, rsigma).
        """
        assert LayerNormFwdPrimitive.inner_primitive is not None
        out, mu, rsigma = LayerNormFwdPrimitive.inner_primitive.bind(
            x, gamma, beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon)
        return out, mu, rsigma

    @staticmethod
    def batcher(batched_args, batch_dims, *, zero_centered_gamma, epsilon):
        """
        to describe batch rules for vmap

        All three outputs are reported with the batch dim of `x`.
        """
        _check_valid_batch_dims(batch_dims)
        assert LayerNormFwdPrimitive.outer_primitive is not None
        x, gamma, beta = batched_args
        x_bdim, _, _ = batch_dims

        out_bdims = x_bdim, x_bdim, x_bdim
        return LayerNormFwdPrimitive.outer_primitive.bind(x,
                                                          gamma,
                                                          beta,
                                                          zero_centered_gamma=zero_centered_gamma,
                                                          epsilon=epsilon), out_bdims

    @staticmethod
    def infer_sharding_from_operands(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
        """
        Derive the shardings of (out, mu, rsigma) from the sharding of `x`.

        The last (hidden) axis of `out` is always left unsharded; a warning is
        issued when `x` arrives with that axis sharded.
        """
        del zero_centered_gamma, epsilon, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {LayerNormFwdPrimitive.name}! "
                "Force to not shard the hidden dim, which might introduce extra collective ops, "
                "and hurt performance."
            )
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        mu_sharding = rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
        return (out_sharding, mu_sharding, rsigma_sharding)

    @staticmethod
    def partition(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
        """
        Return (mesh, per-shard impl, output shardings, argument shardings).

        `x` keeps its sharding except for the last axis, which is forced to be
        unsharded; gamma and beta keep their incoming specs.
        """
        del result_infos
        x_spec, g_spec, b_spec = map(get_padded_spec, arg_infos)
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {LayerNormFwdPrimitive.name}! "
                "Force to not shard the hidden dim, which might introduce extra collective ops, "
                "and hurt performance."
            )
        x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        g_sharding = NamedSharding(mesh, PartitionSpec(*g_spec))
        b_sharding = NamedSharding(mesh, PartitionSpec(*b_spec))
        out_sharding = x_sharding
        mu_sharding = rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))

        arg_shardings = (x_sharding, g_sharding, b_sharding)
        out_shardings = (out_sharding, mu_sharding, rsigma_sharding)
        impl = partial(LayerNormFwdPrimitive.impl,
                       zero_centered_gamma=zero_centered_gamma,
                       epsilon=epsilon)
        return mesh, impl, out_shardings, arg_shardings
422
423


424
# Create the inner/outer JAX primitives for LayerNormFwdPrimitive and store them on the class.
register_primitive(LayerNormFwdPrimitive)
425
426


427
428
def layernorm_fwd(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, zero_centered_gamma: bool,
                  epsilon: float):
    """
    Wrapper for TE layernorm fwd

    Binds `LayerNormFwdPrimitive.outer_primitive` and returns its results
    (out, mu, rsigma).
    """
    return LayerNormFwdPrimitive.outer_primitive.bind(x,
                                                      gamma,
                                                      beta,
                                                      zero_centered_gamma=zero_centered_gamma,
                                                      epsilon=epsilon)
437
438


439
class LayerNormBwdPrimitive(BasePrimitive):
    """
    Layer Normalization Backward Primitive
    """
    name = "te_layernorm_backward"
    multiple_results = True
    impl_static_args = (5, 6)    # zero_centered_gamma, epsilon
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval, **kwargs):    # pylint: disable=unused-argument
        """
        Layernorm bwd abstract

        Returns the abstract values of (dx, dgamma, dbeta): `dx` mirrors
        `dz_aval`, `dgamma` and `dbeta` both mirror `gamma_aval`.
        """
        w_dtype = dtypes.canonicalize_dtype(gamma_aval.dtype)
        mu_dtype = dtypes.canonicalize_dtype(mu_aval.dtype)
        rsigma_dtype = dtypes.canonicalize_dtype(rsigma_aval.dtype)

        assert dtypes.canonicalize_dtype(dz_aval.dtype) == w_dtype
        assert dz_aval.shape == x_aval.shape
        assert mu_aval.shape == rsigma_aval.shape == x_aval.shape[:-1]
        assert mu_dtype == rsigma_dtype == jnp.float32

        dx_aval = core.raise_to_shaped(dz_aval)
        dgamma_aval = dbeta_aval = core.raise_to_shaped(gamma_aval)
        return dx_aval, dgamma_aval, dbeta_aval

    @staticmethod
    def lowering(ctx, dz, x, mu, rsigma, gamma, *, zero_centered_gamma, epsilon):
        """
        Layernorm bwd lowering rules

        Emits the `te_layernorm_backward` custom call with results
        (dx, dgamma, dbeta). Note that the operand order handed to the custom
        call is (dz, mu, rsigma, x, gamma), which differs from this function's
        own argument order.
        """
        _, x_aval, _, _, gamma_aval = ctx.avals_in
        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        g_type = ir.RankedTensorType(gamma.type)
        g_shape = g_type.shape

        dz_shape = ir.RankedTensorType(dz.type).shape
        mu_shape = ir.RankedTensorType(mu.type).shape
        rsigma_shape = ir.RankedTensorType(rsigma.type).shape

        hidden_size = reduce(operator.mul, g_shape)
        # All elements of x that do not belong to the hidden dimension form the batch.
        batch_size = reduce(operator.mul, x_shape) // hidden_size

        # beta is not an operand of the backward pass, so dbeta takes gamma's
        # shape and element type, exactly like dgamma.
        out_types = [
            ir.RankedTensorType.get(x_shape, x_type.element_type),
            ir.RankedTensorType.get(g_shape, g_type.element_type),
            ir.RankedTensorType.get(g_shape, g_type.element_type),
        ]
        operands = [dz, mu, rsigma, x, gamma]
        operand_shapes = [dz_shape, mu_shape, rsigma_shape, x_shape, g_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        # Read from the environment at lowering time; defaults to 0 when unset.
        sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))

        opaque = transformer_engine_jax.pack_norm_descriptor(
            batch_size,
            hidden_size,
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(gamma_aval.dtype),
            zero_centered_gamma,
            epsilon,
            sm_margin,
        )

        out = custom_caller(LayerNormBwdPrimitive.name, args, opaque, False)

        return out

    @staticmethod
    def impl(dz, x, mu, rsigma, gamma, zero_centered_gamma, epsilon):
        """
        to describe implementation

        Binds the inner primitive and returns (dx, dgamma, dbeta).
        """
        assert LayerNormBwdPrimitive.inner_primitive is not None
        dx, dgamma, dbeta = LayerNormBwdPrimitive.inner_primitive.bind(
            dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon)
        return dx, dgamma, dbeta

    @staticmethod
    def batcher(batched_args, batch_dims, *, zero_centered_gamma, epsilon):
        """
        to describe batch rules for vmap

        `dx` is reported with the batch dim of `x`; `dgamma` and `dbeta` with
        the batch dim of `gamma`.
        """
        _check_valid_batch_dims(batch_dims)
        assert LayerNormBwdPrimitive.outer_primitive is not None
        dz, x, mu, rsigma, gamma = batched_args
        _, x_bdim, _, _, gamma_bdim = batch_dims

        out_bdims = x_bdim, gamma_bdim, gamma_bdim
        return LayerNormBwdPrimitive.outer_primitive.bind(dz,
                                                          x,
                                                          mu,
                                                          rsigma,
                                                          gamma,
                                                          zero_centered_gamma=zero_centered_gamma,
                                                          epsilon=epsilon), out_bdims

    @staticmethod
    def infer_sharding_from_operands(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
        """
        Derive the shardings of (dx, dgamma, dbeta) from those of `x` and `gamma`.

        The last (hidden) axis of `dx` is always left unsharded; a warning is
        issued when `x` arrives with that axis sharded.
        """
        del zero_centered_gamma, epsilon, result_infos
        x_spec = get_padded_spec(arg_infos[1])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {LayerNormBwdPrimitive.name}! "
                "Force to not shard the hidden dim, which might introduce extra collective ops, "
                "and hurt performance."
            )
        g_b_spec = get_padded_spec(arg_infos[4])
        dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        dgamma_sharding = dbeta_sharding = NamedSharding(mesh, PartitionSpec(*g_b_spec))
        return dx_sharding, dgamma_sharding, dbeta_sharding

    @staticmethod
    def partition(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
        """
        Return (mesh, per-shard impl, output shardings, argument shardings).

        The per-shard implementation computes local gradients and then passes
        dgamma and dbeta through `all_reduce_sum_along_dp_fsdp`.
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[1])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {LayerNormBwdPrimitive.name}! "
                "Force to not shard the hidden dim, which might introduce extra collective ops, "
                "and hurt performance."
            )
        g_b_spec = get_padded_spec(arg_infos[4])
        dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        dgamma_sharding = dbeta_sharding = NamedSharding(mesh, PartitionSpec(*g_b_spec))
        out_shardings = dx_sharding, dgamma_sharding, dbeta_sharding
        x_shardings = (dx_sharding,) * 2    # dz and x should have the same sharding.
        mu_shardings = (NamedSharding(mesh, PartitionSpec(*x_spec[:-1])),) * 2
        arg_shardings = (*x_shardings, *mu_shardings, NamedSharding(mesh, PartitionSpec(*g_b_spec)))

        def sharded_impl(dz, x, mu, rsigma, gamma):
            local_dx, local_dgamma, local_dbeta = \
                LayerNormBwdPrimitive.impl(dz, x, mu, rsigma, gamma,
                                           zero_centered_gamma=zero_centered_gamma,
                                           epsilon=epsilon)
            # Presumably sums the per-shard parameter gradients across the
            # data-parallel / FSDP mesh axes (going by the helper's name).
            global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma)
            global_dbeta = all_reduce_sum_along_dp_fsdp(local_dbeta)
            return local_dx, global_dgamma, global_dbeta

        return mesh, sharded_impl, out_shardings, arg_shardings


# Create the inner/outer JAX primitives for LayerNormBwdPrimitive and store them on the class.
register_primitive(LayerNormBwdPrimitive)


def layernorm_bwd(dz: jnp.ndarray, x: jnp.ndarray, mu: jnp.ndarray, rsigma: jnp.ndarray,
                  gamma: jnp.ndarray, zero_centered_gamma: bool, epsilon: float):
    """
    Wrapper for TE layernorm bwd

    Binds `LayerNormBwdPrimitive.outer_primitive` and returns its results
    (dx, dgamma, dbeta).
    """
    return LayerNormBwdPrimitive.outer_primitive.bind(dz,
                                                      x,
                                                      mu,
                                                      rsigma,
                                                      gamma,
                                                      zero_centered_gamma=zero_centered_gamma,
                                                      epsilon=epsilon)
597
598


599
class RmsNormFwdPrimitive(BasePrimitive):
    """
    RMS Normalization Forward Primitive
    """
    name = "te_rmsnorm_forward"
    multiple_results = True
    impl_static_args = (2,)    # epsilon
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, gamma_aval, **kwargs):    # pylint: disable=unused-argument
        """
        RMSNorm fwd abstract

        Returns the abstract values of (out, rsigma): `out` mirrors `x_aval`,
        while `rsigma` is float32 with the shape of `x_aval` minus its last axis.
        """
        x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]

        rsigma_dtype = jnp.float32

        out_aval = core.raise_to_shaped(x_aval)
        rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=rsigma_dtype)

        hidden_size = gamma_aval.size
        assert x_aval.size % hidden_size == 0

        return out_aval, rsigma_aval

    @staticmethod
    def lowering(ctx, x, gamma, *, epsilon):
        """
        RMSNorm fwd lowering rules

        Emits the `te_rmsnorm_forward` custom call with operands (x, gamma) and
        results (out, rsigma).
        """
        x_aval, gamma_aval = ctx.avals_in
        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        g_type = ir.RankedTensorType(gamma.type)
        g_shape = g_type.shape
        rsigma_element_type = ir.F32Type.get()

        out_shape = x_shape
        hidden_size = reduce(operator.mul, g_shape)
        batch_shape = out_shape[:-1]
        # All elements of x that do not belong to the hidden dimension form the batch.
        batch_size = reduce(operator.mul, x_shape) // hidden_size

        out_types = [
            ir.RankedTensorType.get(out_shape, x_type.element_type),
            ir.RankedTensorType.get(batch_shape, rsigma_element_type),
        ]
        operands = [x, gamma]
        operand_shapes = [x_shape, g_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        # Read from the environment at lowering time; defaults to 0 when unset.
        # The variable name is the LayerNorm one; RMSNorm reuses it here.
        sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))

        opaque = transformer_engine_jax.pack_norm_descriptor(
            batch_size,
            hidden_size,
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(gamma_aval.dtype),
            False,    # RMSNorm doesn't support zero_centered_gamma
            epsilon,
            sm_margin,
        )

        out = custom_caller(RmsNormFwdPrimitive.name, args, opaque, False)

        return out

    @staticmethod
    def impl(x, gamma, epsilon):
        """
        to describe implementation

        Binds the inner primitive and returns (out, rsigma).
        """
        assert RmsNormFwdPrimitive.inner_primitive is not None
        out, rsigma = RmsNormFwdPrimitive.inner_primitive.bind(x, gamma, epsilon=epsilon)
        return out, rsigma

    @staticmethod
    def batcher(batched_args, batch_dims, *, epsilon):
        """
        to describe batch rules for vmap

        Both outputs are reported with the batch dim of `x`.
        """
        _check_valid_batch_dims(batch_dims)
        assert RmsNormFwdPrimitive.outer_primitive is not None
        x, gamma = batched_args
        x_bdim, _ = batch_dims

        out_bdims = x_bdim, x_bdim
        return RmsNormFwdPrimitive.outer_primitive.bind(x, gamma, epsilon=epsilon), out_bdims

    @staticmethod
    def infer_sharding_from_operands(epsilon, mesh, arg_infos, result_infos):
        """
        Derive the shardings of (out, rsigma) from the sharding of `x`.

        The last (hidden) axis of `out` is always left unsharded; a warning is
        issued when `x` arrives with that axis sharded.
        """
        del epsilon, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {RmsNormFwdPrimitive.name}! "
                "Force to not shard the hidden dim, which might introduce extra collective ops, "
                "and hurt performance."
            )
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
        return (out_sharding, rsigma_sharding)

    @staticmethod
    def partition(epsilon, mesh, arg_infos, result_infos):
        """
        Return (mesh, per-shard impl, output shardings, argument shardings).

        `x` keeps its sharding except for the last axis, which is forced to be
        unsharded; gamma keeps its incoming spec.
        """
        del result_infos
        x_spec, g_spec = map(get_padded_spec, arg_infos)
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {RmsNormFwdPrimitive.name}! "
                "Force to not shard the hidden dim, which might introduce extra collective ops, "
                "and hurt performance."
            )
        x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        g_sharding = NamedSharding(mesh, PartitionSpec(*g_spec))
        out_sharding = x_sharding
        rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
        arg_shardings = (x_sharding, g_sharding)
        out_shardings = (out_sharding, rsigma_sharding)
        impl = partial(RmsNormFwdPrimitive.impl, epsilon=epsilon)
        return mesh, impl, out_shardings, arg_shardings
722
723


724
# Create the inner/outer JAX primitives for RmsNormFwdPrimitive and store them on the class.
register_primitive(RmsNormFwdPrimitive)
725
726


727
def rmsnorm_fwd(x: jnp.ndarray, gamma: jnp.ndarray, epsilon: float):
    """
    Wrapper for TE rmsnorm fwd

    Binds `RmsNormFwdPrimitive.outer_primitive` and returns its results
    (out, rsigma).
    """
    return RmsNormFwdPrimitive.outer_primitive.bind(x, gamma, epsilon=epsilon)
732
733


734
class RmsNormBwdPrimitive(BasePrimitive):
    """
    RMS Normalization Backward Primitive
    """
    name = "te_rmsnorm_backward"
    multiple_results = True
    impl_static_args = (4,)    # epsilon
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
            dz_aval,
            x_aval,
            rsigma_aval,
            gamma_aval,
            **kwargs    # pylint: disable=unused-argument
    ):
        """
        RMSNorm bwd abstract

        Returns the abstract values of (dx, dgamma): `dx` mirrors `dz_aval` and
        `dgamma` mirrors `gamma_aval`.
        """
        w_dtype = dtypes.canonicalize_dtype(gamma_aval.dtype)
        rsigma_dtype = dtypes.canonicalize_dtype(rsigma_aval.dtype)

        assert dtypes.canonicalize_dtype(dz_aval.dtype) == w_dtype
        assert dz_aval.shape == x_aval.shape
        assert rsigma_aval.shape == x_aval.shape[:-1]
        assert rsigma_dtype == jnp.float32

        dx_aval = core.raise_to_shaped(dz_aval)
        dgamma_aval = core.raise_to_shaped(gamma_aval)
        return dx_aval, dgamma_aval

    @staticmethod
    def lowering(ctx, dz, x, rsigma, gamma, *, epsilon):
        """
        RMSNorm bwd lowering rules

        Emits the `te_rmsnorm_backward` custom call with results (dx, dgamma).
        Note that the operand order handed to the custom call is
        (dz, rsigma, x, gamma), which differs from this function's own
        argument order.
        """
        _, x_aval, _, gamma_aval = ctx.avals_in
        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        g_type = ir.RankedTensorType(gamma.type)
        g_shape = g_type.shape
        dz_shape = ir.RankedTensorType(dz.type).shape
        rsigma_shape = ir.RankedTensorType(rsigma.type).shape

        hidden_size = reduce(operator.mul, g_shape)
        # All elements of x that do not belong to the hidden dimension form the batch.
        batch_size = reduce(operator.mul, x_shape) // hidden_size

        out_types = [
            ir.RankedTensorType.get(x_shape, x_type.element_type),
            ir.RankedTensorType.get(g_shape, g_type.element_type),
        ]
        operands = [dz, rsigma, x, gamma]
        operand_shapes = [dz_shape, rsigma_shape, x_shape, g_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        # Read from the environment at lowering time; defaults to 0 when unset.
        # The variable name is the LayerNorm one; RMSNorm reuses it here.
        sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))

        opaque = transformer_engine_jax.pack_norm_descriptor(
            batch_size,
            hidden_size,
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(gamma_aval.dtype),
            False,    # RMSNorm doesn't support zero_centered_gamma
            epsilon,
            sm_margin,
        )

        out = custom_caller(RmsNormBwdPrimitive.name, args, opaque, False)

        return out

    @staticmethod
    def impl(dz, x, rsigma, gamma, epsilon):
        """
        to describe implementation

        Binds the inner primitive and returns (dx, dgamma).
        """
        assert RmsNormBwdPrimitive.inner_primitive is not None
        dx, dgamma = RmsNormBwdPrimitive.inner_primitive.bind(dz, x, rsigma, gamma, epsilon=epsilon)
        return dx, dgamma

    @staticmethod
    def batcher(batched_args, batch_dims, *, epsilon):
        """
        to describe batch rules for vmap

        `dx` is reported with the batch dim of `x`; `dgamma` with the batch dim
        of `gamma`.
        """
        _check_valid_batch_dims(batch_dims)
        assert RmsNormBwdPrimitive.outer_primitive is not None
        dz, x, rsigma, gamma = batched_args
        _, x_bdim, _, gamma_bdim = batch_dims

        out_bdims = x_bdim, gamma_bdim
        return RmsNormBwdPrimitive.outer_primitive.bind(dz, x, rsigma, gamma,
                                                        epsilon=epsilon), out_bdims

    @staticmethod
    def infer_sharding_from_operands(epsilon, mesh, arg_infos, result_infos):
        """
        Derive the shardings of (dx, dgamma) from those of `x` and `gamma`.

        The last (hidden) axis of `dx` is always left unsharded; a warning is
        issued when `x` arrives with that axis sharded.
        """
        del epsilon, result_infos
        x_spec = get_padded_spec(arg_infos[1])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {RmsNormBwdPrimitive.name}! "
                "Force to not shard the hidden dim, which might introduce extra collective ops, "
                "and hurt performance."
            )
        g_spec = get_padded_spec(arg_infos[3])
        dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        dgamma_sharding = NamedSharding(mesh, PartitionSpec(*g_spec))
        return dx_sharding, dgamma_sharding

    @staticmethod
    def partition(epsilon, mesh, arg_infos, result_infos):
        """
        Return (mesh, per-shard impl, output shardings, argument shardings).

        The per-shard implementation computes local gradients and then passes
        dgamma through `all_reduce_sum_along_dp_fsdp`.
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[1])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {RmsNormBwdPrimitive.name}! "
                "Force to not shard the hidden dim, which might introduce extra collective ops, "
                "and hurt performance."
            )
        g_spec = get_padded_spec(arg_infos[3])
        dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        dgamma_sharding = NamedSharding(mesh, PartitionSpec(*g_spec))
        out_shardings = dx_sharding, dgamma_sharding
        x_shardings = (dx_sharding,) * 2    # dz and x should have the same sharding.
        rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
        arg_shardings = (*x_shardings, rsigma_sharding, NamedSharding(mesh, PartitionSpec(*g_spec)))

        def sharded_impl(dz, x, rsigma, gamma):
            local_dx, local_dgamma = \
                RmsNormBwdPrimitive.impl(dz, x, rsigma, gamma, epsilon=epsilon)
            # Presumably sums the per-shard parameter gradient across the
            # data-parallel / FSDP mesh axes (going by the helper's name).
            global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma)
            return local_dx, global_dgamma

        return mesh, sharded_impl, out_shardings, arg_shardings

865

866
# Create the inner/outer JAX primitives for RmsNormBwdPrimitive and store them on the class.
register_primitive(RmsNormBwdPrimitive)
867
868


869
870
def rmsnorm_bwd(dz: jnp.ndarray, x: jnp.ndarray, rsigma: jnp.ndarray, gamma: jnp.ndarray,
                epsilon: float):
    """
    Wrapper for TE rmsnorm bwd

    Binds RmsNormBwdPrimitive.outer_primitive with (dz, x, rsigma, gamma) and
    passes epsilon as a keyword parameter; returns whatever the bind returns.
    """
    bind_bwd = RmsNormBwdPrimitive.outer_primitive.bind
    return bind_bwd(dz, x, rsigma, gamma, epsilon=epsilon)
875
876


877
class SoftmaxPrimitive(BasePrimitive):
    """
    Softmax Primitive

    Shared abstract-eval, lowering, impl, batching and sharding helpers used by
    the concrete softmax forward/backward primitives. Tensors are assumed to be
    laid out as [...Batch, Head, Q_Seqlen, K_Seqlen].
    """
    # Upper bound on K_Seqlen enforced by forward_abstract.
    max_k_seqlen_supported = 4096

    @staticmethod
    @abstractmethod
    def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
                            dtype: jnp.dtype) -> bool:
        """Check Softmax kernel availability based on size"""
        raise NotImplementedError

    @staticmethod
    def get_batch_per_block(k_seqlen: int) -> int:
        """Get batch per CTA in Softmax kernels"""
        threads_per_warp = 32
        threads_per_block = 128    # Depends on the kernel implementation

        # For k_seqlen >= 1 this is the smallest power of two >= k_seqlen.
        pow2 = 1 << (k_seqlen - 1).bit_length()
        warp_size = pow2 if pow2 < threads_per_warp else threads_per_warp
        batches_per_warp = 2 if pow2 <= 128 else 1
        warps_per_block = threads_per_block // warp_size
        batches_per_block = warps_per_block * batches_per_warp
        return batches_per_block

    @staticmethod
    def forward_abstract(logits_aval, scale_factor):
        """
        softmax_forward abstract

        Checks that logits are fp16/bf16, that K_Seqlen does not exceed
        max_k_seqlen_supported and that Q_Seqlen > 1, then returns an output
        aval derived from logits_aval.
        """
        del scale_factor
        i_dtype = dtypes.canonicalize_dtype(logits_aval.dtype)
        assert i_dtype in [jnp.float16, jnp.bfloat16]
        i_shape = logits_aval.shape
        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
        q_seqlen = i_shape[-2]
        k_seqlen = i_shape[-1]
        assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
        assert q_seqlen > 1

        out_aval = core.raise_to_shaped(logits_aval)
        return out_aval

    @staticmethod
    def forward_lowering(name, ctx, logits, *, scale_factor):
        """
        softmax_forward lowering rules

        Emits the custom call `name` with one operand (logits) and one output of
        the same shape and element type. Returns a one-element list.
        """
        i_aval, = ctx.avals_in
        i_type = ir.RankedTensorType(logits.type)
        i_shape = i_type.shape
        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
        # Initial value 1 keeps rank-3 inputs (no leading batch dims) from
        # raising "reduce() of empty iterable with no initial value".
        batch = reduce(operator.mul, i_shape[:-3], 1)
        pad_batch = batch
        heads = i_shape[-3]
        q_seqlen = i_shape[-2]
        k_seqlen = i_shape[-1]

        out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)]
        operands = [logits]
        operand_shapes = [i_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        opaque = transformer_engine_jax.pack_softmax_descriptor(batch, pad_batch, heads, q_seqlen,
                                                                k_seqlen,
                                                                jax_dtype_to_te_dtype(i_aval.dtype),
                                                                scale_factor)

        out = custom_caller(name, args, opaque, False)

        return [out]

    @staticmethod
    def forward_impl(primitive, logits, scale_factor):
        """
        softmax_forward implementation

        Binds `primitive` on logits with scale_factor as a keyword parameter.
        """
        assert primitive is not None
        output = primitive.bind(logits, scale_factor=scale_factor)
        return output

    @staticmethod
    def forward_batcher(primitive, batched_args, batch_dims, *, scale_factor):
        """
        softmax_forward batcher

        Re-binds `primitive` on the batched logits; the output keeps the logits'
        batch dim.
        """
        assert primitive is not None
        logits, = batched_args
        logits_bdim, = batch_dims

        out_bdims = logits_bdim
        return primitive.bind(logits, scale_factor=scale_factor), out_bdims

    @staticmethod
    def forward_infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_forward infer_sharding_from_operands

        The output takes the same partition spec as the logits operand.
        """
        del scale_factor, result_infos    # Unused.
        logits_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec))
        return out_sharding

    @staticmethod
    def forward_partition(impl, scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_forward partitioning

        Uses the logits sharding for both the operand and the output and returns
        (mesh, impl with scale_factor pre-bound, out_shardings, arg_shardings).
        """
        del result_infos
        logits_spec = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[0])))
        out_spec = logits_spec
        arg_shardings = (logits_spec,)
        out_shardings = out_spec
        impl = partial(impl, scale_factor=scale_factor)
        return mesh, impl, out_shardings, arg_shardings

    @staticmethod
    def backward_abstract(dz_aval, softmax_out_aval, scale_factor=None):    # pylint: disable=unused-argument
        """
        softmax_backward abstract

        Checks that dz and softmax_out share the same fp16/bf16 dtype and the
        same shape, then returns an output aval derived from softmax_out_aval.
        """
        dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        softmax_out_dtype = dtypes.canonicalize_dtype(softmax_out_aval.dtype)
        assert dz_dtype == softmax_out_dtype
        assert dz_dtype in [jnp.float16, jnp.bfloat16]
        assert softmax_out_dtype in [jnp.float16, jnp.bfloat16]

        assert dz_aval.shape == softmax_out_aval.shape

        dx_aval = core.raise_to_shaped(softmax_out_aval)
        return dx_aval

    @staticmethod
    def backward_lowering(name, ctx, dz, softmax_out, *, scale_factor):
        """
        softmax_backward lowering rules

        Emits the custom call `name` with operands (dz, softmax_out) and one
        output shaped and typed like softmax_out. Returns a one-element list.
        """
        dz_aval, _ = ctx.avals_in

        dz_type = ir.RankedTensorType(dz.type)
        dz_shape = dz_type.shape

        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
        # Initial value 1 keeps rank-3 inputs (no leading batch dims) from
        # raising "reduce() of empty iterable with no initial value".
        batch = reduce(operator.mul, dz_shape[:-3], 1)
        pad_batch = batch    # unused
        heads = dz_shape[-3]
        q_seqlen = dz_shape[-2]
        k_seqlen = dz_shape[-1]

        softmax_out_type = ir.RankedTensorType(softmax_out.type)
        softmax_out_shape = softmax_out_type.shape

        out_types = [ir.RankedTensorType.get(softmax_out_shape, softmax_out_type.element_type)]
        operands = [dz, softmax_out]
        operand_shapes = [dz_shape, softmax_out_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        opaque = transformer_engine_jax.pack_softmax_descriptor(
            batch, pad_batch, heads, q_seqlen, k_seqlen, jax_dtype_to_te_dtype(dz_aval.dtype),
            scale_factor)

        out = custom_caller(name, args, opaque, False)

        return [out]

    @staticmethod
    def backward_impl(primitive, dz, softmax_out, scale_factor):
        """
        softmax_backward implementation

        Binds `primitive` on (dz, softmax_out) with scale_factor as a keyword
        parameter.
        """
        assert primitive is not None
        dx = primitive.bind(dz, softmax_out, scale_factor=scale_factor)
        return dx

    @staticmethod
    def backward_batcher(primitive, batched_args, batch_dims, *, scale_factor):
        """
        softmax_backward batcher

        Re-binds `primitive` on the batched operands; the output keeps
        softmax_out's batch dim.
        """
        assert primitive is not None
        dz, softmax_out = batched_args
        _, softmax_out_bdim = batch_dims

        out_bdims = softmax_out_bdim
        return primitive.bind(dz, softmax_out, scale_factor=scale_factor), out_bdims

    @staticmethod
    def backward_infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_backward infer_sharding_from_operands

        dx takes the same partition spec as the softmax_out operand.
        """
        del scale_factor, result_infos    # Unused.
        softmax_out_spec = get_padded_spec(arg_infos[1])
        dx_sharding = NamedSharding(mesh, PartitionSpec(*softmax_out_spec))
        return dx_sharding

    @staticmethod
    def backward_partition(impl, scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_backward partition

        Keeps each operand's own sharding, gives dx the sharding of softmax_out
        and returns (mesh, impl with scale_factor pre-bound, out_shardings,
        arg_shardings).
        """
        del result_infos
        dz_spec = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[0])))
        softmax_out_spec = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        dx_spec = softmax_out_spec
        arg_shardings = (dz_spec, softmax_out_spec)
        out_shardings = dx_spec
        impl = partial(impl, scale_factor=scale_factor)
        return mesh, impl, out_shardings, arg_shardings
1087
1088


1089
1090
1091
1092
1093
1094
1095
1096
1097
class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Softmax Fwd Primitive

    Thin specialisation of SoftmaxPrimitive: every rule forwards to the shared
    forward_* helper with this class's name / primitives.
    """
    name = "te_scaled_softmax_forward"
    multiple_results = False
    impl_static_args = (1,)    # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
                            dtype: jnp.dtype) -> bool:
        """Check Softmax kernel availability based on size"""
        attn_batches = batch * heads
        dtype = dtypes.canonicalize_dtype(dtype)

        if dtype not in [jnp.float16, jnp.bfloat16]:
            return False
        # k_seqlen must lie in (16, max_k_seqlen_supported]
        if not 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported:
            return False
        # q_seqlen must be a multiple of 4
        if q_seqlen % 4 != 0:
            return False
        # batch * heads must be a multiple of 4
        if attn_batches % 4 != 0:
            return False

        per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen)
        return q_seqlen % per_block == 0

    @staticmethod
    def abstract(logits_aval, scale_factor):    # pylint: disable=unused-argument
        """
        te_scaled_softmax_forward abstract
        """
        out_aval = SoftmaxPrimitive.forward_abstract(logits_aval, scale_factor)
        return out_aval

    @staticmethod
    def lowering(ctx, logits, *, scale_factor):
        """
        te_scaled_softmax_forward lowering rules
        """
        target = ScaledSoftmaxFwdPrimitive.name
        return SoftmaxPrimitive.forward_lowering(target, ctx, logits, scale_factor=scale_factor)

    @staticmethod
    def impl(logits, scale_factor):
        """
        te_scaled_softmax_forward implementation (binds the inner primitive)
        """
        inner = ScaledSoftmaxFwdPrimitive.inner_primitive
        return SoftmaxPrimitive.forward_impl(inner, logits, scale_factor)

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        """
        te_scaled_softmax_forward batcher (re-binds the outer primitive)
        """
        _check_valid_batch_dims(batch_dims)
        outer = ScaledSoftmaxFwdPrimitive.outer_primitive
        return SoftmaxPrimitive.forward_batcher(outer,
                                                batched_args,
                                                batch_dims,
                                                scale_factor=scale_factor)

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """
        te_scaled_softmax_forward infer_sharding_from_operands
        """
        infer = SoftmaxPrimitive.forward_infer_sharding_from_operands
        return infer(scale_factor, mesh, arg_infos, result_infos)

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        """
        te_scaled_softmax_forward partition
        """
        own_impl = ScaledSoftmaxFwdPrimitive.impl
        return SoftmaxPrimitive.forward_partition(own_impl, scale_factor, mesh, arg_infos,
                                                  result_infos)
1156
1157


1158
# Register the primitive via register_primitive (defined elsewhere in this file).
# NOTE(review): the class leaves inner_primitive/outer_primitive as None and the
# wrapper below dereferences outer_primitive, so presumably this call fills them
# in -- confirm against register_primitive.
register_primitive(ScaledSoftmaxFwdPrimitive)
1159

1160
1161

def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
    """
    scaled_softmax_forward wrapper
    Return FP16/BF16 tensor

    Binds ScaledSoftmaxFwdPrimitive.outer_primitive on logits, passing
    scale_factor as a keyword parameter.
    """
    fwd_primitive = ScaledSoftmaxFwdPrimitive.outer_primitive
    return fwd_primitive.bind(logits, scale_factor=scale_factor)
1167
1168


1169
class ScaledSoftmaxBwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Softmax Bwd Primitive

    Thin specialisation of SoftmaxPrimitive: every rule forwards to the shared
    backward_* helper with this class's name / primitives.
    """
    name = "te_scaled_softmax_backward"
    multiple_results = False
    impl_static_args = (2,)    # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
                            dtype: jnp.dtype) -> bool:
        """Check Softmax kernel availability based on size"""
        # Same constraints as the matching forward primitive.
        fwd_check = ScaledSoftmaxFwdPrimitive.is_kernel_available
        return fwd_check(batch, heads, q_seqlen, k_seqlen, dtype)

    @staticmethod
    def abstract(dz_aval, softmax_out_aval, scale_factor):
        """
        te_scaled_softmax_backward abstract
        """
        dx_aval = SoftmaxPrimitive.backward_abstract(dz_aval, softmax_out_aval, scale_factor)
        return dx_aval

    @staticmethod
    def lowering(ctx, dz, softmax_out, *, scale_factor):
        """
        te_scaled_softmax_backward lowering rules
        """
        target = ScaledSoftmaxBwdPrimitive.name
        return SoftmaxPrimitive.backward_lowering(target,
                                                  ctx,
                                                  dz,
                                                  softmax_out,
                                                  scale_factor=scale_factor)

    @staticmethod
    def impl(dz, softmax_out, scale_factor):
        """
        te_scaled_softmax_backward implementation (binds the inner primitive)
        """
        inner = ScaledSoftmaxBwdPrimitive.inner_primitive
        return SoftmaxPrimitive.backward_impl(inner, dz, softmax_out, scale_factor=scale_factor)

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        """
        te_scaled_softmax_backward batcher (re-binds the outer primitive)
        """
        _check_valid_batch_dims(batch_dims)
        outer = ScaledSoftmaxBwdPrimitive.outer_primitive
        return SoftmaxPrimitive.backward_batcher(outer,
                                                 batched_args,
                                                 batch_dims,
                                                 scale_factor=scale_factor)

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """
        te_scaled_softmax_backward infer_sharding_from_operands
        """
        infer = SoftmaxPrimitive.backward_infer_sharding_from_operands
        return infer(scale_factor, mesh, arg_infos, result_infos)

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        """
        te_scaled_softmax_backward partition
        """
        own_impl = ScaledSoftmaxBwdPrimitive.impl
        return SoftmaxPrimitive.backward_partition(own_impl, scale_factor, mesh, arg_infos,
                                                   result_infos)


# Register the primitive via register_primitive (defined elsewhere in this file).
# NOTE(review): the class leaves inner_primitive/outer_primitive as None and the
# wrapper below dereferences outer_primitive, so presumably this call fills them
# in -- confirm against register_primitive.
register_primitive(ScaledSoftmaxBwdPrimitive)


def scaled_softmax_bwd(dz: jnp.ndarray, softmax_out: jnp.ndarray,
                       scale_factor: float) -> jnp.ndarray:
    """
    scaled_backward wrapper
    Return FP16/BF16 tensor

    Binds ScaledSoftmaxBwdPrimitive.outer_primitive on (dz, softmax_out),
    passing scale_factor as a keyword parameter.
    """
    bwd_primitive = ScaledSoftmaxBwdPrimitive.outer_primitive
    return bwd_primitive.bind(dz, softmax_out, scale_factor=scale_factor)


class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Masked Softmax Fwd Primitive

    Takes logits and a uint8 mask. Both are assumed to be laid out as
    [...Batch, Head, Q_Seqlen, K_Seqlen]; the mask's head dim must be 1 and its
    flattened batch must be 1 or equal to the logits' flattened batch.
    """
    name = "te_scaled_masked_softmax_forward"
    multiple_results = False
    impl_static_args = (2,)    # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
                            dtype: jnp.dtype) -> bool:
        """Check Softmax kernel availability based on size"""
        attn_batches = batch * heads

        dtype = dtypes.canonicalize_dtype(dtype)
        if (dtype in [jnp.float16, jnp.bfloat16]
                # k_seqlen must lie in (16, max_k_seqlen_supported]
                and 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
                and q_seqlen % 4 == 0    # q_seqlen must be a multiple of 4
                and attn_batches % 4 == 0    # batch * heads must be a multiple of 4
           ):
            batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen)
            return q_seqlen % batch_per_block == 0
        return False

    @staticmethod
    def abstract(logits_aval, mask_aval, scale_factor):    # pylint: disable=unused-argument
        """
        te_scaled_masked_softmax_forward abstract

        Validates dtypes and shapes of logits and mask and returns an output
        aval derived from logits_aval.
        """

        i_dtype = dtypes.canonicalize_dtype(logits_aval.dtype)
        assert i_dtype in [jnp.float16, jnp.bfloat16]
        i_shape = logits_aval.shape

        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
        # Initial value 1 lets rank-3 inputs (no leading batch dims) through.
        batch = reduce(operator.mul, i_shape[:-3], 1)
        q_seqlen = i_shape[-2]
        k_seqlen = i_shape[-1]
        assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
        assert q_seqlen > 1

        mask_dtype = dtypes.canonicalize_dtype(mask_aval.dtype)
        assert mask_dtype in [
            jnp.uint8,
        ]
        mask_shape = mask_aval.shape
        # Keep `batch` (from the logits) intact: the previous chained assignment
        # overwrote it with the mask's batch, making the check below vacuous.
        pad_batch = reduce(operator.mul, mask_shape[:-3], 1)
        assert pad_batch in (1, batch)    # 1 means broadcast
        assert mask_shape[-3] == 1    # 1 means broadcast
        assert mask_shape[-2] == q_seqlen
        assert mask_shape[-1] == k_seqlen

        out_aval = core.raise_to_shaped(logits_aval)
        return out_aval

    @staticmethod
    def lowering(ctx, logits, mask, *, scale_factor):
        """
        te_scaled_masked_softmax_forward lowering rules

        Emits the custom call with operands (logits, mask) and one output shaped
        and typed like logits. Returns a one-element list.
        """

        logits_aval, _ = ctx.avals_in
        i_type = ir.RankedTensorType(logits.type)
        i_shape = i_type.shape
        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
        # Initial value 1 lets rank-3 inputs (no leading batch dims) through.
        batch = reduce(operator.mul, i_shape[:-3], 1)
        heads = i_shape[-3]
        q_seqlen = i_shape[-2]
        k_seqlen = i_shape[-1]

        mask_type = ir.RankedTensorType(mask.type)
        mask_shape = mask_type.shape
        pad_batch = reduce(operator.mul, mask_shape[:-3], 1)

        out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)]
        operands = [logits, mask]
        operand_shapes = [i_shape, mask_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        opaque = transformer_engine_jax.pack_softmax_descriptor(
            batch, pad_batch, heads, q_seqlen, k_seqlen, jax_dtype_to_te_dtype(logits_aval.dtype),
            scale_factor)

        out = custom_caller(ScaledMaskedSoftmaxFwdPrimitive.name, args, opaque, False)

        return [out]

    @staticmethod
    def impl(logits, mask, scale_factor):
        """
        te_scaled_masked_softmax_forward implementation (binds the inner primitive)
        """
        assert ScaledMaskedSoftmaxFwdPrimitive.inner_primitive is not None
        output = ScaledMaskedSoftmaxFwdPrimitive.inner_primitive.bind(logits,
                                                                      mask,
                                                                      scale_factor=scale_factor)
        return output

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        """
        te_scaled_masked_softmax_forward batcher

        Re-binds the outer primitive; the output keeps the logits' batch dim.
        """
        _check_valid_batch_dims(batch_dims)
        assert ScaledMaskedSoftmaxFwdPrimitive.outer_primitive is not None
        logits, mask = batched_args
        logits_bdim, _ = batch_dims

        out_bdims = logits_bdim
        return ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
            logits, mask, scale_factor=scale_factor), out_bdims

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """
        te_scaled_masked_softmax_forward infer_sharding_from_operands

        The output takes the same partition spec as the logits operand.
        """
        del scale_factor, result_infos    # Unused.
        logits_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec))
        return out_sharding

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        """
        te_scaled_masked_softmax_forward partition

        Keeps each operand's own sharding and gives the output the logits'
        sharding.
        """
        del result_infos
        logits_spec = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[0])))
        mask_spec = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = (logits_spec, mask_spec)
        out_shardings = logits_spec
        impl = partial(ScaledMaskedSoftmaxFwdPrimitive.impl, scale_factor=scale_factor)
        return mesh, impl, out_shardings, arg_shardings


# Register the primitive via register_primitive (defined elsewhere in this file).
# NOTE(review): the class leaves inner_primitive/outer_primitive as None and the
# wrapper below dereferences outer_primitive, so presumably this call fills them
# in -- confirm against register_primitive.
register_primitive(ScaledMaskedSoftmaxFwdPrimitive)


def scaled_masked_softmax_fwd(logits: jnp.ndarray, mask: jnp.ndarray,
                              scale_factor: float) -> jnp.ndarray:
    """
    scaled_masked_softmax_forward wrapper
    Return FP16/BF16 tensor

    Binds ScaledMaskedSoftmaxFwdPrimitive.outer_primitive on (logits, mask),
    passing scale_factor as a keyword parameter.
    """
    fwd_primitive = ScaledMaskedSoftmaxFwdPrimitive.outer_primitive
    return fwd_primitive.bind(logits, mask, scale_factor=scale_factor)


class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Masked Softmax Bwd Primitive

    Thin specialisation of SoftmaxPrimitive: every rule forwards to the shared
    backward_* helper with this class's name / primitives.
    """
    name = "te_scaled_masked_softmax_backward"
    multiple_results = False
    impl_static_args = (2,)    # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
                            dtype: jnp.dtype) -> bool:
        """Check Softmax kernel availability based on size"""
        # Delegate to the matching (masked) forward primitive, as the other
        # backward primitives do with their own forward counterpart.
        return ScaledMaskedSoftmaxFwdPrimitive.is_kernel_available(batch, heads, q_seqlen,
                                                                   k_seqlen, dtype)

    @staticmethod
    def abstract(dz_aval, softmax_out_aval, *, scale_factor):
        """
        te_scaled_masked_softmax_backward abstract
        """
        return SoftmaxPrimitive.backward_abstract(dz_aval, softmax_out_aval, scale_factor)

    @staticmethod
    def lowering(ctx, dz, softmax_out, *, scale_factor):
        """
        te_scaled_masked_softmax_backward lowering rules
        """
        out = SoftmaxPrimitive.backward_lowering(ScaledMaskedSoftmaxBwdPrimitive.name,
                                                 ctx,
                                                 dz,
                                                 softmax_out,
                                                 scale_factor=scale_factor)

        return out

    @staticmethod
    def impl(dz, softmax_out, scale_factor):
        """
        te_scaled_masked_softmax_backward implementation (binds the inner primitive)
        """
        return SoftmaxPrimitive.backward_impl(ScaledMaskedSoftmaxBwdPrimitive.inner_primitive,
                                              dz,
                                              softmax_out,
                                              scale_factor=scale_factor)

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        """
        te_scaled_masked_softmax_backward batcher (re-binds the outer primitive)
        """
        _check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.backward_batcher(ScaledMaskedSoftmaxBwdPrimitive.outer_primitive,
                                                 batched_args,
                                                 batch_dims,
                                                 scale_factor=scale_factor)

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """
        te_scaled_masked_softmax_backward infer_sharding_from_operands
        """
        return SoftmaxPrimitive.backward_infer_sharding_from_operands(scale_factor, mesh, arg_infos,
                                                                      result_infos)

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        """
        te_scaled_masked_softmax_backward partition
        """
        return SoftmaxPrimitive.backward_partition(ScaledMaskedSoftmaxBwdPrimitive.impl,
                                                   scale_factor, mesh, arg_infos, result_infos)


# Register the primitive via register_primitive (defined elsewhere in this file).
# NOTE(review): the class leaves inner_primitive/outer_primitive as None and the
# wrapper below dereferences outer_primitive, so presumably this call fills them
# in -- confirm against register_primitive.
register_primitive(ScaledMaskedSoftmaxBwdPrimitive)


def scaled_masked_softmax_bwd(dz: jnp.ndarray, softmax_out: jnp.ndarray,
                              scale_factor: float) -> jnp.ndarray:
    """
    scaled_masked_backward wrapper
    Return FP16/BF16 tensor

    Binds ScaledMaskedSoftmaxBwdPrimitive.outer_primitive on (dz, softmax_out),
    passing scale_factor as a keyword parameter.
    """
    bwd_primitive = ScaledMaskedSoftmaxBwdPrimitive.outer_primitive
    return bwd_primitive.bind(dz, softmax_out, scale_factor=scale_factor)


class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Upper Triang Masked Softmax Fwd Primitive

    Same forward plumbing as the other softmax primitives, with the extra
    requirement (checked in abstract) that Q_Seqlen == K_Seqlen.
    """
    name = "te_scaled_upper_triang_masked_softmax_forward"
    multiple_results = False
    impl_static_args = (1,)    # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
                            dtype: jnp.dtype) -> bool:
        """Check Softmax kernel availability based on size"""
        attn_batches = batch * heads
        dtype = dtypes.canonicalize_dtype(dtype)

        if dtype not in [jnp.float16, jnp.bfloat16]:
            return False
        # k_seqlen must lie in (16, max_k_seqlen_supported]
        if not 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported:
            return False
        # q_seqlen must be a multiple of 4
        if q_seqlen % 4 != 0:
            return False
        # batch * heads must be a multiple of 4
        if attn_batches % 4 != 0:
            return False

        per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen)
        # Unlike the other forward primitives, this one checks batch * heads.
        return attn_batches % per_block == 0

    @staticmethod
    def abstract(logits_aval, scale_factor):    # pylint: disable=unused-argument
        """
        te_scaled_upper_triang_masked_softmax_forward abstract
        """
        logits_shape = logits_aval.shape
        # The last two dims (Q_Seqlen, K_Seqlen) must be equal.
        assert logits_shape[-2] == logits_shape[-1]
        return SoftmaxPrimitive.forward_abstract(logits_aval, scale_factor)

    @staticmethod
    def lowering(ctx, logits, *, scale_factor):
        """
        te_scaled_upper_triang_masked_softmax_forward lowering rules
        """
        target = ScaledUpperTriangMaskedSoftmaxFwdPrimitive.name
        return SoftmaxPrimitive.forward_lowering(target, ctx, logits, scale_factor=scale_factor)

    @staticmethod
    def impl(logits, scale_factor):
        """
        te_scaled_upper_triang_masked_softmax_forward implementation
        """
        inner = ScaledUpperTriangMaskedSoftmaxFwdPrimitive.inner_primitive
        return SoftmaxPrimitive.forward_impl(inner, logits, scale_factor)

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        """
        te_scaled_upper_triang_masked_softmax_forward batcher
        """
        _check_valid_batch_dims(batch_dims)
        outer = ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive
        return SoftmaxPrimitive.forward_batcher(outer,
                                                batched_args,
                                                batch_dims,
                                                scale_factor=scale_factor)

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """
        te_scaled_upper_triang_masked_softmax_forward infer_sharding_from_operands
        """
        infer = SoftmaxPrimitive.forward_infer_sharding_from_operands
        return infer(scale_factor, mesh, arg_infos, result_infos)

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        """
        te_scaled_upper_triang_masked_softmax_forward partition
        """
        own_impl = ScaledUpperTriangMaskedSoftmaxFwdPrimitive.impl
        return SoftmaxPrimitive.forward_partition(own_impl, scale_factor, mesh, arg_infos,
                                                  result_infos)


# Register the primitive via register_primitive (defined elsewhere in this file).
# NOTE(review): the class leaves inner_primitive/outer_primitive as None and the
# wrapper below dereferences outer_primitive, so presumably this call fills them
# in -- confirm against register_primitive.
register_primitive(ScaledUpperTriangMaskedSoftmaxFwdPrimitive)


def scaled_upper_triang_masked_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
    """
    scaled_upper_triang_masked_softmax_forward wrapper
    Return FP16/BF16 tensor

    Binds ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive on logits,
    passing scale_factor as a keyword parameter.
    """
    fwd_primitive = ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive
    return fwd_primitive.bind(logits, scale_factor=scale_factor)


class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Upper Triang Masked Softmax Bwd Primitive

    Thin specialisation of SoftmaxPrimitive: every rule forwards to the shared
    backward_* helper with this class's name / primitives.
    """
    name = "te_scaled_upper_triang_masked_softmax_backward"
    multiple_results = False
    impl_static_args = (2,)    # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
                            dtype: jnp.dtype) -> bool:
        """Check Softmax kernel availability based on size"""
        # Same constraints as the matching forward primitive.
        fwd_check = ScaledUpperTriangMaskedSoftmaxFwdPrimitive.is_kernel_available
        return fwd_check(batch, heads, q_seqlen, k_seqlen, dtype)

    @staticmethod
    def abstract(dz_aval, softmax_out_aval, *, scale_factor):
        """
        te_scaled_upper_triang_masked_backward abstract
        """
        dx_aval = SoftmaxPrimitive.backward_abstract(dz_aval, softmax_out_aval, scale_factor)
        return dx_aval

    @staticmethod
    def lowering(ctx, dz, softmax_out, *, scale_factor):
        """
        te_scaled_upper_triang_masked_backward lowering rules
        """
        target = ScaledUpperTriangMaskedSoftmaxBwdPrimitive.name
        return SoftmaxPrimitive.backward_lowering(target,
                                                  ctx,
                                                  dz,
                                                  softmax_out,
                                                  scale_factor=scale_factor)

    @staticmethod
    def impl(dz, softmax_out, scale_factor):
        """
        te_scaled_upper_triang_masked_backward implementation
        """
        inner = ScaledUpperTriangMaskedSoftmaxBwdPrimitive.inner_primitive
        return SoftmaxPrimitive.backward_impl(inner, dz, softmax_out, scale_factor=scale_factor)

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        """
        te_scaled_upper_triang_masked_backward batcher
        """
        _check_valid_batch_dims(batch_dims)
        outer = ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive
        return SoftmaxPrimitive.backward_batcher(outer,
                                                 batched_args,
                                                 batch_dims,
                                                 scale_factor=scale_factor)

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """
        te_scaled_upper_triang_masked_backward infer_sharding_from_operands
        """
        infer = SoftmaxPrimitive.backward_infer_sharding_from_operands
        return infer(scale_factor, mesh, arg_infos, result_infos)

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        """
        te_scaled_upper_triang_masked_backward partition
        """
        own_impl = ScaledUpperTriangMaskedSoftmaxBwdPrimitive.impl
        return SoftmaxPrimitive.backward_partition(own_impl, scale_factor, mesh, arg_infos,
                                                   result_infos)


# Register the primitive via register_primitive (defined elsewhere in this file).
# NOTE(review): the class leaves inner_primitive/outer_primitive as None and the
# wrapper below dereferences outer_primitive, so presumably this call fills them
# in -- confirm against register_primitive.
register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive)


def scaled_upper_triang_masked_softmax_bwd(dz: jnp.ndarray, softmax_out: jnp.ndarray,
                                           scale_factor: float) -> jnp.ndarray:
    """
    scaled_upper_triang_masked_backward wrapper

    Binds ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive on the
    incoming gradient `dz` and the forward result `softmax_out`, with
    `scale_factor` passed as a static parameter.
    Return FP16/BF16 tensor
    """
    bwd_primitive = ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive
    return bwd_primitive.bind(dz, softmax_out, scale_factor=scale_factor)
1626
1627


1628
1629
@dataclass(frozen=True)
class FusedAttnHelper:
    """
    Helper for the fused attention backend

    Immutable bundle of the attention configuration that is handed to
    transformer_engine_jax.get_fused_attn_backend to pick a kernel backend.
    """

    # JAX dtypes of the query and key/value tensors (converted to TE dtypes on query).
    q_type: jnp.dtype
    kv_type: jnp.dtype
    # Layout / bias / mask enums exported by transformer_engine_jax.
    qkv_layout: NVTE_QKV_Layout
    attn_bias_type: NVTE_Bias_Type
    attn_mask_type: NVTE_Mask_Type
    dropout_probability: float
    max_seqlen_q: int
    max_seqlen_kv: int
    head_dim: int

    def is_fused_attn_kernel_available(self):
        """Check if there is available fused attention kernel"""
        selected_backend = self.get_fused_attn_backend()
        return selected_backend != NVTE_Fused_Attn_Backend.NVTE_No_Backend

    def get_fused_attn_backend(self):
        """Get the fused attention kernel backend"""
        te_q_type = jax_dtype_to_te_dtype(self.q_type)
        te_kv_type = jax_dtype_to_te_dtype(self.kv_type)
        return transformer_engine_jax.get_fused_attn_backend(
            te_q_type,
            te_kv_type,
            self.qkv_layout,
            self.attn_bias_type,
            self.attn_mask_type,
            self.dropout_probability,
            self.max_seqlen_q,
            self.max_seqlen_kv,
            self.head_dim,
        )


@dataclass(frozen=True)
class _FusedAttnRNGStateChecker:
    """
    Guard for the RNG state handed to fused attention.

    The fused attention backend requires a 64 bits seed and a 64 bits offset.
    JAX does not enable 64 bits by default, so the seed is emulated as two
    32 bits values. The offset calculation is maintained in the backend.
    """
    rng_state_dtype: jnp.dtype = jnp.uint32
    # (seed,) with internal dtype int64
    seed_size: int = 2
    # (seed, offset) with internal dtype int64
    rng_state_size: int = 2 * 2

    def check_seed(self, seed, dropout_probability, is_training):
        """
        Validate `seed` and return it with dtype `rng_state_dtype`.

        A `None` seed is only accepted when dropout is inactive; it is replaced
        by an all-zero placeholder. A seed of another dtype is cast, with a warning.
        """
        if seed is None:
            # Jax can't bind None, so substitute a dummy tensor.
            assert not (dropout_probability > 0 and is_training), \
                "seed is not allowed to be None when dropout is enabled."
            placeholder = jnp.zeros(2, dtype=self.rng_state_dtype)
            seed = jnp.repeat(placeholder, num_of_devices())

        needs_cast = seed.dtype != self.rng_state_dtype
        if needs_cast:
            warnings.warn(
                f"Requested {seed.dtype=} is not available, and will be "
                f"casted to dtype {self.rng_state_dtype}. "
                f"Please use threefry/rbg/unsafe_rbg PRNG implementations to remove this warning.")
            seed = seed.astype(self.rng_state_dtype)

        assert seed.dtype == self.rng_state_dtype
        # Backend takes an int64_t seed, so only the first two u32 elements are taken
        assert seed.size >= self.seed_size

        return seed


def generate_cu_seqlen(mask):
    """
    Build the cumulative sequence lengths for a batch.

    Entries of `mask` equal to 0 are counted over the last two axes to get one
    length per remaining index; the result is the running sum of those lengths
    with a leading 0 prepended.
    """
    per_sample_len = jnp.sum(mask == 0, axis=(-1, -2), dtype=jnp.int32)
    running_total = jnp.cumsum(per_sample_len)
    return jnp.hstack((0, running_total))


class SelfFusedAttnFwdPrimitive(BasePrimitive):
    """
    Self Fused Attention Forward Primitive

    Works on a packed QKV tensor of shape (...batch, seqlen, 3, head, hidden)
    and yields three results: the attention output, the softmax auxiliary
    tensor and the RNG state.
    """
    name = "te_self_fused_attn_forward"
    multiple_results = True
    # Indices 4..8 line up with the non-tensor parameters of `impl`
    # (attn_bias_type ... is_training).
    impl_static_args = (4, 5, 6, 7, 8)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(qkv_aval, bias_aval, mask_or_cu_seqlen_aval, seed_aval, *, attn_bias_type,
                 attn_mask_type, scaling_factor, dropout_probability, is_training):
        """
        Self fused attention fwd abstract

        Returns the abstract values of (output, softmax_aux, rng_state).
        Raises ValueError when the selected backend is neither the max-512
        nor the arbitrary-seqlen F16 backend.
        """
        # The outer primitive receives squeezed_mask, the inner primitive cu_seqlen.
        del mask_or_cu_seqlen_aval, scaling_factor, is_training
        qkv_dtype = dtypes.canonicalize_dtype(qkv_aval.dtype)
        *batch_shape, max_seqlen, nqkv, num_head, head_dim = qkv_aval.shape
        assert nqkv == 3
        assert qkv_aval.dtype == bias_aval.dtype

        # The output drops the packed "3" axis of the input.
        output_shape = (*batch_shape, max_seqlen, num_head, head_dim)

        backend = FusedAttnHelper(qkv_dtype, qkv_dtype, NVTE_QKV_Layout.NVTE_BS3HD, attn_bias_type,
                                  attn_mask_type, dropout_probability, max_seqlen, max_seqlen,
                                  head_dim).get_fused_attn_backend()

        # Shape and dtype of the softmax auxiliary tensor depend on the backend.
        if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
            softmax_aux_shape = (*batch_shape, num_head, max_seqlen, max_seqlen)
            softmax_dtype = qkv_dtype
        elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
            softmax_aux_shape = (*batch_shape, num_head, max_seqlen, 1)
            softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
        else:
            raise ValueError(f'Unsupported {backend=}')

        checker = _FusedAttnRNGStateChecker()
        seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype)
        assert seed_dtype == checker.rng_state_dtype
        rng_state_shape = (seed_aval.shape[0], checker.rng_state_size)

        out_aval = qkv_aval.update(shape=output_shape, dtype=qkv_dtype)
        softmax_aux_aval = qkv_aval.update(shape=softmax_aux_shape, dtype=softmax_dtype)
        # The RNG state is derived from the seed, so base its aval on seed_aval
        # (consistent with CrossFusedAttnFwdPrimitive.abstract).
        rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=seed_dtype)
        return out_aval, softmax_aux_aval, rng_state_aval

    @staticmethod
    def lowering(ctx, qkv, bias, cu_seqlen, seed, *, attn_bias_type, attn_mask_type, scaling_factor,
                 dropout_probability, is_training):
        """
        Self fused attention fwd lowering rules

        Packs the problem description into an opaque descriptor and emits the
        custom call registered under this primitive's name.
        """
        qkv_aval, _, _, _ = ctx.avals_in

        *batch_shape, max_seqlen, _, num_head, head_dim = qkv_aval.shape
        # All leading batch axes are flattened into a single batch count.
        batch = reduce(operator.mul, batch_shape)

        operands = [qkv, bias, cu_seqlen, seed]
        operand_shapes = map(lambda operand: operand.type.shape, operands)
        out_types = [
            ir.RankedTensorType.get(out_aval.shape, mlir.dtype_to_ir_type(out_aval.dtype))
            for out_aval in ctx.avals_out
        ]

        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
        # Self attention: query and key/value share the same max sequence length.
        opaque = transformer_engine_jax.pack_fused_attn_descriptor(
            batch, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor, dropout_probability,
            attn_bias_type, attn_mask_type, jax_dtype_to_te_dtype(qkv_aval.dtype), is_training)

        return custom_caller(SelfFusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)

    @staticmethod
    def impl(qkv, bias, squeezed_mask, seed, attn_bias_type, attn_mask_type, scaling_factor,
             dropout_probability, is_training):
        """
        Convert the squeezed mask into cumulative sequence lengths and bind the
        inner primitive. Returns (output, softmax_aux, rng_state).
        """
        assert SelfFusedAttnFwdPrimitive.inner_primitive is not None

        cu_seqlen = generate_cu_seqlen(squeezed_mask)

        output, softmax_aux, rng_state = SelfFusedAttnFwdPrimitive.inner_primitive.bind(
            qkv,
            bias,
            cu_seqlen,
            seed,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training)
        return output, softmax_aux, rng_state

    @staticmethod
    def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor,
                dropout_probability, is_training):
        """
        Batching rule: re-bind the outer primitive on the batched operands.
        Output and softmax_aux reuse the qkv batch dim, rng_state the seed batch dim.
        """
        _check_valid_batch_dims(batch_dims)
        assert SelfFusedAttnFwdPrimitive.outer_primitive is not None
        qkv_bdim, _, _, seed_bdim = batch_dims

        out_bdims = qkv_bdim, qkv_bdim, seed_bdim
        results = SelfFusedAttnFwdPrimitive.outer_primitive.bind(
            *batched_args,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training)
        return results, out_bdims

    @staticmethod
    def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor,
                                     dropout_probability, is_training, mesh, arg_infos,
                                     result_infos):
        """
        Derive the shardings of (output, softmax_aux, rng_state) from the
        partition spec of the qkv operand.
        """
        del attn_bias_type, attn_mask_type, scaling_factor
        del dropout_probability, is_training, result_infos
        x_spec = get_padded_spec(arg_infos[0])    # (...batch, seqlen, 3, head, hidden)
        # Output spec: the qkv spec without the packed "3" axis.
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-3], *x_spec[-2:]))
        # softmax_aux spec: (...batch, head, seqlen, <unsharded>).
        softmax_aux_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec[:-4], x_spec[-2], x_spec[-4], None))
        rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
        return (out_sharding, softmax_aux_sharding, rng_state_sharding)

    @staticmethod
    def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training,
                  mesh, arg_infos, result_infos):
        """
        Partition rule: returns (mesh, per-shard impl, output shardings, operand shardings).
        Operands keep their incoming sharding, except the last one (the seed),
        which takes the same sharding as the rng_state result.
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])    # (...batch, seqlen, 3, head, hidden)
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-3], *x_spec[-2:]))
        softmax_aux_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec[:-4], x_spec[-2], x_spec[-4], None))
        rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
        arg_shardings = (*(arg_i.sharding for arg_i in arg_infos[:-1]), rng_state_sharding)
        out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
        impl = partial(SelfFusedAttnFwdPrimitive.impl,
                       attn_bias_type=attn_bias_type,
                       attn_mask_type=attn_mask_type,
                       scaling_factor=scaling_factor,
                       dropout_probability=dropout_probability,
                       is_training=is_training)
        return mesh, impl, out_shardings, arg_shardings


# Hook the primitive up via register_primitive (defined outside this chunk).
# The class's impl and batcher assert that inner_primitive / outer_primitive
# are set, so presumably this call populates them -- TODO confirm.
register_primitive(SelfFusedAttnFwdPrimitive)


def self_fused_attn_fwd(qkv: jnp.ndarray, bias: jnp.ndarray, squeezed_mask: jnp.ndarray,
                        seed: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
                        attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
                        dropout_probability: float, is_training: bool):
    """
    Wrapper for TE self fused attention fwd
    Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2

    The seed is validated (and cast if needed) first. With NVTE_NO_BIAS the
    caller must pass `bias=None`; an empty placeholder array is bound instead.
    """
    seed = _FusedAttnRNGStateChecker().check_seed(seed, dropout_probability, is_training)

    without_bias = attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS
    if without_bias:
        assert bias is None
        # None cannot be bound as an operand, so use a zero-size array of qkv's dtype.
        bias = jnp.zeros(0, dtype=qkv.dtype)

    static_params = dict(attn_bias_type=attn_bias_type,
                         attn_mask_type=attn_mask_type,
                         scaling_factor=scaling_factor,
                         dropout_probability=dropout_probability,
                         is_training=is_training)
    fwd_primitive = SelfFusedAttnFwdPrimitive.outer_primitive
    return fwd_primitive.bind(qkv, bias, squeezed_mask, seed, **static_params)
1880
1881


1882
class SelfFusedAttnBwdPrimitive(BasePrimitive):
    """
    Self Fused Attention Backward Primitive

    Produces two results: the gradient of the packed qkv tensor and the
    gradient of the bias.
    """
    name = "te_self_fused_attn_backward"
    multiple_results = True
    # Indices 7..11 line up with the non-tensor parameters of `impl`
    # (attn_bias_type ... is_training).
    impl_static_args = (7, 8, 9, 10, 11)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(qkv_aval, bias_aval, softmax_aux_aval, rng_state_aval, output_aval, doutput_aval,
                 mask_or_cu_seqlen_aval, *, attn_bias_type, attn_mask_type, scaling_factor,
                 dropout_probability, is_training):
        """
        Self fused attention bwd abstract

        The gradients have the same shapes as qkv and bias, with canonicalized dtypes.
        """
        del softmax_aux_aval, rng_state_aval
        # The outer primitive receives squeezed_mask, the inner primitive cu_seqlen.
        del mask_or_cu_seqlen_aval, attn_bias_type, attn_mask_type
        del scaling_factor, dropout_probability, is_training
        qkv_dtype = dtypes.canonicalize_dtype(qkv_aval.dtype)
        bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
        assert qkv_aval.dtype == bias_aval.dtype == output_aval.dtype == doutput_aval.dtype

        dqkv_aval = qkv_aval.update(shape=qkv_aval.shape, dtype=qkv_dtype)
        dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype)
        return dqkv_aval, dbias_aval

    @staticmethod
    def lowering(ctx, qkv, bias, softmax_aux, rng_state, output, doutput, cu_seqlen, *,
                 attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training):
        """
        Self fused attention bwd lowering rules

        Packs the problem description into an opaque descriptor and emits the
        custom call registered under this primitive's name.
        """
        qkv_aval, _, _, _, _, _, _ = ctx.avals_in

        *batch_shape, max_seqlen, _, num_head, head_dim = qkv_aval.shape
        # All leading batch axes are flattened into a single batch count.
        batch = reduce(operator.mul, batch_shape)

        operands = [qkv, bias, softmax_aux, rng_state, output, doutput, cu_seqlen]
        operand_shapes = map(lambda operand: operand.type.shape, operands)
        # A distinct loop name avoids visually shadowing the `output` operand.
        out_types = [
            ir.RankedTensorType.get(out_aval.shape, mlir.dtype_to_ir_type(out_aval.dtype))
            for out_aval in ctx.avals_out
        ]

        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        # Self attention: query and key/value share the same max sequence length.
        opaque = transformer_engine_jax.pack_fused_attn_descriptor(
            batch, num_head, max_seqlen, max_seqlen, head_dim, scaling_factor, dropout_probability,
            attn_bias_type, attn_mask_type, jax_dtype_to_te_dtype(qkv_aval.dtype), is_training)

        return custom_caller(SelfFusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)

    @staticmethod
    def impl(qkv, bias, softmax_aux, rng_state, output, doutput, squeezed_mask, attn_bias_type,
             attn_mask_type, scaling_factor, dropout_probability, is_training):
        """
        Convert the squeezed mask into cumulative sequence lengths and bind the
        inner primitive. Returns (dqkv, dbias).
        """
        assert SelfFusedAttnBwdPrimitive.inner_primitive is not None

        cu_seqlen = generate_cu_seqlen(squeezed_mask)

        dqkv, dbias = SelfFusedAttnBwdPrimitive.inner_primitive.bind(
            qkv,
            bias,
            softmax_aux,
            rng_state,
            output,
            doutput,
            cu_seqlen,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training)
        return dqkv, dbias

    @staticmethod
    def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor,
                dropout_probability, is_training):
        """
        Batching rule: re-bind the outer primitive on the batched operands.
        Both results reuse the batch dim of qkv.
        """
        _check_valid_batch_dims(batch_dims)
        assert SelfFusedAttnBwdPrimitive.outer_primitive is not None
        qkv_bdim, *_ = batch_dims

        out_bdims = qkv_bdim, qkv_bdim
        results = SelfFusedAttnBwdPrimitive.outer_primitive.bind(
            *batched_args,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training)
        return results, out_bdims

    @staticmethod
    def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor,
                                     dropout_probability, is_training, mesh, arg_infos,
                                     result_infos):
        """
        The gradients inherit the partition specs of their operands:
        dqkv follows qkv (arg 0) and dbias follows bias (arg 1).
        """
        del attn_bias_type, attn_mask_type, scaling_factor, dropout_probability
        del is_training, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        bias_spec = get_padded_spec(arg_infos[1])
        dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
        return (dx_sharding, dbias_sharding)

    @staticmethod
    def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training,
                  mesh, arg_infos, result_infos):
        """
        Partition rule: returns (mesh, per-shard impl, output shardings, operand shardings).
        Every operand keeps its incoming sharding. When a bias is in use, the
        per-shard bias gradient is summed via all_reduce_sum_along_dp_fsdp.
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        bias_spec = get_padded_spec(arg_infos[1])
        dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (dx_sharding, dbias_sharding)

        def sharded_impl(qkv, bias, softmax_aux, rng_state, output, doutput, cu_seqlen):
            local_dx, local_dbias = SelfFusedAttnBwdPrimitive.impl(
                qkv,
                bias,
                softmax_aux,
                rng_state,
                output,
                doutput,
                cu_seqlen,
                attn_bias_type=attn_bias_type,
                attn_mask_type=attn_mask_type,
                scaling_factor=scaling_factor,
                dropout_probability=dropout_probability,
                is_training=is_training)
            global_dbias = local_dbias
            # Compare by value (as the rest of the file does): an identity test can
            # misfire for an equal enum value that is a different Python object.
            if attn_bias_type != NVTE_Bias_Type.NVTE_NO_BIAS:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
            return local_dx, global_dbias

        return mesh, sharded_impl, out_shardings, arg_shardings


# Hook the primitive up via register_primitive (defined outside this chunk).
# The class's impl and batcher assert that inner_primitive / outer_primitive
# are set, so presumably this call populates them -- TODO confirm.
register_primitive(SelfFusedAttnBwdPrimitive)


2025
2026
2027
2028
2029
def self_fused_attn_bwd(qkv: jnp.ndarray, bias: jnp.ndarray, softmax_aux: jnp.ndarray,
                        rng_state: jnp.ndarray, output: jnp.ndarray, doutput: jnp.ndarray,
                        squeezed_mask: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
                        attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
                        dropout_probability: float, is_training: bool):
    """
    Wrapper for TE self fused attention bwd
    Return the gradients of self fused attention with packed qkv input

    With NVTE_NO_BIAS the caller must pass `bias=None`; an empty placeholder
    array is bound instead.
    """
    without_bias = attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS
    if without_bias:
        assert bias is None
        # None cannot be bound as an operand, so use a zero-size array of qkv's dtype.
        bias = jnp.zeros(0, dtype=qkv.dtype)

    static_params = dict(attn_bias_type=attn_bias_type,
                         attn_mask_type=attn_mask_type,
                         scaling_factor=scaling_factor,
                         dropout_probability=dropout_probability,
                         is_training=is_training)
    bwd_primitive = SelfFusedAttnBwdPrimitive.outer_primitive
    return bwd_primitive.bind(qkv, bias, softmax_aux, rng_state, output, doutput, squeezed_mask,
                              **static_params)
2049
2050


2051
class CrossFusedAttnFwdPrimitive(BasePrimitive):
    """
    Cross Fused Attention Forward Primitive

    Takes a query tensor (...batch, q_seqlen, head, hidden) and a packed
    key/value tensor (...batch, kv_seqlen, 2, head, hidden); yields the
    attention output, the softmax auxiliary tensor and the RNG state.
    """
    name = "te_cross_fused_attn_forward"
    multiple_results = True
    # Indices 6..10 line up with the non-tensor parameters of `impl`
    # (attn_bias_type ... is_training).
    impl_static_args = (6, 7, 8, 9, 10)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(q_aval, kv_aval, bias_aval, q_mask_or_cu_seqlen_aval, kv_mask_or_cu_seqlen_aval,
                 seed_aval, *, attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
                 is_training):
        """
        Cross fused attention fwd abstract

        Returns the abstract values of (output, softmax_aux, rng_state).
        Raises ValueError when the selected backend is neither the max-512
        nor the arbitrary-seqlen F16 backend.
        """
        # The outer primitive receives squeezed masks, the inner primitive cu_seqlens.
        del scaling_factor, is_training

        q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
        *q_batch_shape, q_max_seqlen, q_num_head, q_head_dim = q_aval.shape

        kv_dtype = dtypes.canonicalize_dtype(kv_aval.dtype)
        *kv_batch_shape, kv_max_seqlen, nkv, kv_num_head, kv_head_dim = kv_aval.shape

        bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)

        # q and kv must agree on dtype, batch shape, head count and head size.
        assert q_dtype == kv_dtype == bias_dtype
        assert q_batch_shape == kv_batch_shape
        assert q_num_head == kv_num_head
        assert q_head_dim == kv_head_dim
        assert nkv == 2
        assert q_mask_or_cu_seqlen_aval.dtype == kv_mask_or_cu_seqlen_aval.dtype

        helper = FusedAttnHelper(q_type=q_dtype,
                                 kv_type=kv_dtype,
                                 qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BS2HD,
                                 attn_bias_type=attn_bias_type,
                                 attn_mask_type=attn_mask_type,
                                 dropout_probability=dropout_probability,
                                 max_seqlen_q=q_max_seqlen,
                                 max_seqlen_kv=kv_max_seqlen,
                                 head_dim=q_head_dim)
        backend = helper.get_fused_attn_backend()

        # Shape and dtype of the softmax auxiliary tensor depend on the backend.
        if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
            softmax_aux_shape = (*q_batch_shape, q_num_head, q_max_seqlen, kv_max_seqlen)
            softmax_aux_dtype = q_dtype
        elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
            softmax_aux_shape = (*q_batch_shape, q_num_head, q_max_seqlen, 1)
            softmax_aux_dtype = dtypes.canonicalize_dtype(jnp.float32)
        else:
            raise ValueError(f'Unsupported {backend=}')

        rng_checker = _FusedAttnRNGStateChecker()
        seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype)
        assert seed_dtype == rng_checker.rng_state_dtype
        rng_state_shape = (seed_aval.shape[0], rng_checker.rng_state_size)

        # The output has exactly the query's shape.
        out_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
        softmax_aux_aval = q_aval.update(shape=softmax_aux_shape, dtype=softmax_aux_dtype)
        rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=seed_dtype)

        return out_aval, softmax_aux_aval, rng_state_aval

    @staticmethod
    def lowering(ctx, q, kv, bias, q_cu_seqlen, kv_cu_seqlen, seed, *, attn_bias_type,
                 attn_mask_type, scaling_factor, dropout_probability, is_training):
        """
        Cross fused attention fwd lowering rules

        Packs the problem description into an opaque descriptor and emits the
        custom call registered under this primitive's name.
        """
        q_aval, kv_aval, *_ = ctx.avals_in
        assert q_aval.dtype == kv_aval.dtype

        *batch_shape, q_max_seqlen, num_head, head_dim = q_aval.shape
        # All leading batch axes are flattened into a single batch count.
        batch = reduce(operator.mul, batch_shape)
        # kv is (...batch, kv_seqlen, 2, head, hidden): the seqlen axis is 4th from the end.
        kv_max_seqlen = kv_aval.shape[-4]

        operands = [q, kv, bias, q_cu_seqlen, kv_cu_seqlen, seed]
        operand_shapes = map(lambda operand: operand.type.shape, operands)
        out_types = [
            ir.RankedTensorType.get(out_aval.shape, mlir.dtype_to_ir_type(out_aval.dtype))
            for out_aval in ctx.avals_out
        ]

        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
        te_dtype = jax_dtype_to_te_dtype(q_aval.dtype)
        opaque = transformer_engine_jax.pack_fused_attn_descriptor(
            batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim,
            scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
            te_dtype, is_training)

        return custom_caller(CrossFusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)

    @staticmethod
    def impl(q, kv, bias, q_squeezed_mask, kv_squeezed_mask, seed, attn_bias_type, attn_mask_type,
             scaling_factor, dropout_probability, is_training):
        """
        Convert both squeezed masks into cumulative sequence lengths and bind
        the inner primitive. Returns (output, softmax_aux, rng_state).
        """
        assert CrossFusedAttnFwdPrimitive.inner_primitive is not None

        q_cu_seqlen = generate_cu_seqlen(q_squeezed_mask)
        kv_cu_seqlen = generate_cu_seqlen(kv_squeezed_mask)

        output, softmax_aux, rng_state = CrossFusedAttnFwdPrimitive.inner_primitive.bind(
            q,
            kv,
            bias,
            q_cu_seqlen,
            kv_cu_seqlen,
            seed,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training)
        return output, softmax_aux, rng_state

    @staticmethod
    def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor,
                dropout_probability, is_training):
        """
        Batching rule: re-bind the outer primitive on the batched operands.
        Output and softmax_aux reuse the q batch dim, rng_state the seed batch dim.
        """
        _check_valid_batch_dims(batch_dims)
        assert CrossFusedAttnFwdPrimitive.outer_primitive is not None
        q_bdim, *_, seed_bdim = batch_dims

        out_bdims = q_bdim, q_bdim, seed_bdim
        results = CrossFusedAttnFwdPrimitive.outer_primitive.bind(
            *batched_args,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training)
        return results, out_bdims

    @staticmethod
    def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor,
                                     dropout_probability, is_training, mesh, arg_infos,
                                     result_infos):
        """
        Derive the shardings of (output, softmax_aux, rng_state) from the
        partition specs of the q and kv operands.
        """
        del attn_bias_type, attn_mask_type, scaling_factor
        del dropout_probability, is_training, result_infos
        q_spec = get_padded_spec(arg_infos[0])    # (...batch, q_seqlen, head, hidden)
        kv_spec = get_padded_spec(arg_infos[1])    # (...batch, kv_seqlen, 2, head, hidden)
        out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        # softmax_aux spec: (...batch, head, q_seqlen, kv_seqlen).
        aux_spec = PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], kv_spec[-4])
        softmax_aux_sharding = NamedSharding(mesh, aux_spec)
        rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
        return (out_sharding, softmax_aux_sharding, rng_state_sharding)

    @staticmethod
    def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training,
                  mesh, arg_infos, result_infos):
        """
        Partition rule: returns (mesh, per-shard impl, output shardings, operand shardings).
        Operands keep their incoming sharding, except the last one (the seed),
        which takes the same sharding as the rng_state result.
        """
        del result_infos
        q_spec = get_padded_spec(arg_infos[0])    # (...batch, q_seqlen, head, hidden)
        kv_spec = get_padded_spec(arg_infos[1])    # (...batch, kv_seqlen, 2, head, hidden)
        out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        aux_spec = PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], kv_spec[-4])
        softmax_aux_sharding = NamedSharding(mesh, aux_spec)
        # A single sharding object serves both the seed operand and the rng_state result.
        seed_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
        rng_state_sharding = seed_sharding
        arg_shardings = (*(arg_i.sharding for arg_i in arg_infos[:-1]), seed_sharding)
        out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
        impl = partial(CrossFusedAttnFwdPrimitive.impl,
                       attn_bias_type=attn_bias_type,
                       attn_mask_type=attn_mask_type,
                       scaling_factor=scaling_factor,
                       dropout_probability=dropout_probability,
                       is_training=is_training)
        return mesh, impl, out_shardings, arg_shardings


# Hook the primitive up via register_primitive (defined outside this chunk).
# The class's impl and batcher assert that inner_primitive / outer_primitive
# are set, so presumably this call populates them -- TODO confirm.
register_primitive(CrossFusedAttnFwdPrimitive)


2221
2222
2223
2224
2225
def cross_fused_attn_fwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray,
                         q_squeezed_mask: jnp.ndarray, kv_squeezed_mask: jnp.ndarray,
                         seed: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
                         attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
                         dropout_probability: float, is_training: bool):
    """
    Wrapper for TE cross fused attention fwd
    Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2

    The seed is validated (and cast if needed) first. With NVTE_NO_BIAS the
    caller must pass `bias=None`; an empty placeholder array is bound instead.
    """
    seed = _FusedAttnRNGStateChecker().check_seed(seed, dropout_probability, is_training)

    without_bias = attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS
    if without_bias:
        assert bias is None
        # None cannot be bound as an operand, so use a zero-size array of q's dtype.
        bias = jnp.zeros(0, dtype=q.dtype)

    static_params = dict(attn_bias_type=attn_bias_type,
                         attn_mask_type=attn_mask_type,
                         scaling_factor=scaling_factor,
                         dropout_probability=dropout_probability,
                         is_training=is_training)
    fwd_primitive = CrossFusedAttnFwdPrimitive.outer_primitive
    return fwd_primitive.bind(q, kv, bias, q_squeezed_mask, kv_squeezed_mask, seed,
                              **static_params)
2248
2249


2250
class CrossFusedAttnBwdPrimitive(BasePrimitive):
    """
    Cross Fused Attention Backward Primitive

    Produces the gradients (dq, dkv, dbias) of cross fused attention.
    """
    name = "te_cross_fused_attn_backward"
    multiple_results = True
    # Positions of attn_bias_type, attn_mask_type, scaling_factor,
    # dropout_probability and is_training in the `impl` signature.
    impl_static_args = (9, 10, 11, 12, 13)
    # Must be set before use: `impl` and `batcher` assert they are not None.
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(q_aval, kv_aval, bias_aval, softmax_aux_aval, rng_state_aval, output_aval,
                 doutput_aval, q_cu_seqlen_aval, kv_cu_seqlen_aval, *, attn_bias_type,
                 attn_mask_type, scaling_factor, dropout_probability, is_training):
        """
        Cross fused attention bwd abstract

        Checks that q, kv, bias and doutput share one dtype and that both
        cu_seqlen operands share one dtype. Returns the avals of dq, dkv and
        dbias, which mirror the shape and dtype of q, kv and bias.
        """
        del softmax_aux_aval, rng_state_aval, output_aval
        del attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training
        q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
        kv_dtype = dtypes.canonicalize_dtype(kv_aval.dtype)
        bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
        doutput_dtype = dtypes.canonicalize_dtype(doutput_aval.dtype)
        assert q_dtype == kv_dtype == bias_dtype == doutput_dtype
        assert q_cu_seqlen_aval.dtype == kv_cu_seqlen_aval.dtype

        dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
        dkv_aval = kv_aval.update(shape=kv_aval.shape, dtype=kv_dtype)
        dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype)
        return dq_aval, dkv_aval, dbias_aval

    @staticmethod
    def lowering(ctx, q, kv, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen,
                 kv_cu_seqlen, *, attn_bias_type, attn_mask_type, scaling_factor,
                 dropout_probability, is_training):
        """
        Cross fused attention bwd lowering rules

        Emits the custom call registered under `name`, with a packed fused
        attention descriptor as the opaque payload.
        """
        q_aval, kv_aval, *_ = ctx.avals_in
        assert q_aval.dtype == kv_aval.dtype

        # q is read as (*batch, q_max_seqlen, num_head, head_dim); all leading
        # axes are folded into a single batch size.
        *batch_shape, q_max_seqlen, num_head, head_dim = q_aval.shape
        batch = reduce(operator.mul, batch_shape)
        # NOTE(review): presumably kv is (*batch, kv_max_seqlen, 2, num_head, head_dim)
        # -- confirm against the forward primitive.
        kv_max_seqlen = kv_aval.shape[-4]

        operands = [q, kv, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, kv_cu_seqlen]
        # Materialize the shapes so the sequence can be iterated more than once.
        operand_shapes = [operand.type.shape for operand in operands]
        out_types = [
            ir.RankedTensorType.get(out_aval.shape, mlir.dtype_to_ir_type(out_aval.dtype))
            for out_aval in ctx.avals_out
        ]

        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        # the dropout elements are encoded in the forward auxiliary tensor
        # so seed is not needed in backward
        opaque = transformer_engine_jax.pack_fused_attn_descriptor(
            batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim,
            scaling_factor, dropout_probability, attn_bias_type, attn_mask_type,
            jax_dtype_to_te_dtype(q_aval.dtype), is_training)

        out = custom_caller(CrossFusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)

        return out

    @staticmethod
    def impl(q, kv, bias, softmax_aux, rng_state, output, doutput, q_squeezed_mask,
             kv_squeezed_mask, attn_bias_type, attn_mask_type, scaling_factor, dropout_probability,
             is_training):
        """
        Cross fused attention bwd implementation

        Converts both squeezed masks to cu_seqlen arrays with
        `generate_cu_seqlen`, then binds the inner primitive.
        Returns (dq, dkv, dbias).
        """
        assert CrossFusedAttnBwdPrimitive.inner_primitive is not None

        q_cu_seqlen = generate_cu_seqlen(q_squeezed_mask)
        kv_cu_seqlen = generate_cu_seqlen(kv_squeezed_mask)

        dq, dkv, dbias = CrossFusedAttnBwdPrimitive.inner_primitive.bind(
            q,
            kv,
            bias,
            softmax_aux,
            rng_state,
            output,
            doutput,
            q_cu_seqlen,
            kv_cu_seqlen,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training)
        return dq, dkv, dbias

    @staticmethod
    def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor,
                dropout_probability, is_training):
        """
        Cross fused attention bwd batcher

        Re-binds the outer primitive on the batched operands. dq and dbias
        report q's batch dim; dkv reports kv's batch dim.
        """
        _check_valid_batch_dims(batch_dims)
        assert CrossFusedAttnBwdPrimitive.outer_primitive is not None
        q_bdim, kv_bdim, *_ = batch_dims

        out_bdims = q_bdim, kv_bdim, q_bdim
        return CrossFusedAttnBwdPrimitive.outer_primitive.bind(
            *batched_args,
            attn_bias_type=attn_bias_type,
            attn_mask_type=attn_mask_type,
            scaling_factor=scaling_factor,
            dropout_probability=dropout_probability,
            is_training=is_training), out_bdims

    @staticmethod
    def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor,
                                     dropout_probability, is_training, mesh, arg_infos,
                                     result_infos):
        """
        Cross fused attention bwd infer_sharding_from_operands

        dq, dkv and dbias take the padded partition specs of q, kv and bias.
        """
        del attn_bias_type, attn_mask_type, scaling_factor
        del dropout_probability, is_training, result_infos
        q_spec = get_padded_spec(arg_infos[0])
        kv_spec = get_padded_spec(arg_infos[1])
        bias_spec = get_padded_spec(arg_infos[2])
        dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        dkv_sharding = NamedSharding(mesh, PartitionSpec(*kv_spec))
        dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
        return (dq_sharding, dkv_sharding, dbias_sharding)

    @staticmethod
    def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training,
                  mesh, arg_infos, result_infos):
        """
        Cross fused attention bwd partition

        Returns (mesh, sharded_impl, out_shardings, arg_shardings). Outputs
        follow the specs of q, kv and bias; operands keep their own shardings.
        """
        del result_infos
        q_spec = get_padded_spec(arg_infos[0])
        kv_spec = get_padded_spec(arg_infos[1])
        bias_spec = get_padded_spec(arg_infos[2])
        dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
        dkv_sharding = NamedSharding(mesh, PartitionSpec(*kv_spec))
        dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (dq_sharding, dkv_sharding, dbias_sharding)

        def sharded_impl(q, kv, bias, softmax_aux, rng_state, output, doutput, q_squeezed_mask,
                         kv_squeezed_mask):
            local_dq, local_dkv, local_dbias = CrossFusedAttnBwdPrimitive.impl(
                q,
                kv,
                bias,
                softmax_aux,
                rng_state,
                output,
                doutput,
                q_squeezed_mask,
                kv_squeezed_mask,
                attn_bias_type=attn_bias_type,
                attn_mask_type=attn_mask_type,
                scaling_factor=scaling_factor,
                dropout_probability=dropout_probability,
                is_training=is_training)
            global_dbias = local_dbias
            # Compare by value, as the wrapper functions do with `==`. An identity
            # test would take this branch for an equal but distinct enum instance.
            if attn_bias_type != NVTE_Bias_Type.NVTE_NO_BIAS:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
            return local_dq, local_dkv, global_dbias

        return mesh, sharded_impl, out_shardings, arg_shardings
2407

2408

2409
register_primitive(CrossFusedAttnBwdPrimitive)
2410
2411


2412
2413
def cross_fused_attn_bwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray,
                         softmax_aux: jnp.ndarray, rng_state: jnp.ndarray, output: jnp.ndarray,
                         doutput: jnp.ndarray, q_squeezed_mask: jnp.ndarray,
                         kv_squeezed_mask: jnp.ndarray, attn_bias_type: NVTE_Bias_Type,
                         attn_mask_type: NVTE_Mask_Type, scaling_factor: float,
                         dropout_probability: float, is_training: bool):
    """
    Wrapper for TE cross fused attention bwd
    Return the gradients of cross fused attention with packed kv input

    When `attn_bias_type` is NVTE_NO_BIAS, `bias` must be None; an empty array
    with q's dtype is bound in its place.
    """
    bias_is_disabled = attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS
    if bias_is_disabled:
        assert bias is None
        # The primitive always takes a bias operand, so bind an empty placeholder.
        bias = jnp.zeros(0, dtype=q.dtype)

    tensor_args = (q, kv, bias, softmax_aux, rng_state, output, doutput, q_squeezed_mask,
                   kv_squeezed_mask)
    static_kwargs = {
        "attn_bias_type": attn_bias_type,
        "attn_mask_type": attn_mask_type,
        "scaling_factor": scaling_factor,
        "dropout_probability": dropout_probability,
        "is_training": is_training,
    }
    return CrossFusedAttnBwdPrimitive.outer_primitive.bind(*tensor_args, **static_kwargs)
2439
2440


2441
class GatedGeluPrimitive(BasePrimitive):
    """
    Gated Gelu Forward Primitive

    Maps an input of shape (..., 2, hidden) to an output of shape (..., hidden)
    with the same dtype.
    """
    name = "te_gated_gelu"
    multiple_results = False
    # Must be set before use: `impl` and `batcher` assert they are not None.
    inner_primitive = None
    outer_primitive = None
    impl_static_args = ()

    @staticmethod
    def abstract(x_aval):
        """
        gated_gelu abstract

        Requires an fp32/fp16/bf16 input whose second-to-last axis has size 2.
        The output aval drops that axis and keeps the input dtype.
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        x_shape = x_aval.shape
        assert x_shape[-2] == 2    # Assume x in (....., 2, hidden)
        hidden_size = x_shape[-1]
        batch_shapes = x_shape[:-2]
        out_aval = core.raise_to_shaped(x_aval)
        out_shape = (batch_shapes) + (hidden_size,)
        out_aval = out_aval.update(shape=out_shape, dtype=dtype)

        return out_aval

    @staticmethod
    def lowering(ctx, x):
        """
        gated_gelu lowering rules

        Emits the custom call registered under `name`; the opaque descriptor
        carries (batch_size, hidden_size) and the input dtype.
        """
        (x_aval,) = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        ir_x_type = ir.RankedTensorType(x.type)
        ir_x_shape = ir_x_type.shape
        out_shape = ir_x_shape[:-2] + [ir_x_shape[-1]]

        out_types = [
            ir.RankedTensorType.get(out_shape, ir_x_type.element_type),
        ]
        operands = [x]
        operand_shapes = [ir_x_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        hidden_size = ir_x_shape[-1]
        # Initial value 1 keeps a rank-2 input (2, hidden), which `abstract`
        # accepts, from crashing reduce() on an empty batch shape.
        batch_size = reduce(operator.mul, ir_x_shape[:-2], 1)
        in_dtype = jax_dtype_to_te_dtype(x_aval.dtype)
        opaque = transformer_engine_jax.pack_common_descriptor((batch_size, hidden_size), in_dtype,
                                                               in_dtype)

        out = custom_caller(GatedGeluPrimitive.name, args, opaque, False)

        return [out]

    @staticmethod
    def impl(x):
        """
        gated_gelu implementation: binds the inner primitive on `x`.
        """
        assert GatedGeluPrimitive.inner_primitive is not None
        out = GatedGeluPrimitive.inner_primitive.bind(x)
        return out

    @staticmethod
    def batcher(batched_args, batch_dims):
        """
        gated_gelu batcher

        Re-binds the outer primitive; the output keeps the input's batch dim.
        """
        _check_valid_batch_dims(batch_dims)
        assert GatedGeluPrimitive.outer_primitive is not None
        inputs, = batched_args
        inputs_bdim, = batch_dims

        out_bdims = inputs_bdim
        return GatedGeluPrimitive.outer_primitive.bind(inputs), out_bdims

    @staticmethod
    def infer_sharding_from_operands(mesh, arg_infos, result_infos):
        """
        gated_gelu infer_sharding_from_operands

        The output spec is x's padded spec with the second-to-last (size-2)
        axis entry removed.
        """
        del result_infos    # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
        return out_sharding

    @staticmethod
    def partition(mesh, arg_infos, result_infos):
        """
        gated_gelu partitioning

        Returns (mesh, impl, out_sharding, arg_shardings); operands keep their
        own shardings and the output follows x with the size-2 axis dropped.
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
        impl = GatedGeluPrimitive.impl
        return mesh, impl, out_sharding, arg_shardings
2537
2538


2539
register_primitive(GatedGeluPrimitive)
2540
2541


2542
def gated_gelu(inputs: jnp.ndarray) -> jnp.ndarray:
    """
    gated gelu wrapper
    Return geglu(inputs) in the same dtype as inputs
    Assume inputs is laid out as (..., 2, H); the result has shape (..., H)
    """
    outer = GatedGeluPrimitive.outer_primitive
    return outer.bind(inputs)
2549
2550


2551
class DgatedGeluPrimitive(BasePrimitive):
    """
    Dgated Gelu Primitive

    Takes the incoming gradient `dz` and the gated-gelu input `x`; the result
    has the shape and dtype of `x`.
    """
    name = "te_dgated_gelu"
    multiple_results = False
    # Must be set before use: `impl` and `batcher` assert they are not None.
    inner_primitive = None
    outer_primitive = None
    impl_static_args = ()

    @staticmethod
    def abstract(dz_aval, x_aval):
        """
        dgated_gelu abstract

        Requires fp32/fp16/bf16 operands of one dtype, matching leading axes,
        matching hidden size, and x in (..., 2, hidden). Returns x's aval.
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dtype
        # Every axis of dz except the last must match the same axis of x.
        for axis in range(len(dz_aval.shape) - 1):
            assert dz_aval.shape[axis] == x_aval.shape[axis]

        assert x_aval.shape[-2] == 2    # Assume x in (....., 2, hidden)

        i_hidden_size = dz_aval.shape[-1]
        g_hidden_size = x_aval.shape[-1]
        assert i_hidden_size == g_hidden_size
        out_aval = core.raise_to_shaped(x_aval)
        return out_aval

    @staticmethod
    def lowering(ctx, dz, x):
        """
        dgated_gelu lowering rules

        Emits the custom call registered under `name`; the opaque descriptor
        carries (batch_size, hidden_size) derived from dz and the input dtype.
        """
        in_aval, gi_aval = ctx.avals_in
        assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert gi_aval.dtype == in_aval.dtype
        ir_in_type = ir.RankedTensorType(dz.type)
        ir_in_shape = ir_in_type.shape
        gi_type = ir.RankedTensorType(x.type)
        gi_shape = gi_type.shape
        for axis in range(len(ir_in_shape) - 1):
            assert ir_in_shape[axis] == gi_shape[axis]

        # Initial value 1 keeps a rank-1 dz (hidden,) from crashing reduce()
        # on an empty batch shape.
        ir_batch_size = reduce(operator.mul, ir_in_shape[:-1], 1)
        i_hidden_size = ir_in_shape[-1]
        g_hidden_size = gi_shape[-1]
        assert i_hidden_size == g_hidden_size
        out_dtype = ir_in_type.element_type
        out_shape = gi_shape

        out_types = [
            ir.RankedTensorType.get(out_shape, out_dtype),
        ]
        operands = [dz, x]
        operand_shapes = [ir_in_shape, gi_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        in_dtype = jax_dtype_to_te_dtype(in_aval.dtype)
        opaque = transformer_engine_jax.pack_common_descriptor((ir_batch_size, i_hidden_size),
                                                               in_dtype, in_dtype)

        out = custom_caller(DgatedGeluPrimitive.name, args, opaque, False)

        return [out]

    @staticmethod
    def impl(dz, x):
        """
        dgated_gelu implementation: binds the inner primitive on (dz, x).
        """
        assert DgatedGeluPrimitive.inner_primitive is not None
        dx = DgatedGeluPrimitive.inner_primitive.bind(dz, x)
        return dx

    @staticmethod
    def batcher(batched_args, batch_dims):
        """
        dgated_gelu batcher

        Re-binds the outer primitive; the output reports x's batch dim.
        """
        _check_valid_batch_dims(batch_dims)
        assert DgatedGeluPrimitive.outer_primitive is not None
        dz, x = batched_args
        _, x_bdim = batch_dims

        out_bdims = x_bdim
        return DgatedGeluPrimitive.outer_primitive.bind(dz, x), out_bdims

    @staticmethod
    def infer_sharding_from_operands(mesh, arg_infos, result_infos):
        """
        dgated_gelu infer_sharding_from_operands

        dx takes the padded partition spec of x (the second operand).
        """
        del result_infos    # Unused.
        x_spec = get_padded_spec(arg_infos[1])
        dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        return dx_sharding

    @staticmethod
    def partition(mesh, arg_infos, result_infos):
        """
        dgated_gelu partition

        Returns (mesh, impl, out_shardings, arg_shardings); operands keep their
        own shardings and dx follows x's padded spec.
        """
        del result_infos
        dx_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = dx_sharding
        impl = DgatedGeluPrimitive.impl
        return mesh, impl, out_shardings, arg_shardings
2660
2661


2662
register_primitive(DgatedGeluPrimitive)
2663
2664


2665
2666
2667
2668
2669
2670
def dgated_gelu(inputs: jnp.ndarray, gelu_inputs: jnp.ndarray) -> jnp.ndarray:
    """
    dgated_gelu fusion wrapper
    Return dgeglu(inputs)

    `inputs` is the incoming gradient and `gelu_inputs` the tensor that was fed
    to gated_gelu; the result has the shape of `gelu_inputs`.
    """
    outer = DgatedGeluPrimitive.outer_primitive
    return outer.bind(inputs, gelu_inputs)
2671
2672


2673
2674
def _normalize_axis_boundary(axis, ndim):
    return axis if axis >= 0 else ndim + axis
2675
2676


2677
def _multidim_transpose(shape, static_axis_boundary, transpose_axis_boundary):
    """
    te_cast_transpose_p multi-dims transpose

    static_axis_boundary: int, axes with index <= static_axis_boundary stay in place and
        are not involved in the transpose; any negative value means every axis is involved.
    transpose_axis_boundary: int, the axis at which the remaining axes are split into the
        two groups that swap places. It may be negative (counted from the end) and must
        end up greater than static_axis_boundary + 1.

    examples:
        X in shape (dim0, dim1, dim2, dim3, dim4)

        static_axis_boundary == -1, transpose_axis_boundary == 2
            Xt = (dim2, dim3, dim4, dim0, dim1)

        static_axis_boundary == 0, transpose_axis_boundary == 2
            Xt = (dim0, dim2, dim3, dim4, dim1)

        static_axis_boundary == 0, transpose_axis_boundary == 3
            Xt = (dim0, dim3, dim4, dim1, dim2)
    """
    rank = len(shape)
    # Every negative value collapses to -1, i.e. "no static axes".
    static_axis_boundary = -1 if static_axis_boundary < 0 else static_axis_boundary
    assert static_axis_boundary < rank - 2    # at least 2 remaining for transpose.

    first_moved = static_axis_boundary + 1
    split = _normalize_axis_boundary(transpose_axis_boundary, rank)
    assert first_moved < split

    static_part = shape[:first_moved]
    front_group = shape[first_moved:split]
    back_group = shape[split:]
    # The two non-static groups swap places.
    return (*static_part, *back_group, *front_group)
2706
2707


2708
class CastTransposePrimitive(BasePrimitive):
    """
    Cast Transpose Primitive

    Yields the casted input, the casted and transposed input, and the updated amax.
    """
    name = "te_cast_transpose"
    multiple_results = True
    # out_dtype, static_axis_boundary, transpose_axis_boundary
    impl_static_args = (4, 5, 6)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype, static_axis_boundary,
                 transpose_axis_boundary):
        """
        te_cast_transpose_p abstract

        x must be fp32/fp16/bf16 and amax/scale/scale_inv must be fp32.
        """
        x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        for fp32_aval in (amax_aval, scale_aval, scale_inv_aval):
            assert fp32_aval.dtype == jnp.float32

        xt_shape = _multidim_transpose(x_aval.shape, static_axis_boundary,
                                       transpose_axis_boundary)

        return (
            x_aval.update(shape=x_aval.shape, dtype=out_dtype),
            x_aval.update(shape=xt_shape, dtype=out_dtype),
            amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype),
        )

    @staticmethod
    def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary,
                 transpose_axis_boundary):
        """
        te_cast_transpose_p lowering rules

        Emits the custom call registered under `name`; operand 1 (amax) is
        aliased to output 2 (updated amax).
        """
        x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        for fp32_aval in (amax_aval, scale_aval, scale_inv_aval):
            assert fp32_aval.dtype == jnp.float32

        x_ir_type = ir.RankedTensorType(x.type)
        x_ir_shape = x_ir_type.shape
        # Static axes must all have size 1; the range is empty when the
        # boundary is negative.
        for axis in range(static_axis_boundary + 1):
            assert x_ir_shape[axis] == 1

        out_ir_dtype = jax_dtype_to_ir_dtype(out_dtype)
        amax_ir_type = ir.RankedTensorType(amax.type)
        amax_ir_dtype = amax_ir_type.element_type
        amax_ir_shape = amax_ir_type.shape

        xt_ir_shape = _multidim_transpose(x_ir_shape, static_axis_boundary,
                                          transpose_axis_boundary)

        out_types = [
            ir.RankedTensorType.get(x_ir_shape, out_ir_dtype),
            ir.RankedTensorType.get(xt_ir_shape, out_ir_dtype),
            ir.RankedTensorType.get(amax_ir_shape, amax_ir_dtype),
        ]
        operands = [x, amax, scale, scale_inv]
        # scale and scale_inv reuse the amax shape.
        operand_shapes = [x_ir_shape, amax_ir_shape, amax_ir_shape, amax_ir_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        # Collapse x to a 2D (rows, cols) problem around the transpose boundary.
        rows = reduce(operator.mul, x_ir_shape[:transpose_axis_boundary])
        cols = reduce(operator.mul, x_ir_shape[transpose_axis_boundary:])
        opaque = transformer_engine_jax.pack_common_descriptor((rows, cols),
                                                               jax_dtype_to_te_dtype(x_aval.dtype),
                                                               jax_dtype_to_te_dtype(out_dtype))

        return custom_caller(CastTransposePrimitive.name,
                             args,
                             opaque,
                             False,
                             operand_output_aliases={1: 2})

    @staticmethod
    def impl(x, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis_boundary):
        """
        te_cast_transpose implementation
        """
        assert CastTransposePrimitive.inner_primitive is not None
        static_kwargs = {
            "out_dtype": out_dtype,
            "static_axis_boundary": static_axis_boundary,
            "transpose_axis_boundary": transpose_axis_boundary,
        }
        cx, cxt, new_amax = CastTransposePrimitive.inner_primitive.bind(
            x, amax, scale, scale_inv, **static_kwargs)
        return cx, cxt, new_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary,
                transpose_axis_boundary):
        """
        te_cast_transpose batcher

        Only supported when no static axes were requested; the batch dim of x
        becomes the static axis boundary of the re-bound primitive.
        """
        _check_valid_batch_dims(batch_dims)
        assert CastTransposePrimitive.outer_primitive is not None
        assert static_axis_boundary < 0

        x, amax, scale, scale_inv = batched_args
        x_bdim, amax_bdim, *_ = batch_dims

        # Normalize against the rank without the batch dim, then shift by one
        # to account for the batch dim.
        batched_boundary = _normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1) + 1

        results = CastTransposePrimitive.outer_primitive.bind(
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=x_bdim,
            transpose_axis_boundary=batched_boundary)
        return results, (x_bdim, x_bdim, amax_bdim)

    @staticmethod
    def infer_sharding_from_operands(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh,
                                     arg_infos, result_infos):
        """
        te_cast_transpose infer_sharding_from_operands

        casted x follows x's padded spec, casted x^T follows that spec
        transposed, and amax keeps its own padded spec.
        """
        del out_dtype, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        amax_spec = get_padded_spec(arg_infos[1])
        return (
            NamedSharding(mesh, PartitionSpec(*x_spec)),
            NamedSharding(mesh, PartitionSpec(*xt_spec)),
            NamedSharding(mesh, PartitionSpec(*amax_spec)),
        )

    @staticmethod
    def partition(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos,
                  result_infos):
        """
        te_cast_transpose partition

        Returns (mesh, sharded_impl, out_shardings, arg_shardings).
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        amax_spec = get_padded_spec(arg_infos[1])
        out_shardings = (
            NamedSharding(mesh, PartitionSpec(*x_spec)),
            NamedSharding(mesh, PartitionSpec(*xt_spec)),
            NamedSharding(mesh, PartitionSpec(*amax_spec)),
        )
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)

        def sharded_impl(x, amax, scale, scale_inv):
            local_cx, local_cxt, local_amax = CastTransposePrimitive.impl(
                x,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=static_axis_boundary,
                transpose_axis_boundary=transpose_axis_boundary)
            # Reduce the locally updated amax across the mesh.
            global_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
            return local_cx, local_cxt, global_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(CastTransposePrimitive)


def cast_transpose(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
                   out_dtype: jnp.dtype, static_axis_boundary: int,
                   transpose_axis_boundary: int) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    cast transpose wrapper
    Return two tensors, FP8(inputs) and FP8(inputs.T), which are scaled by `scale`,
    followed by the updated amax
    """
    static_kwargs = {
        "out_dtype": out_dtype,
        "static_axis_boundary": static_axis_boundary,
        "transpose_axis_boundary": transpose_axis_boundary,
    }
    return CastTransposePrimitive.outer_primitive.bind(x, amax, scale, scale_inv, **static_kwargs)
2879
2880


2881
class TransposePrimitive(BasePrimitive):
    """
    Transpose Primitive
    """
    name = "te_transpose"
    multiple_results = False
    # static_axis_boundary, transpose_axis_boundary
    impl_static_args = (1, 2)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, *, static_axis_boundary, transpose_axis_boundary):
        """
        _transpose abstract

        The output keeps x's dtype and takes the multi-dim transposed shape.
        """
        xt_shape = _multidim_transpose(x_aval.shape, static_axis_boundary,
                                       transpose_axis_boundary)
        return x_aval.update(shape=xt_shape, dtype=x_aval.dtype)

    @staticmethod
    def lowering(ctx, x, *, static_axis_boundary, transpose_axis_boundary):
        """
        _transpose cuda lowering
        """
        x_aval = ctx.avals_in[0]
        supported_dtypes = [
            jnp.float32, jnp.float16, jnp.bfloat16, jnp.float8_e4m3fn, jnp.float8_e5m2
        ]
        assert x_aval.dtype in supported_dtypes

        x_ir_shape = ir.RankedTensorType(x.type).shape
        out_ir_dtype = jax_dtype_to_ir_dtype(x_aval.dtype)
        # Static axes must all have size 1; the range is empty when the
        # boundary is negative.
        for axis in range(static_axis_boundary + 1):
            assert x_ir_shape[axis] == 1

        xt_ir_shape = _multidim_transpose(x_ir_shape, static_axis_boundary,
                                          transpose_axis_boundary)

        out_types = [ir.RankedTensorType.get(xt_ir_shape, out_ir_dtype)]
        args = CustomCallArgsWrapper(out_types, [x], [x_ir_shape])

        te_dtype = jax_dtype_to_te_dtype(x_aval.dtype)
        # Collapse x to a 2D (rows, cols) problem around the transpose boundary.
        rows = reduce(operator.mul, x_ir_shape[:transpose_axis_boundary])
        cols = reduce(operator.mul, x_ir_shape[transpose_axis_boundary:])
        opaque = transformer_engine_jax.pack_common_descriptor((rows, cols), te_dtype, te_dtype)

        out = custom_caller(TransposePrimitive.name, args, opaque, False)

        return [out]

    @staticmethod
    def impl(x, static_axis_boundary, transpose_axis_boundary):
        """
        transpose implementation
        """
        assert TransposePrimitive.inner_primitive is not None
        return TransposePrimitive.inner_primitive.bind(
            x,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary)

    @staticmethod
    def batcher(batched_args, batch_dims, *, static_axis_boundary, transpose_axis_boundary):
        """
        transpose batcher

        Only supported when no static axes were requested; the batch dim of x
        becomes the static axis boundary of the re-bound primitive.
        """
        _check_valid_batch_dims(batch_dims)
        assert TransposePrimitive.outer_primitive is not None
        assert static_axis_boundary < 0

        x, = batched_args
        x_bdim, = batch_dims

        # Normalize against the rank without the batch dim, then shift by one
        # to account for the batch dim.
        batched_boundary = _normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1) + 1

        transposed = TransposePrimitive.outer_primitive.bind(
            x, static_axis_boundary=x_bdim, transpose_axis_boundary=batched_boundary)
        return transposed, x_bdim

    @staticmethod
    def infer_sharding_from_operands(static_axis_boundary, transpose_axis_boundary, mesh, arg_infos,
                                     result_infos):
        """
        transpose infer_sharding_from_operands

        The output follows x's padded spec, transposed like the data.
        """
        del result_infos
        xt_spec = _multidim_transpose(get_padded_spec(arg_infos[0]), static_axis_boundary,
                                      transpose_axis_boundary)
        return NamedSharding(mesh, PartitionSpec(*xt_spec))

    @staticmethod
    def partition(static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos):
        """
        transpose partition

        Returns (mesh, impl, out_shardings, arg_shardings).
        """
        del result_infos
        xt_spec = _multidim_transpose(get_padded_spec(arg_infos[0]), static_axis_boundary,
                                      transpose_axis_boundary)
        out_shardings = NamedSharding(mesh, PartitionSpec(*xt_spec))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)

        def sharded_impl(x):
            return TransposePrimitive.impl(x,
                                           static_axis_boundary=static_axis_boundary,
                                           transpose_axis_boundary=transpose_axis_boundary)

        return mesh, sharded_impl, out_shardings, arg_shardings
2991
2992


2993
register_primitive(TransposePrimitive)
2994
2995


2996
2997
def transpose(x: jnp.ndarray, static_axis_boundary: int,
              transpose_axis_boundary: int) -> jnp.ndarray:
    """
    transpose wrapper

    Binds TransposePrimitive.outer_primitive on `x`, forwarding both axis
    boundaries as static parameters, and returns the bind result.
    """
    static_params = {
        "static_axis_boundary": static_axis_boundary,
        "transpose_axis_boundary": transpose_axis_boundary,
    }
    return TransposePrimitive.outer_primitive.bind(x, **static_params)
3004
3005


3006
class LayerNormFwdFp8Primitive(BasePrimitive):
    """
    Layer Normalization Forward FP8 Primitive

    Operands: (x, gamma, beta, amax, scale, scale_inv).
    Results:  (out, mu, rsigma, updated_amax).
    """
    name = "te_layernorm_forward_fp8"
    multiple_results = True
    impl_static_args = (6, 7, 8)    # out_type, zero_centered_gamma, epsilon
    # Both are asserted to be non-None before use in impl()/batcher().
    # NOTE(review): presumably filled in by register_primitive - confirm.
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, gamma_aval, beta_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype,
                 zero_centered_gamma, epsilon):
        """
        LayerNorm fwd (fp8 out) abstract

        Checks operand dtypes and returns the abstract values of the results:
        `out` has the shape of x with dtype `out_dtype`; `mu` and `rsigma` are
        float32 with the shape of x minus its last axis; the updated amax keeps
        the shape and dtype of the incoming amax.
        """
        del zero_centered_gamma, epsilon
        x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)

        assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        mu_rsigma_dtype = jnp.float32

        assert gamma_aval.size == beta_aval.size

        out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
        # One statistic per row: drop the last (hidden) axis.
        mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigma_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)

        return out_aval, mu_aval, rsigma_aval, updated_amax_aval

    @staticmethod
    def lowering(ctx, x, gamma, beta, amax, scale, scale_inv, *, out_dtype, zero_centered_gamma,
                 epsilon):
        """
        LayerNorm fwd (fp8 out) lowering rules

        Builds the result types, packs a norm descriptor and emits the custom
        call named after this primitive. Returns whatever custom_caller returns.
        """
        x_aval, gamma_aval, beta_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in

        # Currently only support casting to E4M3 only in C side.
        assert out_dtype == jnp.float8_e4m3fn

        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert gamma_aval.dtype == beta_aval.dtype
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        g_type = ir.RankedTensorType(gamma.type)
        g_shape = g_type.shape
        b_type = ir.RankedTensorType(beta.type)
        b_shape = b_type.shape

        assert g_type == b_type
        assert g_shape == b_shape

        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_mu_dtype = ir.F32Type.get()
        ir_rsigma_dtype = ir.F32Type.get()
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        # scale and scale_inv are described with the same shape as amax.
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape

        out_shape = x_shape
        # hidden_size is the element count of gamma; the rest of x is the batch.
        hidden_size = reduce(operator.mul, g_shape)
        batch_shape = out_shape[:-1]
        batch_size = reduce(operator.mul, x_shape) // hidden_size

        out_types = [
            ir.RankedTensorType.get(out_shape, ir_out_dtype),
            ir.RankedTensorType.get(batch_shape, ir_mu_dtype),
            ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
        ]
        operands = [x, gamma, beta, amax, scale, scale_inv]
        operand_shapes = [
            x_shape, g_shape, b_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape
        ]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        # Read from the environment at lowering time; defaults to 0.
        sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))

        opaque = transformer_engine_jax.pack_norm_descriptor(
            batch_size,
            hidden_size,
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(gamma_aval.dtype),
            zero_centered_gamma,
            epsilon,
            sm_margin,
        )

        # Operand 3 (amax) is declared as aliased with result 3 (updated amax).
        out = custom_caller(LayerNormFwdFp8Primitive.name,
                            args,
                            opaque,
                            False,
                            operand_output_aliases={3: 3})

        return out

    @staticmethod
    def impl(x, gamma, beta, amax, scale, scale_inv, out_dtype, zero_centered_gamma, epsilon):
        """
        to describe implementation

        Binds the inner primitive and returns (out, mu, rsigma, updated_amax).
        """
        assert LayerNormFwdFp8Primitive.inner_primitive is not None
        out, mu, rsigma, updated_amax = LayerNormFwdFp8Primitive.inner_primitive.bind(
            x,
            gamma,
            beta,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon)
        return out, mu, rsigma, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, zero_centered_gamma, epsilon):
        """
        to describe batch rules for vmap

        Re-binds the outer primitive on the batched operands. out, mu and
        rsigma take the batch dim of x; the updated amax takes that of amax.
        """
        _check_valid_batch_dims(batch_dims)
        assert LayerNormFwdFp8Primitive.outer_primitive is not None
        x, gamma, beta, amax, scale, scale_inv = batched_args
        x_bdim, _, _, amax_bdim, _, _ = batch_dims

        out_bdims = x_bdim, x_bdim, x_bdim, amax_bdim
        return LayerNormFwdFp8Primitive.outer_primitive.bind(
            x,
            gamma,
            beta,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon), out_bdims

    @staticmethod
    def infer_sharding_from_operands(out_dtype, zero_centered_gamma, epsilon, mesh, arg_infos,
                                     result_infos):
        """
        Infer result shardings from the operand shardings.

        The last (hidden) axis of the output is always left unsharded; a
        warning is emitted if x was sharded along it. mu/rsigma follow x
        without its last axis; the updated amax follows the amax operand.
        """
        del out_dtype, zero_centered_gamma, epsilon, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        if x_spec[-1] is not None:
            # Report this primitive's own name (previously named the non-FP8 primitive).
            warnings.warn(
                f"Does not support to shard hidden dim in {LayerNormFwdFp8Primitive.name}! " \
                f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
                f"and hurt performance.")

        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        mu_sharding = rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[3])))
        return (out_sharding, mu_sharding, rsigma_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
        """
        Build the per-shard implementation and the operand/result shardings.

        Returns (mesh, sharded_impl, out_shardings, arg_shardings). x is forced
        to be unsharded along its last axis; amax, scale and scale_inv all use
        the sharding of the amax operand. The per-shard amax is passed through
        all_reduce_max_along_all_axes_except_PP before being returned.
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        if x_spec[-1] is not None:
            # Report this primitive's own name (previously named the non-FP8 primitive).
            warnings.warn(
                f"Does not support to shard hidden dim in {LayerNormFwdFp8Primitive.name}! " \
                f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
                f"and hurt performance."
            )
        x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        g_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        b_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        out_sharding = x_sharding
        mu_sharding = rsigma_sharding = NamedSharding(
            mesh, PartitionSpec(*get_padded_spec(arg_infos[0])[:-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[3])))
        fp8_meta_sharding = amax_sharding
        arg_shardings = (x_sharding, g_sharding, b_sharding) + (fp8_meta_sharding,) * 3
        out_shardings = (out_sharding, mu_sharding, rsigma_sharding, amax_sharding)

        def sharded_impl(x, gamma, beta, amax, scale, scale_inv):
            local_x, local_mu, local_rsigma, local_amax = \
                LayerNormFwdFp8Primitive.impl(x, gamma, beta, amax, scale, scale_inv,
                                              out_dtype=out_dtype,
                                              zero_centered_gamma=zero_centered_gamma,
                                              epsilon=epsilon)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)

            return local_x, local_mu, local_rsigma, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings
3201

3202
3203
3204
3205
3206
3207
3208

# Register LayerNormFwdFp8Primitive via register_primitive (defined elsewhere in this file).
# NOTE(review): presumably this is what populates inner_primitive/outer_primitive - confirm.
register_primitive(LayerNormFwdFp8Primitive)


def layernorm_fwd_fp8(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, amax: jnp.ndarray,
                      scale: jnp.ndarray, scale_inv: jnp.ndarray, out_dtype: jnp.dtype,
                      zero_centered_gamma: bool, epsilon: float):
    """
    Wrapper for TE layernorm fwd (fp8 out)

    Binds LayerNormFwdFp8Primitive.outer_primitive on the six array operands,
    forwarding out_dtype, zero_centered_gamma and epsilon as static
    parameters, and returns the bind result unchanged.
    """
    array_operands = (x, gamma, beta, amax, scale, scale_inv)
    static_params = {
        "out_dtype": out_dtype,
        "zero_centered_gamma": zero_centered_gamma,
        "epsilon": epsilon,
    }
    return LayerNormFwdFp8Primitive.outer_primitive.bind(*array_operands, **static_params)
3221
3222


3223
class RmsNormFwdFp8Primitive(BasePrimitive):
    """
    RMS Normalization Forward FP8 Primitive

    Operands: (x, gamma, amax, scale, scale_inv).
    Results:  (out, rsigma, updated_amax).
    """
    name = "te_rmsnorm_forward_fp8"
    multiple_results = True
    impl_static_args = (5, 6)    # out_dtype, epsilon
    # Both are asserted to be non-None before use in impl()/batcher().
    # NOTE(review): presumably filled in by register_primitive - confirm.
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval, out_dtype, epsilon):
        """
        RMSNorm fwd (fp8 out) abstract

        Checks operand dtypes and returns the abstract values of the results:
        `out` has the shape of x with dtype `out_dtype`; `rsigma` is float32
        with the shape of x minus its last axis; the updated amax keeps the
        shape and dtype of the incoming amax.
        """
        del epsilon
        x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)

        assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        hidden_size = gamma_aval.size
        assert x_aval.size % hidden_size == 0

        rsigma_dtype = jnp.float32

        out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
        rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=rsigma_dtype)
        # Derive the result aval from amax itself (as the sibling FP8 primitives
        # do) rather than from out_aval, and do not rebind the parameter.
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)

        return out_aval, rsigma_aval, updated_amax_aval

    @staticmethod
    def lowering(ctx, x, gamma, amax, scale, scale_inv, *, out_dtype, epsilon):
        """
        RMSNorm fwd (fp8 out) lowering rules

        Builds the result types, packs a norm descriptor and emits the custom
        call named after this primitive. Returns whatever custom_caller returns.
        """

        # Currently only support casting to E4M3 only in C side.
        assert out_dtype == jnp.float8_e4m3fn

        x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in

        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        g_type = ir.RankedTensorType(gamma.type)
        g_shape = g_type.shape

        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_rsigma_dtype = ir.F32Type.get()
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        # scale and scale_inv are described with the same shape as amax.
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape

        out_shape = x_shape
        # hidden_size is the element count of gamma; the rest of x is the batch.
        hidden_size = reduce(operator.mul, g_shape)
        batch_shape = out_shape[:-1]
        batch_size = reduce(operator.mul, x_shape) // hidden_size

        out_types = [
            ir.RankedTensorType.get(out_shape, ir_out_dtype),
            ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
        ]
        operands = [x, gamma, amax, scale, scale_inv]
        operand_shapes = [x_shape, g_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        # Read from the environment at lowering time; defaults to 0.
        sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))

        opaque = transformer_engine_jax.pack_norm_descriptor(
            batch_size,
            hidden_size,
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(gamma_aval.dtype),
            False,    # RMSNorm doesn't support zero_centered_gamma
            epsilon,
            sm_margin,
        )

        # Operand 2 (amax) is declared as aliased with result 2 (updated amax).
        out = custom_caller(RmsNormFwdFp8Primitive.name,
                            args,
                            opaque,
                            False,
                            operand_output_aliases={2: 2})

        return out

    @staticmethod
    def impl(x, gamma, amax, scale, scale_inv, out_dtype, epsilon):
        """
        to describe implementation

        Binds the inner primitive and returns (out, rsigma, updated amax).
        """
        assert RmsNormFwdFp8Primitive.inner_primitive is not None
        out, rsigma, amax = RmsNormFwdFp8Primitive.inner_primitive.bind(x,
                                                                        gamma,
                                                                        amax,
                                                                        scale,
                                                                        scale_inv,
                                                                        out_dtype=out_dtype,
                                                                        epsilon=epsilon)
        return out, rsigma, amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, epsilon):
        """
        to describe batch rules for vmap

        Re-binds the outer primitive on the batched operands. out and rsigma
        take the batch dim of x; the updated amax takes that of amax.
        """
        _check_valid_batch_dims(batch_dims)
        assert RmsNormFwdFp8Primitive.outer_primitive is not None
        x, gamma, amax, scale, scale_inv = batched_args
        x_bdim, _, amax_bdim, _, _ = batch_dims
        out_bdims = x_bdim, x_bdim, amax_bdim
        return RmsNormFwdFp8Primitive.outer_primitive.bind(x,
                                                           gamma,
                                                           amax,
                                                           scale,
                                                           scale_inv,
                                                           out_dtype=out_dtype,
                                                           epsilon=epsilon), out_bdims

    @staticmethod
    def infer_sharding_from_operands(out_dtype, epsilon, mesh, arg_infos, result_infos):
        """
        Infer result shardings from the operand shardings.

        The last (hidden) axis of the output is always left unsharded; a
        warning is emitted if x was sharded along it. rsigma follows x without
        its last axis; the updated amax follows the amax operand.
        """
        del out_dtype, epsilon, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {RmsNormFwdFp8Primitive.name}! " \
                f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
                f"and hurt performance."
            )
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        return (out_sharding, rsigma_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, epsilon, mesh, arg_infos, result_infos):
        """
        Build the per-shard implementation and the operand/result shardings.

        Returns (mesh, sharded_impl, out_shardings, arg_shardings). x is forced
        to be unsharded along its last axis; amax, scale and scale_inv all use
        the sharding of the amax operand. The per-shard amax is passed through
        all_reduce_max_along_all_axes_except_PP before being returned.
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {RmsNormFwdFp8Primitive.name}! " \
                f"Force to not shard the hidden dim, which might introduce extra collective ops, " \
                f"and hurt performance."
            )
        x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        g_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        out_sharding = x_sharding
        rsigma_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[0])[:-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        fp8_meta_sharding = amax_sharding
        arg_shardings = (x_sharding, g_sharding) + (fp8_meta_sharding,) * 3
        out_shardings = (out_sharding, rsigma_sharding, amax_sharding)

        def sharded_impl(x, gamma, amax, scale, scale_inv):
            local_x, local_rsigma, local_amax = \
                RmsNormFwdFp8Primitive.impl(x, gamma, amax, scale, scale_inv,
                                            out_dtype=out_dtype, epsilon=epsilon)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)

            return local_x, local_rsigma, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings
3396
3397


3398
# Register RmsNormFwdFp8Primitive via register_primitive (defined elsewhere in this file).
# NOTE(review): presumably this is what populates inner_primitive/outer_primitive - confirm.
register_primitive(RmsNormFwdFp8Primitive)
3399

3400
3401
3402

def rmsnorm_fwd_fp8(x: jnp.ndarray, gamma: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
                    scale_inv: jnp.ndarray, out_dtype: jnp.dtype, epsilon: float):
    """
    Wrapper for TE rmsnorm fwd (fp8 out)

    Binds RmsNormFwdFp8Primitive.outer_primitive on the five array operands,
    forwarding out_dtype and epsilon as static parameters, and returns the
    bind result unchanged.
    """
    array_operands = (x, gamma, amax, scale, scale_inv)
    static_params = {"out_dtype": out_dtype, "epsilon": epsilon}
    return RmsNormFwdFp8Primitive.outer_primitive.bind(*array_operands, **static_params)
3413
3414


3415
class GatedGeluFp8Primitive(BasePrimitive):
    """
    Gated Gelu FP8 Primitive

    Operands: (x, amax, scale, scale_inv), with x shaped (..., 2, hidden).
    Results:  (out, updated_amax), with out shaped (..., hidden).
    """
    name = "te_gated_gelu_fp8"
    multiple_results = True
    impl_static_args = (4,)    #out_dtype
    # Both are asserted to be non-None before use in impl()/batcher().
    # NOTE(review): presumably filled in by register_primitive - confirm.
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype):
        """
        te_gated_gelu_p abstract

        Checks operand dtypes and the (..., 2, hidden) layout of x, then
        returns the avals of `out` (x without its second-to-last axis, dtype
        `out_dtype`) and of the updated amax (same shape/dtype as amax).
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        # Currently only support casting to E4M3 only in C side.
        assert out_dtype == jnp.float8_e4m3fn
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        assert x_aval.shape[-2] == 2    # Assume x in (....., 2, hidden)
        hidden_size = x_aval.shape[-1]
        batch_shape = x_aval.shape[:-2]
        out_shape = (batch_shape) + (hidden_size,)
        out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)

        return out_aval, updated_amax_aval

    @staticmethod
    def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype):
        """
        te_gated_gelu_p lowering rules

        Builds the result types, packs a common descriptor for a
        (batch_size, hidden) problem and emits the custom call named after
        this primitive. Returns whatever custom_caller returns.
        """
        x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_x_type = ir.RankedTensorType(x.type)
        ir_x_shape = ir_x_type.shape
        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        # scale and scale_inv are described with the same shape as amax.
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape

        hidden_size = ir_x_shape[-1]
        batch_shape = ir_x_shape[:-2]
        # Initial value 1 keeps this well-defined when x has no leading batch
        # dims (shape (2, hidden)); reduce() without it raises on an empty list.
        batch_size = reduce(operator.mul, batch_shape, 1)
        out_shape = batch_shape + [hidden_size]
        out_types = [
            ir.RankedTensorType.get(out_shape, ir_out_dtype),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
        ]
        operands = [x, amax, scale, scale_inv]
        operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        opaque = transformer_engine_jax.pack_common_descriptor((batch_size, out_shape[-1]),
                                                               jax_dtype_to_te_dtype(x_aval.dtype),
                                                               jax_dtype_to_te_dtype(out_dtype))

        # Operand 1 (amax) is declared as aliased with result 1 (updated amax).
        out = custom_caller(GatedGeluFp8Primitive.name,
                            args,
                            opaque,
                            False,
                            operand_output_aliases={1: 1})

        return out

    @staticmethod
    def impl(x, amax, scale, scale_inv, out_dtype):
        """
        to describe implementation

        Binds the inner primitive and returns (out, updated_amax).
        """
        assert GatedGeluFp8Primitive.inner_primitive is not None
        out, updated_amax = GatedGeluFp8Primitive.inner_primitive.bind(x,
                                                                       amax,
                                                                       scale,
                                                                       scale_inv,
                                                                       out_dtype=out_dtype)
        return out, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype):
        """
        to describe batch rules for vmap

        Re-binds the outer primitive on the batched operands. out takes the
        batch dim of x; the updated amax takes that of amax.
        """
        _check_valid_batch_dims(batch_dims)
        assert GatedGeluFp8Primitive.outer_primitive is not None
        x, amax, scale, scale_inv = batched_args
        x_bdim, amax_bdim, _, _ = batch_dims

        out_bdims = x_bdim, amax_bdim
        return GatedGeluFp8Primitive.outer_primitive.bind(x,
                                                          amax,
                                                          scale,
                                                          scale_inv,
                                                          out_dtype=out_dtype), out_bdims

    @staticmethod
    def infer_sharding_from_operands(out_dtype, mesh, arg_infos, result_infos):
        """
        Infer result shardings from the operand shardings.

        out uses the spec of x with the second-to-last entry removed; the
        updated amax follows the amax operand.
        """
        del out_dtype, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        return (out_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, mesh, arg_infos, result_infos):
        """
        Build the per-shard implementation and the operand/result shardings.

        Returns (mesh, sharded_impl, out_shardings, arg_shardings). Operands
        keep the shardings reported in arg_infos. The per-shard amax is passed
        through all_reduce_max_along_all_axes_except_PP before being returned.
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (out_sharding, amax_sharding)

        def sharded_impl(x, amax, scale, scale_inv):
            local_x, local_amax = GatedGeluFp8Primitive.impl(x,
                                                             amax,
                                                             scale,
                                                             scale_inv,
                                                             out_dtype=out_dtype)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)

            return local_x, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings
3548
3549


3550
# Register GatedGeluFp8Primitive via register_primitive (defined elsewhere in this file).
# NOTE(review): presumably this is what populates inner_primitive/outer_primitive - confirm.
register_primitive(GatedGeluFp8Primitive)
3551

3552
3553
3554

def gated_gelu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
                   out_dtype: jnp.dtype) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    gated gelu wrapper
    Return FP8(geglu(x))

    Binds GatedGeluFp8Primitive.outer_primitive on (x, amax, scale, scale_inv)
    with out_dtype as a static parameter. The primitive's abstract rule
    produces two results, (out, updated_amax), hence the 2-tuple annotation.
    """
    return GatedGeluFp8Primitive.outer_primitive.bind(x,
                                                      amax,
                                                      scale,
                                                      scale_inv,
                                                      out_dtype=out_dtype)
3564
3565


3566
class DgatedGeluCastTransposePrimitive(BasePrimitive):
    """
    Dgated Gelu Cast Transpose Primitive

    Operands: (dz, x, amax, scale, scale_inv), with x shaped (..., 2, hidden).
    Results:  (out, t_out, updated_amax); out has the shape of x and t_out
    the shape _multidim_transpose gives for x.
    """
    name = "te_dgated_gelu_cast_transpose"
    multiple_results = True
    impl_static_args = (5, 6)    # out_dtype, static_axis_boundary
    # Both are asserted to be non-None before use in impl()/batcher().
    # NOTE(review): presumably filled in by register_primitive - confirm.
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype,
                 static_axis_boundary):
        """
        te_dgated_gelu_cast_transpose_p abstract

        Checks operand dtypes and that dz and x agree on the hidden size, then
        returns the avals of out, t_out and the updated amax.
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dtype
        assert x_aval.shape[-2] == 2    # Linear + GeLU
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_hidden_size = dz_aval.shape[-1]
        gi_hidden_size = x_aval.shape[-1]
        assert ir_hidden_size == gi_hidden_size
        t_shape = _multidim_transpose(x_aval.shape, static_axis_boundary, -2)
        out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype)
        t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
        return out, t_out, updated_amax_aval

    @staticmethod
    def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary):
        """
        te_dgated_gelu_cast_transpose_p lowering rules

        Builds the result types, packs a common descriptor for a
        (batch_size, hidden) problem and emits the custom call named after
        this primitive. Returns whatever custom_caller returns.
        """
        dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_aval.dtype
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_dz_type = ir.RankedTensorType(dz.type)
        ir_dz_shape = ir_dz_type.shape
        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        # Initial value 1 keeps both products well-defined when there are no
        # leading batch dims; reduce() without it raises on an empty list.
        dz_batch_size = reduce(operator.mul, ir_dz_shape[:-1], 1)
        x_batch_size = reduce(operator.mul, x_shape[:-2], 1)
        assert dz_batch_size == x_batch_size
        assert x_shape[-2] == 2    # Linear + GeLU
        ir_hidden_size = ir_dz_shape[-1]
        gi_hidden_size = x_shape[-1]
        assert ir_hidden_size == gi_hidden_size
        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        # scale and scale_inv are described with the same shape as amax.
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape
        transposed_x_shape = _multidim_transpose(x_shape, static_axis_boundary, -2)
        out_types = [
            ir.RankedTensorType.get(x_shape, ir_out_dtype),
            ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
        ]
        operands = [dz, x, amax, scale, scale_inv]
        operand_shapes = [ir_dz_shape, x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
        contracted_x_shape = (x_batch_size, x_shape[-1])
        opaque = transformer_engine_jax.pack_common_descriptor(contracted_x_shape,
                                                               jax_dtype_to_te_dtype(dz_aval.dtype),
                                                               jax_dtype_to_te_dtype(out_dtype))

        # Operand 2 (amax) is declared as aliased with result 2 (updated amax).
        out = custom_caller(DgatedGeluCastTransposePrimitive.name,
                            args,
                            opaque,
                            False,
                            operand_output_aliases={2: 2})

        return out

    @staticmethod
    def impl(dz, x, amax, scale, scale_inv, out_dtype, static_axis_boundary):
        """
        to describe implementation

        Binds the inner primitive and returns (out, t_out, updated_amax).
        """
        assert DgatedGeluCastTransposePrimitive.inner_primitive is not None
        out, t_out, updated_amax = DgatedGeluCastTransposePrimitive.inner_primitive.bind(
            dz,
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary)
        return out, t_out, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary):
        """
        to describe batch rules for vmap

        Re-binds the outer primitive on the batched operands. The incoming
        static_axis_boundary is discarded and replaced by the batch dim of the
        first operand (dz), which is also reported as the batch dim of out and
        t_out; the updated amax takes the batch dim of amax.
        """
        del static_axis_boundary
        _check_valid_batch_dims(batch_dims)
        assert DgatedGeluCastTransposePrimitive.outer_primitive is not None
        dz, x, amax, scale, scale_inv = batched_args
        # The first entry belongs to dz (operand order is dz, x, amax, ...).
        dz_bdim, _, amax_bdim, _, _ = batch_dims

        out_bdims = dz_bdim, dz_bdim, amax_bdim
        return DgatedGeluCastTransposePrimitive.outer_primitive.bind(
            dz, x, amax, scale, scale_inv, out_dtype=out_dtype,
            static_axis_boundary=dz_bdim), out_bdims

    @staticmethod
    def infer_sharding_from_operands(out_dtype, static_axis_boundary, mesh, arg_infos,
                                     result_infos):
        """
        Infer result shardings from the operand shardings.

        out follows the spec of x (arg_infos[1]); t_out uses that spec run
        through _multidim_transpose; the updated amax follows the amax operand.
        """
        del out_dtype, result_infos
        x_spec = get_padded_spec(arg_infos[1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, -2)
        transposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        return (out_sharding, transposed_out_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, static_axis_boundary, mesh, arg_infos, result_infos):
        """
        Build the per-shard implementation and the operand/result shardings.

        Returns (mesh, sharded_impl, out_shardings, arg_shardings). Operands
        keep the shardings reported in arg_infos. The per-shard amax is passed
        through all_reduce_max_along_all_axes_except_PP before being returned.
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[1])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = _multidim_transpose(x_spec, static_axis_boundary, -2)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))

        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (casted_x_sharding, casted_transposed_x_sharding, amax_sharding)

        def sharded_impl(dz, x, amax, scale, scale_inv):
            local_out, local_t_out, local_amax = DgatedGeluCastTransposePrimitive.impl(
                dz,
                x,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=static_axis_boundary)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
            return local_out, local_t_out, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings
3716
3717


3718
# Register DgatedGeluCastTransposePrimitive via register_primitive (defined elsewhere in this file).
# NOTE(review): presumably this is what populates inner_primitive/outer_primitive - confirm.
register_primitive(DgatedGeluCastTransposePrimitive)
3719

3720
3721
3722
3723
3724

def dgated_gelu_cast_transpose(
        dz: jnp.ndarray, x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
        scale_inv: jnp.ndarray, out_dtype: jnp.dtype,
        static_axis_boundary: int) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    cast transpose d_gated_gelu fusion wrapper
    Return FP8(dgeglu(inputs))

    Binds DgatedGeluCastTransposePrimitive.outer_primitive on
    (dz, x, amax, scale, scale_inv) with out_dtype and static_axis_boundary as
    static parameters. The primitive's abstract rule produces three results:
    (out, t_out, updated_amax).

    out_dtype is a JAX dtype: the primitive passes it to aval.update(dtype=...)
    and converts it itself with jax_dtype_to_te_dtype.
    """
    return DgatedGeluCastTransposePrimitive.outer_primitive.bind(
        dz,
        x,
        amax,
        scale,
        scale_inv,
        out_dtype=out_dtype,
        static_axis_boundary=static_axis_boundary)