transpose.py 44.7 KB
Newer Older
1
2
3
4
5
6
7
8
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for transpose"""
from functools import partial, reduce
from typing import Tuple, Sequence, Union, Callable
import operator

9
import jax
10
11
12
13
import jax.numpy as jnp
from jax import dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
14
from jax.extend import ffi
15
16
17
18
19
20
21
22
23
24
25
26
27

from transformer_engine import transformer_engine_jax
from transformer_engine.transformer_engine_jax import DType as TEDType

from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import (
    check_valid_batch_dims,
    jax_dtype_to_te_dtype,
    jax_dtype_to_ir_dtype,
    te_dtype_to_jax_dtype,
    get_padded_spec,
    multidim_transpose,
28
    normalize_axis_boundary,
29
    is_ffi_enabled,
30
31
)
from .activation import ActivationEnum
32
33
from .activation import _jax_act_lu
from .quantization import _jax_cast_fp8
34
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
35
36


37
38
39
40
41
42
43
__all__ = [
    "transpose",
    "cast_transpose",
    "dbias_cast_transpose",
    "dact_lu_dbias_cast_transpose",
    "dgated_act_lu_cast_transpose",
]
44
45


46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
def _jax_transpose(inputs, static_axis_boundary, transpose_axis_boundary):
    """
    Pure-JAX transpose used when the TE custom call is disabled.

    The axis permutation is obtained by applying ``multidim_transpose`` to the
    axis indices ``0 .. inputs.ndim - 1`` and is then handed to ``jnp.transpose``.
    """
    permutation = multidim_transpose(
        range(inputs.ndim), static_axis_boundary, transpose_axis_boundary
    )
    transposed = jnp.transpose(inputs, axes=permutation)
    return transposed


def _jax_cast_transpose(
    inputs, scale, amax, out_dtype, static_axis_boundary, transpose_axis_boundary
):
    """
    Pure-JAX cast + transpose used when the TE custom call is disabled.

    Casts ``inputs`` with ``_jax_cast_fp8`` and transposes the cast result with
    ``_jax_transpose``.

    Returns:
        ``(casted, casted_transposed, updated_amax)``.
    """
    casted, new_amax = _jax_cast_fp8(inputs, scale, amax, out_dtype=out_dtype)
    return (
        casted,
        _jax_transpose(casted, static_axis_boundary, transpose_axis_boundary),
        new_amax,
    )


67
68
69
70
class TransposePrimitive(BasePrimitive):
    """
    Transpose Primitive.

    Permutes the axes of a tensor according to ``multidim_transpose`` with the
    given ``static_axis_boundary`` / ``transpose_axis_boundary``. The dtype is
    left unchanged.
    """

    name = "te_transpose"
    multiple_results = False
    # static_axis_boundary, transpose_axis_boundary
    impl_static_args = (1, 2)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, *, static_axis_boundary, transpose_axis_boundary):
        """
        te_transpose_p abstract: same dtype as the input, shape permuted by
        ``multidim_transpose``.
        """
        transposed_x_shape = multidim_transpose(
            x_aval.shape, static_axis_boundary, transpose_axis_boundary
        )
        return x_aval.update(shape=transposed_x_shape, dtype=x_aval.dtype)

    @staticmethod
    def lowering(ctx, x, *, static_axis_boundary, transpose_axis_boundary):
        """
        te_transpose_p lowering rules.

        Uses the FFI target when ``is_ffi_enabled()``; otherwise it emits the
        legacy custom call with a packed descriptor.
        """
        x_aval = ctx.avals_in[0]
        assert x_aval.dtype in [
            jnp.float32,
            jnp.float16,
            jnp.bfloat16,
            jnp.float8_e4m3fn,
            jnp.float8_e5m2,
        ]

        if is_ffi_enabled():
            name = "te_transpose_ffi"
            return ffi.ffi_lowering(name)(ctx, x, transpose_axis=transpose_axis_boundary)

        ir_x_type = ir.RankedTensorType(x.type)
        ir_x_shape = ir_x_type.shape
        ir_out_dtype = jax_dtype_to_ir_dtype(x_aval.dtype)
        if static_axis_boundary >= 0:
            # On the legacy path every static (leading) axis must have size 1.
            for i in range(static_axis_boundary + 1):
                assert ir_x_shape[i] == 1

        transposed_x_shape = multidim_transpose(
            ir_x_shape, static_axis_boundary, transpose_axis_boundary
        )

        out_types = [ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype)]
        operands = [x]
        operand_shapes = [ir_x_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        te_dtype = jax_dtype_to_te_dtype(x_aval.dtype)
        # The descriptor carries the input collapsed to 2D: (product of the dims
        # before the transpose boundary, product of the dims from it onward).
        # The initializer 1 keeps an empty side well-defined instead of raising.
        contracted_x_shape = (
            reduce(operator.mul, ir_x_shape[:transpose_axis_boundary], 1),
            reduce(operator.mul, ir_x_shape[transpose_axis_boundary:], 1),
        )
        opaque = transformer_engine_jax.pack_common_descriptor(
            contracted_x_shape, te_dtype, te_dtype
        )

        return custom_caller(TransposePrimitive.name, args, opaque, False)

    @staticmethod
    def impl(x, static_axis_boundary, transpose_axis_boundary):
        """
        te_transpose implementation: binds the inner primitive.
        """
        assert TransposePrimitive.inner_primitive is not None
        transposed_x = TransposePrimitive.inner_primitive.bind(
            x,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary,
        )
        return transposed_x

    @staticmethod
    def batcher(batched_args, batch_dims, *, static_axis_boundary, transpose_axis_boundary):
        """
        vmap rule for te_transpose_p.

        The batch dimension becomes the static axis boundary, and the transpose
        boundary is shifted by one to account for the extra leading axis.
        """
        check_valid_batch_dims(batch_dims)
        assert TransposePrimitive.outer_primitive is not None
        # Only one level of batching is supported: the incoming primitive must
        # not already carry a static axis.
        assert static_axis_boundary < 0

        (x,) = batched_args
        (x_bdim,) = batch_dims

        # Minus batch dim.
        transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1)
        transpose_axis_boundary += 1  # Plus batch dim

        out_bdims = x_bdim
        return (
            TransposePrimitive.outer_primitive.bind(
                x, static_axis_boundary=x_bdim, transpose_axis_boundary=transpose_axis_boundary
            ),
            out_bdims,
        )

    @staticmethod
    def _transposed_sharding(static_axis_boundary, transpose_axis_boundary, mesh, arg_infos):
        """Output sharding: the input's PartitionSpec permuted like its axes."""
        x_spec = get_padded_spec(arg_infos[0])
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        return NamedSharding(mesh, PartitionSpec(*xt_spec))

    @staticmethod
    def infer_sharding_from_operands(
        static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos
    ):
        """Infer the output sharding from the operand sharding."""
        del result_infos
        return TransposePrimitive._transposed_sharding(
            static_axis_boundary, transpose_axis_boundary, mesh, arg_infos
        )

    @staticmethod
    def partition(static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos):
        """
        Partition rule: each shard is transposed locally. The operand shardings
        are kept as-is and no cross-device communication is needed.
        """
        del result_infos
        out_shardings = TransposePrimitive._transposed_sharding(
            static_axis_boundary, transpose_axis_boundary, mesh, arg_infos
        )
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)

        impl = partial(
            TransposePrimitive.impl,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary,
        )

        return mesh, impl, out_shardings, arg_shardings


