transpose.py 44.8 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
6
7
8
#
# See LICENSE for license information.
"""JAX/TE custom ops for transpose"""
from functools import partial, reduce
from typing import Tuple, Sequence, Union, Callable
import operator

9
import jax
10
11
12
13
import jax.numpy as jnp
from jax import dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
14
from jax import ffi
15

16
17
import transformer_engine_jax
from transformer_engine_jax import DType as TEDType
18
19
20
21
22
23
24
25
26
27

from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import (
    check_valid_batch_dims,
    jax_dtype_to_te_dtype,
    jax_dtype_to_ir_dtype,
    te_dtype_to_jax_dtype,
    get_padded_spec,
    multidim_transpose,
28
    normalize_axis_boundary,
29
    is_ffi_enabled,
30
31
)
from .activation import ActivationEnum
32
33
from .activation import _jax_act_lu
from .quantization import _jax_cast_fp8
34
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
35
36


37
38
39
40
41
42
43
# Public API of this module: the wrapper functions below. The *Primitive classes and the
# _jax_* pure-JAX fallbacks are implementation details and are intentionally not exported.
__all__ = [
    "transpose",
    "cast_transpose",
    "dbias_cast_transpose",
    "dact_lu_dbias_cast_transpose",
    "dgated_act_lu_cast_transpose",
]
44
45


46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
def _jax_transpose(inputs, static_axis_boundary, transpose_axis_boundary):
    """
    Pure-JAX fallback for the TE transpose primitive.

    The axis permutation is obtained by applying ``multidim_transpose`` to the axis
    indices ``0 .. inputs.ndim - 1``, i.e. the same rule the primitive applies to shapes.
    """
    permutation = multidim_transpose(
        range(inputs.ndim), static_axis_boundary, transpose_axis_boundary
    )
    return jnp.transpose(inputs, axes=permutation)


def _jax_cast_transpose(
    inputs, scale, amax, out_dtype, static_axis_boundary, transpose_axis_boundary
):
    """
    Pure-JAX fallback for the TE cast-transpose primitive.

    Casts ``inputs`` with ``_jax_cast_fp8`` and transposes the casted result with
    ``_jax_transpose``.

    Returns:
        ``(casted_output, casted_transposed_output, updated_amax)``.
    """
    casted, new_amax = _jax_cast_fp8(inputs, scale, amax, out_dtype=out_dtype)
    casted_t = _jax_transpose(casted, static_axis_boundary, transpose_axis_boundary)
    return casted, casted_t, new_amax


67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
def _jax_dbias_cast_transpose(
    dz, amax, scale, out_dtype, static_axis_boundary, transpose_axis_boundary
):
    """
    JAX native dbias_cast_transpose implementation.

    Returns:
        ``(casted_dz, cast_transposed_dz, dbias, updated_amax)``.

    ``dbias`` has the same shape that ``DBiasCastTransposePrimitive.abstract`` declares:
    ``(*dz.shape[:static_axis_boundary + 1], hidden)``. Here ``hidden`` is the product of
    ``dz.shape[transpose_axis_boundary:]``. The axes between the static axes and
    ``transpose_axis_boundary`` are summed over. With no static axes
    (``static_axis_boundary < 0``), the result is a 1D array.
    """
    casted_dz, cast_transposed_dz, updated_amax = _jax_cast_transpose(
        dz,
        scale,
        amax,
        out_dtype=out_dtype,
        static_axis_boundary=static_axis_boundary,
        transpose_axis_boundary=transpose_axis_boundary,
    )

    # Normalize a negative boundary to an absolute axis index.
    boundary = transpose_axis_boundary
    if boundary < 0:
        boundary += dz.ndim

    # Leading static axes are kept rather than reduced, matching the primitive's
    # dbias shape; static_axis_boundary == -1 (or lower) means "no static axes".
    num_static = max(static_axis_boundary + 1, 0)
    reduce_axes = tuple(range(num_static, boundary))
    dbias = jnp.sum(dz, axis=reduce_axes, keepdims=False)

    # Flatten the hidden axes into one trailing dim: (*static_dims, hidden).
    # With no static axes this is equivalent to ravel().
    dbias = dbias.reshape(*dz.shape[:num_static], -1)
    return casted_dz, cast_transposed_dz, dbias, updated_amax


