"docs/api/c/multi_tensor.rst" did not exist on "c9ea6be92948e1ec553037f1a04900617b9f7f6b"
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for transpose"""
from functools import partial, reduce
from typing import Tuple, Sequence, Union, Callable
import operator

import jax.numpy as jnp
from jax import dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding

from transformer_engine import transformer_engine_jax
from transformer_engine.transformer_engine_jax import DType as TEDType

from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import (
    check_valid_batch_dims,
    jax_dtype_to_te_dtype,
    jax_dtype_to_ir_dtype,
    te_dtype_to_jax_dtype,
    get_padded_spec,
    multidim_transpose,
    normalize_axis_boundary,
)
from .activation import ActivationEnum
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp


# Public wrapper functions exported by this module.
__all__ = [
    "transpose",
    "cast_transpose",
    "dbias_cast_transpose",
    "dact_lu_dbias_cast_transpose",
    "dgated_act_lu_cast_transpose",
]


class TransposePrimitive(BasePrimitive):
    """
    Transpose Primitive
    """

    name = "te_transpose"
    multiple_results = False
    # static_axis_boundary, transpose_axis_boundary
    impl_static_args = (1, 2)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, *, static_axis_boundary, transpose_axis_boundary):
        """
        te_transpose abstract

        The result keeps the input dtype; its shape is whatever
        multidim_transpose produces for the input shape and the two boundaries.
        """
        out_shape = multidim_transpose(x_aval.shape, static_axis_boundary, transpose_axis_boundary)
        return x_aval.update(shape=out_shape, dtype=x_aval.dtype)

    @staticmethod
    def lowering(ctx, x, *, static_axis_boundary, transpose_axis_boundary):
        """
        te_transpose cuda lowering
        """
        x_aval = ctx.avals_in[0]
        supported_dtypes = [
            jnp.float32,
            jnp.float16,
            jnp.bfloat16,
            jnp.float8_e4m3fn,
            jnp.float8_e5m2,
        ]
        assert x_aval.dtype in supported_dtypes

        in_shape = ir.RankedTensorType(x.type).shape
        out_ir_dtype = jax_dtype_to_ir_dtype(x_aval.dtype)
        if static_axis_boundary >= 0:
            # Every axis up to and including the static boundary must have extent 1.
            assert all(in_shape[i] == 1 for i in range(static_axis_boundary + 1))

        out_shape = multidim_transpose(in_shape, static_axis_boundary, transpose_axis_boundary)
        args = CustomCallArgsWrapper(
            [ir.RankedTensorType.get(out_shape, out_ir_dtype)],
            [x],
            [in_shape],
        )

        te_dtype = jax_dtype_to_te_dtype(x_aval.dtype)
        # Collapse the axes on each side of the transpose boundary into a 2D shape.
        rows = reduce(operator.mul, in_shape[:transpose_axis_boundary])
        cols = reduce(operator.mul, in_shape[transpose_axis_boundary:])
        opaque = transformer_engine_jax.pack_common_descriptor((rows, cols), te_dtype, te_dtype)

        return [custom_caller(TransposePrimitive.name, args, opaque, False)]

    @staticmethod
    def impl(x, static_axis_boundary, transpose_axis_boundary):
        """
        te_transpose implementation
        """
        assert TransposePrimitive.inner_primitive is not None
        return TransposePrimitive.inner_primitive.bind(
            x,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary,
        )

    @staticmethod
    def batcher(batched_args, batch_dims, *, static_axis_boundary, transpose_axis_boundary):
        """
        te_transpose batch rule for vmap
        """
        check_valid_batch_dims(batch_dims)
        assert TransposePrimitive.outer_primitive is not None
        assert static_axis_boundary < 0

        (x,) = batched_args
        (x_bdim,) = batch_dims

        # Normalize against the rank without the batch dim, then add one for the batch dim.
        batched_boundary = normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1) + 1

        transposed = TransposePrimitive.outer_primitive.bind(
            x, static_axis_boundary=x_bdim, transpose_axis_boundary=batched_boundary
        )
        return transposed, x_bdim

    @staticmethod
    def infer_sharding_from_operands(
        static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos
    ):
        """
        Output sharding: the operand's padded spec, permuted by multidim_transpose.
        """
        del result_infos
        in_spec = get_padded_spec(arg_infos[0])
        out_spec = multidim_transpose(in_spec, static_axis_boundary, transpose_axis_boundary)
        return NamedSharding(mesh, PartitionSpec(*out_spec))

    @staticmethod
    def partition(static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos):
        """
        Partition rule: operands keep their shardings; impl runs on each shard.
        """
        del result_infos
        in_spec = get_padded_spec(arg_infos[0])
        out_spec = multidim_transpose(in_spec, static_axis_boundary, transpose_axis_boundary)
        out_shardings = NamedSharding(mesh, PartitionSpec(*out_spec))
        arg_shardings = tuple(info.sharding for info in arg_infos)

        impl = partial(
            TransposePrimitive.impl,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary,
        )

        return mesh, impl, out_shardings, arg_shardings


# Hand the class to register_primitive (imported from .base). impl/batcher above assert that
# inner_primitive/outer_primitive are set -- presumably by this call; confirm in .base.
register_primitive(TransposePrimitive)


def transpose(
    x: jnp.ndarray, static_axis_boundary: int, transpose_axis_boundary: int
) -> jnp.ndarray:
    """
    transpose wrapper
    Binds `x` to the te_transpose outer primitive with the given axis boundaries.
    """
    boundaries = {
        "static_axis_boundary": static_axis_boundary,
        "transpose_axis_boundary": transpose_axis_boundary,
    }
    return TransposePrimitive.outer_primitive.bind(x, **boundaries)


class CastTransposePrimitive(BasePrimitive):
    """
    Cast Transpose Primitive
    """

    name = "te_cast_transpose"
    multiple_results = True
    # out_dtype, static_axis_boundary, transpose_axis_boundary
    impl_static_args = (4, 5, 6)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        amax_aval,
        scale_aval,
        scale_inv_aval,
        *,
        out_dtype,
        static_axis_boundary,
        transpose_axis_boundary
    ):
        """
        te_cast_transpose_p abstract

        Returns avals for (casted x, casted transposed x, updated amax).
        """
        in_dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert in_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        for meta_aval in (amax_aval, scale_aval, scale_inv_aval):
            assert meta_aval.dtype == jnp.float32

        xt_shape = multidim_transpose(x_aval.shape, static_axis_boundary, transpose_axis_boundary)

        return (
            x_aval.update(shape=x_aval.shape, dtype=out_dtype),
            x_aval.update(shape=xt_shape, dtype=out_dtype),
            amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype),
        )

    @staticmethod
    def lowering(
        ctx, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, transpose_axis_boundary
    ):
        """
        te_cast_transpose_p lowering rules
        """
        x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        for meta_aval in (amax_aval, scale_aval, scale_inv_aval):
            assert meta_aval.dtype == jnp.float32

        x_shape = ir.RankedTensorType(x.type).shape
        if static_axis_boundary >= 0:
            # Every axis up to and including the static boundary must have extent 1.
            assert all(x_shape[i] == 1 for i in range(static_axis_boundary + 1))

        out_ir_dtype = jax_dtype_to_ir_dtype(out_dtype)
        amax_type = ir.RankedTensorType(amax.type)
        amax_ir_dtype = amax_type.element_type
        # scale and scale_inv are described with the same shape as amax.
        amax_shape = amax_type.shape

        xt_shape = multidim_transpose(x_shape, static_axis_boundary, transpose_axis_boundary)

        out_types = [
            ir.RankedTensorType.get(x_shape, out_ir_dtype),
            ir.RankedTensorType.get(xt_shape, out_ir_dtype),
            ir.RankedTensorType.get(amax_shape, amax_ir_dtype),
        ]
        args = CustomCallArgsWrapper(
            out_types,
            [x, amax, scale, scale_inv],
            [x_shape, amax_shape, amax_shape, amax_shape],
        )

        # Collapse the axes on each side of the transpose boundary into a 2D shape.
        rows = reduce(operator.mul, x_shape[:transpose_axis_boundary])
        cols = reduce(operator.mul, x_shape[transpose_axis_boundary:])
        opaque = transformer_engine_jax.pack_common_descriptor(
            (rows, cols),
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(out_dtype),
        )

        # Operand 1 (amax) is paired with output 2 (updated amax).
        return custom_caller(
            CastTransposePrimitive.name, args, opaque, False, operand_output_aliases={1: 2}
        )

    @staticmethod
    def impl(x, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis_boundary):
        """
        te_cast_transpose implementation
        """
        assert CastTransposePrimitive.inner_primitive is not None
        results = CastTransposePrimitive.inner_primitive.bind(
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary,
        )
        casted_x, casted_transposed_x, updated_amax = results
        return casted_x, casted_transposed_x, updated_amax

    @staticmethod
    def batcher(
        batched_args, batch_dims, *, out_dtype, static_axis_boundary, transpose_axis_boundary
    ):
        """
        te_cast_transpose batch rule for vmap
        """
        check_valid_batch_dims(batch_dims)
        assert CastTransposePrimitive.outer_primitive is not None
        assert static_axis_boundary < 0

        x, amax, scale, scale_inv = batched_args
        x_bdim, amax_bdim, *_ = batch_dims

        # Normalize against the rank without the batch dim, then add one for the batch dim.
        batched_boundary = normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1) + 1

        outputs = CastTransposePrimitive.outer_primitive.bind(
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=x_bdim,
            transpose_axis_boundary=batched_boundary,
        )
        return outputs, (x_bdim, x_bdim, amax_bdim)

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos
    ):
        """
        Output shardings for (casted x, casted transposed x, updated amax).
        """
        del out_dtype, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        amax_spec = get_padded_spec(arg_infos[1])
        return (
            NamedSharding(mesh, PartitionSpec(*x_spec)),
            NamedSharding(mesh, PartitionSpec(*xt_spec)),
            NamedSharding(mesh, PartitionSpec(*amax_spec)),
        )

    @staticmethod
    def partition(
        out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos
    ):
        """
        Partition rule: run impl per shard, then max-reduce the updated amax.
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        amax_spec = get_padded_spec(arg_infos[1])
        out_shardings = (
            NamedSharding(mesh, PartitionSpec(*x_spec)),
            NamedSharding(mesh, PartitionSpec(*xt_spec)),
            NamedSharding(mesh, PartitionSpec(*amax_spec)),
        )
        arg_shardings = tuple(info.sharding for info in arg_infos)

        def sharded_impl(x, amax, scale, scale_inv):
            shard_cx, shard_cxt, shard_amax = CastTransposePrimitive.impl(
                x,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=static_axis_boundary,
                transpose_axis_boundary=transpose_axis_boundary,
            )
            return shard_cx, shard_cxt, all_reduce_max_along_all_axes_except_PP(shard_amax)

        return mesh, sharded_impl, out_shardings, arg_shardings


