# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for softmax"""
from abc import abstractmethod
from functools import partial, reduce
import operator
import warnings

import jax
import jax.numpy as jnp
from jax import core, dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding

from transformer_engine import transformer_engine_jax

from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import get_padded_spec, check_valid_batch_dims, jax_dtype_to_te_dtype
from ..softmax import SoftmaxType


__all__ = [
    "scaled_softmax_fwd",
    "scaled_softmax_bwd",
    "scaled_masked_softmax_fwd",
    "scaled_masked_softmax_bwd",
    "scaled_upper_triang_masked_softmax_fwd",
    "scaled_upper_triang_masked_softmax_bwd",
    "is_softmax_kernel_available",
]


def _jax_scaled_softmax(logits: jnp.ndarray, scale_factor: float):
    return jax.nn.softmax(scale_factor * logits)


def _jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float):
    if mask is not None:
        logits += jax.lax.select(
            mask > 0,
            jnp.full(mask.shape, -1e10).astype(logits.dtype),
            jnp.full(mask.shape, 0.0).astype(logits.dtype),
        )
    return jax.nn.softmax(logits * scale_factor)


def _jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: float):
    mask = 1 - jnp.tril(jnp.ones_like(logits))
    logits += jax.lax.select(
        mask > 0,
        jnp.full(mask.shape, -1e10).astype(logits.dtype),
        jnp.full(mask.shape, 0.0).astype(logits.dtype),
    )
    return jax.nn.softmax(logits * scale_factor)
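
# A quick sanity-check sketch of the reference implementations above
# (illustrative shapes): with zero logits, the causal variant spreads
# probability uniformly over the unmasked prefix of each row.
#   x = jnp.zeros((2, 4, 8, 8), jnp.float16)
#   y = _jax_scaled_upper_triang_masked_softmax(x, 1.0)
#   # y[..., i, j] == 0 for j > i, and each row sums to ~1.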


def is_softmax_kernel_available(
    softmax_type: SoftmaxType,
    batch: int,
    heads: int,
    q_seqlen: int,
    k_seqlen: int,
    dtype: jnp.dtype,
):
    """Check whether a fused softmax kernel is available for the given shape and dtype"""
    if softmax_type is SoftmaxType.SCALED:
        return ScaledSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )
    if softmax_type is SoftmaxType.SCALED_MASKED:
        return ScaledMaskedSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )
    if softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
        return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )

    raise NotImplementedError
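
# Usage sketch (illustrative values): callers can probe kernel availability
# before choosing between the fused custom op and the pure-JAX fallback, e.g.
#   is_softmax_kernel_available(SoftmaxType.SCALED, batch=8, heads=16,
#                               q_seqlen=512, k_seqlen=512, dtype=jnp.float16)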


class SoftmaxPrimitive(BasePrimitive):
    """
    Softmax Primitive
    """

    max_k_seqlen_supported = 16384
    name = "te_softmax_internal_placeholder"

    @staticmethod
    @abstractmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        raise NotImplementedError

    @staticmethod
    def get_batch_per_block(k_seqlen: int) -> int:
        """Get batch per CTA in Softmax kernels"""
        threads_per_warp = 32
        threads_per_block = 128  # Depends on the kernel implementation

        pow2 = 1 << (k_seqlen - 1).bit_length()
        warp_size = pow2 if pow2 < threads_per_warp else threads_per_warp
        batches_per_warp = 2 if pow2 <= 128 else 1
        warps_per_block = threads_per_block // warp_size
        batches_per_block = warps_per_block * batches_per_warp
        return batches_per_block
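
    # Worked example of the arithmetic above (a sketch): for k_seqlen = 1024,
    # pow2 = 1024, so warp_size = 32, batches_per_warp = 1 (pow2 > 128), and
    # warps_per_block = 128 // 32 = 4, giving get_batch_per_block(1024) == 4.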

    @staticmethod
    def forward_abstract(logits_aval, scale_factor):
        """
        softmax_forward abstract
        """
        del scale_factor
        i_dtype = dtypes.canonicalize_dtype(logits_aval.dtype)
        assert i_dtype in [jnp.float16, jnp.bfloat16]
        i_shape = logits_aval.shape
        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
        q_seqlen = i_shape[-2]
        k_seqlen = i_shape[-1]
        assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
        assert q_seqlen > 1

        out_aval = core.raise_to_shaped(logits_aval)
        return out_aval

    @staticmethod
    def forward_lowering(name, ctx, logits, *, scale_factor):
        """
        softmax_forward lowering rules
        """
        (i_aval,) = ctx.avals_in
        i_type = ir.RankedTensorType(logits.type)
        i_shape = i_type.shape
        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
        batch = reduce(operator.mul, i_shape[:-3])
        pad_batch = batch
        heads = i_shape[-3]
        q_seqlen = i_shape[-2]
        k_seqlen = i_shape[-1]

        out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)]
        operands = [logits]
        operand_shapes = [i_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        opaque = transformer_engine_jax.pack_softmax_descriptor(
            batch,
            pad_batch,
            heads,
            q_seqlen,
            k_seqlen,
            jax_dtype_to_te_dtype(i_aval.dtype),
            scale_factor,
        )

        out = custom_caller(name, args, opaque, False)

        return out

    @staticmethod
    def forward_impl(primitive, logits, scale_factor):
        """
        softmax_forward implementation
        """
        assert primitive is not None
        output = primitive.bind(logits, scale_factor=scale_factor)
        return output

    @staticmethod
    def forward_batcher(primitive, batched_args, batch_dims, *, scale_factor):
        """
        softmax_forward batcher
        """
        assert primitive is not None
        (logits,) = batched_args
        (logits_bdim,) = batch_dims

        out_bdims = logits_bdim
        return primitive.bind(logits, scale_factor=scale_factor), out_bdims

    @classmethod
    def forward_infer_sharding_from_operands(cls, scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_forward infer_sharding_from_operands
        """
        del scale_factor, result_infos  # Unused.
        logits_spec = get_padded_spec(arg_infos[0])
        if logits_spec[-1] is not None:
            warnings.warn(
                f"Sharding the hidden dimension is not supported in {cls.name}! "
                "Forcing XLA to not shard the hidden dim, which might introduce extra "
                "collective ops and hurt performance."
            )
        out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec[:-1], None))
        return out_sharding

    @classmethod
    def forward_partition(cls, impl, scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_forward partitioning
        """
        del result_infos
        logits_spec = get_padded_spec(arg_infos[0])
        if logits_spec[-1] is not None:
            warnings.warn(
                f"Sharding the hidden dimension is not supported in {cls.name}! "
                "Forcing XLA to not shard the hidden dim, which might introduce extra "
                "collective ops and hurt performance."
            )
        out_shardings = NamedSharding(mesh, PartitionSpec(*logits_spec[:-1], None))
        arg_shardings = (out_shardings,)
        impl = partial(impl, scale_factor=scale_factor)
        return mesh, impl, out_shardings, arg_shardings

    @staticmethod
    def backward_abstract(
        dz_aval, softmax_out_aval, scale_factor=None
    ):  # pylint: disable=unused-argument
        """
        softmax_backward abstract
        """
        dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        softmax_out_dtype = dtypes.canonicalize_dtype(softmax_out_aval.dtype)
        assert dz_dtype == softmax_out_dtype
        assert dz_dtype in [jnp.float16, jnp.bfloat16]
        assert softmax_out_dtype in [jnp.float16, jnp.bfloat16]

        assert dz_aval.shape == softmax_out_aval.shape

        dx_aval = core.raise_to_shaped(dz_aval)
        return dx_aval

    @staticmethod
    def backward_lowering(name, ctx, dz, softmax_out, *, scale_factor):
        """
        softmax_backward lowering rules
        """
        dz_aval, _ = ctx.avals_in

        dz_type = ir.RankedTensorType(dz.type)
        dz_shape = dz_type.shape

        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
        batch = reduce(operator.mul, dz_shape[:-3])
        pad_batch = batch  # unused
        heads = dz_shape[-3]
        q_seqlen = dz_shape[-2]
        k_seqlen = dz_shape[-1]

        softmax_out_type = ir.RankedTensorType(softmax_out.type)
        softmax_out_shape = softmax_out_type.shape

        out_types = [ir.RankedTensorType.get(dz_shape, dz_type.element_type)]
        operands = [dz, softmax_out]
        operand_shapes = [dz_shape, softmax_out_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        opaque = transformer_engine_jax.pack_softmax_descriptor(
            batch,
            pad_batch,
            heads,
            q_seqlen,
            k_seqlen,
            jax_dtype_to_te_dtype(dz_aval.dtype),
            scale_factor,
        )

        out = custom_caller(name, args, opaque, False)

        return out

    @staticmethod
    def backward_impl(primitive, dz, softmax_out, scale_factor):
        """
        softmax_backward implementation
        """
        assert primitive is not None
        dx = primitive.bind(dz, softmax_out, scale_factor=scale_factor)
        return dx

    @staticmethod
    def backward_batcher(primitive, batched_args, batch_dims, *, scale_factor):
        """
        softmax_backward batcher
        """
        assert primitive is not None
        dz, softmax_out = batched_args
        _, softmax_out_bdim = batch_dims

        out_bdims = softmax_out_bdim
        return primitive.bind(dz, softmax_out, scale_factor=scale_factor), out_bdims

    @classmethod
    def backward_infer_sharding_from_operands(cls, scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_backward infer_sharding_from_operands
        """
        del scale_factor, result_infos  # Unused.
        dz_spec = get_padded_spec(arg_infos[0])
        if dz_spec[-1] is not None:
            warnings.warn(
                f"Sharding the hidden dimension is not supported in {cls.name}! "
                "Forcing XLA to not shard the hidden dim, which might introduce extra "
                "collective ops and hurt performance."
            )
        dx_sharding = NamedSharding(mesh, PartitionSpec(*dz_spec[:-1], None))
        return dx_sharding

    @classmethod
    def backward_partition(cls, impl, scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_backward partition
        """
        del result_infos

        dz_spec = get_padded_spec(arg_infos[0])
        softmax_out_spec = get_padded_spec(arg_infos[1])
        if dz_spec[-1] is not None or softmax_out_spec[-1] is not None:
            warnings.warn(
                f"Sharding the hidden dimension is not supported in {cls.name}! "
                "Forcing XLA to not shard the hidden dim, which might introduce extra "
                "collective ops and hurt performance."
            )

        dz_sharding = NamedSharding(mesh, PartitionSpec(*dz_spec[:-1], None))
        softmax_out_sharding = NamedSharding(mesh, PartitionSpec(*softmax_out_spec[:-1], None))
        dx_sharding = dz_sharding
        arg_shardings = (dz_sharding, softmax_out_sharding)
        out_shardings = dx_sharding

        impl = partial(impl, scale_factor=scale_factor)
        return mesh, impl, out_shardings, arg_shardings


class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Softmax Fwd Primitive
    """

    name = "te_scaled_softmax_forward"
    multiple_results = False
    impl_static_args = (1,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        attn_batches = batch * heads

        dtype = dtypes.canonicalize_dtype(dtype)
        if (
            dtype in [jnp.float16, jnp.bfloat16]
            and 16 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
            and q_seqlen % 4 == 0  # q_seqlen must be divisible by 4
            and attn_batches % 4 == 0  # batch * heads must be divisible by 4
        ):
            if 0 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported:
                batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen)
                return q_seqlen % batch_per_block == 0
        return False
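
    # Shape sketch (illustrative values): batch=8, heads=16, q_seqlen=512,
    # k_seqlen=512 with jnp.float16 passes every check above: k_seqlen lies in
    # [16, 16384], q_seqlen % 4 == 0, batch * heads % 4 == 0, and
    # q_seqlen % get_batch_per_block(512) == 0 (512 % 4 == 0).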

    @staticmethod
    def abstract(logits_aval, scale_factor):  # pylint: disable=unused-argument
        """
        te_scaled_softmax_forward abstract
        """
        return SoftmaxPrimitive.forward_abstract(logits_aval, scale_factor)

    @staticmethod
    def lowering(ctx, logits, *, scale_factor):
        """
        te_scaled_softmax_forward lowering rules
        """
        return SoftmaxPrimitive.forward_lowering(
            ScaledSoftmaxFwdPrimitive.name, ctx, logits, scale_factor=scale_factor
        )

    @staticmethod
    def impl(logits, scale_factor):
        return SoftmaxPrimitive.forward_impl(
            ScaledSoftmaxFwdPrimitive.inner_primitive, logits, scale_factor
        )

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.forward_batcher(
            ScaledSoftmaxFwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        return ScaledSoftmaxFwdPrimitive.forward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        return ScaledSoftmaxFwdPrimitive.forward_partition(
            ScaledSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
        )


register_primitive(ScaledSoftmaxFwdPrimitive)


def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
    """
    scaled_softmax_forward wrapper
    Return FP16/BF16 tensor
    """
    if not ScaledSoftmaxFwdPrimitive.enabled():
        return _jax_scaled_softmax(logits, scale_factor)
    return ScaledSoftmaxFwdPrimitive.outer_primitive.bind(logits, scale_factor=scale_factor)
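
# Usage sketch (illustrative; `head_dim` stands in for the caller's value):
#   probs = scaled_softmax_fwd(logits, scale_factor=1.0 / head_dim**0.5)
# This dispatches to the fused TE kernel when the primitive is enabled and
# otherwise falls back to the pure-JAX reference implementation above.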


class ScaledSoftmaxBwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Softmax Bwd Primitive
    """

    name = "te_scaled_softmax_backward"
    multiple_results = False
    impl_static_args = (2,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        return ScaledSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )

    @staticmethod
    def abstract(dz_aval, softmax_out_aval, scale_factor):
        """
        te_scaled_softmax_backward abstract
        """
        return SoftmaxPrimitive.backward_abstract(dz_aval, softmax_out_aval, scale_factor)

    @staticmethod
    def lowering(ctx, dz, softmax_out, *, scale_factor):
        """
        te_scaled_softmax_backward lowering rules
        """
        out = SoftmaxPrimitive.backward_lowering(
            ScaledSoftmaxBwdPrimitive.name, ctx, dz, softmax_out, scale_factor=scale_factor
        )

        return out

    @staticmethod
    def impl(dz, softmax_out, scale_factor):
        return SoftmaxPrimitive.backward_impl(
            ScaledSoftmaxBwdPrimitive.inner_primitive, dz, softmax_out, scale_factor=scale_factor
        )

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.backward_batcher(
            ScaledSoftmaxBwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        return ScaledSoftmaxBwdPrimitive.backward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        return ScaledSoftmaxBwdPrimitive.backward_partition(
            ScaledSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
        )


register_primitive(ScaledSoftmaxBwdPrimitive)


def scaled_softmax_bwd(
    dz: jnp.ndarray, softmax_out: jnp.ndarray, logits: jnp.ndarray, scale_factor: float
) -> jnp.ndarray:
    """
    scaled_softmax_backward wrapper
    Return FP16/BF16 tensor
    """
    if not ScaledSoftmaxBwdPrimitive.enabled():
        _, vjp_func = jax.vjp(partial(_jax_scaled_softmax, scale_factor=scale_factor), logits)
        return vjp_func(dz)[0]

    return ScaledSoftmaxBwdPrimitive.outer_primitive.bind(
        dz, softmax_out, scale_factor=scale_factor
    )
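
# Note: `logits` feeds only the pure-JAX fallback above, which recomputes the
# softmax under jax.vjp; the fused kernel derives dx from `softmax_out` and
# `dz` alone.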


class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Masked Softmax Fwd Primitive
    """

    name = "te_scaled_masked_softmax_forward"
    multiple_results = False
    impl_static_args = (2,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        attn_batches = batch * heads

        dtype = dtypes.canonicalize_dtype(dtype)
        if (
            dtype in [jnp.float16, jnp.bfloat16]
            and 16 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
            and q_seqlen % 4 == 0  # q_seqlen must be divisible by 4
            and attn_batches % 4 == 0  # batch * heads must be divisible by 4
        ):
            if 0 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported:
                batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen)
                return q_seqlen % batch_per_block == 0
        return False

    @staticmethod
    def abstract(logits_aval, mask_aval, scale_factor):  # pylint: disable=unused-argument
        """
        te_scaled_masked_softmax_forward abstract
        """

        i_dtype = dtypes.canonicalize_dtype(logits_aval.dtype)
        assert i_dtype in [jnp.float16, jnp.bfloat16]
        i_shape = logits_aval.shape

        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
        batch = reduce(operator.mul, i_shape[:-3])
        q_seqlen = i_shape[-2]
        k_seqlen = i_shape[-1]
        assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
        assert q_seqlen > 1

        mask_dtype = dtypes.canonicalize_dtype(mask_aval.dtype)
        assert mask_dtype in [
            jnp.uint8,
        ]
        mask_shape = mask_aval.shape
        pad_batch = reduce(operator.mul, mask_shape[:-3])
        assert pad_batch in (1, batch)  # 1 means broadcast
        assert mask_shape[-3] == 1  # 1 means broadcast
        assert mask_shape[-2] == q_seqlen
        assert mask_shape[-1] == k_seqlen

        out_aval = core.raise_to_shaped(logits_aval)
        return out_aval

    @staticmethod
    def lowering(ctx, logits, mask, *, scale_factor):
        """
        te_scaled_masked_softmax_forward lowering rules
        """

        logits_aval, _ = ctx.avals_in
        i_type = ir.RankedTensorType(logits.type)
        i_shape = i_type.shape
        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
        batch = reduce(operator.mul, i_shape[:-3])
        heads = i_shape[-3]
        q_seqlen = i_shape[-2]
        k_seqlen = i_shape[-1]

        mask_type = ir.RankedTensorType(mask.type)
        mask_shape = mask_type.shape
        pad_batch = reduce(operator.mul, mask_shape[:-3])

        out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)]
        operands = [logits, mask]
        operand_shapes = [i_shape, mask_shape]
        args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

        opaque = transformer_engine_jax.pack_softmax_descriptor(
            batch,
            pad_batch,
            heads,
            q_seqlen,
            k_seqlen,
            jax_dtype_to_te_dtype(logits_aval.dtype),
            scale_factor,
        )

        out = custom_caller(ScaledMaskedSoftmaxFwdPrimitive.name, args, opaque, False)

        return out

    @staticmethod
    def impl(logits, mask, scale_factor):
        assert ScaledMaskedSoftmaxFwdPrimitive.inner_primitive is not None
        output = ScaledMaskedSoftmaxFwdPrimitive.inner_primitive.bind(
            logits, mask, scale_factor=scale_factor
        )
        return output

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        check_valid_batch_dims(batch_dims)
        assert ScaledMaskedSoftmaxFwdPrimitive.outer_primitive is not None
        logits, mask = batched_args
        logits_bdim, _ = batch_dims

        out_bdims = logits_bdim
        return (
            ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
                logits, mask, scale_factor=scale_factor
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        return ScaledMaskedSoftmaxFwdPrimitive.forward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
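        # Reuses backward_partition: like the backward ops, this forward op
        # takes two operands (logits and mask) with the same last-dim sharding
        # constraint, so the partitioning logic carries over.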
        return ScaledMaskedSoftmaxFwdPrimitive.backward_partition(
            ScaledMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
        )


register_primitive(ScaledMaskedSoftmaxFwdPrimitive)


def scaled_masked_softmax_fwd(
    logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float
) -> jnp.ndarray:
    """
    scaled_masked_softmax_forward wrapper
    Return FP16/BF16 tensor
    """
    if not ScaledMaskedSoftmaxFwdPrimitive.enabled():
        return _jax_scaled_masked_softmax(logits, mask, scale_factor)
    return ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
        logits, mask, scale_factor=scale_factor
    )
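
# Mask convention sketch (per the abstract rule above): `mask` is a uint8
# tensor broadcastable to [batch, 1, q_seqlen, k_seqlen]; entries with
# mask > 0 are excluded, i.e. mask[..., i, j] = 1 hides key j from query i.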


class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Masked Softmax Bwd Primitive
    """

    name = "te_scaled_masked_softmax_backward"
    multiple_results = False
    impl_static_args = (2,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        return ScaledSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )

    @staticmethod
    def abstract(dz_aval, softmax_out_aval, *, scale_factor):
        """
        te_scaled_masked_softmax_backward abstract
        """
        return SoftmaxPrimitive.backward_abstract(dz_aval, softmax_out_aval, scale_factor)

    @staticmethod
    def lowering(ctx, dz, softmax_out, *, scale_factor):
        """
        te_scaled_masked_softmax_backward lowering rules
        """
        out = SoftmaxPrimitive.backward_lowering(
            ScaledMaskedSoftmaxBwdPrimitive.name, ctx, dz, softmax_out, scale_factor=scale_factor
        )

        return out

    @staticmethod
    def impl(dz, softmax_out, scale_factor):
        return SoftmaxPrimitive.backward_impl(
            ScaledMaskedSoftmaxBwdPrimitive.inner_primitive,
            dz,
            softmax_out,
            scale_factor=scale_factor,
        )

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.backward_batcher(
            ScaledMaskedSoftmaxBwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        return ScaledMaskedSoftmaxBwdPrimitive.backward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        return ScaledMaskedSoftmaxBwdPrimitive.backward_partition(
            ScaledMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
        )


register_primitive(ScaledMaskedSoftmaxBwdPrimitive)


def scaled_masked_softmax_bwd(
    dz: jnp.ndarray,
    softmax_out: jnp.ndarray,
    logits: jnp.ndarray,
    mask: jnp.ndarray,
    scale_factor: float,
) -> jnp.ndarray:
    """
    scaled_masked_softmax_backward wrapper
    Return FP16/BF16 tensor
    """
    if not ScaledMaskedSoftmaxBwdPrimitive.enabled():
        _, vjp_func = jax.vjp(
            partial(_jax_scaled_masked_softmax, scale_factor=scale_factor), logits, mask
        )
        return vjp_func(dz)[0]
    return ScaledMaskedSoftmaxBwdPrimitive.outer_primitive.bind(
        dz, softmax_out, scale_factor=scale_factor
    )


class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Upper Triang Masked Softmax Fwd Primitive
    """

    name = "te_scaled_upper_triang_masked_softmax_forward"
    multiple_results = False
    impl_static_args = (1,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        attn_batches = batch * heads

        dtype = dtypes.canonicalize_dtype(dtype)
        if (
            dtype in [jnp.float16, jnp.bfloat16]
            and 16 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
            and q_seqlen % 4 == 0  # q_seqlen must be divisible by 4
            and attn_batches % 4 == 0  # batch * heads must be divisible by 4
            and k_seqlen == q_seqlen
        ):
            if 0 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported:
                batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen)
                return attn_batches % batch_per_block == 0
        return False

    @staticmethod
    def abstract(logits_aval, scale_factor):  # pylint: disable=unused-argument
        """
        te_scaled_upper_triang_masked_softmax_forward abstract
        """
        q_seqlen = logits_aval.shape[-2]
        k_seqlen = logits_aval.shape[-1]
        assert q_seqlen == k_seqlen
        return SoftmaxPrimitive.forward_abstract(logits_aval, scale_factor)

    @staticmethod
    def lowering(ctx, logits, *, scale_factor):
        """
        te_scaled_upper_triang_masked_softmax_forward lowering rules
        """
        return SoftmaxPrimitive.forward_lowering(
            ScaledUpperTriangMaskedSoftmaxFwdPrimitive.name, ctx, logits, scale_factor=scale_factor
        )

    @staticmethod
    def impl(logits, scale_factor):
        return SoftmaxPrimitive.forward_impl(
            ScaledUpperTriangMaskedSoftmaxFwdPrimitive.inner_primitive, logits, scale_factor
        )

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.forward_batcher(
            ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.forward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.forward_partition(
            ScaledUpperTriangMaskedSoftmaxFwdPrimitive.impl,
            scale_factor,
            mesh,
            arg_infos,
            result_infos,
        )


register_primitive(ScaledUpperTriangMaskedSoftmaxFwdPrimitive)


def scaled_upper_triang_masked_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
    """
    scaled_upper_triang_masked_softmax_forward wrapper
    Return FP16/BF16 tensor
    """
    if not ScaledUpperTriangMaskedSoftmaxFwdPrimitive.enabled():
        return _jax_scaled_upper_triang_masked_softmax(logits, scale_factor)
    return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
        logits, scale_factor=scale_factor
    )
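
# Note: the upper-triangular (causal) kernels require square attention
# matrices; the abstract rule above asserts q_seqlen == k_seqlen.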


class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Upper Triang Masked Softmax Bwd Primitive
    """

    name = "te_scaled_upper_triang_masked_softmax_backward"
    multiple_results = False
    impl_static_args = (2,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )

    @staticmethod
    def abstract(dz_aval, softmax_out_aval, *, scale_factor):
        """
        te_scaled_upper_triang_masked_softmax_backward abstract
        """
        return SoftmaxPrimitive.backward_abstract(dz_aval, softmax_out_aval, scale_factor)

    @staticmethod
    def lowering(ctx, dz, softmax_out, *, scale_factor):
        """
        te_scaled_upper_triang_masked_softmax_backward lowering rules
        """
        out = SoftmaxPrimitive.backward_lowering(
            ScaledUpperTriangMaskedSoftmaxBwdPrimitive.name,
            ctx,
            dz,
            softmax_out,
            scale_factor=scale_factor,
        )

        return out

    @staticmethod
    def impl(dz, softmax_out, scale_factor):
        return SoftmaxPrimitive.backward_impl(
            ScaledUpperTriangMaskedSoftmaxBwdPrimitive.inner_primitive,
            dz,
            softmax_out,
            scale_factor=scale_factor,
        )

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.backward_batcher(
            ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.backward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.backward_partition(
            ScaledUpperTriangMaskedSoftmaxBwdPrimitive.impl,
            scale_factor,
            mesh,
            arg_infos,
            result_infos,
        )


register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive)


def scaled_upper_triang_masked_softmax_bwd(
    dz: jnp.ndarray, softmax_out: jnp.ndarray, logits: jnp.ndarray, scale_factor: float
) -> jnp.ndarray:
    """
    scaled_upper_triang_masked_softmax_backward wrapper
    Return FP16/BF16 tensor
    """
    if not ScaledUpperTriangMaskedSoftmaxBwdPrimitive.enabled():
        _, vjp_func = jax.vjp(
            partial(_jax_scaled_upper_triang_masked_softmax, scale_factor=scale_factor), logits
        )
        return vjp_func(dz)[0]
    return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive.bind(
        dz, softmax_out, scale_factor=scale_factor
    )
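
# End-to-end sketch (illustrative, e.g. inside a custom-VJP attention step):
#   probs = scaled_upper_triang_masked_softmax_fwd(logits, scale)
#   ...
#   dlogits = scaled_upper_triang_masked_softmax_bwd(dprobs, probs, logits, scale)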