96
97
98
99
class TransposePrimitive(BasePrimitive):
    """
    Transpose Primitive.

    Reorders the axes of a single tensor according to ``multidim_transpose`` with the
    given ``static_axis_boundary`` / ``transpose_axis_boundary``. The dtype is unchanged.
    """

    name = "te_transpose"
    multiple_results = False
    # static_axis_boundary, transpose_axis_boundary
    impl_static_args = (1, 2)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, *, static_axis_boundary, transpose_axis_boundary):
        """
        _transpose abstract: same dtype as the input, shape permuted by multidim_transpose.
        """
        return x_aval.update(
            shape=multidim_transpose(x_aval.shape, static_axis_boundary, transpose_axis_boundary),
            dtype=x_aval.dtype,
        )

    @staticmethod
    def lowering(ctx, x, *, static_axis_boundary, transpose_axis_boundary):
        """
        _transpose cuda lowering.

        When FFI is enabled, this emits the ``te_transpose_ffi`` call. Otherwise it emits
        the legacy custom call, whose opaque descriptor views the input as a 2D
        (rows, cols) matrix split at ``transpose_axis_boundary``.
        """
        x_aval = ctx.avals_in[0]
        assert x_aval.dtype in (
            jnp.float32,
            jnp.float16,
            jnp.bfloat16,
            jnp.float8_e4m3fn,
            jnp.float8_e5m2,
        )

        if is_ffi_enabled():
            return ffi.ffi_lowering("te_transpose_ffi")(
                ctx, x, transpose_axis=transpose_axis_boundary
            )

        in_shape = ir.RankedTensorType(x.type).shape
        if static_axis_boundary >= 0:
            # On the legacy path every static axis must be unit-sized.
            assert all(in_shape[axis] == 1 for axis in range(static_axis_boundary + 1))

        out_shape = multidim_transpose(in_shape, static_axis_boundary, transpose_axis_boundary)
        out_types = [ir.RankedTensorType.get(out_shape, jax_dtype_to_ir_dtype(x_aval.dtype))]
        args = CustomCallArgsWrapper(out_types, [x], [in_shape])

        te_dtype = jax_dtype_to_te_dtype(x_aval.dtype)
        rows = reduce(operator.mul, in_shape[:transpose_axis_boundary])
        cols = reduce(operator.mul, in_shape[transpose_axis_boundary:])
        opaque = transformer_engine_jax.pack_common_descriptor((rows, cols), te_dtype, te_dtype)
        return custom_caller(TransposePrimitive.name, args, opaque, False)

    @staticmethod
    def impl(x, static_axis_boundary, transpose_axis_boundary):
        """
        transpose implementation: binds the inner primitive.
        """
        assert TransposePrimitive.inner_primitive is not None
        return TransposePrimitive.inner_primitive.bind(
            x,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary,
        )

    @staticmethod
    def batcher(batched_args, batch_dims, *, static_axis_boundary, transpose_axis_boundary):
        """
        vmap rule: the batch dimension is passed on as the static axis, and the
        transpose boundary is shifted by one to account for the extra dimension.
        """
        check_valid_batch_dims(batch_dims)
        assert TransposePrimitive.outer_primitive is not None
        assert static_axis_boundary < 0

        (x,) = batched_args
        (x_bdim,) = batch_dims

        # Normalize against the unbatched rank, then add one for the batch dim.
        shifted_boundary = normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1) + 1

        transposed = TransposePrimitive.outer_primitive.bind(
            x, static_axis_boundary=x_bdim, transpose_axis_boundary=shifted_boundary
        )
        return transposed, x_bdim

    @staticmethod
    def _transposed_sharding(static_axis_boundary, transpose_axis_boundary, mesh, arg_infos):
        """Output sharding: the input's padded partition spec, permuted like the data."""
        permuted_spec = multidim_transpose(
            get_padded_spec(arg_infos[0]), static_axis_boundary, transpose_axis_boundary
        )
        return NamedSharding(mesh, PartitionSpec(*permuted_spec))

    @staticmethod
    def infer_sharding_from_operands(
        static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos
    ):
        """Infer the output sharding from the operand's sharding."""
        del result_infos
        return TransposePrimitive._transposed_sharding(
            static_axis_boundary, transpose_axis_boundary, mesh, arg_infos
        )

    @staticmethod
    def partition(static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos):
        """Partition rule: each shard is transposed locally; no communication is needed."""
        del result_infos
        out_shardings = TransposePrimitive._transposed_sharding(
            static_axis_boundary, transpose_axis_boundary, mesh, arg_infos
        )
        arg_shardings = tuple(info.sharding for info in arg_infos)
        local_impl = partial(
            TransposePrimitive.impl,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary,
        )
        return mesh, local_impl, out_shardings, arg_shardings


register_primitive(TransposePrimitive)


232
233
234
def transpose(
    x: jnp.ndarray, static_axis_boundary: int, transpose_axis_boundary: int
) -> jnp.ndarray:
    """
    transpose wrapper.

    Permutes the axes of ``x`` as ``multidim_transpose`` defines for the given
    boundaries. Uses the TE custom primitive when it is enabled, and the pure-JAX
    fallback otherwise.
    """
    if TransposePrimitive.enabled():
        return TransposePrimitive.outer_primitive.bind(
            x,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary,
        )
    return _jax_transpose(x, static_axis_boundary, transpose_axis_boundary)
245
246
247
248
249
250


