# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for softmax"""
from abc import abstractmethod
from functools import partial, reduce
import operator
import warnings
from packaging import version

import jax
import jax.numpy as jnp
from jax import dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding

import transformer_engine_jax

from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import get_padded_spec, check_valid_batch_dims, jax_dtype_to_te_dtype, is_ffi_enabled
from ..softmax import SoftmaxType

if version.parse(jax.__version__) >= version.parse("0.5.0"):
    from jax import ffi  # pylint: disable=ungrouped-imports
else:
    from jax.extend import ffi  # pylint: disable=ungrouped-imports


__all__ = [
    "scaled_softmax_fwd",
    "scaled_softmax_bwd",
    "scaled_masked_softmax_fwd",
    "scaled_masked_softmax_bwd",
    "scaled_upper_triang_masked_softmax_fwd",
    "scaled_upper_triang_masked_softmax_bwd",
    "is_softmax_kernel_available",
]


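# Pure-JAX reference implementations. The public wrappers below fall back to
# these when the corresponding TE custom-kernel primitive is disabled.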
def _jax_scaled_softmax(logits: jnp.ndarray, scale_factor: float):
    return jax.nn.softmax(scale_factor * logits)


def _jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float):
    if mask is not None:
        logits += jax.lax.select(
            mask > 0,
            jnp.full(mask.shape, -1e10).astype(logits.dtype),
            jnp.full(mask.shape, 0.0).astype(logits.dtype),
        )
    return jax.nn.softmax(logits * scale_factor)


def _jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: float):
    mask = 1 - jnp.tril(jnp.ones_like(logits))
    logits += jax.lax.select(
        mask > 0,
        jnp.full(mask.shape, -1e10).astype(logits.dtype),
        jnp.full(mask.shape, 0.0).astype(logits.dtype),
    )
    return jax.nn.softmax(logits * scale_factor)


def is_softmax_kernel_available(
    softmax_type: SoftmaxType,
    batch: int,
    heads: int,
    q_seqlen: int,
    k_seqlen: int,
    dtype: jnp.dtype,
):
    """check softmax available"""
    if softmax_type is SoftmaxType.SCALED:
        return ScaledSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )
    if softmax_type is SoftmaxType.SCALED_MASKED:
        return ScaledMaskedSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )
    if softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
        return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )

    raise NotImplementedError(f"Unsupported softmax type: {softmax_type}")
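
# Example (a minimal sketch; the sizes and dtype are illustrative):
#   ok = is_softmax_kernel_available(
#       SoftmaxType.SCALED, batch=2, heads=8, q_seqlen=128, k_seqlen=128,
#       dtype=jnp.float16,
#   )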


class SoftmaxPrimitive(BasePrimitive):
    """
    Softmax Primitive
    """

    max_k_seqlen_supported = 16384
    name = "te_softmax_internal_placeholder"

    @staticmethod
    @abstractmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        raise NotImplementedError

    @staticmethod
    def get_batch_per_block(k_seqlen: int) -> int:
        """Get batch per CTA in Softmax kernels"""
        threads_per_warp = 32
        threads_per_block = 128  # Depends on the kernel implementation

        # Mirror the CUDA kernel's scheduling: round k_seqlen up to the next
        # power of two, narrow the warp for short rows, and pack two batches
        # per warp when rows are short.
        pow2 = 1 << (k_seqlen - 1).bit_length()
        warp_size = pow2 if pow2 < threads_per_warp else threads_per_warp
        batches_per_warp = 2 if pow2 <= 128 else 1
        warps_per_block = threads_per_block // warp_size
        batches_per_block = warps_per_block * batches_per_warp
        return batches_per_block
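
        # Worked example (illustrative): k_seqlen=512 -> pow2=512, so
        # warp_size=32, batches_per_warp=1, warps_per_block=128 // 32 = 4,
        # and batches_per_block=4.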

    @staticmethod
    def forward_abstract(logits_aval, scale_factor):
        """
        softmax_forward abstract
        """
        del scale_factor
        i_dtype = dtypes.canonicalize_dtype(logits_aval.dtype)
        assert i_dtype in [jnp.float16, jnp.bfloat16]
        i_shape = logits_aval.shape
        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
        q_seqlen = i_shape[-2]
        k_seqlen = i_shape[-1]
        assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
        assert q_seqlen > 1

        out_aval = logits_aval
        return out_aval

    @staticmethod
    def forward_lowering(name, ctx, logits, *, scale_factor):
        """
        softmax_forward lowering rules
        """
        if is_ffi_enabled():
            ffi_name = name + "_ffi"
            out = ffi.ffi_lowering(ffi_name)(ctx, logits, scale_factor=scale_factor)
        else:
            (i_aval,) = ctx.avals_in
            i_type = ir.RankedTensorType(logits.type)
            i_shape = i_type.shape
            # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
            batch = reduce(operator.mul, i_shape[:-3])
            pad_batch = batch
            heads = i_shape[-3]
            q_seqlen = i_shape[-2]
            k_seqlen = i_shape[-1]

            out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)]
            operands = [logits]
            operand_shapes = [i_shape]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

            opaque = transformer_engine_jax.pack_softmax_descriptor(
                batch,
                pad_batch,
                heads,
                q_seqlen,
                k_seqlen,
                jax_dtype_to_te_dtype(i_aval.dtype),
                scale_factor,
            )

            out = custom_caller(name, args, opaque, False)

        return out

    @staticmethod
    def forward_impl(primitive, logits, scale_factor):
        """
        softmax_forward implementation
        """
        assert primitive is not None
        output = primitive.bind(logits, scale_factor=scale_factor)
        return output

    @staticmethod
    def forward_batcher(primitive, batched_args, batch_dims, *, scale_factor):
        """
        softmax_forward batcher
        """
        assert primitive is not None
        (logits,) = batched_args
        (logits_bdim,) = batch_dims

        out_bdims = logits_bdim
        return primitive.bind(logits, scale_factor=scale_factor), out_bdims

    @classmethod
    def forward_infer_sharding_from_operands(cls, scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_forward infer_sharding_from_operands
        """
        del scale_factor, result_infos  # Unused.
        logits_spec = get_padded_spec(arg_infos[0])
        if logits_spec[-1] is not None:
            warnings.warn(
                f"Sharding the hidden dimension is not supported in {cls.name}! "
                "Forcing XLA to not shard the hidden dim, which might introduce extra "
                "collective ops and hurt performance."
            )
        out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec[:-1], None))
        return out_sharding

    @classmethod
    def forward_partition(cls, impl, scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_forward partitioning
        """
        del result_infos
        logits_spec = get_padded_spec(arg_infos[0])
        if logits_spec[-1] is not None:
            warnings.warn(
                f"Sharding the hidden dimension is not supported in {cls.name}! "
                "Forcing XLA to not shard the hidden dim, which might introduce extra "
                "collective ops and hurt performance."
            )
        out_shardings = NamedSharding(mesh, PartitionSpec(*logits_spec[:-1], None))
        arg_shardings = (out_shardings,)
        impl = partial(impl, scale_factor=scale_factor)
        return mesh, impl, out_shardings, arg_shardings

    @staticmethod
    def backward_abstract(
        dz_aval, softmax_out_aval, scale_factor=None
    ):  # pylint: disable=unused-argument
        """
        softmax_backward abstract
        """
        dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        softmax_out_dtype = dtypes.canonicalize_dtype(softmax_out_aval.dtype)
        assert dz_dtype == softmax_out_dtype
        assert dz_dtype in [jnp.float16, jnp.bfloat16]
        assert softmax_out_dtype in [jnp.float16, jnp.bfloat16]

        assert dz_aval.shape == softmax_out_aval.shape

        dx_aval = dz_aval
        return dx_aval

    @staticmethod
    def backward_lowering(name, ctx, dz, softmax_out, *, scale_factor):
        """
        softmax_backward lowering rules
        """
        if is_ffi_enabled():
            ffi_name = name + "_ffi"
            out = ffi.ffi_lowering(ffi_name)(ctx, dz, softmax_out, scale_factor=scale_factor)
        else:
            dz_aval, _ = ctx.avals_in

            dz_type = ir.RankedTensorType(dz.type)
            dz_shape = dz_type.shape

            # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
            batch = reduce(operator.mul, dz_shape[:-3])
            pad_batch = batch  # unused
            heads = dz_shape[-3]
            q_seqlen = dz_shape[-2]
            k_seqlen = dz_shape[-1]

            softmax_out_type = ir.RankedTensorType(softmax_out.type)
            softmax_out_shape = softmax_out_type.shape

            out_types = [ir.RankedTensorType.get(dz_shape, dz_type.element_type)]
            operands = [dz, softmax_out]
            operand_shapes = [dz_shape, softmax_out_shape]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

            opaque = transformer_engine_jax.pack_softmax_descriptor(
                batch,
                pad_batch,
                heads,
                q_seqlen,
                k_seqlen,
                jax_dtype_to_te_dtype(dz_aval.dtype),
                scale_factor,
            )

            out = custom_caller(name, args, opaque, False)

        return out

    @staticmethod
    def backward_impl(primitive, dz, softmax_out, scale_factor):
        """
        softmax_backward implementation
        """
        assert primitive is not None
        dx = primitive.bind(dz, softmax_out, scale_factor=scale_factor)
        return dx

    @staticmethod
    def backward_batcher(primitive, batched_args, batch_dims, *, scale_factor):
        """
        softmax_backward batcher
        """
        assert primitive is not None
        dz, softmax_out = batched_args
        _, softmax_out_bdim = batch_dims

        out_bdims = softmax_out_bdim
        return primitive.bind(dz, softmax_out, scale_factor=scale_factor), out_bdims

    @classmethod
    def backward_infer_sharding_from_operands(cls, scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_backward infer_sharding_from_operands
        """
        del scale_factor, result_infos  # Unused.
        dz_spec = get_padded_spec(arg_infos[0])
        if dz_spec[-1] is not None:
            warnings.warn(
                f"Sharding the hidden dimension is not supported in {cls.name}! "
                "Forcing XLA to not shard the hidden dim, which might introduce extra "
                "collective ops and hurt performance."
            )
        dx_sharding = NamedSharding(mesh, PartitionSpec(*dz_spec[:-1], None))
        return dx_sharding

    @classmethod
    def backward_partition(cls, impl, scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_backward partition
        """
        del result_infos

        dz_spec = get_padded_spec(arg_infos[0])
        softmax_out_spec = get_padded_spec(arg_infos[1])
        if dz_spec[-1] is not None or softmax_out_spec[-1] is not None:
            warnings.warn(
                f"Sharding the hidden dimension is not supported in {cls.name}! "
                "Forcing XLA to not shard the hidden dim, which might introduce extra "
                "collective ops and hurt performance."
            )

        dz_sharding = NamedSharding(mesh, PartitionSpec(*dz_spec[:-1], None))
        softmax_out_sharding = NamedSharding(mesh, PartitionSpec(*softmax_out_spec[:-1], None))
        dx_sharding = dz_sharding
        arg_shardings = (dz_sharding, softmax_out_sharding)
        out_shardings = dx_sharding

        impl = partial(impl, scale_factor=scale_factor)
        return mesh, impl, out_shardings, arg_shardings


class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Softmax Fwd Primitive
    """

    name = "te_scaled_softmax_forward"
    multiple_results = False
    impl_static_args = (1,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        attn_batches = batch * heads

        dtype = dtypes.canonicalize_dtype(dtype)
        if (
            dtype in [jnp.float16, jnp.bfloat16]
            and 16 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
            and q_seqlen % 4 == 0  # q_seqlen must be a multiple of 4
            and attn_batches % 4 == 0  # batch * heads must be a multiple of 4
        ):
            batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen)
            return q_seqlen % batch_per_block == 0
        return False

    @staticmethod
    def abstract(logits_aval, scale_factor):  # pylint: disable=unused-argument
        """
        te_scaled_softmax_forward abstract
        """
        return SoftmaxPrimitive.forward_abstract(logits_aval, scale_factor)

    @staticmethod
    def lowering(ctx, logits, *, scale_factor):
        """
        te_scaled_softmax_forward lowering rules
        """
        return SoftmaxPrimitive.forward_lowering(
            ScaledSoftmaxFwdPrimitive.name, ctx, logits, scale_factor=scale_factor
        )

    @staticmethod
    def impl(logits, scale_factor):
        return SoftmaxPrimitive.forward_impl(
            ScaledSoftmaxFwdPrimitive.inner_primitive, logits, scale_factor
        )

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.forward_batcher(
            ScaledSoftmaxFwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        return ScaledSoftmaxFwdPrimitive.forward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        return ScaledSoftmaxFwdPrimitive.forward_partition(
            ScaledSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
        )


register_primitive(ScaledSoftmaxFwdPrimitive)


def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
    """
    scaled_softmax_forward wrapper
    Return FP16/BF16 tensor
    """
    if not ScaledSoftmaxFwdPrimitive.enabled():
        return _jax_scaled_softmax(logits, scale_factor)
    return ScaledSoftmaxFwdPrimitive.outer_primitive.bind(logits, scale_factor=scale_factor)
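
# Example (a minimal sketch; shapes follow the assumed
# [..., heads, q_seqlen, k_seqlen] layout):
#   logits = jnp.zeros((2, 8, 128, 128), jnp.float16)
#   probs = scaled_softmax_fwd(logits, scale_factor=0.125)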


class ScaledSoftmaxBwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Softmax Bwd Primitive
    """

    name = "te_scaled_softmax_backward"
    multiple_results = False
    impl_static_args = (2,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        return ScaledSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )

    @staticmethod
    def abstract(dz_aval, softmax_out_aval, scale_factor):
        """
        te_scaled_softmax_backward abstract
        """
        return SoftmaxPrimitive.backward_abstract(dz_aval, softmax_out_aval, scale_factor)

    @staticmethod
    def lowering(ctx, dz, softmax_out, *, scale_factor):
        """
        te_scaled_softmax_backward lowering rules
        """
        out = SoftmaxPrimitive.backward_lowering(
            ScaledSoftmaxBwdPrimitive.name, ctx, dz, softmax_out, scale_factor=scale_factor
        )

        return out

    @staticmethod
    def impl(dz, softmax_out, scale_factor):
        return SoftmaxPrimitive.backward_impl(
            ScaledSoftmaxBwdPrimitive.inner_primitive, dz, softmax_out, scale_factor=scale_factor
        )

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.backward_batcher(
            ScaledSoftmaxBwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        return ScaledSoftmaxBwdPrimitive.backward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        return ScaledSoftmaxBwdPrimitive.backward_partition(
            ScaledSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
        )


register_primitive(ScaledSoftmaxBwdPrimitive)


def scaled_softmax_bwd(
    dz: jnp.ndarray, softmax_out: jnp.ndarray, logits: jnp.ndarray, scale_factor: float
) -> jnp.ndarray:
    """
    scaled_softmax_backward wrapper
    Return FP16/BF16 tensor

    Note: `logits` is only consumed by the pure-JAX reference fallback; the TE
    primitive computes the gradient from `dz` and `softmax_out`.
    """
    if not ScaledSoftmaxBwdPrimitive.enabled():
        _, vjp_func = jax.vjp(partial(_jax_scaled_softmax, scale_factor=scale_factor), logits)
        return vjp_func(dz)[0]

    return ScaledSoftmaxBwdPrimitive.outer_primitive.bind(
        dz, softmax_out, scale_factor=scale_factor
    )
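
# Example (a minimal sketch, continuing the forward example above; `logits` is
# only needed for the reference fallback):
#   dz = jnp.ones_like(probs)
#   dlogits = scaled_softmax_bwd(dz, probs, logits, scale_factor=0.125)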


class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Masked Softmax Fwd Primitive
    """

    name = "te_scaled_masked_softmax_forward"
    multiple_results = False
    impl_static_args = (2,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        attn_batches = batch * heads

        dtype = dtypes.canonicalize_dtype(dtype)
        if (
            dtype in [jnp.float16, jnp.bfloat16]
            and 16 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
            and q_seqlen % 4 == 0  # q_seqlen must be a multiple of 4
            and attn_batches % 4 == 0  # batch * heads must be a multiple of 4
        ):
            batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen)
            return q_seqlen % batch_per_block == 0
        return False

    @staticmethod
    def abstract(logits_aval, mask_aval, scale_factor):  # pylint: disable=unused-argument
        """
        te_scaled_masked_softmax_forward abstract
        """

        i_dtype = dtypes.canonicalize_dtype(logits_aval.dtype)
        assert i_dtype in [jnp.float16, jnp.bfloat16]
        i_shape = logits_aval.shape

        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
        batch = reduce(operator.mul, i_shape[:-3])
        q_seqlen = i_shape[-2]
        k_seqlen = i_shape[-1]
        assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
        assert q_seqlen > 1

        mask_dtype = dtypes.canonicalize_dtype(mask_aval.dtype)
        assert mask_dtype in [
            jnp.uint8,
        ]
        mask_shape = mask_aval.shape
        pad_batch = reduce(operator.mul, mask_shape[:-3])
        assert pad_batch in (1, batch)  # 1 means broadcast
        assert mask_shape[-3] == 1  # 1 means broadcast
        assert mask_shape[-2] == q_seqlen
        assert mask_shape[-1] == k_seqlen

        out_aval = logits_aval
        return out_aval

    @staticmethod
    def lowering(ctx, logits, mask, *, scale_factor):
        """
        te_scaled_masked_softmax_forward lowering rules
        """
        if is_ffi_enabled():
            ffi_name = "te_scaled_masked_softmax_forward_ffi"
            out = ffi.ffi_lowering(ffi_name)(ctx, logits, mask, scale_factor=scale_factor)
        else:
            logits_aval, _ = ctx.avals_in
            i_type = ir.RankedTensorType(logits.type)
            i_shape = i_type.shape
            # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
            batch = reduce(operator.mul, i_shape[:-3])
            heads = i_shape[-3]
            q_seqlen = i_shape[-2]
            k_seqlen = i_shape[-1]

            mask_type = ir.RankedTensorType(mask.type)
            mask_shape = mask_type.shape
            pad_batch = reduce(operator.mul, mask_shape[:-3])

            out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)]
            operands = [logits, mask]
            operand_shapes = [i_shape, mask_shape]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

            opaque = transformer_engine_jax.pack_softmax_descriptor(
                batch,
                pad_batch,
                heads,
                q_seqlen,
                k_seqlen,
                jax_dtype_to_te_dtype(logits_aval.dtype),
                scale_factor,
            )

            out = custom_caller(ScaledMaskedSoftmaxFwdPrimitive.name, args, opaque, False)

        return out

    @staticmethod
    def impl(logits, mask, scale_factor):
        assert ScaledMaskedSoftmaxFwdPrimitive.inner_primitive is not None
        output = ScaledMaskedSoftmaxFwdPrimitive.inner_primitive.bind(
            logits, mask, scale_factor=scale_factor
        )
        return output

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        check_valid_batch_dims(batch_dims)
        assert ScaledMaskedSoftmaxFwdPrimitive.outer_primitive is not None
        logits, mask = batched_args
        logits_bdim, _ = batch_dims

        out_bdims = logits_bdim
        return (
            ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
                logits, mask, scale_factor=scale_factor
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        return ScaledMaskedSoftmaxFwdPrimitive.forward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
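        # Note: backward_partition is reused here because, like this forward
        # primitive, it handles two operands (logits and mask take the dz and
        # softmax_out slots); forward_partition only covers one operand.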
        return ScaledMaskedSoftmaxFwdPrimitive.backward_partition(
            ScaledMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
        )


register_primitive(ScaledMaskedSoftmaxFwdPrimitive)


def scaled_masked_softmax_fwd(
    logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float
) -> jnp.ndarray:
    """
    scaled_masked_softmax_forward wrapper
    Return FP16/BF16 tensor
    """
    if not ScaledMaskedSoftmaxFwdPrimitive.enabled():
        return _jax_scaled_masked_softmax(logits, mask, scale_factor)
    return ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
        logits, mask, scale_factor=scale_factor
    )
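
# Example (a minimal sketch): `mask` is uint8 with a broadcastable shape
# [batch or 1, 1, q_seqlen, k_seqlen]; entries > 0 are masked out with a large
# negative bias before the softmax.
#   logits = jnp.zeros((2, 8, 128, 128), jnp.bfloat16)
#   mask = jnp.zeros((2, 1, 128, 128), jnp.uint8)
#   probs = scaled_masked_softmax_fwd(logits, mask, scale_factor=0.125)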


class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Masked Softmax Bwd Primitive
    """

    name = "te_scaled_masked_softmax_backward"
    multiple_results = False
    impl_static_args = (2,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        return ScaledSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )

    @staticmethod
    def abstract(dz_aval, softmax_out_aval, *, scale_factor):
        """
        te_scaled_masked_softmax_backward abstract
        """
        return SoftmaxPrimitive.backward_abstract(dz_aval, softmax_out_aval, scale_factor)

    @staticmethod
    def lowering(ctx, dz, softmax_out, *, scale_factor):
        """
        te_scaled_masked_softmax_backward lowering rules
        """
        out = SoftmaxPrimitive.backward_lowering(
            ScaledMaskedSoftmaxBwdPrimitive.name, ctx, dz, softmax_out, scale_factor=scale_factor
        )

        return out

    @staticmethod
    def impl(dz, softmax_out, scale_factor):
        return SoftmaxPrimitive.backward_impl(
            ScaledMaskedSoftmaxBwdPrimitive.inner_primitive,
            dz,
            softmax_out,
            scale_factor=scale_factor,
        )

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.backward_batcher(
            ScaledMaskedSoftmaxBwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        return ScaledMaskedSoftmaxBwdPrimitive.backward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        return ScaledMaskedSoftmaxBwdPrimitive.backward_partition(
            ScaledMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
        )


register_primitive(ScaledMaskedSoftmaxBwdPrimitive)


def scaled_masked_softmax_bwd(
    dz: jnp.ndarray,
    softmax_out: jnp.ndarray,
    logits: jnp.ndarray,
    mask: jnp.ndarray,
    scale_factor: float,
) -> jnp.ndarray:
    """
    scaled_masked_softmax_backward wrapper
    Return FP16/BF16 tensor

    Note: `logits` and `mask` are only consumed by the pure-JAX reference
    fallback; the TE primitive computes the gradient from `dz` and
    `softmax_out`.
    """
    if not ScaledMaskedSoftmaxBwdPrimitive.enabled():
        _, vjp_func = jax.vjp(
            partial(_jax_scaled_masked_softmax, scale_factor=scale_factor), logits, mask
        )
        return vjp_func(dz)[0]
    return ScaledMaskedSoftmaxBwdPrimitive.outer_primitive.bind(
        dz, softmax_out, scale_factor=scale_factor
    )


class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Upper Triang Masked Softmax Fwd Primitive
    """

    name = "te_scaled_upper_triang_masked_softmax_forward"
    multiple_results = False
    impl_static_args = (1,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        attn_batches = batch * heads

        dtype = dtypes.canonicalize_dtype(dtype)
        if (
            dtype in [jnp.float16, jnp.bfloat16]
            and 16 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
            and q_seqlen % 4 == 0  # q_seqlen must be a multiple of 4
            and attn_batches % 4 == 0  # batch * heads must be a multiple of 4
            and k_seqlen == q_seqlen
        ):
            batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen)
            return attn_batches % batch_per_block == 0
        return False

    @staticmethod
    def abstract(logits_aval, scale_factor):  # pylint: disable=unused-argument
        """
        te_scaled_upper_triang_masked_softmax_forward abstract
        """
        q_seqlen = logits_aval.shape[-2]
        k_seqlen = logits_aval.shape[-1]
        assert q_seqlen == k_seqlen
        return SoftmaxPrimitive.forward_abstract(logits_aval, scale_factor)

    @staticmethod
    def lowering(ctx, logits, *, scale_factor):
        """
        te_scaled_upper_triang_masked_softmax_forward lowering rules
        """
        return SoftmaxPrimitive.forward_lowering(
            ScaledUpperTriangMaskedSoftmaxFwdPrimitive.name, ctx, logits, scale_factor=scale_factor
        )

    @staticmethod
    def impl(logits, scale_factor):
        return SoftmaxPrimitive.forward_impl(
            ScaledUpperTriangMaskedSoftmaxFwdPrimitive.inner_primitive, logits, scale_factor
        )

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.forward_batcher(
            ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.forward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.forward_partition(
            ScaledUpperTriangMaskedSoftmaxFwdPrimitive.impl,
            scale_factor,
            mesh,
            arg_infos,
            result_infos,
        )


register_primitive(ScaledUpperTriangMaskedSoftmaxFwdPrimitive)


def scaled_upper_triang_masked_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
    """
    scaled_upper_triang_masked_softmax_forward wrapper
    Return FP16/BF16 tensor
    """
    if not ScaledUpperTriangMaskedSoftmaxFwdPrimitive.enabled():
        return _jax_scaled_upper_triang_masked_softmax(logits, scale_factor)
    return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
        logits, scale_factor=scale_factor
    )
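
# Example (a minimal sketch): the causal mask is built internally, so no mask
# operand is passed, but q_seqlen must equal k_seqlen.
#   logits = jnp.zeros((2, 8, 128, 128), jnp.float16)
#   probs = scaled_upper_triang_masked_softmax_fwd(logits, scale_factor=0.125)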


class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Upper Triang Masked Softmax Bwd Primitive
    """

    name = "te_scaled_upper_triang_masked_softmax_backward"
    multiple_results = False
    impl_static_args = (2,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )

    @staticmethod
    def abstract(dz_aval, softmax_out_aval, *, scale_factor):
        """
        te_scaled_upper_triang_masked_softmax_backward abstract
        """
        return SoftmaxPrimitive.backward_abstract(dz_aval, softmax_out_aval, scale_factor)

    @staticmethod
    def lowering(ctx, dz, softmax_out, *, scale_factor):
        """
        te_scaled_upper_triang_masked_softmax_backward lowering rules
        """
        out = SoftmaxPrimitive.backward_lowering(
            ScaledUpperTriangMaskedSoftmaxBwdPrimitive.name,
            ctx,
            dz,
            softmax_out,
            scale_factor=scale_factor,
        )

        return out

    @staticmethod
    def impl(dz, softmax_out, scale_factor):
        return SoftmaxPrimitive.backward_impl(
            ScaledUpperTriangMaskedSoftmaxBwdPrimitive.inner_primitive,
            dz,
            softmax_out,
            scale_factor=scale_factor,
        )

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.backward_batcher(
            ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.backward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.backward_partition(
            ScaledUpperTriangMaskedSoftmaxBwdPrimitive.impl,
            scale_factor,
            mesh,
            arg_infos,
            result_infos,
        )


register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive)


def scaled_upper_triang_masked_softmax_bwd(
    dz: jnp.ndarray, softmax_out: jnp.ndarray, logits: jnp.ndarray, scale_factor: float
) -> jnp.ndarray:
    """
    scaled_upper_triang_masked_softmax_backward wrapper
    Return FP16/BF16 tensor

    Note: `logits` is only consumed by the pure-JAX reference fallback; the TE
    primitive computes the gradient from `dz` and `softmax_out`.
    """
    if not ScaledUpperTriangMaskedSoftmaxBwdPrimitive.enabled():
        _, vjp_func = jax.vjp(
            partial(_jax_scaled_upper_triang_masked_softmax, scale_factor=scale_factor), logits
        )
        return vjp_func(dz)[0]
    return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive.bind(
        dz, softmax_out, scale_factor=scale_factor
    )