# Hand the class to register_primitive (imported from .base). impl/batcher above assert that
# inner_primitive/outer_primitive are set -- presumably by this call; confirm in .base.
register_primitive(CastTransposePrimitive)


def cast_transpose(
    x: jnp.ndarray,
    amax: jnp.ndarray,
    scale: jnp.ndarray,
    scale_inv: jnp.ndarray,
    out_dtype: jnp.dtype,
    static_axis_boundary: int,
    transpose_axis_boundary: int,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    cast transpose wrapper
    Return two tensors, FP8(inputs) and FP8(inputs.T), which are scaled by `scale`,
    followed by the updated amax.
    """
    static_params = {
        "out_dtype": out_dtype,
        "static_axis_boundary": static_axis_boundary,
        "transpose_axis_boundary": transpose_axis_boundary,
    }
    return CastTransposePrimitive.outer_primitive.bind(x, amax, scale, scale_inv, **static_params)


class DBiasCastTransposePrimitive(BasePrimitive):
    """
    DBias Cast Transpose Primitive
    """

    name = "te_dbias_cast_transpose"
    multiple_results = True
    # out_dtype, static_axis_boundary, transpose_axis_boundary
    impl_static_args = (4, 5, 6)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        dz_aval,
        amax_aval,
        scale_aval,
        scale_inv_aval,
        *,
        out_dtype,
        static_axis_boundary,
        transpose_axis_boundary
    ):
        """
        te_dbias_cast_transpose_p abstract

        Returns avals for (casted dz, casted transposed dz, dbias, updated amax, workspace).
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        gi_hidden_size = reduce(operator.mul, dz_aval.shape[transpose_axis_boundary:])
        t_shape = multidim_transpose(dz_aval.shape, static_axis_boundary, transpose_axis_boundary)
        out = dz_aval.update(shape=dz_aval.shape, dtype=out_dtype)
        t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)

        # Any negative boundary means "no static axes"; clamp so that e.g. -2 does not
        # slice off a wrong prefix (the public wrapper applies the same normalization).
        num_static_axes = max(static_axis_boundary, -1) + 1
        dbias_shape = (*dz_aval.shape[:num_static_axes], gi_hidden_size)
        dbias = dz_aval.update(shape=dbias_shape, dtype=dtype)

        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
        (wkspace_info,) = transformer_engine_jax.get_dbias_ct_workspace_sizes(
            dz_aval.size // gi_hidden_size,
            gi_hidden_size,
            jax_dtype_to_te_dtype(dz_aval.dtype),
            jax_dtype_to_te_dtype(out_dtype),
        )
        # wkspace_info is indexed as (shape, TE dtype).
        wkspace_aval = dz_aval.update(
            shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
        )

        return out, t_out, dbias, updated_amax_aval, wkspace_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dbias_cast_transpose_p outer abstract

        Same as abstract, with the trailing workspace aval dropped.
        """
        out, t_out, dbias, updated_amax_aval, _ = DBiasCastTransposePrimitive.abstract(
            *args, **kwargs
        )
        return out, t_out, dbias, updated_amax_aval

    @staticmethod
    def lowering(
        ctx, dz, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, transpose_axis_boundary
    ):
        """
        te_dbias_cast_transpose_p lowering rules
        """
        dz_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_dz_type = ir.RankedTensorType(dz.type)
        ir_dz_shape = ir_dz_type.shape
        # Collapse the axes on each side of the transpose boundary into a 2D shape.
        batch_size = reduce(operator.mul, ir_dz_shape[:transpose_axis_boundary])
        ir_hidden_size = reduce(operator.mul, ir_dz_shape[transpose_axis_boundary:])
        contracted_dz_shape = (batch_size, ir_hidden_size)
        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        # scale and scale_inv are described with the same shape as amax.
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape
        transposed_dz_shape = multidim_transpose(
            ir_dz_shape, static_axis_boundary, transpose_axis_boundary
        )
        # Any negative boundary means "no static axes" (matches abstract).
        num_static_axes = max(static_axis_boundary, -1) + 1
        dbias_shape = (*ir_dz_shape[:num_static_axes], ir_hidden_size)

        # The workspace aval is the last output produced by abstract.
        wkspace_aval = ctx.avals_out[-1]

        out_types = [
            ir.RankedTensorType.get(ir_dz_shape, ir_out_dtype),
            ir.RankedTensorType.get(transposed_dz_shape, ir_out_dtype),
            ir.RankedTensorType.get(dbias_shape, ir_dz_type.element_type),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
            ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
        ]
        operands = [dz, amax, scale, scale_inv]
        operand_shapes = [ir_dz_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
        opaque = transformer_engine_jax.pack_common_wk_descriptor(
            contracted_dz_shape,
            wkspace_aval.shape,
            jax_dtype_to_te_dtype(dz_aval.dtype),
            jax_dtype_to_te_dtype(out_dtype),
            jax_dtype_to_te_dtype(wkspace_aval.dtype),
        )

        # Operand 1 (amax) is paired with output 3 (updated amax).
        out = custom_caller(
            DBiasCastTransposePrimitive.name, args, opaque, False, operand_output_aliases={1: 3}
        )

        return out

    @staticmethod
    def impl(dz, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis_boundary):
        """
        te_dbias_cast_transpose implementation

        Binds the inner primitive and drops its trailing workspace output.
        """
        assert DBiasCastTransposePrimitive.inner_primitive is not None
        out, t_out, dbias, updated_amax, _ = DBiasCastTransposePrimitive.inner_primitive.bind(
            dz,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary,
        )
        return out, t_out, dbias, updated_amax

    @staticmethod
    def batcher(
        batched_args, batch_dims, *, out_dtype, static_axis_boundary, transpose_axis_boundary
    ):
        """
        te_dbias_cast_transpose batch rule for vmap
        """
        # The incoming static boundary is ignored; dz's batch dim is bound in its place below.
        del static_axis_boundary
        check_valid_batch_dims(batch_dims)
        assert DBiasCastTransposePrimitive.outer_primitive is not None
        dz, amax, scale, scale_inv = batched_args
        dz_bdim, amax_bdim, _, _ = batch_dims

        # Minus batch dim.
        transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, dz.ndim - 1)
        transpose_axis_boundary += 1  # Plus batch dim

        out_bdims = dz_bdim, dz_bdim, dz_bdim, amax_bdim
        return (
            DBiasCastTransposePrimitive.outer_primitive.bind(
                dz,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=dz_bdim,
                transpose_axis_boundary=transpose_axis_boundary,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos
    ):
        """
        Output shardings for (casted dz, casted transposed dz, dbias, updated amax).
        """
        del out_dtype, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        transposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        # Any negative boundary means "no static axes" (matches abstract).
        num_static_axes = max(static_axis_boundary, -1) + 1
        # dbias spec: the static-axis entries followed by the last entry of the input spec.
        dbias_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec[:num_static_axes], x_spec[-1])
        )
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        return (out_sharding, transposed_out_sharding, dbias_sharding, amax_sharding)

    @staticmethod
    def partition(
        out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos
    ):
        """
        Partition rule: run impl per shard, then sum-reduce dbias and max-reduce amax.
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))

        # Any negative boundary means "no static axes" (matches abstract).
        num_static_axes = max(static_axis_boundary, -1) + 1
        dbias_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec[:num_static_axes], x_spec[-1])
        )

        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (
            casted_x_sharding,
            casted_transposed_x_sharding,
            dbias_sharding,
            amax_sharding,
        )

        def sharded_impl(dz, amax, scale, scale_inv):
            local_out, local_t_out, local_dbias, local_amax = DBiasCastTransposePrimitive.impl(
                dz,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=static_axis_boundary,
                transpose_axis_boundary=transpose_axis_boundary,
            )
            global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
            return local_out, local_t_out, global_dbias, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