class CastTransposePrimitive(BasePrimitive):
    """
    Cast Transpose Primitive.

    Casts a float32/float16/bfloat16 tensor to ``out_dtype`` and also returns the casted
    tensor transposed (via ``multidim_transpose``), plus an updated ``amax``.
    """

    name = "te_cast_transpose"
    multiple_results = True
    # out_dtype, static_axis_boundary, transpose_axis_boundary
    impl_static_args = (4, 5, 6)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        amax_aval,
        scale_aval,
        scale_inv_aval,
        *,
        out_dtype,
        static_axis_boundary,
        transpose_axis_boundary
    ):
        """
        te_cast_transpose_p abstract.

        Returns avals for (casted x, casted-and-transposed x, updated amax).
        """
        canonical = dtypes.canonicalize_dtype(x_aval.dtype)
        assert canonical in [jnp.float32, jnp.float16, jnp.bfloat16]
        for fp32_aval in (amax_aval, scale_aval, scale_inv_aval):
            assert fp32_aval.dtype == jnp.float32

        xt_shape = multidim_transpose(x_aval.shape, static_axis_boundary, transpose_axis_boundary)
        return (
            x_aval.update(shape=x_aval.shape, dtype=out_dtype),
            x_aval.update(shape=xt_shape, dtype=out_dtype),
            amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype),
        )

    @staticmethod
    def lowering(
        ctx, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, transpose_axis_boundary
    ):
        """
        te_cast_transpose_p lowering rules.

        Operand 1 (``amax``) is aliased to output 2 (the updated amax) on both the FFI
        and the legacy custom-call paths.
        """
        x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        for fp32_aval in (amax_aval, scale_aval, scale_inv_aval):
            assert fp32_aval.dtype == jnp.float32

        if is_ffi_enabled():
            return ffi.ffi_lowering("te_cast_transpose_ffi", operand_output_aliases={1: 2})(
                ctx, x, amax, scale, scale_inv, transpose_axis=transpose_axis_boundary
            )

        x_shape = ir.RankedTensorType(x.type).shape
        if static_axis_boundary >= 0:
            # On the legacy path every static axis must be unit-sized.
            assert all(x_shape[axis] == 1 for axis in range(static_axis_boundary + 1))
        out_ir_dtype = jax_dtype_to_ir_dtype(out_dtype)
        amax_type = ir.RankedTensorType(amax.type)
        amax_shape = amax_type.shape
        xt_shape = multidim_transpose(x_shape, static_axis_boundary, transpose_axis_boundary)

        out_types = [
            ir.RankedTensorType.get(x_shape, out_ir_dtype),
            ir.RankedTensorType.get(xt_shape, out_ir_dtype),
            ir.RankedTensorType.get(amax_shape, amax_type.element_type),
        ]
        # scale and scale_inv are described with amax's shape.
        operand_shapes = [x_shape, amax_shape, amax_shape, amax_shape]
        args = CustomCallArgsWrapper(out_types, [x, amax, scale, scale_inv], operand_shapes)

        # The legacy kernel sees x as a 2D (rows, cols) matrix split at the boundary.
        rows = reduce(operator.mul, x_shape[:transpose_axis_boundary])
        cols = reduce(operator.mul, x_shape[transpose_axis_boundary:])
        opaque = transformer_engine_jax.pack_common_descriptor(
            (rows, cols),
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(out_dtype),
        )
        return custom_caller(
            CastTransposePrimitive.name, args, opaque, False, operand_output_aliases={1: 2}
        )

    @staticmethod
    def impl(x, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis_boundary):
        """
        te_cast_transpose implementation: binds the inner primitive.
        """
        assert CastTransposePrimitive.inner_primitive is not None
        results = CastTransposePrimitive.inner_primitive.bind(
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary,
        )
        casted_x, casted_transposed_x, updated_amax = results
        return casted_x, casted_transposed_x, updated_amax

    @staticmethod
    def batcher(
        batched_args, batch_dims, *, out_dtype, static_axis_boundary, transpose_axis_boundary
    ):
        """
        vmap rule: the batch dimension of ``x`` is passed on as the static axis, and the
        transpose boundary is shifted by one.
        """
        check_valid_batch_dims(batch_dims)
        assert CastTransposePrimitive.outer_primitive is not None
        assert static_axis_boundary < 0

        x, amax, scale, scale_inv = batched_args
        x_bdim, amax_bdim, *_ = batch_dims

        # Normalize against the unbatched rank, then add one for the batch dim.
        shifted_boundary = normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1) + 1

        outputs = CastTransposePrimitive.outer_primitive.bind(
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=x_bdim,
            transpose_axis_boundary=shifted_boundary,
        )
        return outputs, (x_bdim, x_bdim, amax_bdim)

    @staticmethod
    def _output_shardings(static_axis_boundary, transpose_axis_boundary, mesh, arg_infos):
        """Shardings for (casted x, casted-transposed x, amax), derived from the operands."""
        x_spec = get_padded_spec(arg_infos[0])
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        amax_spec = get_padded_spec(arg_infos[1])
        return (
            NamedSharding(mesh, PartitionSpec(*x_spec)),
            NamedSharding(mesh, PartitionSpec(*xt_spec)),
            NamedSharding(mesh, PartitionSpec(*amax_spec)),
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos
    ):
        """Infer output shardings from operand shardings."""
        del out_dtype, result_infos
        return CastTransposePrimitive._output_shardings(
            static_axis_boundary, transpose_axis_boundary, mesh, arg_infos
        )

    @staticmethod
    def partition(
        out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos
    ):
        """
        Partition rule: cast/transpose locally, then max-reduce the amax across the mesh
        (all axes except PP).
        """
        del result_infos
        out_shardings = CastTransposePrimitive._output_shardings(
            static_axis_boundary, transpose_axis_boundary, mesh, arg_infos
        )
        arg_shardings = tuple(info.sharding for info in arg_infos)

        def sharded_impl(x, amax, scale, scale_inv):
            local_cx, local_cxt, local_amax = CastTransposePrimitive.impl(
                x,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=static_axis_boundary,
                transpose_axis_boundary=transpose_axis_boundary,
            )
            return local_cx, local_cxt, all_reduce_max_along_all_axes_except_PP(local_amax, mesh)

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(CastTransposePrimitive)