register_primitive(TransposePrimitive)


203
204
205
def transpose(
    x: jnp.ndarray, static_axis_boundary: int, transpose_axis_boundary: int
) -> jnp.ndarray:
    """
    transpose wrapper

    Dispatches to the TE custom-call primitive when it is enabled, otherwise
    to the pure-JAX implementation.
    """
    if TransposePrimitive.enabled():
        return TransposePrimitive.outer_primitive.bind(
            x,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary,
        )
    # Custom calls disabled: fall back to plain JAX.
    return _jax_transpose(x, static_axis_boundary, transpose_axis_boundary)
216
217
218
219
220
221


class CastTransposePrimitive(BasePrimitive):
    """
    Cast Transpose Primitive.

    Casts the input to ``out_dtype`` (scaled by ``scale``) and returns the cast
    tensor, its transpose, and the updated ``amax``.
    """

    name = "te_cast_transpose"
    multiple_results = True
    # out_dtype, static_axis_boundary, transpose_axis_boundary
    impl_static_args = (4, 5, 6)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        amax_aval,
        scale_aval,
        scale_inv_aval,
        *,
        out_dtype,
        static_axis_boundary,
        transpose_axis_boundary
    ):
        """
        te_cast_transpose_p abstract.

        Returns avals for ``(casted_x, casted_transposed_x, updated_amax)``.
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        transposed_x_shape = multidim_transpose(
            x_aval.shape, static_axis_boundary, transpose_axis_boundary
        )

        casted_x_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
        casted_xt_aval = x_aval.update(shape=transposed_x_shape, dtype=out_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)

        return casted_x_aval, casted_xt_aval, updated_amax_aval

    @staticmethod
    def lowering(
        ctx, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, transpose_axis_boundary
    ):
        """
        te_cast_transpose_p lowering rules.

        Uses the FFI target when ``is_ffi_enabled()``; otherwise it emits the
        legacy custom call. In both cases the ``amax`` operand (index 1) is
        aliased to the ``updated_amax`` output (index 2).
        """
        x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        if is_ffi_enabled():
            name = "te_cast_transpose_ffi"
            return ffi.ffi_lowering(name, operand_output_aliases={1: 2})(
                ctx, x, amax, scale, scale_inv, transpose_axis=transpose_axis_boundary
            )

        ir_x_type = ir.RankedTensorType(x.type)
        ir_x_shape = ir_x_type.shape
        if static_axis_boundary >= 0:
            # On the legacy path every static (leading) axis must have size 1.
            for i in range(static_axis_boundary + 1):
                assert ir_x_shape[i] == 1
        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        # scale and scale_inv are declared with the same shape as amax.
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape

        transposed_x_shape = multidim_transpose(
            ir_x_shape, static_axis_boundary, transpose_axis_boundary
        )

        out_types = [
            ir.RankedTensorType.get(ir_x_shape, ir_out_dtype),
            ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
        ]
        operands = [x, amax, scale, scale_inv]
        operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        # Input collapsed to 2D at the transpose boundary. The initializer 1
        # keeps an empty side well-defined instead of raising.
        contracted_x_shape = (
            reduce(operator.mul, ir_x_shape[:transpose_axis_boundary], 1),
            reduce(operator.mul, ir_x_shape[transpose_axis_boundary:], 1),
        )
        opaque = transformer_engine_jax.pack_common_descriptor(
            contracted_x_shape,
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(out_dtype),
        )
        return custom_caller(
            CastTransposePrimitive.name, args, opaque, False, operand_output_aliases={1: 2}
        )

    @staticmethod
    def impl(x, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis_boundary):
        """
        te_cast_transpose implementation: binds the inner primitive.
        """
        assert CastTransposePrimitive.inner_primitive is not None
        casted_x, casted_transposed_x, updated_amax = CastTransposePrimitive.inner_primitive.bind(
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary,
        )
        return casted_x, casted_transposed_x, updated_amax

    @staticmethod
    def batcher(
        batched_args, batch_dims, *, out_dtype, static_axis_boundary, transpose_axis_boundary
    ):
        """
        vmap rule for te_cast_transpose_p.

        The batch dimension of ``x`` becomes the static axis boundary, and the
        transpose boundary is shifted by one for the extra leading axis.
        """
        check_valid_batch_dims(batch_dims)
        assert CastTransposePrimitive.outer_primitive is not None
        # Only one level of batching is supported.
        assert static_axis_boundary < 0

        x, amax, scale, scale_inv = batched_args
        x_bdim, amax_bdim, *_ = batch_dims

        # Minus batch dim.
        transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1)
        transpose_axis_boundary += 1  # Plus batch dim

        out_bdims = x_bdim, x_bdim, amax_bdim
        return (
            CastTransposePrimitive.outer_primitive.bind(
                x,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=x_bdim,
                transpose_axis_boundary=transpose_axis_boundary,
            ),
            out_bdims,
        )

    @staticmethod
    def _output_shardings(static_axis_boundary, transpose_axis_boundary, mesh, arg_infos):
        """
        Compute the output shardings ``(casted_x, casted_transposed_x, updated_amax)``.

        The cast output follows ``x``. The transposed output follows ``x`` with its
        PartitionSpec permuted like the axes. ``updated_amax`` follows ``amax``.
        """
        x_spec = get_padded_spec(arg_infos[0])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        return (casted_x_sharding, casted_transposed_x_sharding, amax_sharding)

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos
    ):
        """Infer the output shardings from the operand shardings."""
        del out_dtype, result_infos
        return CastTransposePrimitive._output_shardings(
            static_axis_boundary, transpose_axis_boundary, mesh, arg_infos
        )

    @staticmethod
    def partition(
        out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos
    ):
        """
        Partition rule: cast + transpose each shard locally, then combine the
        per-shard amax with ``all_reduce_max_along_all_axes_except_PP``.
        """
        del result_infos
        out_shardings = CastTransposePrimitive._output_shardings(
            static_axis_boundary, transpose_axis_boundary, mesh, arg_infos
        )
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)

        def sharded_impl(x, amax, scale, scale_inv):
            local_cx, local_cxt, local_updated_amax = CastTransposePrimitive.impl(
                x,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=static_axis_boundary,
                transpose_axis_boundary=transpose_axis_boundary,
            )
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax, mesh)

            return local_cx, local_cxt, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(CastTransposePrimitive)


407
408
409
410
411
412
413
414
415
def cast_transpose(
    x: jnp.ndarray,
    amax: jnp.ndarray,
    scale: jnp.ndarray,
    scale_inv: jnp.ndarray,
    out_dtype: jnp.dtype,
    static_axis_boundary: int,
    transpose_axis_boundary: int,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    cast transpose wrapper
    Return two tensors, FP8(inputs) and FP8(inputs.T), which are scaled by `scale`

    The third returned value is the updated ``amax``. The TE primitive is used
    when enabled; otherwise the pure-JAX fallback runs.
    """
    axis_kwargs = {
        "static_axis_boundary": static_axis_boundary,
        "transpose_axis_boundary": transpose_axis_boundary,
    }
    if CastTransposePrimitive.enabled():
        return CastTransposePrimitive.outer_primitive.bind(
            x, amax, scale, scale_inv, out_dtype=out_dtype, **axis_kwargs
        )
    # Note the fallback's argument order: (inputs, scale, amax, ...).
    return _jax_cast_transpose(x, scale, amax, out_dtype=out_dtype, **axis_kwargs)