# Hand the class to register_primitive (imported from .base). impl/batcher above assert that
# inner_primitive/outer_primitive are set -- presumably by this call; confirm in .base.
register_primitive(DBiasCastTransposePrimitive)


def dbias_cast_transpose(
    dz: jnp.ndarray,
    amax: jnp.ndarray,
    scale: jnp.ndarray,
    scale_inv: jnp.ndarray,
    out_dtype: jnp.dtype,
    static_axis_boundary: int,
    transpose_axis_boundary: int = -1,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    cast transpose dbias partial fusion wrapper
    Return FP8(inputs), FP8(inputs.T), dbias and the updated amax.

    `out_dtype` is a JAX dtype: the primitive uses it directly as the dtype of the
    casted outputs. Any negative `static_axis_boundary` is treated as -1.
    """
    if static_axis_boundary < 0:
        static_axis_boundary = -1  # means no static axes

    return DBiasCastTransposePrimitive.outer_primitive.bind(
        dz,
        amax,
        scale,
        scale_inv,
        out_dtype=out_dtype,
        static_axis_boundary=static_axis_boundary,
        transpose_axis_boundary=transpose_axis_boundary,
    )


class DActLuDBiasCastTransposePrimitive(BasePrimitive):
    """
    DActLu DBias Cast Transpose Primitive
    """
649

650
651
652
653
654
655
656
657
    name = "te_dact_lu_dbias_cast_transpose"
    multiple_results = True
    # out_dtype, static_axis_boundary, transpose_axis_boundary, act_enum
    impl_static_args = (5, 6, 7, 8)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
658
659
660
661
662
663
664
665
666
667
668
669
    def abstract(
        dz_aval,
        x_aval,
        amax_aval,
        scale_aval,
        scale_inv_aval,
        *,
        out_dtype,
        static_axis_boundary,
        transpose_axis_boundary,
        act_enum
    ):  # pylint: disable=unused-argument
670
671
672
673
674
675
676
677
678
679
680
681
        """
        te_dact_lu_dbais_cast_transpose_p abstract
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dtype
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_hidden_szie = dz_aval.shape[-1]
        gi_hidden_size = x_aval.shape[-1]
        assert ir_hidden_szie == gi_hidden_size
682
        t_shape = multidim_transpose(x_aval.shape, static_axis_boundary, transpose_axis_boundary)
683
684
685
        out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype)
        t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)

