"vscode:/vscode.git/clone" did not exist on "dcc92d0ab6c4ce022162a23566d44f673251eee4"
softmax.py 32.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for softmax"""
from abc import abstractmethod
from functools import partial, reduce
import operator
import warnings

import jax
import jax.numpy as jnp
from jax import dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax.extend import ffi

from transformer_engine import transformer_engine_jax

from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import get_padded_spec, check_valid_batch_dims, jax_dtype_to_te_dtype, is_ffi_enabled
from ..softmax import SoftmaxType


__all__ = [
    "scaled_softmax_fwd",
    "scaled_softmax_bwd",
    "scaled_masked_softmax_fwd",
    "scaled_masked_softmax_bwd",
    "scaled_upper_triang_masked_softmax_fwd",
    "scaled_upper_triang_masked_softmax_bwd",
    "is_softmax_kernel_available",
]


def _jax_scaled_softmax(logits: jnp.ndarray, scale_factor: float):
    return jax.nn.softmax(scale_factor * logits)


def _jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float):
    if mask is not None:
        logits += jax.lax.select(
            mask > 0,
            jnp.full(mask.shape, -1e10).astype(logits.dtype),
            jnp.full(mask.shape, 0.0).astype(logits.dtype),
        )
    return jax.nn.softmax(logits * scale_factor)


def _jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: float):
    mask = 1 - jnp.tril(jnp.ones_like(logits))
    logits += jax.lax.select(
        mask > 0,
        jnp.full(mask.shape, -1e10).astype(logits.dtype),
        jnp.full(mask.shape, 0.0).astype(logits.dtype),
    )
    return jax.nn.softmax(logits * scale_factor)
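
# Illustrative note (not part of the public API): the three `_jax_*` helpers
# above are the pure-JAX fallbacks used when the fused TE primitives are
# disabled. For a square logits block, the upper-triangular variant matches the
# masked variant given an explicit causal mask, e.g. (shapes/dtype are
# assumptions for illustration):
#
#     logits = jnp.zeros((1, 2, 8, 8), dtype=jnp.bfloat16)
#     causal_mask = (1 - jnp.tril(jnp.ones((1, 1, 8, 8)))).astype(jnp.uint8)
#     a = _jax_scaled_upper_triang_masked_softmax(logits, 0.125)
#     b = _jax_scaled_masked_softmax(logits, causal_mask, 0.125)
#     # a and b agree elementwise up to floating-point rounding.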


def is_softmax_kernel_available(
    softmax_type: SoftmaxType,
    batch: int,
    heads: int,
    q_seqlen: int,
    k_seqlen: int,
    dtype: jnp.dtype,
):
    """check softmax available"""
    if softmax_type is SoftmaxType.SCALED:
        return ScaledSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )
    if softmax_type is SoftmaxType.SCALED_MASKED:
        return ScaledMaskedSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )
    if softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
        return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )

    raise NotImplementedError
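
# Minimal usage sketch (values are assumptions for illustration only):
#
#     if is_softmax_kernel_available(
#         SoftmaxType.SCALED_MASKED, batch=2, heads=16,
#         q_seqlen=512, k_seqlen=512, dtype=jnp.float16,
#     ):
#         ...  # the fused kernel path can handle this problem size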


class SoftmaxPrimitive(BasePrimitive):
    """
    Softmax Primitive
    """

    max_k_seqlen_supported = 16384
    name = "te_softmax_internal_placeholder"

    @staticmethod
    @abstractmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        raise NotImplementedError

    @staticmethod
    def get_batch_per_block(k_seqlen: int) -> int:
        """Get batch per CTA in Softmax kernels"""
        threads_per_warp = 32
        threads_per_block = 128  # Depends on the kernel implementation

        pow2 = 1 << (k_seqlen - 1).bit_length()
        warp_size = pow2 if pow2 < threads_per_warp else threads_per_warp
        batches_per_warp = 2 if pow2 <= 128 else 1
        warps_per_block = threads_per_block // warp_size
        batches_per_block = warps_per_block * batches_per_warp
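        # Worked example of the arithmetic above: k_seqlen=128 -> pow2=128,
        # warp_size=32, batches_per_warp=2, warps_per_block=4, i.e. 8 batches
        # per block; k_seqlen=1024 -> pow2=1024 > 128, batches_per_warp=1,
        # i.e. 4 batches per block.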
        return batches_per_block

    @staticmethod
    def forward_abstract(logits_aval, scale_factor):
        """
        softmax_forward abstract
        """
        del scale_factor
        i_dtype = dtypes.canonicalize_dtype(logits_aval.dtype)
        assert i_dtype in [jnp.float16, jnp.bfloat16]
        i_shape = logits_aval.shape
        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
        q_seqlen = i_shape[-2]
        k_seqlen = i_shape[-1]
        assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
        assert q_seqlen > 1

        out_aval = logits_aval
        return out_aval

    @staticmethod
    def forward_lowering(name, ctx, logits, *, scale_factor):
        """
        softmax_forward lowering rules
        """
        if is_ffi_enabled():
            ffi_name = name + "_ffi"
            out = ffi.ffi_lowering(ffi_name)(ctx, logits, scale_factor=scale_factor)
        else:
            (i_aval,) = ctx.avals_in
            i_type = ir.RankedTensorType(logits.type)
            i_shape = i_type.shape
            # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
            batch = reduce(operator.mul, i_shape[:-3])
            pad_batch = batch
            heads = i_shape[-3]
            q_seqlen = i_shape[-2]
            k_seqlen = i_shape[-1]

            out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)]
            operands = [logits]
            operand_shapes = [i_shape]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

            opaque = transformer_engine_jax.pack_softmax_descriptor(
                batch,
                pad_batch,
                heads,
                q_seqlen,
                k_seqlen,
                jax_dtype_to_te_dtype(i_aval.dtype),
                scale_factor,
            )

            out = custom_caller(name, args, opaque, False)

        return out

    @staticmethod
    def forward_impl(primitive, logits, scale_factor):
        """
        softmax_forward implementation
        """
        assert primitive is not None
        output = primitive.bind(logits, scale_factor=scale_factor)
        return output

    @staticmethod
    def forward_batcher(primitive, batched_args, batch_dims, *, scale_factor):
        """
        softmax_forward batcher
        """
        assert primitive is not None
        (logits,) = batched_args
        (logits_bdim,) = batch_dims

        out_bdims = logits_bdim
        return primitive.bind(logits, scale_factor=scale_factor), out_bdims

    @classmethod
    def forward_infer_sharding_from_operands(cls, scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_forward infer_sharding_from_operands
        """
        del scale_factor, result_infos  # Unused.
        logits_spec = get_padded_spec(arg_infos[0])
        if logits_spec[-1] is not None:
            warnings.warn(
                f"Sharding the hidden dimension is not supported in {cls.name}! "
                "Forcing XLA to not shard the hidden dim, which might introduce extra "
                "collective ops and hurt performance."
            )
        out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec[:-1], None))
        return out_sharding

    @classmethod
    def forward_partition(cls, impl, scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_forward partitioning
        """
        del result_infos
        logits_spec = get_padded_spec(arg_infos[0])
        if logits_spec[-1] is not None:
            warnings.warn(
                f"Sharding the hidden dimension is not supported in {cls.name}! "
                "Forcing XLA to not shard the hidden dim, which might introduce extra "
                "collective ops and hurt performance."
            )
        out_shardings = NamedSharding(mesh, PartitionSpec(*logits_spec[:-1], None))
        arg_shardings = (out_shardings,)
        impl = partial(impl, scale_factor=scale_factor)
        return mesh, impl, out_shardings, arg_shardings

    @staticmethod
    def backward_abstract(
        dz_aval, softmax_out_aval, scale_factor=None
    ):  # pylint: disable=unused-argument
        """
        softmax_backward abstract
        """
        dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        softmax_out_dtype = dtypes.canonicalize_dtype(softmax_out_aval.dtype)
        assert dz_dtype == softmax_out_dtype
        assert dz_dtype in [jnp.float16, jnp.bfloat16]
        assert softmax_out_dtype in [jnp.float16, jnp.bfloat16]

        assert dz_aval.shape == softmax_out_aval.shape

        dx_aval = dz_aval
        return dx_aval

    @staticmethod
    def backward_lowering(name, ctx, dz, softmax_out, *, scale_factor):
        """
        softmax_backward lowering rules
        """
        if is_ffi_enabled():
            ffi_name = name + "_ffi"
            out = ffi.ffi_lowering(ffi_name)(ctx, dz, softmax_out, scale_factor=scale_factor)
        else:
            dz_aval, _ = ctx.avals_in

            dz_type = ir.RankedTensorType(dz.type)
            dz_shape = dz_type.shape

            # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
            batch = reduce(operator.mul, dz_shape[:-3])
            pad_batch = batch  # unused
            heads = dz_shape[-3]
            q_seqlen = dz_shape[-2]
            k_seqlen = dz_shape[-1]

            softmax_out_type = ir.RankedTensorType(softmax_out.type)
            softmax_out_shape = softmax_out_type.shape

            out_types = [ir.RankedTensorType.get(dz_shape, dz_type.element_type)]
            operands = [dz, softmax_out]
            operand_shapes = [dz_shape, softmax_out_shape]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

            opaque = transformer_engine_jax.pack_softmax_descriptor(
                batch,
                pad_batch,
                heads,
                q_seqlen,
                k_seqlen,
                jax_dtype_to_te_dtype(dz_aval.dtype),
                scale_factor,
            )

            out = custom_caller(name, args, opaque, False)

        return out

    @staticmethod
    def backward_impl(primitive, dz, softmax_out, scale_factor):
        """
        softmax_backward implementation
        """
        assert primitive is not None
        dx = primitive.bind(dz, softmax_out, scale_factor=scale_factor)
        return dx

    @staticmethod
    def backward_batcher(primitive, batched_args, batch_dims, *, scale_factor):
        """
        softmax_backward batcher
        """
        assert primitive is not None
        dz, softmax_out = batched_args
        _, softmax_out_bdim = batch_dims

        out_bdims = softmax_out_bdim
        return primitive.bind(dz, softmax_out, scale_factor=scale_factor), out_bdims

    @classmethod
    def backward_infer_sharding_from_operands(cls, scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_backward infer_sharding_from_operands
        """
        del scale_factor, result_infos  # Unused.
        dz_spec = get_padded_spec(arg_infos[0])
        if dz_spec[-1] is not None:
            warnings.warn(
                f"Sharding the hidden dimension is not supported in {cls.name}! "
                "Forcing XLA to not shard the hidden dim, which might introduce extra "
                "collective ops and hurt performance."
            )
        dx_sharding = NamedSharding(mesh, PartitionSpec(*dz_spec[:-1], None))
        return dx_sharding

    @classmethod
    def backward_partition(cls, impl, scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_backward partition
        """
        del result_infos

        dz_spec = get_padded_spec(arg_infos[0])
        softmax_out_spec = get_padded_spec(arg_infos[1])
        if dz_spec[-1] is not None or softmax_out_spec[-1] is not None:
            warnings.warn(
                f"Sharding the hidden dimension is not supported in {cls.name}! "
                "Forcing XLA to not shard the hidden dim, which might introduce extra "
                "collective ops and hurt performance."
            )

        dz_sharding = NamedSharding(mesh, PartitionSpec(*dz_spec[:-1], None))
        softmax_out_sharding = NamedSharding(mesh, PartitionSpec(*softmax_out_spec[:-1], None))
        dx_sharding = dz_sharding
        arg_shardings = (dz_sharding, softmax_out_sharding)
        out_shardings = dx_sharding

        impl = partial(impl, scale_factor=scale_factor)
        return mesh, impl, out_shardings, arg_shardings


class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Softmax Fwd Primitive
    """

    name = "te_scaled_softmax_forward"
    multiple_results = False
    impl_static_args = (1,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        attn_batches = batch * heads

        dtype = dtypes.canonicalize_dtype(dtype)
        if (
            dtype in [jnp.float16, jnp.bfloat16]
            and 16 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
            and q_seqlen % 4 == 0  # q_seqlen must be a multiple of 4
            and attn_batches % 4 == 0  # batch * heads must be a multiple of 4
        ):
            if 0 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported:
                batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen)
                return q_seqlen % batch_per_block == 0
        return False

    @staticmethod
    def abstract(logits_aval, scale_factor):  # pylint: disable=unused-argument
        """
        te_scaled_softmax_forward abstract
        """
        return SoftmaxPrimitive.forward_abstract(logits_aval, scale_factor)

    @staticmethod
    def lowering(ctx, logits, *, scale_factor):
        """
        te_scaled_softmax_forward lowering rules
        """
        return SoftmaxPrimitive.forward_lowering(
            ScaledSoftmaxFwdPrimitive.name, ctx, logits, scale_factor=scale_factor
        )

    @staticmethod
    def impl(logits, scale_factor):
        return SoftmaxPrimitive.forward_impl(
            ScaledSoftmaxFwdPrimitive.inner_primitive, logits, scale_factor
        )

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.forward_batcher(
            ScaledSoftmaxFwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        return ScaledSoftmaxFwdPrimitive.forward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        return ScaledSoftmaxFwdPrimitive.forward_partition(
            ScaledSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
        )


register_primitive(ScaledSoftmaxFwdPrimitive)


def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
    """
    scaled_softmax_forward wrapper
    Return FP16/BF16 tensor
    """
    if not ScaledSoftmaxFwdPrimitive.enabled():
        return _jax_scaled_softmax(logits, scale_factor)
    return ScaledSoftmaxFwdPrimitive.outer_primitive.bind(logits, scale_factor=scale_factor)
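
# Minimal forward usage sketch (shapes, dtype, and scale are assumptions for
# illustration). When the primitive is disabled, the pure-JAX fallback above is
# used instead of the fused kernel:
#
#     logits = jnp.zeros((2, 16, 128, 128), dtype=jnp.bfloat16)
#     probs = scaled_softmax_fwd(logits, scale_factor=0.125)
#     # probs keeps the input shape/dtype; the last axis sums to ~1.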


class ScaledSoftmaxBwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Softmax Bwd Primitive
    """

    name = "te_scaled_softmax_backward"
    multiple_results = False
    impl_static_args = (2,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        return ScaledSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )

    @staticmethod
    def abstract(dz_aval, softmax_out_aval, scale_factor):
        """
        te_scaled_softmax_backward abstract
        """
        return SoftmaxPrimitive.backward_abstract(dz_aval, softmax_out_aval, scale_factor)

    @staticmethod
    def lowering(ctx, dz, softmax_out, *, scale_factor):
        """
        te_scaled_softmax_backward lowering rules
        """
        out = SoftmaxPrimitive.backward_lowering(
            ScaledSoftmaxBwdPrimitive.name, ctx, dz, softmax_out, scale_factor=scale_factor
        )

        return out

    @staticmethod
    def impl(dz, softmax_out, scale_factor):
        return SoftmaxPrimitive.backward_impl(
            ScaledSoftmaxBwdPrimitive.inner_primitive, dz, softmax_out, scale_factor=scale_factor
        )

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.backward_batcher(
            ScaledSoftmaxBwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        return ScaledSoftmaxBwdPrimitive.backward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        return ScaledSoftmaxBwdPrimitive.backward_partition(
            ScaledSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
        )


register_primitive(ScaledSoftmaxBwdPrimitive)


def scaled_softmax_bwd(
    dz: jnp.ndarray, softmax_out: jnp.ndarray, logits: jnp.ndarray, scale_factor: float
) -> jnp.ndarray:
    """
    scaled_backward wrapper
    Return FP16/BF16 tensor
    """
    if not ScaledSoftmaxBwdPrimitive.enabled():
        _, vjp_func = jax.vjp(partial(_jax_scaled_softmax, scale_factor=scale_factor), logits)
        return vjp_func(dz)[0]

    return ScaledSoftmaxBwdPrimitive.outer_primitive.bind(
        dz, softmax_out, scale_factor=scale_factor
    )
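
# Backward usage sketch, pairing with scaled_softmax_fwd above (assumed shapes;
# `logits` is only consumed by the pure-JAX fallback path):
#
#     logits = jnp.zeros((2, 16, 128, 128), dtype=jnp.bfloat16)
#     dz = jnp.ones_like(logits)
#     probs = scaled_softmax_fwd(logits, scale_factor=0.125)
#     dlogits = scaled_softmax_bwd(dz, probs, logits, scale_factor=0.125)
#     # dlogits has the same shape/dtype as dz.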


class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Masked Softmax Fwd Primitive
    """

    name = "te_scaled_masked_softmax_forward"
    multiple_results = False
    impl_static_args = (2,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        attn_batches = batch * heads

        dtype = dtypes.canonicalize_dtype(dtype)
        if (
            dtype in [jnp.float16, jnp.bfloat16]
            and 16 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
            and q_seqlen % 4 == 0  # q_seqlen must be a multiple of 4
            and attn_batches % 4 == 0  # batch * heads must be a multiple of 4
        ):
            if 0 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported:
                batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen)
                return q_seqlen % batch_per_block == 0
        return False

    @staticmethod
    def abstract(logits_aval, mask_aval, scale_factor):  # pylint: disable=unused-argument
        """
        te_scaled_masked_softmax_forward abstract
        """

        i_dtype = dtypes.canonicalize_dtype(logits_aval.dtype)
        assert i_dtype in [jnp.float16, jnp.bfloat16]
        i_shape = logits_aval.shape

        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
        batch = reduce(operator.mul, i_shape[:-3])
        q_seqlen = i_shape[-2]
        k_seqlen = i_shape[-1]
        assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
        assert q_seqlen > 1

        mask_dtype = dtypes.canonicalize_dtype(mask_aval.dtype)
        assert mask_dtype in [
            jnp.uint8,
        ]
        mask_shape = mask_aval.shape
        pad_batch = reduce(operator.mul, mask_shape[:-3])
        assert pad_batch in (1, batch)  # 1 means broadcast
        assert mask_shape[-3] == 1  # 1 means broadcast
        assert mask_shape[-2] == q_seqlen
        assert mask_shape[-1] == k_seqlen

        out_aval = logits_aval
        return out_aval

    @staticmethod
    def lowering(ctx, logits, mask, *, scale_factor):
        """
        te_scaled_masked_softmax_forward lowering rules
        """
        if is_ffi_enabled():
            ffi_name = "te_scaled_masked_softmax_forward_ffi"
            out = ffi.ffi_lowering(ffi_name)(ctx, logits, mask, scale_factor=scale_factor)
        else:
            logits_aval, _ = ctx.avals_in
            i_type = ir.RankedTensorType(logits.type)
            i_shape = i_type.shape
            # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
            batch = reduce(operator.mul, i_shape[:-3])
            heads = i_shape[-3]
            q_seqlen = i_shape[-2]
            k_seqlen = i_shape[-1]

            mask_type = ir.RankedTensorType(mask.type)
            mask_shape = mask_type.shape
            pad_batch = reduce(operator.mul, mask_shape[:-3])

            out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)]
            operands = [logits, mask]
            operand_shapes = [i_shape, mask_shape]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

            opaque = transformer_engine_jax.pack_softmax_descriptor(
                batch,
                pad_batch,
                heads,
                q_seqlen,
                k_seqlen,
                jax_dtype_to_te_dtype(logits_aval.dtype),
                scale_factor,
            )

            out = custom_caller(ScaledMaskedSoftmaxFwdPrimitive.name, args, opaque, False)

        return out

    @staticmethod
    def impl(logits, mask, scale_factor):
        assert ScaledMaskedSoftmaxFwdPrimitive.inner_primitive is not None
        output = ScaledMaskedSoftmaxFwdPrimitive.inner_primitive.bind(
            logits, mask, scale_factor=scale_factor
        )
        return output

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        check_valid_batch_dims(batch_dims)
        assert ScaledMaskedSoftmaxFwdPrimitive.outer_primitive is not None
        logits, mask = batched_args
        logits_bdim, _ = batch_dims

        out_bdims = logits_bdim
        return (
            ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
                logits, mask, scale_factor=scale_factor
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        return ScaledMaskedSoftmaxFwdPrimitive.forward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        return ScaledMaskedSoftmaxFwdPrimitive.backward_partition(
            ScaledMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
        )


register_primitive(ScaledMaskedSoftmaxFwdPrimitive)


def scaled_masked_softmax_fwd(
    logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float
) -> jnp.ndarray:
    """
    scaled_masked_softmax_forward wrapper
    Return FP16/BF16 tensor
    """
    if not ScaledMaskedSoftmaxFwdPrimitive.enabled():
        return _jax_scaled_masked_softmax(logits, mask, scale_factor)
    return ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
        logits, mask, scale_factor=scale_factor
    )
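
# Masked forward usage sketch (assumed shapes; per the abstract rule above, the
# mask is uint8 with a broadcastable batch, a singleton head dimension, and
# nonzero entries marking positions to mask out):
#
#     logits = jnp.zeros((2, 16, 128, 128), dtype=jnp.float16)
#     mask = jnp.zeros((2, 1, 128, 128), dtype=jnp.uint8)
#     probs = scaled_masked_softmax_fwd(logits, mask, scale_factor=0.125)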


class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Masked Softmax Bwd Primitive
    """

    name = "te_scaled_masked_softmax_backward"
    multiple_results = False
    impl_static_args = (2,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        return ScaledSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )

    @staticmethod
    def abstract(dz_aval, softmax_out_aval, *, scale_factor):
        """
        te_scaled_masked_softmax_backward abstract
        """
        return SoftmaxPrimitive.backward_abstract(dz_aval, softmax_out_aval, scale_factor)

    @staticmethod
    def lowering(ctx, dz, softmax_out, *, scale_factor):
        """
        te_scaled_masked_softmax_backward lowering rules
        """
        out = SoftmaxPrimitive.backward_lowering(
            ScaledMaskedSoftmaxBwdPrimitive.name, ctx, dz, softmax_out, scale_factor=scale_factor
        )

        return out

    @staticmethod
    def impl(dz, softmax_out, scale_factor):
        return SoftmaxPrimitive.backward_impl(
            ScaledMaskedSoftmaxBwdPrimitive.inner_primitive,
            dz,
            softmax_out,
            scale_factor=scale_factor,
        )

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.backward_batcher(
            ScaledMaskedSoftmaxBwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        return ScaledMaskedSoftmaxBwdPrimitive.backward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        return ScaledMaskedSoftmaxBwdPrimitive.backward_partition(
            ScaledMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
        )


register_primitive(ScaledMaskedSoftmaxBwdPrimitive)


def scaled_masked_softmax_bwd(
    dz: jnp.ndarray,
    softmax_out: jnp.ndarray,
    logits: jnp.ndarray,
    mask: jnp.ndarray,
    scale_factor: float,
) -> jnp.ndarray:
    """
    scaled_masked_backward wrapper
    Return FP16/BF16 tensor
    """
    if not ScaledMaskedSoftmaxBwdPrimitive.enabled():
        _, vjp_func = jax.vjp(
            partial(_jax_scaled_masked_softmax, scale_factor=scale_factor), logits, mask
        )
        return vjp_func(dz)[0]
    return ScaledMaskedSoftmaxBwdPrimitive.outer_primitive.bind(
        dz, softmax_out, scale_factor=scale_factor
    )


class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Upper Triang Masked Softmax Fwd Primitive
    """

    name = "te_scaled_upper_triang_masked_softmax_forward"
    multiple_results = False
    impl_static_args = (1,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        attn_batches = batch * heads

        dtype = dtypes.canonicalize_dtype(dtype)
        if (
            dtype in [jnp.float16, jnp.bfloat16]
            and 16 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
            and q_seqlen % 4 == 0  # q_seqlen must be a multiple of 4
            and attn_batches % 4 == 0  # batch * heads must be a multiple of 4
            and k_seqlen == q_seqlen
        ):
            if 0 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported:
                batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen)
                return attn_batches % batch_per_block == 0
        return False

    @staticmethod
    def abstract(logits_aval, scale_factor):  # pylint: disable=unused-argument
        """
        te_scaled_upper_triang_masked_softmax_forward abstract
        """
        q_seqlen = logits_aval.shape[-2]
        k_seqlen = logits_aval.shape[-1]
        assert q_seqlen == k_seqlen
        return SoftmaxPrimitive.forward_abstract(logits_aval, scale_factor)

    @staticmethod
    def lowering(ctx, logits, *, scale_factor):
        """
        te_scaled_upper_triang_masked_softmax_forward lowering rules
        """
        return SoftmaxPrimitive.forward_lowering(
            ScaledUpperTriangMaskedSoftmaxFwdPrimitive.name, ctx, logits, scale_factor=scale_factor
        )

    @staticmethod
    def impl(logits, scale_factor):
        return SoftmaxPrimitive.forward_impl(
            ScaledUpperTriangMaskedSoftmaxFwdPrimitive.inner_primitive, logits, scale_factor
        )

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.forward_batcher(
            ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.forward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.forward_partition(
            ScaledUpperTriangMaskedSoftmaxFwdPrimitive.impl,
            scale_factor,
            mesh,
            arg_infos,
            result_infos,
        )


register_primitive(ScaledUpperTriangMaskedSoftmaxFwdPrimitive)


def scaled_upper_triang_masked_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
    """
    scaled_upper_triang_masked_softmax_forward wrapper
    Return FP16/BF16 tensor
    """
    if not ScaledUpperTriangMaskedSoftmaxFwdPrimitive.enabled():
        return _jax_scaled_upper_triang_masked_softmax(logits, scale_factor)
    return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
        logits, scale_factor=scale_factor
    )
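
# Causal forward usage sketch; the kernel path requires q_seqlen == k_seqlen
# (shapes assumed for illustration):
#
#     logits = jnp.zeros((2, 16, 256, 256), dtype=jnp.bfloat16)
#     probs = scaled_upper_triang_masked_softmax_fwd(logits, scale_factor=0.0625)
#     # Row i along the last axis has nonzero probability only for columns <= i.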


class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Upper Triang Masked Softmax Bwd Primitive
    """

    name = "te_scaled_upper_triang_masked_softmax_backward"
    multiple_results = False
    impl_static_args = (2,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )

    @staticmethod
    def abstract(dz_aval, softmax_out_aval, *, scale_factor):
        """
        te_scaled_upper_triang_masked_backward abstract
        """
        return SoftmaxPrimitive.backward_abstract(dz_aval, softmax_out_aval, scale_factor)

    @staticmethod
    def lowering(ctx, dz, softmax_out, *, scale_factor):
        """
        te_scaled_upper_triang_masked_backward lowering rules
        """
        out = SoftmaxPrimitive.backward_lowering(
            ScaledUpperTriangMaskedSoftmaxBwdPrimitive.name,
            ctx,
            dz,
            softmax_out,
            scale_factor=scale_factor,
        )

        return out

    @staticmethod
    def impl(dz, softmax_out, scale_factor):
        return SoftmaxPrimitive.backward_impl(
            ScaledUpperTriangMaskedSoftmaxBwdPrimitive.inner_primitive,
            dz,
            softmax_out,
            scale_factor=scale_factor,
        )

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.backward_batcher(
            ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.backward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.backward_partition(
            ScaledUpperTriangMaskedSoftmaxBwdPrimitive.impl,
            scale_factor,
            mesh,
            arg_infos,
            result_infos,
        )


register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive)


def scaled_upper_triang_masked_softmax_bwd(
    dz: jnp.ndarray, softmax_out: jnp.ndarray, logits: jnp.ndarray, scale_factor: float
) -> jnp.ndarray:
    """
    scaled_upper_triang_masked_backward wrapper
    Return FP16/BF16 tensor
    """
    if not ScaledUpperTriangMaskedSoftmaxBwdPrimitive.enabled():
        _, vjp_func = jax.vjp(
            partial(_jax_scaled_upper_triang_masked_softmax, scale_factor=scale_factor), logits
        )
        return vjp_func(dz)[0]
    return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive.bind(
        dz, softmax_out, scale_factor=scale_factor
    )