438
439
440
441
442
443


class DBiasCastTransposePrimitive(BasePrimitive):
    """
    DBias Cast Transpose Primitive.

    Fused cast + transpose of ``dz`` that also returns ``dbias`` and an updated
    ``amax``. ``dbias`` has shape
    ``(*dz.shape[:static_axis_boundary + 1], prod(dz.shape[transpose_axis_boundary:]))``.
    The inner primitive additionally returns a workspace buffer that the outer
    primitive does not expose.
    """

    name = "te_dbias_cast_transpose"
    multiple_results = True
    # out_dtype, static_axis_boundary, transpose_axis_boundary
    impl_static_args = (4, 5, 6)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        dz_aval,
        amax_aval,
        scale_aval,
        scale_inv_aval,
        *,
        out_dtype,
        static_axis_boundary,
        transpose_axis_boundary
    ):
        """
        te_dbias_cast_transpose_p abstract.

        Returns avals for ``(out, t_out, dbias, updated_amax, workspace)``.
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        # Product of the dims from the transpose boundary onward.
        gi_hidden_size = reduce(operator.mul, dz_aval.shape[transpose_axis_boundary:], 1)
        t_shape = multidim_transpose(dz_aval.shape, static_axis_boundary, transpose_axis_boundary)
        out = dz_aval.update(shape=dz_aval.shape, dtype=out_dtype)
        t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)

        dbias_shape = (*dz_aval.shape[: static_axis_boundary + 1], gi_hidden_size)
        dbias = dz_aval.update(shape=dbias_shape, dtype=dtype)

        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
        # The workspace size/dtype is queried from the TE extension.
        (wkspace_info,) = transformer_engine_jax.get_dbias_ct_workspace_sizes(
            dz_aval.size // gi_hidden_size,
            gi_hidden_size,
            jax_dtype_to_te_dtype(dz_aval.dtype),
            jax_dtype_to_te_dtype(out_dtype),
        )
        wkspace_aval = dz_aval.update(
            shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
        )

        return out, t_out, dbias, updated_amax_aval, wkspace_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dbias_cast_transpose_p outer abstract: the inner abstract without the
        workspace output.
        """
        out, t_out, dbias, updated_amax_aval, _ = DBiasCastTransposePrimitive.abstract(
            *args, **kwargs
        )
        return out, t_out, dbias, updated_amax_aval

    @staticmethod
    def lowering(
        ctx, dz, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, transpose_axis_boundary
    ):
        """
        te_dbias_cast_transpose_p lowering rules (legacy custom call).

        The ``amax`` operand (index 1) is aliased to the ``updated_amax``
        output (index 3).
        """
        dz_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_dz_type = ir.RankedTensorType(dz.type)
        ir_dz_shape = ir_dz_type.shape
        # dz collapsed to 2D at the transpose boundary. The initializer 1 keeps
        # an empty side well-defined instead of raising.
        batch_size = reduce(operator.mul, ir_dz_shape[:transpose_axis_boundary], 1)
        ir_hidden_size = reduce(operator.mul, ir_dz_shape[transpose_axis_boundary:], 1)
        contracted_dz_shape = (batch_size, ir_hidden_size)
        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        # scale and scale_inv are declared with the same shape as amax.
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape
        transposed_dz_shape = multidim_transpose(
            ir_dz_shape, static_axis_boundary, transpose_axis_boundary
        )
        dbias_shape = (*ir_dz_shape[: static_axis_boundary + 1], ir_hidden_size)

        # The workspace aval is the last inner output (see ``abstract``).
        wkspace_aval = ctx.avals_out[-1]

        out_types = [
            ir.RankedTensorType.get(ir_dz_shape, ir_out_dtype),
            ir.RankedTensorType.get(transposed_dz_shape, ir_out_dtype),
            ir.RankedTensorType.get(dbias_shape, ir_dz_type.element_type),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
            ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
        ]
        operands = [dz, amax, scale, scale_inv]
        operand_shapes = [ir_dz_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
        opaque = transformer_engine_jax.pack_common_wk_descriptor(
            contracted_dz_shape,
            wkspace_aval.shape,
            jax_dtype_to_te_dtype(dz_aval.dtype),
            jax_dtype_to_te_dtype(out_dtype),
            jax_dtype_to_te_dtype(wkspace_aval.dtype),
        )

        return custom_caller(
            DBiasCastTransposePrimitive.name, args, opaque, False, operand_output_aliases={1: 3}
        )

    @staticmethod
    def impl(dz, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis_boundary):
        """
        te_dbias_cast_transpose implementation: binds the inner primitive and
        drops the workspace output.
        """
        assert DBiasCastTransposePrimitive.inner_primitive is not None
        out, t_out, dbias, updated_amax, _ = DBiasCastTransposePrimitive.inner_primitive.bind(
            dz,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary,
        )
        return out, t_out, dbias, updated_amax

    @staticmethod
    def batcher(
        batched_args, batch_dims, *, out_dtype, static_axis_boundary, transpose_axis_boundary
    ):
        """
        vmap rule for te_dbias_cast_transpose_p.

        The batch dimension of ``dz`` becomes the static axis boundary, and the
        transpose boundary is shifted by one for the extra leading axis.
        """
        check_valid_batch_dims(batch_dims)
        assert DBiasCastTransposePrimitive.outer_primitive is not None
        # Only one level of batching is supported (as in the sibling batchers).
        # Silently overwriting an existing static axis would drop it from dbias.
        assert static_axis_boundary < 0
        dz, amax, scale, scale_inv = batched_args
        dz_bdim, amax_bdim, _, _ = batch_dims

        # Minus batch dim.
        transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, dz.ndim - 1)
        transpose_axis_boundary += 1  # Plus batch dim

        out_bdims = dz_bdim, dz_bdim, dz_bdim, amax_bdim
        return (
            DBiasCastTransposePrimitive.outer_primitive.bind(
                dz,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=dz_bdim,
                transpose_axis_boundary=transpose_axis_boundary,
            ),
            out_bdims,
        )

    @staticmethod
    def _output_shardings(static_axis_boundary, transpose_axis_boundary, mesh, arg_infos):
        """
        Compute the output shardings ``(out, t_out, dbias, updated_amax)`` from the operand shardings.
        """
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        transposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        # dbias keeps the static axes' sharding plus the last axis's sharding for
        # the (flattened) hidden dim.
        # NOTE(review): when transpose_axis_boundary < -1, only x_spec[-1] is
        # carried over to the flattened hidden dim; confirm this is intended.
        dbias_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec[: static_axis_boundary + 1], x_spec[-1])
        )
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        return (out_sharding, transposed_out_sharding, dbias_sharding, amax_sharding)

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos
    ):
        """Infer the output shardings from the operand shardings."""
        del out_dtype, result_infos
        return DBiasCastTransposePrimitive._output_shardings(
            static_axis_boundary, transpose_axis_boundary, mesh, arg_infos
        )

    @staticmethod
    def partition(
        out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos
    ):
        """
        Partition rule: run the fused op per shard, then combine the results.

        ``dbias`` is combined with ``all_reduce_sum_along_dp_fsdp`` and ``amax``
        with ``all_reduce_max_along_all_axes_except_PP``.
        """
        del result_infos
        out_shardings = DBiasCastTransposePrimitive._output_shardings(
            static_axis_boundary, transpose_axis_boundary, mesh, arg_infos
        )
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)

        def sharded_impl(dz, amax, scale, scale_inv):
            local_out, local_t_out, local_dbias, local_amax = DBiasCastTransposePrimitive.impl(
                dz,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=static_axis_boundary,
                transpose_axis_boundary=transpose_axis_boundary,
            )
            global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            return local_out, local_t_out, global_dbias, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(DBiasCastTransposePrimitive)