436
437
438
439
440
441
442
443
444
def cast_transpose(
    x: jnp.ndarray,
    amax: jnp.ndarray,
    scale: jnp.ndarray,
    scale_inv: jnp.ndarray,
    out_dtype: jnp.dtype,
    static_axis_boundary: int,
    transpose_axis_boundary: int,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    cast transpose wrapper.

    Casts ``x`` to ``out_dtype`` using ``scale``. It also produces the casted tensor
    transposed around ``transpose_axis_boundary``, with axes up to
    ``static_axis_boundary`` handled as static by ``multidim_transpose``. Uses the TE
    custom primitive when it is enabled, and the pure-JAX fallback otherwise.

    Returns:
        A 3-tuple ``(casted_x, casted_transposed_x, updated_amax)``.
    """
    if not CastTransposePrimitive.enabled():
        return _jax_cast_transpose(
            x, scale, amax, out_dtype, static_axis_boundary, transpose_axis_boundary
        )
    return CastTransposePrimitive.outer_primitive.bind(
        x,
        amax,
        scale,
        scale_inv,
        out_dtype=out_dtype,
        static_axis_boundary=static_axis_boundary,
        transpose_axis_boundary=transpose_axis_boundary,
    )
462
463
464
465
466
467


class DBiasCastTransposePrimitive(BasePrimitive):
    """
    DBias Cast Transpose Primitive.

    From an incoming gradient ``dz``, produces:
    - ``dz`` cast to ``out_dtype``;
    - its cast-transpose;
    - ``dbias``, shaped ``(*dz.shape[:static_axis_boundary + 1], hidden)``;
    - the updated ``amax``.
    """

    name = "te_dbias_cast_transpose"
    multiple_results = True
    # out_dtype, static_axis_boundary, transpose_axis_boundary
    impl_static_args = (4, 5, 6)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        dz_aval,
        amax_aval,
        scale_aval,
        scale_inv_aval,
        *,
        out_dtype,
        static_axis_boundary,
        transpose_axis_boundary
    ):
        """
        te_dbias_cast_transpose_p abstract.

        Returns avals for (casted dz, cast-transposed dz, dbias, updated amax, workspace).
        The workspace aval belongs to the inner primitive only; ``outer_abstract`` drops it.
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        # Axes from the transpose boundary onward are flattened into one "hidden" size.
        gi_hidden_size = reduce(operator.mul, dz_aval.shape[transpose_axis_boundary:])
        t_shape = multidim_transpose(dz_aval.shape, static_axis_boundary, transpose_axis_boundary)
        out = dz_aval.update(shape=dz_aval.shape, dtype=out_dtype)
        t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)

        # dbias: leading static axes followed by the flattened hidden size, in dz's dtype.
        dbias_shape = (*dz_aval.shape[: static_axis_boundary + 1], gi_hidden_size)
        dbias = dz_aval.update(shape=dbias_shape, dtype=dtype)

        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
        (wkspace_info,) = transformer_engine_jax.get_dbias_ct_workspace_sizes(
            dz_aval.size // gi_hidden_size,
            gi_hidden_size,
            jax_dtype_to_te_dtype(dz_aval.dtype),
            jax_dtype_to_te_dtype(out_dtype),
        )
        # wkspace_info is a (shape, TE dtype) pair.
        wkspace_aval = dz_aval.update(
            shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
        )

        return out, t_out, dbias, updated_amax_aval, wkspace_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dbias_cast_transpose_p outer abstract: the inner abstract minus the workspace.
        """
        out, t_out, dbias, updated_amax_aval, _ = DBiasCastTransposePrimitive.abstract(
            *args, **kwargs
        )
        return out, t_out, dbias, updated_amax_aval

    @staticmethod
    def lowering(
        ctx, dz, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, transpose_axis_boundary
    ):
        """
        te_dbias_cast_transpose_p lowering rules.

        Operand 1 (``amax``) is aliased to output 3 (the updated amax) on both paths. The
        legacy path passes the workspace aval (last output) to the kernel via the opaque
        descriptor.
        """
        dz_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32

        if is_ffi_enabled():
            return ffi.ffi_lowering("te_dbias_cast_transpose_ffi", operand_output_aliases={1: 3})(
                ctx, dz, amax, scale, scale_inv, transpose_axis=transpose_axis_boundary
            )

        ir_dz_type = ir.RankedTensorType(dz.type)
        ir_dz_shape = ir_dz_type.shape
        # The legacy kernel sees dz as a 2D (batch, hidden) matrix split at the boundary.
        batch_size = reduce(operator.mul, ir_dz_shape[:transpose_axis_boundary])
        ir_hidden_size = reduce(operator.mul, ir_dz_shape[transpose_axis_boundary:])
        contracted_dz_shape = (batch_size, ir_hidden_size)
        ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
        ir_amax_type = ir.RankedTensorType(amax.type)
        ir_amax_dtype = ir_amax_type.element_type
        ir_amax_shape = ir_amax_type.shape
        # scale and scale_inv are described with amax's shape.
        ir_scale_shape = ir_amax_shape
        ir_scale_inv_shape = ir_amax_shape
        transposed_dz_shape = multidim_transpose(
            ir_dz_shape, static_axis_boundary, transpose_axis_boundary
        )
        dbias_shape = (*ir_dz_shape[: static_axis_boundary + 1], ir_hidden_size)

        wkspace_aval = ctx.avals_out[-1]

        out_types = [
            ir.RankedTensorType.get(ir_dz_shape, ir_out_dtype),
            ir.RankedTensorType.get(transposed_dz_shape, ir_out_dtype),
            ir.RankedTensorType.get(dbias_shape, ir_dz_type.element_type),
            ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
            ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
        ]
        operands = [dz, amax, scale, scale_inv]
        operand_shapes = [ir_dz_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
        opaque = transformer_engine_jax.pack_common_wk_descriptor(
            contracted_dz_shape,
            wkspace_aval.shape,
            jax_dtype_to_te_dtype(dz_aval.dtype),
            jax_dtype_to_te_dtype(out_dtype),
            jax_dtype_to_te_dtype(wkspace_aval.dtype),
        )
        return custom_caller(
            DBiasCastTransposePrimitive.name, args, opaque, False, operand_output_aliases={1: 3}
        )

    @staticmethod
    def impl(dz, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis_boundary):
        """
        te_dbias_cast_transpose implementation.

        Binds the inner primitive and discards its workspace output.
        """
        assert DBiasCastTransposePrimitive.inner_primitive is not None
        out, t_out, dbias, updated_amax, _ = DBiasCastTransposePrimitive.inner_primitive.bind(
            dz,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary,
        )
        return out, t_out, dbias, updated_amax

    @staticmethod
    def batcher(
        batched_args, batch_dims, *, out_dtype, static_axis_boundary, transpose_axis_boundary
    ):
        """
        vmap rule for te_dbias_cast_transpose.

        The vmapped dimension of ``dz`` is passed to the primitive as its static axis.
        Because of that, an already-present static axis (e.g. from nested vmap) cannot be
        represented. Such a call is rejected rather than silently overwritten, because
        overwriting would fold the old static axis into the dbias reduction and
        mis-shape dbias.
        """
        check_valid_batch_dims(batch_dims)
        assert DBiasCastTransposePrimitive.outer_primitive is not None
        assert static_axis_boundary < 0
        dz, amax, scale, scale_inv = batched_args
        dz_bdim, amax_bdim, _, _ = batch_dims

        # Normalize against the unbatched rank, then add one for the batch dim.
        transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, dz.ndim - 1)
        transpose_axis_boundary += 1

        out_bdims = dz_bdim, dz_bdim, dz_bdim, amax_bdim
        return (
            DBiasCastTransposePrimitive.outer_primitive.bind(
                dz,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=dz_bdim,
                transpose_axis_boundary=transpose_axis_boundary,
            ),
            out_bdims,
        )

    @staticmethod
    def _output_shardings(static_axis_boundary, transpose_axis_boundary, mesh, arg_infos):
        """Shardings for (casted dz, cast-transposed dz, dbias, amax) from the operands."""
        x_spec = get_padded_spec(arg_infos[0])
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        # dbias keeps dz's static-axis specs plus the spec of dz's last axis.
        dbias_spec = (*x_spec[: static_axis_boundary + 1], x_spec[-1])
        amax_spec = get_padded_spec(arg_infos[1])
        return (
            NamedSharding(mesh, PartitionSpec(*x_spec)),
            NamedSharding(mesh, PartitionSpec(*xt_spec)),
            NamedSharding(mesh, PartitionSpec(*dbias_spec)),
            NamedSharding(mesh, PartitionSpec(*amax_spec)),
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos
    ):
        """Infer output shardings from operand shardings."""
        del out_dtype, result_infos
        return DBiasCastTransposePrimitive._output_shardings(
            static_axis_boundary, transpose_axis_boundary, mesh, arg_infos
        )

    @staticmethod
    def partition(
        out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos
    ):
        """
        Partition rule: compute locally, then:
        - sum-reduce dbias along the dp/fsdp axes;
        - max-reduce amax along all axes except PP.
        """
        del result_infos
        out_shardings = DBiasCastTransposePrimitive._output_shardings(
            static_axis_boundary, transpose_axis_boundary, mesh, arg_infos
        )
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)

        def sharded_impl(dz, amax, scale, scale_inv):
            local_out, local_t_out, local_dbias, local_amax = DBiasCastTransposePrimitive.impl(
                dz,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=static_axis_boundary,
                transpose_axis_boundary=transpose_axis_boundary,
            )
            global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            return local_out, local_t_out, global_dbias, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(DBiasCastTransposePrimitive)


def dbias_cast_transpose(
    dz: jnp.ndarray,
    amax: jnp.ndarray,
    scale: jnp.ndarray,
    scale_inv: jnp.ndarray,
    out_dtype: jnp.dtype,
    static_axis_boundary: int,
    transpose_axis_boundary: int = -1,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    cast transpose dbias partial fusion wrapper.

    Casts ``dz`` to ``out_dtype`` (a JAX dtype) using ``scale``, transposes the cast
    result, and computes the bias gradient. Uses the TE custom primitive when it is
    enabled, and the pure-JAX fallback otherwise.

    Args:
        static_axis_boundary: any negative value means "no static axes" and is
            normalized to -1.
        transpose_axis_boundary: axis at which dz is split into (rows, hidden).

    Returns:
        A 4-tuple ``(casted_dz, cast_transposed_dz, dbias, updated_amax)``.
    """
    if static_axis_boundary < 0:
        static_axis_boundary = -1  # means no static axes

    if not DBiasCastTransposePrimitive.enabled():
        return _jax_dbias_cast_transpose(
            dz, amax, scale, out_dtype, static_axis_boundary, transpose_axis_boundary
        )

    return DBiasCastTransposePrimitive.outer_primitive.bind(
        dz,
        amax,
        scale,
        scale_inv,
        out_dtype=out_dtype,
        static_axis_boundary=static_axis_boundary,
        transpose_axis_boundary=transpose_axis_boundary,
    )
