# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for softmax"""
from abc import abstractmethod
from functools import partial, reduce
import operator
import warnings
from packaging import version

import jax
import jax.numpy as jnp
from jax import dtypes
from jax.sharding import PartitionSpec, NamedSharding

from .base import BasePrimitive, register_primitive
from .misc import get_padded_spec, check_valid_batch_dims
from ..softmax import SoftmaxType

if version.parse(jax.__version__) >= version.parse("0.5.0"):
    from jax import ffi  # pylint: disable=ungrouped-imports
else:
    from jax.extend import ffi  # pylint: disable=ungrouped-imports


__all__ = [
    "scaled_softmax_fwd",
    "scaled_softmax_bwd",
    "scaled_masked_softmax_fwd",
    "scaled_masked_softmax_bwd",
    "scaled_upper_triang_masked_softmax_fwd",
    "scaled_upper_triang_masked_softmax_bwd",
    "is_softmax_kernel_available",
]


def is_softmax_kernel_available(
    softmax_type: SoftmaxType,
    batch: int,
    heads: int,
    q_seqlen: int,
    k_seqlen: int,
    dtype: jnp.dtype,
):
    """check softmax available"""
    if softmax_type is SoftmaxType.SCALED:
        return ScaledSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )
    if softmax_type is SoftmaxType.SCALED_MASKED:
        return ScaledMaskedSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )
    if softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
        return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )

    raise NotImplementedError(f"Unsupported softmax type: {softmax_type}")
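
# A minimal usage sketch (hypothetical sizes; the check only inspects sizes and
# dtype, no tensor data is involved):
#
#     ok = is_softmax_kernel_available(
#         SoftmaxType.SCALED, batch=2, heads=8, q_seqlen=128, k_seqlen=128,
#         dtype=jnp.bfloat16,
#     )
#     # When `ok` is False, callers should fall back to a pure-JAX softmax.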


class SoftmaxPrimitive(BasePrimitive):
    """
    Softmax Primitive
    """

    max_k_seqlen_supported = 16384
    name = "te_softmax_internal_placeholder"

    @staticmethod
    @abstractmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        raise NotImplementedError

    @staticmethod
    def get_batch_per_block(k_seqlen: int) -> int:
        """Get batch per CTA in Softmax kernels"""
        threads_per_warp = 32
        threads_per_block = 128  # Depends on the kernel implementation

        pow2 = 1 << (k_seqlen - 1).bit_length()
        warp_size = pow2 if pow2 < threads_per_warp else threads_per_warp
        batches_per_warp = 2 if pow2 <= 128 else 1
        warps_per_block = threads_per_block // warp_size
        batches_per_block = warps_per_block * batches_per_warp
        return batches_per_block

    @staticmethod
    def forward_abstract(logits_aval, scale_factor):
        """
        softmax_forward abstract
        """
        del scale_factor
        i_dtype = dtypes.canonicalize_dtype(logits_aval.dtype)
        assert i_dtype in [jnp.float16, jnp.bfloat16]
        i_shape = logits_aval.shape
        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
        q_seqlen = i_shape[-2]
        k_seqlen = i_shape[-1]
        assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
        assert q_seqlen > 1

        out_aval = logits_aval
        return out_aval

    @staticmethod
    def forward_lowering(name, ctx, logits, *, scale_factor):
        """
        softmax_forward lowering rules
        """
        return ffi.ffi_lowering(name)(ctx, logits, scale_factor=scale_factor)

    @staticmethod
    def forward_impl(primitive, logits, scale_factor):
        """
        softmax_forward implementation
        """
        assert primitive is not None
        output = primitive.bind(logits, scale_factor=scale_factor)
        return output

    @staticmethod
    def forward_batcher(primitive, batched_args, batch_dims, *, scale_factor):
        """
        softmax_forward batcher
        """
        assert primitive is not None
        (logits,) = batched_args
        (logits_bdim,) = batch_dims

        out_bdims = logits_bdim
        return primitive.bind(logits, scale_factor=scale_factor), out_bdims

    @classmethod
    def forward_infer_sharding_from_operands(cls, scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_forward infer_sharding_from_operands
        """
        del scale_factor, result_infos  # Unused.
        logits_spec = get_padded_spec(arg_infos[0])
        if logits_spec[-1] is not None:
            warnings.warn(
                f"Sharding the hidden dimension is not supported in {cls.name}! "
                "Forcing XLA to not shard the hidden dim, which might introduce extra "
                "collective ops and hurt performance."
            )
        out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec[:-1], None))
        return out_sharding

    @classmethod
    def forward_partition(cls, impl, scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_forward partitioning
        """
        del result_infos
        logits_spec = get_padded_spec(arg_infos[0])
        if logits_spec[-1] is not None:
            warnings.warn(
                f"Sharding the hidden dimension is not supported in {cls.name}! "
                "Forcing XLA to not shard the hidden dim, which might introduce extra "
                "collective ops and hurt performance."
            )
        out_shardings = NamedSharding(mesh, PartitionSpec(*logits_spec[:-1], None))
        arg_shardings = (out_shardings,)
        impl = partial(impl, scale_factor=scale_factor)
        return mesh, impl, out_shardings, arg_shardings

    @staticmethod
    def backward_abstract(
        dz_aval, softmax_out_aval, scale_factor=None
    ):  # pylint: disable=unused-argument
        """
        softmax_backward abstract
        """
        dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        softmax_out_dtype = dtypes.canonicalize_dtype(softmax_out_aval.dtype)
        assert dz_dtype == softmax_out_dtype
        assert dz_dtype in [jnp.float16, jnp.bfloat16]
        assert softmax_out_dtype in [jnp.float16, jnp.bfloat16]

        assert dz_aval.shape == softmax_out_aval.shape

        dx_aval = dz_aval
        return dx_aval

    @staticmethod
    def backward_lowering(name, ctx, dz, softmax_out, *, scale_factor):
        """
        softmax_backward lowering rules
        """
        return ffi.ffi_lowering(name)(ctx, dz, softmax_out, scale_factor=scale_factor)

    @staticmethod
    def backward_impl(primitive, dz, softmax_out, scale_factor):
        """
        softmax_backward implementation
        """
        assert primitive is not None
        dx = primitive.bind(dz, softmax_out, scale_factor=scale_factor)
        return dx

    @staticmethod
    def backward_batcher(primitive, batched_args, batch_dims, *, scale_factor):
        """
        softmax_backward batcher
        """
        assert primitive is not None
        dz, softmax_out = batched_args
        _, softmax_out_bdim = batch_dims

        out_bdims = softmax_out_bdim
        return primitive.bind(dz, softmax_out, scale_factor=scale_factor), out_bdims

    @classmethod
    def backward_infer_sharding_from_operands(cls, scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_backward infer_sharding_from_operands
        """
        del scale_factor, result_infos  # Unused.
        dz_spec = get_padded_spec(arg_infos[0])
        if dz_spec[-1] is not None:
            warnings.warn(
                f"Sharding the hidden dimension is not supported in {cls.name}! "
                "Forcing XLA to not shard the hidden dim, which might introduce extra "
                "collective ops and hurt performance."
            )
        dx_sharding = NamedSharding(mesh, PartitionSpec(*dz_spec[:-1], None))
        return dx_sharding

    @classmethod
    def backward_partition(cls, impl, scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_backward partition
        """
        del result_infos

        dz_spec = get_padded_spec(arg_infos[0])
        softmax_out_spec = get_padded_spec(arg_infos[1])
        if dz_spec[-1] is not None or softmax_out_spec[-1] is not None:
            warnings.warn(
                f"Sharding the hidden dimension is not supported in {cls.name}! "
                "Forcing XLA to not shard the hidden dim, which might introduce extra "
                "collective ops and hurt performance."
            )

        dz_sharding = NamedSharding(mesh, PartitionSpec(*dz_spec[:-1], None))
        softmax_out_sharding = NamedSharding(mesh, PartitionSpec(*softmax_out_spec[:-1], None))
        dx_sharding = dz_sharding
        arg_shardings = (dz_sharding, softmax_out_sharding)
        out_shardings = dx_sharding

        impl = partial(impl, scale_factor=scale_factor)
        return mesh, impl, out_shardings, arg_shardings


class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Softmax Fwd Primitive
    """

    name = "te_scaled_softmax_forward_ffi"
    multiple_results = False
    impl_static_args = (1,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        attn_batches = batch * heads

        dtype = dtypes.canonicalize_dtype(dtype)
        if (
            dtype in [jnp.float16, jnp.bfloat16]
            and 16 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
            and q_seqlen % 4 == 0  # q_seqlen must be a multiple of 4
            and attn_batches % 4 == 0  # batch * heads must be a multiple of 4
        ):
            batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen)
            return q_seqlen % batch_per_block == 0
        return False

    @staticmethod
    def abstract(logits_aval, scale_factor):  # pylint: disable=unused-argument
        """
        te_scaled_softmax_forward abstract
        """
        return SoftmaxPrimitive.forward_abstract(logits_aval, scale_factor)

    @staticmethod
    def lowering(ctx, logits, *, scale_factor):
        """
        te_scaled_softmax_forward lowering rules
        """
        return SoftmaxPrimitive.forward_lowering(
            ScaledSoftmaxFwdPrimitive.name, ctx, logits, scale_factor=scale_factor
        )

    @staticmethod
    def impl(logits, scale_factor):
        return SoftmaxPrimitive.forward_impl(
            ScaledSoftmaxFwdPrimitive.inner_primitive, logits, scale_factor
        )

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.forward_batcher(
            ScaledSoftmaxFwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        return ScaledSoftmaxFwdPrimitive.forward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        return ScaledSoftmaxFwdPrimitive.forward_partition(
            ScaledSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def shardy_sharding_rule(*args):
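        # Einsum-like Shardy rule: one operand, and the output keeps the
        # operand's sharding.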
        del args
        return "... -> ..."


register_primitive(ScaledSoftmaxFwdPrimitive)


class ScaledSoftmaxBwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Softmax Bwd Primitive
    """

    name = "te_scaled_softmax_backward_ffi"
    multiple_results = False
    impl_static_args = (2,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        return ScaledSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )

    @staticmethod
    def abstract(dz_aval, softmax_out_aval, scale_factor):
        """
        te_scaled_softmax_backward abstract
        """
        return SoftmaxPrimitive.backward_abstract(dz_aval, softmax_out_aval, scale_factor)

    @staticmethod
    def lowering(ctx, dz, softmax_out, *, scale_factor):
        """
        te_scaled_softmax_backward lowering rules
        """
        out = SoftmaxPrimitive.backward_lowering(
            ScaledSoftmaxBwdPrimitive.name, ctx, dz, softmax_out, scale_factor=scale_factor
        )

        return out

    @staticmethod
    def impl(dz, softmax_out, scale_factor):
        return SoftmaxPrimitive.backward_impl(
            ScaledSoftmaxBwdPrimitive.inner_primitive, dz, softmax_out, scale_factor=scale_factor
        )

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.backward_batcher(
            ScaledSoftmaxBwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        return ScaledSoftmaxBwdPrimitive.backward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        return ScaledSoftmaxBwdPrimitive.backward_partition(
            ScaledSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def shardy_sharding_rule(*args):
        del args
        return "..., ... -> ..."


register_primitive(ScaledSoftmaxBwdPrimitive)


def scaled_softmax_bwd(
    dz: jnp.ndarray, softmax_out: jnp.ndarray, logits: jnp.ndarray, scale_factor: float
) -> jnp.ndarray:
    """
    scaled_softmax_backward wrapper
    Return FP16/BF16 tensor
    """
    if not ScaledSoftmaxBwdPrimitive.enabled():
        _, vjp_func = jax.vjp(partial(_jax_scaled_softmax, scale_factor=scale_factor), logits)
        return vjp_func(dz)[0]

    return ScaledSoftmaxBwdPrimitive.outer_primitive.bind(
        dz, softmax_out, scale_factor=scale_factor
    )


class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Masked Softmax Fwd Primitive
    """

    name = "te_scaled_masked_softmax_forward_ffi"
    multiple_results = False
    impl_static_args = (2,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        attn_batches = batch * heads

        dtype = dtypes.canonicalize_dtype(dtype)
        if (
            dtype in [jnp.float16, jnp.bfloat16]
            and 16 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
            and q_seqlen % 4 == 0  # q_seqlen must be a multiple of 4
            and attn_batches % 4 == 0  # batch * heads must be a multiple of 4
        ):
            batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen)
            return q_seqlen % batch_per_block == 0
        return False

    @staticmethod
    def abstract(logits_aval, mask_aval, scale_factor):  # pylint: disable=unused-argument
        """
        te_scaled_masked_softmax_forward abstract
        """

        i_dtype = dtypes.canonicalize_dtype(logits_aval.dtype)
        assert i_dtype in [jnp.float16, jnp.bfloat16]
        i_shape = logits_aval.shape

        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
        batch = reduce(operator.mul, i_shape[:-3])
        q_seqlen = i_shape[-2]
        k_seqlen = i_shape[-1]
        assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
        assert q_seqlen > 1

        mask_dtype = dtypes.canonicalize_dtype(mask_aval.dtype)
        assert mask_dtype in [
            jnp.uint8,
        ]
        mask_shape = mask_aval.shape
        pad_batch = reduce(operator.mul, mask_shape[:-3])
        assert pad_batch in (1, batch)  # 1 means broadcast
        assert mask_shape[-3] == 1  # 1 means broadcast
        assert mask_shape[-2] == q_seqlen
        assert mask_shape[-1] == k_seqlen

        out_aval = logits_aval
        return out_aval

    @staticmethod
    def lowering(ctx, logits, mask, *, scale_factor):
        """
        te_scaled_masked_softmax_forward lowering rules
        """
        return ffi.ffi_lowering(ScaledMaskedSoftmaxFwdPrimitive.name)(
            ctx, logits, mask, scale_factor=scale_factor
        )

    @staticmethod
    def impl(logits, mask, scale_factor):
        assert ScaledMaskedSoftmaxFwdPrimitive.inner_primitive is not None
        output = ScaledMaskedSoftmaxFwdPrimitive.inner_primitive.bind(
            logits, mask, scale_factor=scale_factor
        )
        return output

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        check_valid_batch_dims(batch_dims)
        assert ScaledMaskedSoftmaxFwdPrimitive.outer_primitive is not None
        logits, mask = batched_args
        logits_bdim, _ = batch_dims

        out_bdims = logits_bdim
        return (
            ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
                logits, mask, scale_factor=scale_factor
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        return ScaledMaskedSoftmaxFwdPrimitive.forward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
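        # Reuses backward_partition because, like the backward primitives, this
        # forward op takes two operands (logits and mask) and so needs two
        # argument shardings.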
        return ScaledMaskedSoftmaxFwdPrimitive.backward_partition(
            ScaledMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def shardy_sharding_rule(*args):
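        # Einsum-like Shardy rule: the output follows the first operand (logits);
        # the mask's dimensions are factored independently, so it may broadcast.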
        del args
        return "...1, ...2 -> ...1"


register_primitive(ScaledMaskedSoftmaxFwdPrimitive)


class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Masked Softmax Bwd Primitive
    """

    name = "te_scaled_masked_softmax_backward_ffi"
    multiple_results = False
    impl_static_args = (2,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        return ScaledSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )

    @staticmethod
    def abstract(dz_aval, softmax_out_aval, *, scale_factor):
        """
        te_scaled_masked_softmax_backward abstract
        """
        return SoftmaxPrimitive.backward_abstract(dz_aval, softmax_out_aval, scale_factor)

    @staticmethod
    def lowering(ctx, dz, softmax_out, *, scale_factor):
        """
        te_scaled_masked_softmax_backward lowering rules
        """
        return SoftmaxPrimitive.backward_lowering(
            ScaledMaskedSoftmaxBwdPrimitive.name, ctx, dz, softmax_out, scale_factor=scale_factor
        )

    @staticmethod
    def impl(dz, softmax_out, scale_factor):
        return SoftmaxPrimitive.backward_impl(
            ScaledMaskedSoftmaxBwdPrimitive.inner_primitive,
            dz,
            softmax_out,
            scale_factor=scale_factor,
        )

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.backward_batcher(
            ScaledMaskedSoftmaxBwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        return ScaledMaskedSoftmaxBwdPrimitive.backward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        return ScaledMaskedSoftmaxBwdPrimitive.backward_partition(
            ScaledMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def shardy_sharding_rule(*args):
        del args
        return "..., ... -> ..."


register_primitive(ScaledMaskedSoftmaxBwdPrimitive)


class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Upper Triang Masked Softmax Fwd Primitive
    """

    name = "te_scaled_upper_triang_masked_softmax_forward_ffi"
    multiple_results = False
    impl_static_args = (1,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        attn_batches = batch * heads

        dtype = dtypes.canonicalize_dtype(dtype)
        if (
            dtype in [jnp.float16, jnp.bfloat16]
            and 16 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
            and q_seqlen % 4 == 0  # q_seqlen must be a multiple of 4
            and attn_batches % 4 == 0  # batch * heads must be a multiple of 4
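            # the upper-triangular (causal) mask assumes a square attention matrix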
            and k_seqlen == q_seqlen
        ):
            batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen)
            return attn_batches % batch_per_block == 0
        return False

    @staticmethod
    def abstract(logits_aval, scale_factor):  # pylint: disable=unused-argument
        """
        te_scaled_upper_triang_masked_softmax_forward abstract
        """
        q_seqlen = logits_aval.shape[-2]
        k_seqlen = logits_aval.shape[-1]
        assert q_seqlen == k_seqlen
        return SoftmaxPrimitive.forward_abstract(logits_aval, scale_factor)

    @staticmethod
    def lowering(ctx, logits, *, scale_factor):
        """
        te_scaled_upper_triang_masked_softmax_forward lowering rules
        """
        return SoftmaxPrimitive.forward_lowering(
            ScaledUpperTriangMaskedSoftmaxFwdPrimitive.name, ctx, logits, scale_factor=scale_factor
        )

    @staticmethod
    def impl(logits, scale_factor):
        return SoftmaxPrimitive.forward_impl(
            ScaledUpperTriangMaskedSoftmaxFwdPrimitive.inner_primitive, logits, scale_factor
        )

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.forward_batcher(
            ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.forward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.forward_partition(
            ScaledUpperTriangMaskedSoftmaxFwdPrimitive.impl,
            scale_factor,
            mesh,
            arg_infos,
            result_infos,
        )

    @staticmethod
    def shardy_sharding_rule(*args):
        del args
        return "... -> ..."


register_primitive(ScaledUpperTriangMaskedSoftmaxFwdPrimitive)


class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Upper Triang Masked Softmax Bwd Primitive
    """

    name = "te_scaled_upper_triang_masked_softmax_backward_ffi"
    multiple_results = False
    impl_static_args = (2,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )

    @staticmethod
    def abstract(dz_aval, softmax_out_aval, *, scale_factor):
        """
        te_scaled_upper_triang_masked_backward abstract
        """
        return SoftmaxPrimitive.backward_abstract(dz_aval, softmax_out_aval, scale_factor)

    @staticmethod
    def lowering(ctx, dz, softmax_out, *, scale_factor):
        """
        te_scaled_upper_triang_masked_softmax_backward lowering rules
        """
        return SoftmaxPrimitive.backward_lowering(
            ScaledUpperTriangMaskedSoftmaxBwdPrimitive.name,
            ctx,
            dz,
            softmax_out,
            scale_factor=scale_factor,
        )

    @staticmethod
    def impl(dz, softmax_out, scale_factor):
        return SoftmaxPrimitive.backward_impl(
            ScaledUpperTriangMaskedSoftmaxBwdPrimitive.inner_primitive,
            dz,
            softmax_out,
            scale_factor=scale_factor,
        )

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.backward_batcher(
            ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.backward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.backward_partition(
            ScaledUpperTriangMaskedSoftmaxBwdPrimitive.impl,
            scale_factor,
            mesh,
            arg_infos,
            result_infos,
        )

    @staticmethod
    def shardy_sharding_rule(*args):
        del args
        return "..., ... -> ..."


register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive)


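# Pure-JAX reference implementations. The public wrappers below fall back to
# these when the corresponding TE primitive is disabled, and the backward
# wrappers differentiate them via jax.vjp.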
def _jax_scaled_softmax(logits: jnp.ndarray, scale_factor: float):
    return jax.nn.softmax(scale_factor * logits)


def _jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float):
    if mask is not None:
        logits += jax.lax.select(
            mask > 0,
            jnp.full(mask.shape, -1e10).astype(logits.dtype),
            jnp.full(mask.shape, 0.0).astype(logits.dtype),
        )
    return jax.nn.softmax(logits * scale_factor)


def _jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: float):
    mask = 1 - jnp.tril(jnp.ones_like(logits))
    logits += jax.lax.select(
        mask > 0,
        jnp.full(mask.shape, -1e10).astype(logits.dtype),
        jnp.full(mask.shape, 0.0).astype(logits.dtype),
    )
    return jax.nn.softmax(logits * scale_factor)


def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
    """
    scaled_softmax_forward wrapper
    Return FP16/BF16 tensor
    """
    if not ScaledSoftmaxFwdPrimitive.enabled():
        return _jax_scaled_softmax(logits, scale_factor)
    return ScaledSoftmaxFwdPrimitive.outer_primitive.bind(logits, scale_factor=scale_factor)
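
# A quick equivalence sketch (hypothetical shapes; the fused kernel and the
# pure-JAX fallback agree up to normal floating-point tolerance):
#
#     key = jax.random.PRNGKey(0)
#     logits = jax.random.normal(key, (2, 8, 128, 128)).astype(jnp.bfloat16)
#     fused = scaled_softmax_fwd(logits, scale_factor=0.125)
#     ref = _jax_scaled_softmax(logits, 0.125)
#     # expect jnp.allclose(fused, ref, rtol=1e-2, atol=1e-2)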


def scaled_masked_softmax_fwd(
    logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float
) -> jnp.ndarray:
    """
    scaled_masked_softmax_forward wrapper
    Return FP16/BF16 tensor
    """
    if not ScaledMaskedSoftmaxFwdPrimitive.enabled():
        return _jax_scaled_masked_softmax(logits, mask, scale_factor)
    return ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
        logits, mask, scale_factor=scale_factor
    )
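
# Mask semantics (a sketch): entries where `mask > 0` are masked out. For
# example, a uint8 padding mask hiding the second half of the key sequence:
#
#     mask = jnp.zeros((2, 1, 128, 128), jnp.uint8).at[..., 64:].set(1)
#     probs = scaled_masked_softmax_fwd(logits, mask, scale_factor=1.0)
#     # probs[..., 64:] is then numerically ~0 and each row renormalizes over
#     # the first 64 keys.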


def scaled_masked_softmax_bwd(
    dz: jnp.ndarray,
    softmax_out: jnp.ndarray,
    logits: jnp.ndarray,
    mask: jnp.ndarray,
    scale_factor: float,
) -> jnp.ndarray:
    """
    scaled_masked_softmax_backward wrapper
    Return FP16/BF16 tensor
    """
    if not ScaledMaskedSoftmaxBwdPrimitive.enabled():
        _, vjp_func = jax.vjp(
            partial(_jax_scaled_masked_softmax, scale_factor=scale_factor), logits, mask
        )
        return vjp_func(dz)[0]
    return ScaledMaskedSoftmaxBwdPrimitive.outer_primitive.bind(
        dz, softmax_out, scale_factor=scale_factor
    )


def scaled_upper_triang_masked_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
    """
    scaled_upper_triang_masked_softmax_forward wrapper
    Return FP16/BF16 tensor
    """
    if not ScaledUpperTriangMaskedSoftmaxFwdPrimitive.enabled():
        return _jax_scaled_upper_triang_masked_softmax(logits, scale_factor)
    return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
        logits, scale_factor=scale_factor
    )
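
# Causal-mask sketch: for square logits, row i of the output attends only to
# columns j <= i:
#
#     probs = scaled_upper_triang_masked_softmax_fwd(logits, scale_factor=1.0)
#     # probs[..., i, j] ~ 0 for j > i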


def scaled_upper_triang_masked_softmax_bwd(
    dz: jnp.ndarray, softmax_out: jnp.ndarray, logits: jnp.ndarray, scale_factor: float
) -> jnp.ndarray:
    """
    scaled_upper_triang_masked_softmax_backward wrapper
    Return FP16/BF16 tensor
    """
    if not ScaledUpperTriangMaskedSoftmaxBwdPrimitive.enabled():
        _, vjp_func = jax.vjp(
            partial(_jax_scaled_upper_triang_masked_softmax, scale_factor=scale_factor), logits
        )
        return vjp_func(dz)[0]
    return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive.bind(
        dz, softmax_out, scale_factor=scale_factor
    )