686
        dbias_shape = (*x_aval.shape[: static_axis_boundary + 1], gi_hidden_size)
687
688
689
690
        dbias = dz_aval.update(shape=dbias_shape, dtype=dtype)

        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)

691
        (wkspace_info,) = transformer_engine_jax.get_dact_dbias_ct_workspace_sizes(
692
693
694
695
696
            x_aval.size // gi_hidden_size,
            gi_hidden_size,
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(out_dtype),
        )
697
698
699
        wkspace_aval = x_aval.update(
            shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
        )
700
701
702
703
704
705
706
707
708

        return out, t_out, dbias, updated_amax_aval, wkspace_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dact_lu_dbais_cast_transpose_p outer abstract
        """

709
710
711
        out, t_out, dbias, updated_amax_aval, _ = DActLuDBiasCastTransposePrimitive.abstract(
            *args, **kwargs
        )
712
713
714
        return out, t_out, dbias, updated_amax_aval

    @staticmethod
715
716
717
718
719
720
721
722
723
724
725
726
727
    def lowering(
        ctx,
        dz,
        x,
        amax,
        scale,
        scale_inv,
        *,
        out_dtype,
        static_axis_boundary,
        transpose_axis_boundary,
        act_enum
    ):
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
        """
        te_dgated_act_lu_cast_transpose_p lowering rules
        """
        dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_aval.dtype
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_dz_type = ir.RankedTensorType(dz.type)
        ir_dz_shape = ir_dz_type.shape
        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        dz_batch_szie = reduce(operator.mul, ir_dz_shape[:-1])
        x_batch_size = reduce(operator.mul, x_shape[:-2])
        assert dz_batch_szie == x_batch_size
        ir_hidden_szie = ir_dz_shape[-1]
        contracted_x_shape = (x_batch_size, ir_hidden_szie)

        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape
753
754
755
756
        transposed_x_shape = multidim_transpose(
            x_shape, static_axis_boundary, transpose_axis_boundary
        )
        dbias_shape = (*x_shape[: static_axis_boundary + 1], ir_hidden_szie)
757
758
759
760
761
762
763
764
765
766
767
768
769
770

        wkspace_aval = ctx.avals_out[-1]

        out_types = [
            ir.RankedTensorType.get(x_shape, ir_out_dtype),
            ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype),
            ir.RankedTensorType.get(dbias_shape, ir_dz_type.element_type),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
            ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
        ]
        operands = [dz, x, amax, scale, scale_inv]
        operand_shapes = [ir_dz_shape, x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
        opaque = transformer_engine_jax.pack_common_wk_descriptor(
771
772
773
774
775
776
777
            contracted_x_shape,
            wkspace_aval.shape,
            jax_dtype_to_te_dtype(dz_aval.dtype),
            jax_dtype_to_te_dtype(out_dtype),
            jax_dtype_to_te_dtype(wkspace_aval.dtype),
            act_enum,
        )
778

779
780
781
782
783
784
785
        out = custom_caller(
            DActLuDBiasCastTransposePrimitive.name,
            args,
            opaque,
            False,
            operand_output_aliases={2: 3},
        )
786
787
788
789

        return out

    @staticmethod
790
791
792
793
794
795
796
797
798
799
800
    def impl(
        dz,
        x,
        amax,
        scale,
        scale_inv,
        out_dtype,
        static_axis_boundary,
        transpose_axis_boundary,
        act_enum,
    ):
        """
        Bind the inner primitive and hand back its first four results.

        The inner primitive yields five values; the trailing one is not part of
        the public result and is discarded here.
        """
        inner = DActLuDBiasCastTransposePrimitive.inner_primitive
        assert inner is not None
        results = inner.bind(
            dz,
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary,
            act_enum=act_enum,
        )
        # Exactly five results are expected; only the last one is dropped.
        casted_x, casted_xt, dbias, new_amax, _ = results
        return casted_x, casted_xt, dbias, new_amax

    @staticmethod
819
820
821
822
823
824
825
826
827
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        static_axis_boundary,
        transpose_axis_boundary,
        act_enum
    ):
        """
        Batching rule used by vmap.

        The incoming static_axis_boundary is ignored; the batch dimension is
        forwarded to the outer primitive in its place, and the transpose
        boundary is shifted by one to account for the extra batch axis.
        """
        del static_axis_boundary
        check_valid_batch_dims(batch_dims)
        outer = DActLuDBiasCastTransposePrimitive.outer_primitive
        assert outer is not None
        dz, x, amax, scale, scale_inv = batched_args
        # NOTE(review): the first batch_dims entry sits at dz's position in
        # batched_args, although it is used as the batch dim of x — confirm the
        # two are always identical.
        lead_bdim, _, amax_bdim, _, _ = batch_dims

        # Normalize against the rank without the batch axis, then add the batch axis back.
        shifted_boundary = normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1) + 1

        batched_results = outer.bind(
            dz,
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=lead_bdim,
            transpose_axis_boundary=shifted_boundary,
            act_enum=act_enum,
        )
        # One batch dim per result: out, transposed out, dbias, amax.
        return batched_results, (lead_bdim, lead_bdim, lead_bdim, amax_bdim)
856
857

    @staticmethod
858
859
860
861
862
863
864
865
866
    def infer_sharding_from_operands(
        out_dtype,
        static_axis_boundary,
        transpose_axis_boundary,
        act_enum,
        mesh,
        arg_infos,
        result_infos,
    ):
        """
        Derive the four result shardings from the operand shardings.

        The casted output follows the spec of x (arg_infos[1]), the transposed
        output follows that spec permuted by multidim_transpose, dbias keeps the
        leading static axes plus the last axis of x, and amax follows
        arg_infos[2].
        """
        del out_dtype, result_infos, act_enum

        def _to_named(spec):
            # Wrap a sequence of partition entries into a sharding on this mesh.
            return NamedSharding(mesh, PartitionSpec(*spec))

        x_spec = get_padded_spec(arg_infos[1])
        casted_sharding = _to_named(x_spec)
        transposed_spec = multidim_transpose(
            x_spec, static_axis_boundary, transpose_axis_boundary
        )
        casted_t_sharding = _to_named(transposed_spec)
        dbias_sharding = _to_named((*x_spec[: static_axis_boundary + 1], x_spec[-1]))
        amax_sharding = _to_named(get_padded_spec(arg_infos[2]))
        return (casted_sharding, casted_t_sharding, dbias_sharding, amax_sharding)

    @staticmethod
879
880
881
882
883
884
885
886
887
    def partition(
        out_dtype,
        static_axis_boundary,
        transpose_axis_boundary,
        act_enum,
        mesh,
        arg_infos,
        result_infos,
    ):
        """
        Build the per-shard implementation together with its shardings.

        Returns the mesh, a function that runs impl on local shards and then
        reduces dbias and amax, the result shardings, and the operand shardings
        taken unchanged from arg_infos.
        """
        del result_infos

        def _to_named(spec):
            # Wrap a sequence of partition entries into a sharding on this mesh.
            return NamedSharding(mesh, PartitionSpec(*spec))

        x_spec = get_padded_spec(arg_infos[1])
        casted_sharding = _to_named(x_spec)
        transposed_spec = multidim_transpose(
            x_spec, static_axis_boundary, transpose_axis_boundary
        )
        casted_t_sharding = _to_named(transposed_spec)
        dbias_sharding = _to_named((*x_spec[: static_axis_boundary + 1], x_spec[-1]))
        amax_sharding = _to_named(get_padded_spec(arg_infos[2]))

        arg_shardings = tuple(info.sharding for info in arg_infos)
        out_shardings = (
            casted_sharding,
            casted_t_sharding,
            dbias_sharding,
            amax_sharding,
        )

        def sharded_impl(dz, x, amax, scale, scale_inv):
            shard_out, shard_t_out, shard_dbias, shard_amax = (
                DActLuDBiasCastTransposePrimitive.impl(
                    dz,
                    x,
                    amax,
                    scale,
                    scale_inv,
                    out_dtype=out_dtype,
                    static_axis_boundary=static_axis_boundary,
                    transpose_axis_boundary=transpose_axis_boundary,
                    act_enum=act_enum,
                )
            )
            # The casted outputs stay local; dbias and amax are reduced (dbias first).
            return (
                shard_out,
                shard_t_out,
                all_reduce_sum_along_dp_fsdp(shard_dbias),
                all_reduce_max_along_all_axes_except_PP(shard_amax),
            )

        return mesh, sharded_impl, out_shardings, arg_shardings


# Register the dact_lu + dbias + cast-transpose primitive.
# NOTE(review): impl/batcher assert that inner_primitive/outer_primitive are not
# None, so this call presumably populates them — confirm in .base.register_primitive.
register_primitive(DActLuDBiasCastTransposePrimitive)


def dact_lu_dbias_cast_transpose(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    amax: jnp.ndarray,
    scale: jnp.ndarray,
    scale_inv: jnp.ndarray,
    out_dtype: TEDType,
    static_axis_boundary: int,
    transpose_axis_boundary: int = -1,
    activation_type: Sequence[Union[str, Callable]] = ("gelu",),
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    cast transpose dact_lu and dbias fusion wrapper
    ONLY support non-gated activation type

    Returns four arrays, in this order: FP8(dact_lu(inputs)), its transposed
    copy, dbias, and the updated amax.

    Any negative static_axis_boundary is treated as -1 (no static axes).
    activation_type is used directly as the key into ActivationEnum.

    NOTE(review): out_dtype is annotated as TEDType, but the lowering passes it
    to jax_dtype_to_te_dtype, which suggests a JAX dtype is expected — confirm.
    """
    if static_axis_boundary < 0:
        static_axis_boundary = -1  # means no static axes

    act_type_id = ActivationEnum[activation_type]
    return DActLuDBiasCastTransposePrimitive.outer_primitive.bind(
        dz,
        x,
        amax,
        scale,
        scale_inv,
        out_dtype=out_dtype,
        static_axis_boundary=static_axis_boundary,
        transpose_axis_boundary=transpose_axis_boundary,
        act_enum=act_type_id,
    )