725
726
727
728
729
730


class DActLuDBiasCastTransposePrimitive(BasePrimitive):
    """
    DActLu DBias Cast Transpose Primitive

    Fuses the backward of a non-gated activation with an FP8 cast, a transpose of the casted
    result, a bias-gradient reduction and an amax update.
    """

    name = "te_dact_lu_dbias_cast_transpose"
    multiple_results = True
    # out_dtype, static_axis_boundary, act_enum
    impl_static_args = (5, 6, 7)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        dz_aval,
        x_aval,
        amax_aval,
        scale_aval,
        scale_inv_aval,
        *,
        out_dtype,
        static_axis_boundary,
        act_enum
    ):  # pylint: disable=unused-argument
        """
        te_dact_lu_dbias_cast_transpose_p abstract

        Returns the abstract values of (out, t_out, dbias, updated_amax, workspace):
        ``out`` has x's shape in ``out_dtype``; ``t_out`` is x's shape permuted by
        ``multidim_transpose(..., static_axis_boundary, -2)``; ``dbias`` keeps the leading
        static axes plus the hidden dimension, in the input dtype.
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dtype
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        dz_hidden_size = dz_aval.shape[-1]
        x_hidden_size = x_aval.shape[-1]
        assert dz_hidden_size == x_hidden_size

        t_shape = multidim_transpose(x_aval.shape, static_axis_boundary, -2)
        out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype)
        t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)

        dbias_shape = (*x_aval.shape[: static_axis_boundary + 1], x_hidden_size)
        dbias = dz_aval.update(shape=dbias_shape, dtype=dtype)

        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)

        # The workspace shape/dtype is queried from the TE extension for the flattened
        # (x.size // hidden, hidden) problem.
        (wkspace_info,) = transformer_engine_jax.get_dact_dbias_ct_workspace_sizes(
            x_aval.size // x_hidden_size,
            x_hidden_size,
            jax_dtype_to_te_dtype(x_aval.dtype),
            jax_dtype_to_te_dtype(out_dtype),
        )
        wkspace_aval = x_aval.update(
            shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
        )

        return out, t_out, dbias, updated_amax_aval, wkspace_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dact_lu_dbias_cast_transpose_p outer abstract

        Same as ``abstract`` with the trailing workspace output dropped.
        """
        out, t_out, dbias, updated_amax_aval, _ = DActLuDBiasCastTransposePrimitive.abstract(
            *args, **kwargs
        )
        return out, t_out, dbias, updated_amax_aval

    @staticmethod
    def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, act_enum):
        """
        te_dact_lu_dbias_cast_transpose_p lowering rules

        Emits the FFI call when FFI is enabled, otherwise the legacy custom call with a packed
        opaque descriptor. In both paths operand 2 (amax) is aliased to result 3 (updated amax).
        """
        dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_aval.dtype
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        if is_ffi_enabled():
            name = "te_dact_lu_dbias_cast_transpose_ffi"
            out = ffi.ffi_lowering(name, operand_output_aliases={2: 3})(
                ctx, dz, x, amax, scale, scale_inv, act_enum=int(act_enum)
            )
        else:
            ir_dz_type = ir.RankedTensorType(dz.type)
            ir_dz_shape = ir_dz_type.shape
            x_type = ir.RankedTensorType(x.type)
            x_shape = x_type.shape
            # x has one extra axis at -2 that is excluded from its batch product.
            dz_batch_size = reduce(operator.mul, ir_dz_shape[:-1])
            x_batch_size = reduce(operator.mul, x_shape[:-2])
            assert dz_batch_size == x_batch_size
            ir_hidden_size = ir_dz_shape[-1]
            contracted_x_shape = (x_batch_size, ir_hidden_size)

            ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
            ir_amax_type = ir.RankedTensorType(amax.type)
            ir_amax_dtype = ir_amax_type.element_type
            ir_amax_shape = ir_amax_type.shape
            ir_scale_shape = ir_amax_shape
            ir_scale_inv_shape = ir_amax_shape
            transposed_x_shape = multidim_transpose(x_shape, static_axis_boundary, -2)
            dbias_shape = (*x_shape[: static_axis_boundary + 1], ir_hidden_size)

            # The workspace is the last inner-primitive output (see `abstract`).
            wkspace_aval = ctx.avals_out[-1]

            out_types = [
                ir.RankedTensorType.get(x_shape, ir_out_dtype),
                ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype),
                ir.RankedTensorType.get(dbias_shape, ir_dz_type.element_type),
                ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
                ir.RankedTensorType.get(
                    wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
                ),
            ]
            operands = [dz, x, amax, scale, scale_inv]
            operand_shapes = [
                ir_dz_shape,
                x_shape,
                ir_amax_shape,
                ir_scale_shape,
                ir_scale_inv_shape,
            ]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
            opaque = transformer_engine_jax.pack_common_wk_descriptor(
                contracted_x_shape,
                wkspace_aval.shape,
                jax_dtype_to_te_dtype(dz_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
                jax_dtype_to_te_dtype(wkspace_aval.dtype),
                act_enum,
            )

            out = custom_caller(
                DActLuDBiasCastTransposePrimitive.name,
                args,
                opaque,
                False,
                operand_output_aliases={2: 3},
            )

        return out

    @staticmethod
    def impl(
        dz,
        x,
        amax,
        scale,
        scale_inv,
        out_dtype,
        static_axis_boundary,
        act_enum,
    ):
        """
        to describe implementation

        Binds the inner primitive and discards its workspace output.
        """
        assert DActLuDBiasCastTransposePrimitive.inner_primitive is not None
        out, t_out, dbias, updated_amax, _ = DActLuDBiasCastTransposePrimitive.inner_primitive.bind(
            dz,
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary,
            act_enum=act_enum,
        )
        return out, t_out, dbias, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary, act_enum):
        """
        to describe batch rules for vmap

        Re-binds the outer primitive with the batch axis used as the static axis boundary.
        """
        del static_axis_boundary
        check_valid_batch_dims(batch_dims)
        assert DActLuDBiasCastTransposePrimitive.outer_primitive is not None
        dz, x, amax, scale, scale_inv = batched_args
        # NOTE(review): batch_dims[0] is dz's batch dim; it is also used for the x-shaped
        # outputs, which assumes dz and x share the batch axis (presumably enforced by
        # check_valid_batch_dims — confirm).
        dz_bdim, _, amax_bdim, _, _ = batch_dims

        out_bdims = dz_bdim, dz_bdim, dz_bdim, amax_bdim
        return (
            DActLuDBiasCastTransposePrimitive.outer_primitive.bind(
                dz,
                x,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=dz_bdim,
                act_enum=act_enum,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype,
        static_axis_boundary,
        act_enum,
        mesh,
        arg_infos,
        result_infos,
    ):
        """
        Output shardings derived from x's spec (arg 1) and amax's spec (arg 2).
        """
        del out_dtype, result_infos, act_enum
        x_spec = get_padded_spec(arg_infos[1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, -2)
        transposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        # dbias keeps the static leading axes plus the hidden (last) axis.
        dbias_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec[: static_axis_boundary + 1], x_spec[-1])
        )
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        return (out_sharding, transposed_out_sharding, dbias_sharding, amax_sharding)

    @staticmethod
    def partition(
        out_dtype,
        static_axis_boundary,
        act_enum,
        mesh,
        arg_infos,
        result_infos,
    ):
        """
        Partitioning rule: runs the op per shard, then sums dbias over the dp/fsdp axes and
        takes the max of amax over all mesh axes except PP.
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[1])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, -2)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))

        dbias_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec[: static_axis_boundary + 1], x_spec[-1])
        )

        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (
            casted_x_sharding,
            casted_transposed_x_sharding,
            dbias_sharding,
            amax_sharding,
        )

        def sharded_impl(dz, x, amax, scale, scale_inv):
            local_out, local_t_out, local_dbias, local_amax = (
                DActLuDBiasCastTransposePrimitive.impl(
                    dz,
                    x,
                    amax,
                    scale,
                    scale_inv,
                    out_dtype=out_dtype,
                    static_axis_boundary=static_axis_boundary,
                    act_enum=act_enum,
                )
            )
            global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            return local_out, local_t_out, global_dbias, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


# Register through the shared .base helper; presumably this populates inner_primitive /
# outer_primitive (asserted non-None above) — confirm in .base.
register_primitive(DActLuDBiasCastTransposePrimitive)


def dact_lu_dbias_cast_transpose(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    amax: jnp.ndarray,
    scale: jnp.ndarray,
    scale_inv: jnp.ndarray,
    out_dtype: TEDType,
    static_axis_boundary: int,
    activation_type: Sequence[Union[str, Callable]] = ("gelu",),
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    cast transpose dact_lu and dbias fusion wrapper

    Back-propagates ``dz`` through the activation ``activation_type`` evaluated at ``x``.
    It then casts the result to ``out_dtype``, transposes it (split at axis -2), computes
    dbias and updates ``amax``. ONLY support non-gated activation type.

    Args:
        dz: incoming gradient.
        x: activation input.
        amax: running amax buffer (float32).
        scale: FP8 scale (float32).
        scale_inv: inverse FP8 scale (float32); only consumed by the fused primitive.
        out_dtype: dtype of the casted outputs.
        static_axis_boundary: axes ``<= static_axis_boundary`` are kept out of the transpose;
            any negative value is normalized to -1, meaning "no static axes".
        activation_type: key into ``ActivationEnum`` selecting the activation.

    Returns:
        Tuple of (FP8(dact_lu(inputs)), its transpose, dbias, updated amax).
    """
    if static_axis_boundary < 0:
        static_axis_boundary = -1  # means no static axes

    if not DActLuDBiasCastTransposePrimitive.enabled():
        # Pure-JAX fallback: VJP of the reference activation, then dbias/cast/transpose.
        _, vjp_func = jax.vjp(partial(_jax_act_lu, activation_type=activation_type), x)
        (dx,) = vjp_func(dz)
        transpose_axis_boundary = -2
        return _jax_dbias_cast_transpose(
            dx, amax, scale, out_dtype, static_axis_boundary, transpose_axis_boundary
        )

    act_type_id = ActivationEnum[activation_type]
    return DActLuDBiasCastTransposePrimitive.outer_primitive.bind(
        dz,
        x,
        amax,
        scale,
        scale_inv,
        out_dtype=out_dtype,
        static_axis_boundary=static_axis_boundary,
        act_enum=act_type_id,
    )


