# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for transpose"""
import operator
from functools import partial, reduce
from typing import Tuple, Sequence, Union, Callable
from packaging import version

import jax
import jax.numpy as jnp
from jax import dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding

import transformer_engine_jax
from transformer_engine_jax import DType as TEDType

from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import (
    check_valid_batch_dims,
    jax_dtype_to_te_dtype,
    jax_dtype_to_ir_dtype,
    te_dtype_to_jax_dtype,
    get_padded_spec,
    multidim_transpose,
    normalize_axis_boundary,
    is_ffi_enabled,
)
from .activation import ActivationEnum
from .activation import _jax_act_lu
from .quantization import _jax_cast_fp8
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp

if version.parse(jax.__version__) >= version.parse("0.5.0"):
    from jax import ffi  # pylint: disable=ungrouped-imports
else:
    from jax.extend import ffi  # pylint: disable=ungrouped-imports


__all__ = [
    "transpose",
    "cast_transpose",
    "dbias_cast_transpose",
    "dact_lu_dbias_cast_transpose",
    "dgated_act_lu_cast_transpose",
]


def _jax_transpose(inputs, static_axis_boundary, transpose_axis_boundary):
    """
    JAX native transpose implementation
    """
    axes = multidim_transpose(range(inputs.ndim), static_axis_boundary, transpose_axis_boundary)
    return jnp.transpose(inputs, axes=axes)


def _jax_cast_transpose(
    inputs, scale, amax, out_dtype, static_axis_boundary, transpose_axis_boundary
):
    """
    JAX native cast_transpose implementation
    """
    casted_output, updated_amax = _jax_cast_fp8(inputs, scale, amax, out_dtype=out_dtype)
    casted_transposed_output = _jax_transpose(
        casted_output, static_axis_boundary, transpose_axis_boundary
    )
    return casted_output, casted_transposed_output, updated_amax


def _jax_dbias_cast_transpose(
    dz, amax, scale, out_dtype, static_axis_boundary, transpose_axis_boundary
):
    """
    JAX native dbias_cast_transpose implementation
    """
    casted_dz, cast_transposed_dz, updated_amax = _jax_cast_transpose(
        dz,
        scale,
        amax,
        out_dtype=out_dtype,
        static_axis_boundary=static_axis_boundary,
        transpose_axis_boundary=transpose_axis_boundary,
    )
    dbias = jnp.sum(
        dz,
        axis=tuple(
            range(
                transpose_axis_boundary
                if transpose_axis_boundary > 0
                else transpose_axis_boundary + dz.ndim
            )
        ),
        keepdims=False,
    )
    dbias = dbias.ravel()  # the C++ kernel returns a 1D array for dbias
    return casted_dz, cast_transposed_dz, dbias, updated_amax
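

# A minimal shape sketch of the reference path above (illustrative only, not part
# of the module API): for dz of shape (B, S, H) with transpose_axis_boundary=-1
# and no static axes, the cast keeps (B, S, H), the transpose yields (H, B, S),
# and dbias reduces over the leading (B, S) axes into a 1D vector of width H.
def _example_jax_dbias_cast_transpose():  # pragma: no cover
    dz = jnp.ones((2, 3, 4), dtype=jnp.float32)
    amax = jnp.zeros((1,), dtype=jnp.float32)
    scale = jnp.ones((1,), dtype=jnp.float32)
    casted, casted_t, dbias, _ = _jax_dbias_cast_transpose(
        dz, amax, scale, jnp.float8_e4m3fn, -1, -1
    )
    assert casted.shape == (2, 3, 4)
    assert casted_t.shape == (4, 2, 3)
    assert dbias.shape == (4,)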


class TransposePrimitive(BasePrimitive):
    """
    Transpose Primitive
    """

    name = "te_transpose"
    multiple_results = False
    impl_static_args = (1, 2)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, *, static_axis_boundary, transpose_axis_boundary):
        """
        _transpose abstract
        """
        transposed_x_shape = multidim_transpose(
            x_aval.shape, static_axis_boundary, transpose_axis_boundary
        )
        xt_aval = x_aval.update(shape=transposed_x_shape, dtype=x_aval.dtype)

        return xt_aval

    @staticmethod
    def lowering(ctx, x, *, static_axis_boundary, transpose_axis_boundary):
        """
        _transpose cuda lowering
        """

        x_aval = ctx.avals_in[0]
        assert x_aval.dtype in [
            jnp.float32,
            jnp.float16,
            jnp.bfloat16,
            jnp.float8_e4m3fn,
            jnp.float8_e5m2,
        ]

        if is_ffi_enabled():
            name = "te_transpose_ffi"
            out = ffi.ffi_lowering(name)(ctx, x, transpose_axis=transpose_axis_boundary)
        else:
            ir_x_type = ir.RankedTensorType(x.type)
            ir_x_shape = ir_x_type.shape
            ir_out_dtype = jax_dtype_to_ir_dtype(x_aval.dtype)
            if static_axis_boundary >= 0:
                for i in range(static_axis_boundary + 1):
                    assert ir_x_shape[i] == 1

            transposed_x_shape = multidim_transpose(
                ir_x_shape, static_axis_boundary, transpose_axis_boundary
            )

            out_types = [ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype)]
            operands = [x]
            operand_shapes = [ir_x_shape]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

            te_dtype = jax_dtype_to_te_dtype(x_aval.dtype)
            contracted_x_shape = (
                reduce(operator.mul, ir_x_shape[:transpose_axis_boundary]),
                reduce(operator.mul, ir_x_shape[transpose_axis_boundary:]),
            )
            opaque = transformer_engine_jax.pack_common_descriptor(
                contracted_x_shape, te_dtype, te_dtype
            )

            out = custom_caller(TransposePrimitive.name, args, opaque, False)

        return out

    @staticmethod
    def impl(x, static_axis_boundary, transpose_axis_boundary):
        """
        te_transpose implementation
        """
        assert TransposePrimitive.inner_primitive is not None
        transposed_x = TransposePrimitive.inner_primitive.bind(
            x,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary,
        )
        return transposed_x

    @staticmethod
    def batcher(batched_args, batch_dims, *, static_axis_boundary, transpose_axis_boundary):
        check_valid_batch_dims(batch_dims)
        assert TransposePrimitive.outer_primitive is not None
        assert static_axis_boundary < 0

        (x,) = batched_args
        (x_bdim,) = batch_dims

        # Minus batch dim.
        transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1)
        transpose_axis_boundary += 1  # Plus batch dim

        out_bdims = x_bdim
        return (
            TransposePrimitive.outer_primitive.bind(
                x, static_axis_boundary=x_bdim, transpose_axis_boundary=transpose_axis_boundary
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos
    ):
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        return transposed_x_sharding

    @staticmethod
    def partition(static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos):
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = transposed_x_sharding

        impl = partial(
            TransposePrimitive.impl,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary,
        )

        return mesh, impl, out_shardings, arg_shardings


register_primitive(TransposePrimitive)


def transpose(
    x: jnp.ndarray, static_axis_boundary: int, transpose_axis_boundary: int
) -> jnp.ndarray:
    """
    transpose wrapper
    """
    if not TransposePrimitive.enabled():
        return _jax_transpose(x, static_axis_boundary, transpose_axis_boundary)
    return TransposePrimitive.outer_primitive.bind(
        x,
        static_axis_boundary=static_axis_boundary,
        transpose_axis_boundary=transpose_axis_boundary,
    )
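

# A minimal usage sketch (illustrative only, not part of the module API): with no
# static axes (static_axis_boundary=-1), `transpose` logically flattens the input
# into a 2D matrix split at `transpose_axis_boundary` and swaps the two axis groups.
def _example_transpose():  # pragma: no cover
    x = jnp.zeros((4, 8, 16), dtype=jnp.bfloat16)
    # Split at the last axis: (4, 8, 16) ~ (32, 16) -> (16, 32) ~ (16, 4, 8).
    xt = transpose(x, static_axis_boundary=-1, transpose_axis_boundary=-1)
    assert xt.shape == (16, 4, 8)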


class CastTransposePrimitive(BasePrimitive):
    """
    Cast Transpose Primitive
    """

    name = "te_cast_transpose"
    multiple_results = True
    impl_static_args = (4, 5, 6)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        amax_aval,
        scale_aval,
        scale_inv_aval,
        *,
        out_dtype,
        static_axis_boundary,
        transpose_axis_boundary
    ):
        """
        te_cast_transpose_p abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        transposed_x_shape = multidim_transpose(
            x_aval.shape, static_axis_boundary, transpose_axis_boundary
        )

        casted_x_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
        casted_xt_aval = x_aval.update(shape=transposed_x_shape, dtype=out_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)

        return casted_x_aval, casted_xt_aval, updated_amax_aval

    @staticmethod
    def lowering(
        ctx, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, transpose_axis_boundary
    ):
        """
        te_cast_transpose_p lowering rules
        """
        x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        if is_ffi_enabled():
            name = "te_cast_transpose_ffi"
            out = ffi.ffi_lowering(name, operand_output_aliases={1: 2})(
                ctx, x, amax, scale, scale_inv, transpose_axis=transpose_axis_boundary
            )
        else:
            ir_x_type = ir.RankedTensorType(x.type)
            ir_x_shape = ir_x_type.shape
            if static_axis_boundary >= 0:
                for i in range(static_axis_boundary + 1):
                    assert ir_x_shape[i] == 1
            ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
            ir_amax_type = ir.RankedTensorType(amax.type)
            ir_amax_dtype = ir_amax_type.element_type
            ir_amax_shape = ir_amax_type.shape
            ir_scale_shape = ir_amax_shape
            ir_scale_inv_shape = ir_amax_shape

            transposed_x_shape = multidim_transpose(
                ir_x_shape, static_axis_boundary, transpose_axis_boundary
            )

            out_types = [
                ir.RankedTensorType.get(ir_x_shape, ir_out_dtype),
                ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype),
                ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
            ]
            operands = [x, amax, scale, scale_inv]
            operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

            contracted_x_shape = (
                reduce(operator.mul, ir_x_shape[:transpose_axis_boundary]),
                reduce(operator.mul, ir_x_shape[transpose_axis_boundary:]),
            )
            opaque = transformer_engine_jax.pack_common_descriptor(
                contracted_x_shape,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
            )
            out = custom_caller(
                CastTransposePrimitive.name, args, opaque, False, operand_output_aliases={1: 2}
            )
        return out

    @staticmethod
    def impl(x, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis_boundary):
        """
        te_cast_transpose implementation
        """
        assert CastTransposePrimitive.inner_primitive is not None
        casted_x, casted_transposed_x, updated_amax = CastTransposePrimitive.inner_primitive.bind(
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary,
        )
        return casted_x, casted_transposed_x, updated_amax

    @staticmethod
    def batcher(
        batched_args, batch_dims, *, out_dtype, static_axis_boundary, transpose_axis_boundary
    ):
        check_valid_batch_dims(batch_dims)
        assert CastTransposePrimitive.outer_primitive is not None
        assert static_axis_boundary < 0

        x, amax, scale, scale_inv = batched_args
        x_bdim, amax_bdim, *_ = batch_dims

        # Minus batch dim.
        transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1)
        transpose_axis_boundary += 1  # Plus batch dim

        out_bdims = x_bdim, x_bdim, amax_bdim
        return (
            CastTransposePrimitive.outer_primitive.bind(
                x,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=x_bdim,
                transpose_axis_boundary=transpose_axis_boundary,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos
    ):
        del out_dtype, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        return (casted_x_sharding, casted_transposed_x_sharding, amax_sharding)

    @staticmethod
    def partition(
        out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos
    ):
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (casted_x_sharding, casted_transposed_x_sharding, amax_sharding)

        def sharded_impl(x, amax, scale, scale_inv):
            local_cx, local_cxt, local_updated_amax = CastTransposePrimitive.impl(
                x,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=static_axis_boundary,
                transpose_axis_boundary=transpose_axis_boundary,
            )
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax, mesh)

            return local_cx, local_cxt, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(CastTransposePrimitive)


def cast_transpose(
    x: jnp.ndarray,
    amax: jnp.ndarray,
    scale: jnp.ndarray,
    scale_inv: jnp.ndarray,
    out_dtype: jnp.dtype,
    static_axis_boundary: int,
    transpose_axis_boundary: int,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    cast transpose wrapper
    Return FP8(inputs) and FP8(inputs.T), both scaled by `scale`, along with the updated amax
    """
    if not CastTransposePrimitive.enabled():
        return _jax_cast_transpose(
            x, scale, amax, out_dtype, static_axis_boundary, transpose_axis_boundary
        )
    return CastTransposePrimitive.outer_primitive.bind(
        x,
        amax,
        scale,
        scale_inv,
        out_dtype=out_dtype,
        static_axis_boundary=static_axis_boundary,
        transpose_axis_boundary=transpose_axis_boundary,
    )
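

# A minimal usage sketch (illustrative only): `cast_transpose` quantizes once and
# returns both layouts plus the updated amax, so the row-major and column-major
# FP8 GEMM operands come from a single cast. amax/scale/scale_inv are assumed to
# be the usual shape-(1,) float32 scaling metadata.
def _example_cast_transpose():  # pragma: no cover
    x = jnp.ones((8, 16), dtype=jnp.float32)
    amax = jnp.zeros((1,), dtype=jnp.float32)
    scale = jnp.ones((1,), dtype=jnp.float32)
    scale_inv = jnp.ones((1,), dtype=jnp.float32)
    casted, casted_t, updated_amax = cast_transpose(
        x,
        amax,
        scale,
        scale_inv,
        out_dtype=jnp.float8_e4m3fn,
        static_axis_boundary=-1,
        transpose_axis_boundary=-1,
    )
    assert casted.shape == (8, 16) and casted_t.shape == (16, 8)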


class DBiasCastTransposePrimitive(BasePrimitive):
    """
    DBias Cast Transpose Primitive
    """

    name = "te_dbias_cast_transpose"
    multiple_results = True
    # out_dtype, static_axis_boundary, transpose_axis_boundary
    impl_static_args = (4, 5, 6)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        dz_aval,
        amax_aval,
        scale_aval,
        scale_inv_aval,
        *,
        out_dtype,
        static_axis_boundary,
        transpose_axis_boundary
    ):
        """
        te_dbias_cast_transpose_p abstract
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        gi_hidden_size = reduce(operator.mul, dz_aval.shape[transpose_axis_boundary:])
        t_shape = multidim_transpose(dz_aval.shape, static_axis_boundary, transpose_axis_boundary)
        out = dz_aval.update(shape=dz_aval.shape, dtype=out_dtype)
        t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)

        dbias_shape = (*dz_aval.shape[: static_axis_boundary + 1], gi_hidden_size)
        dbias = dz_aval.update(shape=dbias_shape, dtype=dtype)

        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
        (wkspace_info,) = transformer_engine_jax.get_dbias_ct_workspace_sizes(
            dz_aval.size // gi_hidden_size,
            gi_hidden_size,
            jax_dtype_to_te_dtype(dz_aval.dtype),
            jax_dtype_to_te_dtype(out_dtype),
        )
        wkspace_aval = dz_aval.update(
            shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
        )

        return out, t_out, dbias, updated_amax_aval, wkspace_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dbias_cast_transpose_p outer abstract
        """

        out, t_out, dbias, updated_amax_aval, _ = DBiasCastTransposePrimitive.abstract(
            *args, **kwargs
        )
        return out, t_out, dbias, updated_amax_aval

    @staticmethod
    def lowering(
        ctx, dz, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, transpose_axis_boundary
    ):
        """
        te_dbias_cast_transpose_p lowering rules
        """
        dz_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        if is_ffi_enabled():
            name = "te_dbias_cast_transpose_ffi"
            out = ffi.ffi_lowering(name, operand_output_aliases={1: 3})(
                ctx, dz, amax, scale, scale_inv, transpose_axis=transpose_axis_boundary
            )
        else:
            ir_dz_type = ir.RankedTensorType(dz.type)
            ir_dz_shape = ir_dz_type.shape
            batch_size = reduce(operator.mul, ir_dz_shape[:transpose_axis_boundary])
            ir_hidden_size = reduce(operator.mul, ir_dz_shape[transpose_axis_boundary:])
            contracted_dz_shape = (batch_size, ir_hidden_size)
            ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
            ir_amax_type = ir.RankedTensorType(amax.type)
            ir_amax_dtype = ir_amax_type.element_type
            ir_amax_shape = ir_amax_type.shape
            ir_scale_shape = ir_amax_shape
            ir_scale_inv_shape = ir_amax_shape
            transposed_dz_shape = multidim_transpose(
                ir_dz_shape, static_axis_boundary, transpose_axis_boundary
            )
            dbias_shape = (*ir_dz_shape[: static_axis_boundary + 1], ir_hidden_size)

            wkspace_aval = ctx.avals_out[-1]

            out_types = [
                ir.RankedTensorType.get(ir_dz_shape, ir_out_dtype),
                ir.RankedTensorType.get(transposed_dz_shape, ir_out_dtype),
                ir.RankedTensorType.get(dbias_shape, ir_dz_type.element_type),
                ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
                ir.RankedTensorType.get(
                    wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
                ),
            ]
            operands = [dz, amax, scale, scale_inv]
            operand_shapes = [ir_dz_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
            opaque = transformer_engine_jax.pack_common_wk_descriptor(
                contracted_dz_shape,
                wkspace_aval.shape,
                jax_dtype_to_te_dtype(dz_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
                jax_dtype_to_te_dtype(wkspace_aval.dtype),
            )

            out = custom_caller(
                DBiasCastTransposePrimitive.name, args, opaque, False, operand_output_aliases={1: 3}
            )

        return out

    @staticmethod
    def impl(dz, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis_boundary):
        """
        te_dbias_cast_transpose implementation
        """
        assert DBiasCastTransposePrimitive.inner_primitive is not None
        out, t_out, dbias, updated_amax, _ = DBiasCastTransposePrimitive.inner_primitive.bind(
            dz,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary,
        )
        return out, t_out, dbias, updated_amax

    @staticmethod
    def batcher(
        batched_args, batch_dims, *, out_dtype, static_axis_boundary, transpose_axis_boundary
    ):
        """
        te_dbias_cast_transpose batch rule for vmap
        """
        del static_axis_boundary
        check_valid_batch_dims(batch_dims)
        assert DBiasCastTransposePrimitive.outer_primitive is not None
        dz, amax, scale, scale_inv = batched_args
        dz_bdim, amax_bdim, _, _ = batch_dims

        # Minus batch dim.
        transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, dz.ndim - 1)
        transpose_axis_boundary += 1  # Plus batch dim

        out_bdims = dz_bdim, dz_bdim, dz_bdim, amax_bdim
        return (
            DBiasCastTransposePrimitive.outer_primitive.bind(
                dz,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=dz_bdim,
                transpose_axis_boundary=transpose_axis_boundary,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos
    ):
        del out_dtype, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        transposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        dbias_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec[: static_axis_boundary + 1], x_spec[-1])
        )
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        return (out_sharding, transposed_out_sharding, dbias_sharding, amax_sharding)

    @staticmethod
    def partition(
        out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos
    ):
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))

        dbias_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec[: static_axis_boundary + 1], x_spec[-1])
        )

        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (
            casted_x_sharding,
            casted_transposed_x_sharding,
            dbias_sharding,
            amax_sharding,
        )

        def sharded_impl(dz, amax, scale, scale_inv):
            local_out, local_t_out, local_dbias, local_amax = DBiasCastTransposePrimitive.impl(
                dz,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=static_axis_boundary,
                transpose_axis_boundary=transpose_axis_boundary,
            )
            global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            return local_out, local_t_out, global_dbias, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(DBiasCastTransposePrimitive)


def dbias_cast_transpose(
    dz: jnp.ndarray,
    amax: jnp.ndarray,
    scale: jnp.ndarray,
    scale_inv: jnp.ndarray,
    out_dtype: TEDType,
    static_axis_boundary: int,
    transpose_axis_boundary: int = -1,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    cast transpose dbias partial fusion wrapper
    Return FP8(inputs), FP8(inputs.T), dbias, and the updated amax
    """
    if static_axis_boundary < 0:
        static_axis_boundary = -1  # means no static axes

    if not DBiasCastTransposePrimitive.enabled():
        return _jax_dbias_cast_transpose(
            dz, amax, scale, out_dtype, static_axis_boundary, transpose_axis_boundary
        )

    return DBiasCastTransposePrimitive.outer_primitive.bind(
        dz,
        amax,
        scale,
        scale_inv,
        out_dtype=out_dtype,
        static_axis_boundary=static_axis_boundary,
        transpose_axis_boundary=transpose_axis_boundary,
    )
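

# A minimal usage sketch (illustrative only): the dbias fusion quantizes the
# incoming gradient in both layouts and reduces it over all leading axes in the
# same pass, yielding the 1D bias gradient.
def _example_dbias_cast_transpose():  # pragma: no cover
    dz = jnp.ones((2, 8, 16), dtype=jnp.float32)
    amax = jnp.zeros((1,), dtype=jnp.float32)
    scale = jnp.ones((1,), dtype=jnp.float32)
    scale_inv = jnp.ones((1,), dtype=jnp.float32)
    casted_dz, casted_dz_t, dbias, updated_amax = dbias_cast_transpose(
        dz,
        amax,
        scale,
        scale_inv,
        out_dtype=jnp.float8_e4m3fn,
        static_axis_boundary=-1,
        transpose_axis_boundary=-1,
    )
    assert casted_dz_t.shape == (16, 2, 8)
    assert dbias.shape == (16,)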


class DActLuDBiasCastTransposePrimitive(BasePrimitive):
    """
    DActLu DBias Cast Transpose Primitive
    """

    name = "te_dact_lu_dbias_cast_transpose"
    multiple_results = True
    # out_dtype, static_axis_boundary, act_enum
    impl_static_args = (5, 6, 7)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        dz_aval,
        x_aval,
        amax_aval,
        scale_aval,
        scale_inv_aval,
        *,
        out_dtype,
        static_axis_boundary,
        act_enum
    ):  # pylint: disable=unused-argument
        """
        te_dact_lu_dbias_cast_transpose_p abstract
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dtype
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_hidden_size = dz_aval.shape[-1]
        gi_hidden_size = x_aval.shape[-1]
        assert ir_hidden_size == gi_hidden_size
        t_shape = multidim_transpose(x_aval.shape, static_axis_boundary, -2)
        out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype)
        t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)

        dbias_shape = (*x_aval.shape[: static_axis_boundary + 1], gi_hidden_size)
        dbias = dz_aval.update(shape=dbias_shape, dtype=dtype)

        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)

        (wkspace_info,) = transformer_engine_jax.get_dact_dbias_ct_workspace_sizes(
            x_aval.size // gi_hidden_size,
            gi_hidden_size,
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(out_dtype),
        )
        wkspace_aval = x_aval.update(
            shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
        )

        return out, t_out, dbias, updated_amax_aval, wkspace_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dact_lu_dbias_cast_transpose_p outer abstract
        """

        out, t_out, dbias, updated_amax_aval, _ = DActLuDBiasCastTransposePrimitive.abstract(
            *args, **kwargs
        )
        return out, t_out, dbias, updated_amax_aval

    @staticmethod
    def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, act_enum):
        """
        te_dact_lu_dbias_cast_transpose_p lowering rules
        """
        dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_aval.dtype
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        if is_ffi_enabled():
            name = "te_dact_lu_dbias_cast_transpose_ffi"
            out = ffi.ffi_lowering(name, operand_output_aliases={2: 3})(
                ctx, dz, x, amax, scale, scale_inv, act_enum=int(act_enum)
            )
        else:
            ir_dz_type = ir.RankedTensorType(dz.type)
            ir_dz_shape = ir_dz_type.shape
            x_type = ir.RankedTensorType(x.type)
            x_shape = x_type.shape
            dz_batch_size = reduce(operator.mul, ir_dz_shape[:-1])
            x_batch_size = reduce(operator.mul, x_shape[:-2])
            assert dz_batch_size == x_batch_size
            ir_hidden_size = ir_dz_shape[-1]
            contracted_x_shape = (x_batch_size, ir_hidden_size)

            ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
            ir_amax_type = ir.RankedTensorType(amax.type)
            ir_amax_dtype = ir_amax_type.element_type
            ir_amax_shape = ir_amax_type.shape
            ir_scale_shape = ir_amax_shape
            ir_scale_inv_shape = ir_amax_shape
            transposed_x_shape = multidim_transpose(x_shape, static_axis_boundary, -2)
            dbias_shape = (*x_shape[: static_axis_boundary + 1], ir_hidden_size)

            wkspace_aval = ctx.avals_out[-1]

            out_types = [
                ir.RankedTensorType.get(x_shape, ir_out_dtype),
                ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype),
                ir.RankedTensorType.get(dbias_shape, ir_dz_type.element_type),
                ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
                ir.RankedTensorType.get(
                    wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
                ),
            ]
            operands = [dz, x, amax, scale, scale_inv]
            operand_shapes = [
                ir_dz_shape,
                x_shape,
                ir_amax_shape,
                ir_scale_shape,
                ir_scale_inv_shape,
            ]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
            opaque = transformer_engine_jax.pack_common_wk_descriptor(
                contracted_x_shape,
                wkspace_aval.shape,
                jax_dtype_to_te_dtype(dz_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
                jax_dtype_to_te_dtype(wkspace_aval.dtype),
                act_enum,
            )

            out = custom_caller(
                DActLuDBiasCastTransposePrimitive.name,
                args,
                opaque,
                False,
                operand_output_aliases={2: 3},
            )

        return out

    @staticmethod
    def impl(
        dz,
        x,
        amax,
        scale,
        scale_inv,
        out_dtype,
        static_axis_boundary,
        act_enum,
    ):
        """
        te_dact_lu_dbias_cast_transpose implementation
        """
        assert DActLuDBiasCastTransposePrimitive.inner_primitive is not None
        out, t_out, dbias, updated_amax, _ = DActLuDBiasCastTransposePrimitive.inner_primitive.bind(
            dz,
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary,
            act_enum=act_enum,
        )
        return out, t_out, dbias, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary, act_enum):
        """
        te_dact_lu_dbias_cast_transpose batch rule for vmap
        """
        del static_axis_boundary
        check_valid_batch_dims(batch_dims)
        assert DActLuDBiasCastTransposePrimitive.outer_primitive is not None
        dz, x, amax, scale, scale_inv = batched_args
        x_bdim, _, amax_bdim, _, _ = batch_dims

        out_bdims = x_bdim, x_bdim, x_bdim, amax_bdim
        return (
            DActLuDBiasCastTransposePrimitive.outer_primitive.bind(
                dz,
                x,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=x_bdim,
                act_enum=act_enum,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        static_axis_boundary,
        act_enum,
        mesh,
        arg_infos,
        result_infos,
    ):
        del out_dtype, result_infos, act_enum
        x_spec = get_padded_spec(arg_infos[1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, -2)
        transposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        dbias_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec[: static_axis_boundary + 1], x_spec[-1])
        )
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        return (out_sharding, transposed_out_sharding, dbias_sharding, amax_sharding)

    @staticmethod
    def partition(
        out_dtype,
        static_axis_boundary,
        act_enum,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos
        x_spec = get_padded_spec(arg_infos[1])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, -2)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))

        dbias_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec[: static_axis_boundary + 1], x_spec[-1])
        )

        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (
            casted_x_sharding,
            casted_transposed_x_sharding,
            dbias_sharding,
            amax_sharding,
        )

        def sharded_impl(dz, x, amax, scale, scale_inv):
            local_out, local_t_out, local_dbias, local_amax = (
                DActLuDBiasCastTransposePrimitive.impl(
                    dz,
                    x,
                    amax,
                    scale,
                    scale_inv,
                    out_dtype=out_dtype,
                    static_axis_boundary=static_axis_boundary,
                    act_enum=act_enum,
                )
            )
            global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            return local_out, local_t_out, global_dbias, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(DActLuDBiasCastTransposePrimitive)


def dact_lu_dbias_cast_transpose(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    amax: jnp.ndarray,
    scale: jnp.ndarray,
    scale_inv: jnp.ndarray,
    out_dtype: TEDType,
    static_axis_boundary: int,
    activation_type: Sequence[Union[str, Callable]] = ("gelu",),
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    cast transpose dact_lu and dbias fusion wrapper
    Return FP8(dact_lu(inputs)), FP8 transpose of it, dbias, and the updated amax
    ONLY supports non-gated activation types
    """
    if static_axis_boundary < 0:
        static_axis_boundary = -1  # means no static axes

    if not DActLuDBiasCastTransposePrimitive.enabled():
        _, vjp_func = jax.vjp(partial(_jax_act_lu, activation_type=activation_type), x)
        (dx,) = vjp_func(dz)
        transpose_axis_boundary = -2
        return _jax_dbias_cast_transpose(
            dx, amax, scale, out_dtype, static_axis_boundary, transpose_axis_boundary
        )

    act_type_id = ActivationEnum[activation_type]
    return DActLuDBiasCastTransposePrimitive.outer_primitive.bind(
        dz,
        x,
        amax,
        scale,
        scale_inv,
        out_dtype=out_dtype,
        static_axis_boundary=static_axis_boundary,
        act_enum=act_type_id,
    )
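

# A minimal usage sketch (illustrative only): fuses the non-gated activation
# backward pass with the dbias reduction and the FP8 cast-transpose. `x` is the
# activation input saved from the forward pass, carrying a length-1 activation
# axis at position -2.
def _example_dact_lu_dbias_cast_transpose():  # pragma: no cover
    dz = jnp.ones((2, 8, 16), dtype=jnp.float32)
    x = jnp.ones((2, 8, 1, 16), dtype=jnp.float32)  # non-gated: one tensor per act
    amax = jnp.zeros((1,), dtype=jnp.float32)
    scale = jnp.ones((1,), dtype=jnp.float32)
    scale_inv = jnp.ones((1,), dtype=jnp.float32)
    out, t_out, dbias, updated_amax = dact_lu_dbias_cast_transpose(
        dz,
        x,
        amax,
        scale,
        scale_inv,
        out_dtype=jnp.float8_e4m3fn,
        static_axis_boundary=-1,
        activation_type=("gelu",),
    )
    assert out.shape == x.shape and t_out.shape == (1, 16, 2, 8)
    assert dbias.shape == (16,)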


class DgatedActLuCastTransposePrimitive(BasePrimitive):
    """
    Dgated ActLu Cast Transpose Primitive
    """

    name = "te_dgated_act_lu_cast_transpose"
    multiple_results = True
    impl_static_args = (5, 6, 7)  # out_dtype, static_axis_boundary, act_enum
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        dz_aval,
        x_aval,
        amax_aval,
        scale_aval,
        scale_inv_aval,
        *,
        out_dtype,
        static_axis_boundary,
        act_enum
    ):  # pylint: disable=unused-argument
        """
        te_dgated_act_lu_cast_transpose_p abstract
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dtype
        assert x_aval.shape[-2] == 2  # gated activations pack two tensors (e.g. Linear + GeLU)
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_hidden_size = dz_aval.shape[-1]
        gi_hidden_size = x_aval.shape[-1]
        assert ir_hidden_size == gi_hidden_size
        t_shape = multidim_transpose(x_aval.shape, static_axis_boundary, -2)
        out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype)
        t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
        return out, t_out, updated_amax_aval

    @staticmethod
    def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, act_enum):
        """
        te_dgated_act_lu_cast_transpose_p lowering rules
        """
        dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_aval.dtype
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        if is_ffi_enabled():
            name = "te_dgated_act_lu_cast_transpose_ffi"
            out = ffi.ffi_lowering(name, operand_output_aliases={2: 2})(
                ctx, dz, x, amax, scale, scale_inv, act_enum=int(act_enum)
            )
        else:
            ir_dz_type = ir.RankedTensorType(dz.type)
            ir_dz_shape = ir_dz_type.shape
            x_type = ir.RankedTensorType(x.type)
            x_shape = x_type.shape
            dz_batch_size = reduce(operator.mul, ir_dz_shape[:-1])
            x_batch_size = reduce(operator.mul, x_shape[:-2])
            assert dz_batch_size == x_batch_size
            assert x_shape[-2] == 2  # gated activations pack two tensors (e.g. Linear + GeLU)
            ir_hidden_size = ir_dz_shape[-1]
            gi_hidden_size = x_shape[-1]
            assert ir_hidden_size == gi_hidden_size
            ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
            ir_amax_type = ir.RankedTensorType(amax.type)
            ir_amax_dtype = ir_amax_type.element_type
            ir_amax_shape = ir_amax_type.shape
            ir_scale_shape = ir_amax_shape
            ir_scale_inv_shape = ir_amax_shape
            transposed_x_shape = multidim_transpose(x_shape, static_axis_boundary, -2)
            out_types = [
                ir.RankedTensorType.get(x_shape, ir_out_dtype),
                ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype),
                ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
            ]
            operands = [dz, x, amax, scale, scale_inv]
            operand_shapes = [
                ir_dz_shape,
                x_shape,
                ir_amax_shape,
                ir_scale_shape,
                ir_scale_inv_shape,
            ]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
            contracted_x_shape = (x_batch_size, x_shape[-1])
            opaque = transformer_engine_jax.pack_common_descriptor(
                contracted_x_shape,
                jax_dtype_to_te_dtype(dz_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
                act_enum,
            )

            out = custom_caller(
                DgatedActLuCastTransposePrimitive.name,
                args,
                opaque,
                False,
                operand_output_aliases={2: 2},
            )

        return out

    @staticmethod
    def impl(dz, x, amax, scale, scale_inv, out_dtype, static_axis_boundary, act_enum):
        """
        te_dgated_act_lu_cast_transpose implementation
        """
        assert DgatedActLuCastTransposePrimitive.inner_primitive is not None
        out, t_out, updated_amax = DgatedActLuCastTransposePrimitive.inner_primitive.bind(
            dz,
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary,
            act_enum=act_enum,
        )
        return out, t_out, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary, act_enum):
        """
        te_dgated_act_lu_cast_transpose batch rule for vmap
        """
        del static_axis_boundary
        check_valid_batch_dims(batch_dims)
        assert DgatedActLuCastTransposePrimitive.outer_primitive is not None
        dz, x, amax, scale, scale_inv = batched_args
        x_bdim, _, amax_bdim, _, _ = batch_dims

        out_bdims = x_bdim, x_bdim, amax_bdim
        return (
            DgatedActLuCastTransposePrimitive.outer_primitive.bind(
                dz,
                x,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=x_bdim,
                act_enum=act_enum,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype, static_axis_boundary, act_enum, mesh, arg_infos, result_infos
    ):
        del out_dtype, result_infos, act_enum
        x_spec = get_padded_spec(arg_infos[1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, -2)
        transposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        return (out_sharding, transposed_out_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, static_axis_boundary, act_enum, mesh, arg_infos, result_infos):
        del result_infos
        x_spec = get_padded_spec(arg_infos[1])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, -2)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))

        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (casted_x_sharding, casted_transposed_x_sharding, amax_sharding)

        def sharded_impl(dz, x, amax, scale, scale_inv):
            local_out, local_t_out, local_amax = DgatedActLuCastTransposePrimitive.impl(
                dz,
                x,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=static_axis_boundary,
                act_enum=act_enum,
            )
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            return local_out, local_t_out, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(DgatedActLuCastTransposePrimitive)


def dgated_act_lu_cast_transpose(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    amax: jnp.ndarray,
    scale: jnp.ndarray,
    scale_inv: jnp.ndarray,
    out_dtype: TEDType,
    static_axis_boundary: int,
    activation_type: Sequence[Union[str, Callable]],
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    cast transpose d_gated_act_lu fusion wrapper
    Return FP8(dgated_act_lu(inputs)), its FP8 transpose, and the updated amax
    """
    act_type_id = ActivationEnum[activation_type]
    if not DgatedActLuCastTransposePrimitive.enabled():
        _, vjp_func = jax.vjp(partial(_jax_act_lu, activation_type=activation_type), x)
        (dx,) = vjp_func(dz)
        return _jax_cast_transpose(
            dx,
            scale,
            amax,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=-2,
        )
    return DgatedActLuCastTransposePrimitive.outer_primitive.bind(
        dz,
        x,
        amax,
        scale,
        scale_inv,
        out_dtype=out_dtype,
        static_axis_boundary=static_axis_boundary,
        act_enum=act_type_id,
    )
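

# A minimal usage sketch (illustrative only): for gated activations the saved
# forward input `x` carries a pair axis of size 2 at position -2 (the linear half
# and the gating half), and the fused backward pass emits the FP8 gradient of `x`
# in both layouts. ("gelu", "linear") is assumed to be the GeGLU entry of
# ActivationEnum.
def _example_dgated_act_lu_cast_transpose():  # pragma: no cover
    dz = jnp.ones((2, 8, 16), dtype=jnp.float32)
    x = jnp.ones((2, 8, 2, 16), dtype=jnp.float32)  # [linear, gate] pair at axis -2
    amax = jnp.zeros((1,), dtype=jnp.float32)
    scale = jnp.ones((1,), dtype=jnp.float32)
    scale_inv = jnp.ones((1,), dtype=jnp.float32)
    dx, dx_t, updated_amax = dgated_act_lu_cast_transpose(
        dz,
        x,
        amax,
        scale,
        scale_inv,
        out_dtype=jnp.float8_e4m3fn,
        static_axis_boundary=-1,
        activation_type=("gelu", "linear"),
    )
    assert dx.shape == x.shape and dx_t.shape == (2, 16, 2, 8)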