def dbias_cast_transpose(
    dz: jnp.ndarray,
    amax: jnp.ndarray,
    scale: jnp.ndarray,
    scale_inv: jnp.ndarray,
    out_dtype: jnp.dtype,
    static_axis_boundary: int,
    transpose_axis_boundary: int = -1,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    cast transpose dbias partial fusion wrapper

    Returns ``(FP8(dz), FP8(dz).T, dbias, updated_amax)``.

    ``dbias`` has shape
    ``(*dz.shape[:static_axis_boundary + 1], prod(dz.shape[transpose_axis_boundary:]))``
    on both the TE primitive path and the pure-JAX fallback. A negative
    ``static_axis_boundary`` means there are no static axes.
    """
    if static_axis_boundary < 0:
        static_axis_boundary = -1  # means no static axes

    if not DBiasCastTransposePrimitive.enabled():
        casted_dz, cast_transposed_dz, updated_amax = _jax_cast_transpose(
            dz,
            scale,
            amax,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary,
        )
        hidden_axis_start = (
            transpose_axis_boundary
            if transpose_axis_boundary > 0
            else transpose_axis_boundary + dz.ndim
        )
        # Reduce only the axes between the static axes and the hidden axes, so
        # the static axes survive just as in the primitive's dbias.
        reduce_axes = tuple(range(static_axis_boundary + 1, hidden_axis_start))
        dbias = jnp.sum(dz, axis=reduce_axes, keepdims=False)
        # Flatten the hidden axes so the shape matches the primitive's dbias.
        hidden_size = reduce(operator.mul, dz.shape[hidden_axis_start:], 1)
        dbias = jnp.reshape(dbias, (*dz.shape[: static_axis_boundary + 1], hidden_size))
        return casted_dz, cast_transposed_dz, dbias, updated_amax

    return DBiasCastTransposePrimitive.outer_primitive.bind(
        dz,
        amax,
        scale,
        scale_inv,
        out_dtype=out_dtype,
        static_axis_boundary=static_axis_boundary,
        transpose_axis_boundary=transpose_axis_boundary,
    )
710
711
712
713
714
715


class DActLuDBiasCastTransposePrimitive(BasePrimitive):
    """
    DActLu DBias Cast Transpose Primitive

    Fused backward for a non-gated activation. From the incoming gradient ``dz`` and the
    activation input ``x`` it produces the activation-gradient output cast to ``out_dtype``,
    its cast transpose, dbias and the updated amax. The inner primitive additionally returns a
    workspace buffer, which the outer primitive hides.
    """

    name = "te_dact_lu_dbias_cast_transpose"
    multiple_results = True
    # out_dtype, static_axis_boundary, transpose_axis_boundary, act_enum
    impl_static_args = (5, 6, 7, 8)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        dz_aval,
        x_aval,
        amax_aval,
        scale_aval,
        scale_inv_aval,
        *,
        out_dtype,
        static_axis_boundary,
        transpose_axis_boundary,
        act_enum
    ):  # pylint: disable=unused-argument
        """
        te_dact_lu_dbias_cast_transpose_p abstract

        Returns the avals of (out, t_out, dbias, updated_amax, workspace).
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dtype
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        dz_hidden_size = dz_aval.shape[-1]
        x_hidden_size = x_aval.shape[-1]
        assert dz_hidden_size == x_hidden_size

        t_shape = multidim_transpose(x_aval.shape, static_axis_boundary, transpose_axis_boundary)
        out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype)
        t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)

        # dbias keeps only the static leading axes plus the hidden axis, in the input dtype.
        dbias_shape = (*x_aval.shape[: static_axis_boundary + 1], x_hidden_size)
        dbias = dz_aval.update(shape=dbias_shape, dtype=dtype)

        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)

        # Query the TE extension for the scratch workspace needed by a (rows, hidden) problem.
        (wkspace_info,) = transformer_engine_jax.get_dact_dbias_ct_workspace_sizes(
            x_aval.size // x_hidden_size,
            x_hidden_size,
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(out_dtype),
        )
        wkspace_aval = x_aval.update(
            shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
        )

        return out, t_out, dbias, updated_amax_aval, wkspace_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dact_lu_dbias_cast_transpose_p outer abstract

        Same as ``abstract`` but without the workspace aval.
        """
        out, t_out, dbias, updated_amax_aval, _ = DActLuDBiasCastTransposePrimitive.abstract(
            *args, **kwargs
        )
        return out, t_out, dbias, updated_amax_aval

    @staticmethod
    def lowering(
        ctx,
        dz,
        x,
        amax,
        scale,
        scale_inv,
        *,
        out_dtype,
        static_axis_boundary,
        transpose_axis_boundary,
        act_enum
    ):
        """
        te_dact_lu_dbias_cast_transpose_p lowering rules
        """
        dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_aval.dtype
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_dz_type = ir.RankedTensorType(dz.type)
        ir_dz_shape = ir_dz_type.shape
        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        # dz is flattened over all but its last axis, x over all but its last two axes.
        # NOTE(review): x presumably carries one extra (activation) axis in front of hidden —
        # confirm against callers.
        dz_batch_size = reduce(operator.mul, ir_dz_shape[:-1])
        x_batch_size = reduce(operator.mul, x_shape[:-2])
        assert dz_batch_size == x_batch_size
        ir_hidden_size = ir_dz_shape[-1]
        contracted_x_shape = (x_batch_size, ir_hidden_size)

        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape
        transposed_x_shape = multidim_transpose(
            x_shape, static_axis_boundary, transpose_axis_boundary
        )
        dbias_shape = (*x_shape[: static_axis_boundary + 1], ir_hidden_size)

        # The workspace aval is the last inner-primitive output (see ``abstract``).
        wkspace_aval = ctx.avals_out[-1]

        out_types = [
            ir.RankedTensorType.get(x_shape, ir_out_dtype),
            ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype),
            ir.RankedTensorType.get(dbias_shape, ir_dz_type.element_type),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
            ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
        ]
        operands = [dz, x, amax, scale, scale_inv]
        operand_shapes = [ir_dz_shape, x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
        opaque = transformer_engine_jax.pack_common_wk_descriptor(
            contracted_x_shape,
            wkspace_aval.shape,
            jax_dtype_to_te_dtype(dz_aval.dtype),
            jax_dtype_to_te_dtype(out_dtype),
            jax_dtype_to_te_dtype(wkspace_aval.dtype),
            act_enum,
        )

        # Operand 2 (amax) is aliased to output 3 (updated_amax).
        out = custom_caller(
            DActLuDBiasCastTransposePrimitive.name,
            args,
            opaque,
            False,
            operand_output_aliases={2: 3},
        )

        return out

    @staticmethod
    def impl(
        dz,
        x,
        amax,
        scale,
        scale_inv,
        out_dtype,
        static_axis_boundary,
        transpose_axis_boundary,
        act_enum,
    ):
        """
        Bind the inner primitive and drop its workspace output.
        """
        assert DActLuDBiasCastTransposePrimitive.inner_primitive is not None
        out, t_out, dbias, updated_amax, _ = DActLuDBiasCastTransposePrimitive.inner_primitive.bind(
            dz,
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary,
            act_enum=act_enum,
        )
        return out, t_out, dbias, updated_amax

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        static_axis_boundary,
        transpose_axis_boundary,
        act_enum
    ):
        """
        vmap rule: the batch axis becomes the static axis of the transpose.
        """
        del static_axis_boundary
        check_valid_batch_dims(batch_dims)
        assert DActLuDBiasCastTransposePrimitive.outer_primitive is not None
        dz, x, amax, scale, scale_inv = batched_args
        # batch_dims follows operand order (dz, x, amax, scale, scale_inv); dz's batch dim is
        # reused for every non-amax output and as the static axis boundary.
        dz_bdim, _, amax_bdim, _, _ = batch_dims

        # Normalize against the un-batched rank, then shift past the batch axis.
        transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1)
        transpose_axis_boundary += 1

        out_bdims = dz_bdim, dz_bdim, dz_bdim, amax_bdim
        return (
            DActLuDBiasCastTransposePrimitive.outer_primitive.bind(
                dz,
                x,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=dz_bdim,
                transpose_axis_boundary=transpose_axis_boundary,
                act_enum=act_enum,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        static_axis_boundary,
        transpose_axis_boundary,
        act_enum,
        mesh,
        arg_infos,
        result_infos,
    ):
        """
        Derive output shardings from the sharding of ``x`` (arg 1) and ``amax`` (arg 2).
        """
        del out_dtype, result_infos, act_enum
        x_spec = get_padded_spec(arg_infos[1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        transposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        dbias_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec[: static_axis_boundary + 1], x_spec[-1])
        )
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        return (out_sharding, transposed_out_sharding, dbias_sharding, amax_sharding)

    @staticmethod
    def partition(
        out_dtype,
        static_axis_boundary,
        transpose_axis_boundary,
        act_enum,
        mesh,
        arg_infos,
        result_infos,
    ):
        """
        Per-shard implementation plus the collectives that globalize dbias and amax.
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[1])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))

        dbias_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec[: static_axis_boundary + 1], x_spec[-1])
        )

        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (
            casted_x_sharding,
            casted_transposed_x_sharding,
            dbias_sharding,
            amax_sharding,
        )

        def sharded_impl(dz, x, amax, scale, scale_inv):
            local_out, local_t_out, local_dbias, local_amax = (
                DActLuDBiasCastTransposePrimitive.impl(
                    dz,
                    x,
                    amax,
                    scale,
                    scale_inv,
                    out_dtype=out_dtype,
                    static_axis_boundary=static_axis_boundary,
                    transpose_axis_boundary=transpose_axis_boundary,
                    act_enum=act_enum,
                )
            )
            # Per-shard dbias is summed over the dp/fsdp mesh axes; amax is max-reduced over
            # every mesh axis except the pipeline axis (per the helper names).
            global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            return local_out, local_t_out, global_dbias, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