class DgatedActLuCastTransposePrimitive(BasePrimitive):
    """
    Dgated ActLu Cast Transpose Primitive

    Fuses the backward of a gated activation with an FP8 cast, a transpose of the casted
    result and an amax update (no bias gradient).
    """

    name = "te_dgated_act_lu_cast_transpose"
    multiple_results = True
    impl_static_args = (5, 6, 7)  # out_dtype, static_axis_boundary, act_enum
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        dz_aval,
        x_aval,
        amax_aval,
        scale_aval,
        scale_inv_aval,
        *,
        out_dtype,
        static_axis_boundary,
        act_enum
    ):  # pylint: disable=unused-argument
        """
        te_dgated_act_lu_cast_transpose_p abstract

        Returns the abstract values of (out, t_out, updated_amax). ``out`` has x's shape in
        ``out_dtype``. ``t_out`` is x's shape permuted by
        ``multidim_transpose(..., static_axis_boundary, -2)``.
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dtype
        # Gated activations expect exactly two entries on axis -2 (e.g. Linear + GeLU).
        assert x_aval.shape[-2] == 2
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        dz_hidden_size = dz_aval.shape[-1]
        x_hidden_size = x_aval.shape[-1]
        assert dz_hidden_size == x_hidden_size
        t_shape = multidim_transpose(x_aval.shape, static_axis_boundary, -2)
        out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype)
        t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
        return out, t_out, updated_amax_aval

    @staticmethod
    def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, act_enum):
        """
        te_dgated_act_lu_cast_transpose_p lowering rules

        Emits the FFI call when FFI is enabled, otherwise the legacy custom call with a packed
        opaque descriptor. In both paths operand 2 (amax) is aliased to result 2 (updated amax).
        """
        dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_aval.dtype
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        if is_ffi_enabled():
            name = "te_dgated_act_lu_cast_transpose_ffi"
            out = ffi.ffi_lowering(name, operand_output_aliases={2: 2})(
                ctx, dz, x, amax, scale, scale_inv, act_enum=int(act_enum)
            )
        else:
            ir_dz_type = ir.RankedTensorType(dz.type)
            ir_dz_shape = ir_dz_type.shape
            x_type = ir.RankedTensorType(x.type)
            x_shape = x_type.shape
            dz_batch_size = reduce(operator.mul, ir_dz_shape[:-1])
            x_batch_size = reduce(operator.mul, x_shape[:-2])
            assert dz_batch_size == x_batch_size
            assert x_shape[-2] == 2  # gated: two entries on axis -2 (e.g. Linear + GeLU)
            ir_hidden_size = ir_dz_shape[-1]
            x_hidden_size = x_shape[-1]
            assert ir_hidden_size == x_hidden_size
            ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
            ir_amax_type = ir.RankedTensorType(amax.type)
            ir_amax_dtype = ir_amax_type.element_type
            ir_amax_shape = ir_amax_type.shape
            ir_scale_shape = ir_amax_shape
            ir_scale_inv_shape = ir_amax_shape
            transposed_x_shape = multidim_transpose(x_shape, static_axis_boundary, -2)
            out_types = [
                ir.RankedTensorType.get(x_shape, ir_out_dtype),
                ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype),
                ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
            ]
            operands = [dz, x, amax, scale, scale_inv]
            operand_shapes = [
                ir_dz_shape,
                x_shape,
                ir_amax_shape,
                ir_scale_shape,
                ir_scale_inv_shape,
            ]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
            contracted_x_shape = (x_batch_size, x_shape[-1])
            opaque = transformer_engine_jax.pack_common_descriptor(
                contracted_x_shape,
                jax_dtype_to_te_dtype(dz_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
                act_enum,
            )

            out = custom_caller(
                DgatedActLuCastTransposePrimitive.name,
                args,
                opaque,
                False,
                operand_output_aliases={2: 2},
            )

        return out

    @staticmethod
    def impl(dz, x, amax, scale, scale_inv, out_dtype, static_axis_boundary, act_enum):
        """
        to describe implementation

        Binds the inner primitive and returns (out, t_out, updated_amax).
        """
        assert DgatedActLuCastTransposePrimitive.inner_primitive is not None
        out, t_out, updated_amax = DgatedActLuCastTransposePrimitive.inner_primitive.bind(
            dz,
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary,
            act_enum=act_enum,
        )
        return out, t_out, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary, act_enum):
        """
        to describe batch rules for vmap

        Re-binds the outer primitive with the batch axis used as the static axis boundary.
        """
        del static_axis_boundary
        check_valid_batch_dims(batch_dims)
        assert DgatedActLuCastTransposePrimitive.outer_primitive is not None
        dz, x, amax, scale, scale_inv = batched_args
        # NOTE(review): batch_dims[0] is dz's batch dim; it is also used for the x-shaped
        # outputs, which assumes dz and x share the batch axis (presumably enforced by
        # check_valid_batch_dims — confirm).
        dz_bdim, _, amax_bdim, _, _ = batch_dims

        out_bdims = dz_bdim, dz_bdim, amax_bdim
        return (
            DgatedActLuCastTransposePrimitive.outer_primitive.bind(
                dz,
                x,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=dz_bdim,
                act_enum=act_enum,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype, static_axis_boundary, act_enum, mesh, arg_infos, result_infos
    ):
        """
        Output shardings derived from x's spec (arg 1) and amax's spec (arg 2).
        """
        del out_dtype, result_infos, act_enum
        x_spec = get_padded_spec(arg_infos[1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, -2)
        transposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        return (out_sharding, transposed_out_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, static_axis_boundary, act_enum, mesh, arg_infos, result_infos):
        """
        Partitioning rule: runs the op per shard, then takes the max of amax over all mesh
        axes except PP.
        """
        del result_infos
        x_spec = get_padded_spec(arg_infos[1])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, -2)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))

        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (casted_x_sharding, casted_transposed_x_sharding, amax_sharding)

        def sharded_impl(dz, x, amax, scale, scale_inv):
            local_out, local_t_out, local_amax = DgatedActLuCastTransposePrimitive.impl(
                dz,
                x,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=static_axis_boundary,
                act_enum=act_enum,
            )
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            return local_out, local_t_out, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


# Register through the shared .base helper; presumably this populates inner_primitive /
# outer_primitive (asserted non-None above) — confirm in .base.
register_primitive(DgatedActLuCastTransposePrimitive)


def dgated_act_lu_cast_transpose(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    amax: jnp.ndarray,
    scale: jnp.ndarray,
    scale_inv: jnp.ndarray,
    out_dtype: TEDType,
    static_axis_boundary: int,
    activation_type: Sequence[Union[str, Callable]],
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    cast transpose d_gated_act_lu fusion wrapper

    Back-propagates ``dz`` through the gated activation ``activation_type`` evaluated at
    ``x``. It then casts the result to ``out_dtype``, transposes it (split at axis -2) and
    updates ``amax``.

    Args:
        dz: incoming gradient.
        x: gated activation input (axis -2 holds the two gate entries).
        amax: running amax buffer (float32).
        scale: FP8 scale (float32).
        scale_inv: inverse FP8 scale (float32); only consumed by the fused primitive.
        out_dtype: dtype of the casted outputs.
        static_axis_boundary: axes ``<= static_axis_boundary`` are kept out of the transpose;
            any negative value is normalized to -1 ("no static axes"), matching the other
            wrappers in this module.
        activation_type: key into ``ActivationEnum`` selecting the activation.

    Returns:
        Tuple of (FP8(dgated_act_lu(inputs)), its transpose, updated amax).
    """
    if static_axis_boundary < 0:
        static_axis_boundary = -1  # means no static axes

    act_type_id = ActivationEnum[activation_type]
    if not DgatedActLuCastTransposePrimitive.enabled():
        # Pure-JAX fallback: VJP of the reference activation, then cast + transpose.
        _, vjp_func = jax.vjp(partial(_jax_act_lu, activation_type=activation_type), x)
        (dx,) = vjp_func(dz)
        return _jax_cast_transpose(
            dx,
            scale,
            amax,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=-2,
        )

    return DgatedActLuCastTransposePrimitive.outer_primitive.bind(
        dz,
        x,
        amax,
        scale,
        scale_inv,
        out_dtype=out_dtype,
        static_axis_boundary=static_axis_boundary,
        act_enum=act_type_id,
    )