962
963
964
965
966
967


class DgatedActLuCastTransposePrimitive(BasePrimitive):
    """
    Dgated ActLu Cast Transpose Primitive

    Produces three results: the casted gradient, its transposed copy and the
    updated amax.
    """

    name = "te_dgated_act_lu_cast_transpose"
    multiple_results = True
    impl_static_args = (5, 6, 7)  # out_dtype, static_axis_boundary, act_enum
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        dz_aval,
        x_aval,
        amax_aval,
        scale_aval,
        scale_inv_aval,
        *,
        out_dtype,
        static_axis_boundary,
        act_enum
    ):  # pylint: disable=unused-argument
        """
        te_dgated_act_lu_cast_transpose_p abstract

        Validates operand dtypes and shapes, then describes the three results:
        an array shaped like x, its transpose around axis -2, and amax.
        """
        grad_dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert grad_dtype in (jnp.float32, jnp.float16, jnp.bfloat16)
        assert x_aval.dtype == grad_dtype
        # Axis -2 of x must have size two (original note: Linear + GeLU).
        assert x_aval.shape[-2] == 2
        for fp32_aval in (amax_aval, scale_aval, scale_inv_aval):
            assert fp32_aval.dtype == jnp.float32
        # dz and x must agree on the hidden (last) dimension.
        assert dz_aval.shape[-1] == x_aval.shape[-1]

        transposed_shape = multidim_transpose(x_aval.shape, static_axis_boundary, -2)
        casted_aval = dz_aval.update(shape=x_aval.shape, dtype=out_dtype)
        casted_t_aval = dz_aval.update(shape=transposed_shape, dtype=out_dtype)
        new_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
        return casted_aval, casted_t_aval, new_amax_aval

    @staticmethod
    def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, act_enum):
        """
        te_dgated_act_lu_cast_transpose_p lowering rules

        Re-checks dtypes and shapes on the IR operands, packs the descriptor and
        emits the custom call.
        """
        dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert dz_aval.dtype in (jnp.float32, jnp.float16, jnp.bfloat16)
        assert x_aval.dtype == dz_aval.dtype
        for fp32_aval in (amax_aval, scale_aval, scale_inv_aval):
            assert fp32_aval.dtype == jnp.float32

        dz_shape = ir.RankedTensorType(dz.type).shape
        x_shape = ir.RankedTensorType(x.type).shape
        # Collapse every leading axis into a single batch size for each operand.
        dz_batch_size = reduce(operator.mul, dz_shape[:-1])
        x_batch_size = reduce(operator.mul, x_shape[:-2])
        assert dz_batch_size == x_batch_size
        # Axis -2 of x must have size two (original note: Linear + GeLU).
        assert x_shape[-2] == 2
        assert dz_shape[-1] == x_shape[-1]

        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        amax_type = ir.RankedTensorType(amax.type)
        amax_elem_type = amax_type.element_type
        amax_shape = amax_type.shape
        transposed_x_shape = multidim_transpose(x_shape, static_axis_boundary, -2)

        out_types = [
            ir.RankedTensorType.get(x_shape, ir_out_dtype),
            ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype),
            ir.RankedTensorType.get(amax_shape, amax_elem_type),
        ]
        # scale and scale_inv are described with the same shape as amax.
        args = CustomCallArgsWrapper(
            out_types,
            [dz, x, amax, scale, scale_inv],
            [dz_shape, x_shape, amax_shape, amax_shape, amax_shape],
        )
        opaque = transformer_engine_jax.pack_common_descriptor(
            (x_batch_size, x_shape[-1]),
            jax_dtype_to_te_dtype(dz_aval.dtype),
            jax_dtype_to_te_dtype(out_dtype),
            act_enum,
        )

        # Operand 2 (amax) is aliased to result 2 (updated amax).
        return custom_caller(
            DgatedActLuCastTransposePrimitive.name,
            args,
            opaque,
            False,
            operand_output_aliases={2: 2},
        )

    @staticmethod
    def impl(dz, x, amax, scale, scale_inv, out_dtype, static_axis_boundary, act_enum):
        """
        Bind the inner primitive and return its three results.
        """
        inner = DgatedActLuCastTransposePrimitive.inner_primitive
        assert inner is not None
        casted_x, casted_xt, new_amax = inner.bind(
            dz,
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary,
            act_enum=act_enum,
        )
        return casted_x, casted_xt, new_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary, act_enum):
        """
        Batching rule used by vmap.

        The incoming static_axis_boundary is ignored; the batch dimension is
        forwarded to the outer primitive in its place.
        """
        del static_axis_boundary
        check_valid_batch_dims(batch_dims)
        outer = DgatedActLuCastTransposePrimitive.outer_primitive
        assert outer is not None
        dz, x, amax, scale, scale_inv = batched_args
        # NOTE(review): the first batch_dims entry sits at dz's position in
        # batched_args, although it is used as the batch dim of x — confirm the
        # two are always identical.
        lead_bdim, _, amax_bdim, _, _ = batch_dims

        batched_results = outer.bind(
            dz,
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=lead_bdim,
            act_enum=act_enum,
        )
        # One batch dim per result: out, transposed out, amax.
        return batched_results, (lead_bdim, lead_bdim, amax_bdim)

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype, static_axis_boundary, act_enum, mesh, arg_infos, result_infos
    ):
        """
        Derive the three result shardings from the operand shardings.

        The casted output follows the spec of x (arg_infos[1]), the transposed
        output follows that spec permuted around axis -2, and amax follows
        arg_infos[2].
        """
        del out_dtype, result_infos, act_enum

        def _to_named(spec):
            # Wrap a sequence of partition entries into a sharding on this mesh.
            return NamedSharding(mesh, PartitionSpec(*spec))

        x_spec = get_padded_spec(arg_infos[1])
        casted_sharding = _to_named(x_spec)
        casted_t_sharding = _to_named(multidim_transpose(x_spec, static_axis_boundary, -2))
        amax_sharding = _to_named(get_padded_spec(arg_infos[2]))
        return (casted_sharding, casted_t_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, static_axis_boundary, act_enum, mesh, arg_infos, result_infos):
        """
        Build the per-shard implementation together with its shardings.

        Returns the mesh, a function that runs impl on local shards and then
        max-reduces amax, the result shardings, and the operand shardings taken
        unchanged from arg_infos.
        """
        del result_infos

        def _to_named(spec):
            # Wrap a sequence of partition entries into a sharding on this mesh.
            return NamedSharding(mesh, PartitionSpec(*spec))

        x_spec = get_padded_spec(arg_infos[1])
        casted_sharding = _to_named(x_spec)
        casted_t_sharding = _to_named(multidim_transpose(x_spec, static_axis_boundary, -2))
        amax_sharding = _to_named(get_padded_spec(arg_infos[2]))

        arg_shardings = tuple(info.sharding for info in arg_infos)
        out_shardings = (casted_sharding, casted_t_sharding, amax_sharding)

        def sharded_impl(dz, x, amax, scale, scale_inv):
            shard_out, shard_t_out, shard_amax = DgatedActLuCastTransposePrimitive.impl(
                dz,
                x,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=static_axis_boundary,
                act_enum=act_enum,
            )
            # The casted outputs stay local; only amax is reduced.
            return (
                shard_out,
                shard_t_out,
                all_reduce_max_along_all_axes_except_PP(shard_amax),
            )

        return mesh, sharded_impl, out_shardings, arg_shardings


# Register the dgated_act_lu + cast-transpose primitive.
# NOTE(review): inner_primitive/outer_primitive default to None on the class and
# are asserted non-None in impl/batcher, so this call presumably populates them —
# confirm in .base.register_primitive.
register_primitive(DgatedActLuCastTransposePrimitive)


def dgated_act_lu_cast_transpose(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    amax: jnp.ndarray,
    scale: jnp.ndarray,
    scale_inv: jnp.ndarray,
    out_dtype: TEDType,
    static_axis_boundary: int,
    activation_type: Sequence[Union[str, Callable]],
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    cast transpose d_gated_act_lu fusion wrapper

    Returns FP8(dgated_act_lu(inputs)), its transposed copy and the updated
    amax. activation_type is used directly as the key into ActivationEnum.
    """
    # Resolve the activation id first so a bad key fails before anything is bound.
    static_params = {
        "out_dtype": out_dtype,
        "static_axis_boundary": static_axis_boundary,
        "act_enum": ActivationEnum[activation_type],
    }
    operands = (dz, x, amax, scale, scale_inv)
    return DgatedActLuCastTransposePrimitive.outer_primitive.bind(*operands, **static_params)