# Register DActLuDBiasCastTransposePrimitive through the shared helper from .base.
# NOTE(review): presumably this fills inner_primitive/outer_primitive (impl and batcher
# assert they are not None) — confirm in .base.register_primitive.
register_primitive(DActLuDBiasCastTransposePrimitive)


def dact_lu_dbias_cast_transpose(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    amax: jnp.ndarray,
    scale: jnp.ndarray,
    scale_inv: jnp.ndarray,
    out_dtype: TEDType,
    static_axis_boundary: int,
    transpose_axis_boundary: int = -1,
    activation_type: Sequence[Union[str, Callable]] = ("gelu",),
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    cast transpose dact_lu and dbias fusion wrapper

    Computes the activation gradient ``dx`` (the VJP of the activation at ``x`` applied to
    ``dz``), casts it to ``out_dtype``, produces its cast transpose and reduces it into dbias.
    ONLY supports non-gated activation types.

    Args:
        dz: incoming gradient.
        x: activation input.
        amax: amax state used for the cast (float32).
        scale: cast scale (float32).
        scale_inv: inverse cast scale (float32); only forwarded to the primitive.
        out_dtype: target dtype of the cast outputs.
        static_axis_boundary: last leading axis kept static by the transpose; any negative
            value is normalized to -1 (no static axes).
        transpose_axis_boundary: axis at which the transpose splits the shape; negative
            values count from the end.
        activation_type: activation spec, used as an ``ActivationEnum`` key on the fused path.

    Returns:
        (casted_dx, cast_transposed_dx, dbias, updated_amax)
    """
    if static_axis_boundary < 0:
        static_axis_boundary = -1  # means no static axes

    if not DActLuDBiasCastTransposePrimitive.enabled():
        _, vjp_func = jax.vjp(partial(_jax_act_lu, activation_type=activation_type), x)
        (dx,) = vjp_func(dz)
        casted_dx, cast_transposed_dx, updated_amax = _jax_cast_transpose(
            dx,
            scale,
            amax,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary,
        )
        # Sum over every axis in front of transpose_axis_boundary (a non-positive boundary is
        # counted from the end), then drop any remaining size-1 axes.
        dbias = jnp.squeeze(
            jnp.sum(
                dx,
                axis=tuple(
                    range(
                        transpose_axis_boundary
                        if transpose_axis_boundary > 0
                        else transpose_axis_boundary + dx.ndim
                    )
                ),
            )
        )
        return casted_dx, cast_transposed_dx, dbias, updated_amax

    act_type_id = ActivationEnum[activation_type]
    return DActLuDBiasCastTransposePrimitive.outer_primitive.bind(
        dz,
        x,
        amax,
        scale,
        scale_inv,
        out_dtype=out_dtype,
        static_axis_boundary=static_axis_boundary,
        transpose_axis_boundary=transpose_axis_boundary,
        act_enum=act_type_id,
    )


class DgatedActLuCastTransposePrimitive(BasePrimitive):
    """
    Dgated ActLu Cast Transpose Primitive

    Fused backward for a gated activation with an FP8 cast and transpose. ``x`` stacks the two
    gated-activation inputs along axis -2 (size 2), and the transpose is always taken at axis -2.
    Outputs are (out, t_out, updated_amax).
    """

    name = "te_dgated_act_lu_cast_transpose"
    multiple_results = True
    impl_static_args = (5, 6, 7)  # out_dtype, static_axis_boundary, act_enum
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        dz_aval,
        x_aval,
        amax_aval,
        scale_aval,
        scale_inv_aval,
        *,
        out_dtype,
        static_axis_boundary,
        act_enum
    ):  # pylint: disable=unused-argument
        """
        te_dgated_act_lu_cast_transpose_p abstract
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dtype
        # Two stacked inputs of the gated activation (e.g. linear and activation branches).
        assert x_aval.shape[-2] == 2
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        dz_hidden_size = dz_aval.shape[-1]
        x_hidden_size = x_aval.shape[-1]
        assert dz_hidden_size == x_hidden_size
        t_shape = multidim_transpose(x_aval.shape, static_axis_boundary, -2)
        out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype)
        t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
        return out, t_out, updated_amax_aval

    @staticmethod
    def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, act_enum):
        """
        te_dgated_act_lu_cast_transpose_p lowering rules
        """
        dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_aval.dtype
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_dz_type = ir.RankedTensorType(dz.type)
        ir_dz_shape = ir_dz_type.shape
        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        # dz is flattened over all but its last axis; x over all but its (2, hidden) tail.
        dz_batch_size = reduce(operator.mul, ir_dz_shape[:-1])
        x_batch_size = reduce(operator.mul, x_shape[:-2])
        assert dz_batch_size == x_batch_size
        # Two stacked inputs of the gated activation.
        assert x_shape[-2] == 2
        ir_hidden_size = ir_dz_shape[-1]
        x_hidden_size = x_shape[-1]
        assert ir_hidden_size == x_hidden_size
        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape
        transposed_x_shape = multidim_transpose(x_shape, static_axis_boundary, -2)
        out_types = [
            ir.RankedTensorType.get(x_shape, ir_out_dtype),
            ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
        ]
        operands = [dz, x, amax, scale, scale_inv]
        operand_shapes = [ir_dz_shape, x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
        contracted_x_shape = (x_batch_size, x_shape[-1])
        opaque = transformer_engine_jax.pack_common_descriptor(
            contracted_x_shape,
            jax_dtype_to_te_dtype(dz_aval.dtype),
            jax_dtype_to_te_dtype(out_dtype),
            act_enum,
        )

        # Operand 2 (amax) is aliased to output 2 (updated_amax).
        out = custom_caller(
            DgatedActLuCastTransposePrimitive.name,
            args,
            opaque,
            False,
            operand_output_aliases={2: 2},
        )

        return out

    @staticmethod
    def impl(dz, x, amax, scale, scale_inv, out_dtype, static_axis_boundary, act_enum):
        """
        Bind the inner primitive and return its three outputs.
        """
        assert DgatedActLuCastTransposePrimitive.inner_primitive is not None
        out, t_out, updated_amax = DgatedActLuCastTransposePrimitive.inner_primitive.bind(
            dz,
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary,
            act_enum=act_enum,
        )
        return out, t_out, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary, act_enum):
        """
        vmap rule: the batch axis becomes the static axis of the transpose.
        """
        del static_axis_boundary
        check_valid_batch_dims(batch_dims)
        assert DgatedActLuCastTransposePrimitive.outer_primitive is not None
        dz, x, amax, scale, scale_inv = batched_args
        # batch_dims follows operand order (dz, x, amax, scale, scale_inv); dz's batch dim is
        # reused for the two cast outputs and as the static axis boundary.
        dz_bdim, _, amax_bdim, _, _ = batch_dims

        out_bdims = dz_bdim, dz_bdim, amax_bdim
        return (
            DgatedActLuCastTransposePrimitive.outer_primitive.bind(
                dz,
                x,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=dz_bdim,
                act_enum=act_enum,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype, static_axis_boundary, act_enum, mesh, arg_infos, result_infos
    ):
        """
        Derive output shardings from the sharding of ``x`` (arg 1) and ``amax`` (arg 2).
        """
        del out_dtype, result_infos, act_enum
        x_spec = get_padded_spec(arg_infos[1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, -2)
        transposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        return (out_sharding, transposed_out_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, static_axis_boundary, act_enum, mesh, arg_infos, result_infos):
        """
        Per-shard implementation plus the amax collective.
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[1])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, -2)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))

        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (casted_x_sharding, casted_transposed_x_sharding, amax_sharding)

        def sharded_impl(dz, x, amax, scale, scale_inv):
            local_out, local_t_out, local_amax = DgatedActLuCastTransposePrimitive.impl(
                dz,
                x,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=static_axis_boundary,
                act_enum=act_enum,
            )
            # amax is max-reduced over every mesh axis except the pipeline axis (per the helper name).
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            return local_out, local_t_out, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


# Register DgatedActLuCastTransposePrimitive through the shared helper from .base.
# NOTE(review): presumably this fills inner_primitive/outer_primitive (impl and batcher
# assert they are not None) — confirm in .base.register_primitive.
register_primitive(DgatedActLuCastTransposePrimitive)


def dgated_act_lu_cast_transpose(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    amax: jnp.ndarray,
    scale: jnp.ndarray,
    scale_inv: jnp.ndarray,
    out_dtype: TEDType,
    static_axis_boundary: int,
    activation_type: Sequence[Union[str, Callable]],
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    cast transpose d_gated_act_lu fusion wrapper

    Computes the gradient of the gated activation with respect to ``x`` (the VJP at ``x``
    applied to ``dz``). It returns that gradient cast to ``out_dtype``, its cast transpose
    (taken at axis -2) and the updated amax. The fused TE primitive is used when enabled;
    otherwise a pure-JAX fallback is used.

    Returns:
        (casted_dx, cast_transposed_dx, updated_amax)
    """
    # Looked up before choosing a path, so an unsupported activation_type presumably raises
    # here on both the fused and the fallback path.
    act_type_id = ActivationEnum[activation_type]

    if DgatedActLuCastTransposePrimitive.enabled():
        return DgatedActLuCastTransposePrimitive.outer_primitive.bind(
            dz,
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary,
            act_enum=act_type_id,
        )

    # Pure-JAX fallback: differentiate the reference activation, then cast + transpose.
    act_fn = partial(_jax_act_lu, activation_type=activation_type)
    _, pullback = jax.vjp(act_fn, x)
    (grad_x,) = pullback(dz)
    return _jax_cast_transpose(
        grad_x,
        scale,
        amax,
        out_dtype=out_dtype,
        static_axis_boundary=static_axis_boundary,
        transpose_axis_boundary=-2,
    )