"vllm/vscode:/vscode.git/clone" did not exist on "83658c8ace771617460f9e2d5f1cf6f811d6d6fb"
softmax.py 31 KB
Newer Older
1
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
6
7
8
9
#
# See LICENSE for license information.
"""JAX/TE custom ops for softmax"""
from abc import abstractmethod
from functools import partial, reduce
import operator
import warnings

10
import jax
11
import jax.numpy as jnp
12
from jax import dtypes, ffi
13
from jax.sharding import PartitionSpec, NamedSharding
14
from .attention import AttnSoftmaxType
15
16

from .base import BasePrimitive, register_primitive
17
from .misc import get_padded_spec, check_valid_batch_dims
18
from ..softmax import SoftmaxFusionType
19
20


21
22
23
24
25
26
27
28
# Public API of this module: the fused-kernel wrappers, the availability check,
# and the pure-JAX reference implementations used as fallbacks.
__all__ = [
    "scaled_softmax_fwd",
    "scaled_softmax_bwd",
    "scaled_masked_softmax_fwd",
    "scaled_masked_softmax_bwd",
    "scaled_upper_triang_masked_softmax_fwd",
    "scaled_upper_triang_masked_softmax_bwd",
    "is_softmax_kernel_available",
    "jax_scaled_softmax",
    "jax_scaled_masked_softmax",
    "jax_scaled_upper_triang_masked_softmax",
]


def is_softmax_kernel_available(
    softmax_fusion_type: SoftmaxFusionType,
    softmax_type: AttnSoftmaxType,
    batch: int,
    heads: int,
    q_seqlen: int,
    k_seqlen: int,
    dtype: jnp.dtype,
):
    """check softmax available

    Report whether a fused TE softmax kernel can serve the given problem.

    Args:
        softmax_fusion_type: Which fused variant is requested (scaled, scaled +
            masked, or scaled + upper-triangular masked).
        softmax_type: Attention softmax flavour; only
            ``AttnSoftmaxType.VANILLA_SOFTMAX`` is handled by the fused kernels.
        batch, heads, q_seqlen, k_seqlen: Problem sizes.
        dtype: Element type of the logits.

    Returns:
        ``True`` if the forward primitive for ``softmax_fusion_type`` accepts
        these sizes and dtype, ``False`` otherwise (including any non-vanilla
        ``softmax_type``).

    Raises:
        NotImplementedError: if ``softmax_fusion_type`` has no fused kernel here.
    """
    if softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX:
        return False

    if softmax_fusion_type is SoftmaxFusionType.SCALED:
        return ScaledSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )
    if softmax_fusion_type is SoftmaxFusionType.SCALED_MASKED:
        return ScaledMaskedSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )
    if softmax_fusion_type is SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED:
        return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )

    raise NotImplementedError(
        f"No fused softmax kernel for softmax_fusion_type={softmax_fusion_type!r}"
    )


class SoftmaxPrimitive(BasePrimitive):
    """
    Softmax Primitive

    Abstract base shared by the TE fused-softmax primitives. It provides the
    common abstract-eval, lowering, impl, batching and sharding helpers that the
    concrete forward/backward primitives delegate to.
    """

    # Upper bound on the last (k_seqlen) axis; asserted in ``forward_abstract`` and
    # checked by the subclasses' ``is_kernel_available``.
    max_k_seqlen_supported = 16384
    name = "te_softmax_internal_placeholder"

    @staticmethod
    @abstractmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        raise NotImplementedError

    @staticmethod
    def get_batch_per_block(k_seqlen: int) -> int:
        """Get batch per CTA in Softmax kernels

        ``k_seqlen`` is rounded up to a power of two. Rows narrower than a warp use
        a correspondingly narrower "warp", and rows of at most 128 elements pack two
        batches per warp. The result is the number of rows handled by one block of
        ``threads_per_block`` threads.
        """
        threads_per_warp = 32
        threads_per_block = 128  # Depends on the kernel implementation

        pow2 = 1 << (k_seqlen - 1).bit_length()  # smallest power of two >= k_seqlen
        warp_size = min(pow2, threads_per_warp)
        batches_per_warp = 2 if pow2 <= 128 else 1
        warps_per_block = threads_per_block // warp_size
        return warps_per_block * batches_per_warp

    @classmethod
    def _warn_if_hidden_dim_sharded(cls, *specs):
        """Warn once per call if any padded spec shards the last (hidden) axis.

        Callers always replicate the last axis; this only emits the user-facing
        warning. ``stacklevel=2`` attributes it to the calling rule.
        """
        if any(spec[-1] is not None for spec in specs):
            warnings.warn(
                f"Sharding the hidden dimension is not supported in {cls.name}! "
                "Forcing XLA to not shard the hidden dim, which might introduce extra "
                "collective ops and hurt performance.",
                stacklevel=2,
            )

    @staticmethod
    def forward_abstract(logits_aval, scale_factor):
        """
        softmax_forward abstract

        Validates the logits (FP16/BF16, k_seqlen within the supported range,
        q_seqlen > 1) and returns an output aval identical to the input.
        """
        del scale_factor
        i_dtype = dtypes.canonicalize_dtype(logits_aval.dtype)
        assert i_dtype in [jnp.float16, jnp.bfloat16]
        i_shape = logits_aval.shape
        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
        q_seqlen = i_shape[-2]
        k_seqlen = i_shape[-1]
        assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
        assert q_seqlen > 1

        out_aval = logits_aval
        return out_aval

    @staticmethod
    def forward_lowering(name, ctx, logits, *, scale_factor):
        """
        softmax_forward lowering rules

        Lowers to the FFI custom call registered under ``name``.
        """
        return ffi.ffi_lowering(name)(ctx, logits, scale_factor=scale_factor)

    @staticmethod
    def forward_impl(primitive, logits, scale_factor):
        """
        softmax_forward implementation
        """
        assert primitive is not None
        output = primitive.bind(logits, scale_factor=scale_factor)
        return output

    @staticmethod
    def forward_batcher(primitive, batched_args, batch_dims, *, scale_factor):
        """
        softmax_forward batcher

        Re-binds ``primitive`` on the batched logits; the output keeps the logits'
        batch dimension.
        """
        assert primitive is not None
        (logits,) = batched_args
        (logits_bdim,) = batch_dims

        out_bdims = logits_bdim
        return primitive.bind(logits, scale_factor=scale_factor), out_bdims

    @classmethod
    def forward_infer_sharding_from_operands(cls, scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_forward infer_sharding_from_operands

        The output follows the logits sharding except for the last axis, which is
        always replicated.
        """
        del scale_factor, result_infos  # Unused.
        logits_spec = get_padded_spec(arg_infos[0])
        cls._warn_if_hidden_dim_sharded(logits_spec)
        out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec[:-1], None))
        return out_sharding

    @classmethod
    def forward_partition(cls, impl, scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_forward partitioning

        Returns ``(mesh, impl, out_shardings, arg_shardings)``. The single operand
        is resharded to match the output (last axis replicated).
        """
        del result_infos
        logits_spec = get_padded_spec(arg_infos[0])
        cls._warn_if_hidden_dim_sharded(logits_spec)
        out_shardings = NamedSharding(mesh, PartitionSpec(*logits_spec[:-1], None))
        arg_shardings = (out_shardings,)
        impl = partial(impl, scale_factor=scale_factor)
        return mesh, impl, out_shardings, arg_shardings

    @staticmethod
    def backward_abstract(
        dz_aval, softmax_out_aval, scale_factor=None
    ):  # pylint: disable=unused-argument
        """
        softmax_backward abstract

        ``dz`` and ``softmax_out`` must share dtype (FP16/BF16) and shape; the
        gradient aval equals ``dz``'s.
        """
        dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        softmax_out_dtype = dtypes.canonicalize_dtype(softmax_out_aval.dtype)
        assert dz_dtype == softmax_out_dtype
        assert dz_dtype in [jnp.float16, jnp.bfloat16]
        assert softmax_out_dtype in [jnp.float16, jnp.bfloat16]

        assert dz_aval.shape == softmax_out_aval.shape

        dx_aval = dz_aval
        return dx_aval

    @staticmethod
    def backward_lowering(name, ctx, dz, softmax_out, *, scale_factor):
        """
        softmax_backward lowering rules

        Lowers to the FFI custom call registered under ``name``.
        """
        return ffi.ffi_lowering(name)(ctx, dz, softmax_out, scale_factor=scale_factor)

    @staticmethod
    def backward_impl(primitive, dz, softmax_out, scale_factor):
        """
        softmax_backward implementation
        """
        assert primitive is not None
        dx = primitive.bind(dz, softmax_out, scale_factor=scale_factor)
        return dx

    @staticmethod
    def backward_batcher(primitive, batched_args, batch_dims, *, scale_factor):
        """
        softmax_backward batcher

        The output takes ``softmax_out``'s batch dimension.
        """
        assert primitive is not None
        dz, softmax_out = batched_args
        _, softmax_out_bdim = batch_dims

        out_bdims = softmax_out_bdim
        return primitive.bind(dz, softmax_out, scale_factor=scale_factor), out_bdims

    @classmethod
    def backward_infer_sharding_from_operands(cls, scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_backward infer_sharding_from_operands

        The gradient follows ``dz``'s sharding except for the last axis, which is
        always replicated.
        """
        del scale_factor, result_infos  # Unused.
        dz_spec = get_padded_spec(arg_infos[0])
        cls._warn_if_hidden_dim_sharded(dz_spec)
        dx_sharding = NamedSharding(mesh, PartitionSpec(*dz_spec[:-1], None))
        return dx_sharding

    @classmethod
    def backward_partition(cls, impl, scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_backward partition

        Generic two-operand rule: each operand keeps its own sharding minus the
        last axis, and the output is sharded like the first operand.
        """
        del result_infos

        dz_spec = get_padded_spec(arg_infos[0])
        softmax_out_spec = get_padded_spec(arg_infos[1])
        cls._warn_if_hidden_dim_sharded(dz_spec, softmax_out_spec)

        dz_sharding = NamedSharding(mesh, PartitionSpec(*dz_spec[:-1], None))
        softmax_out_sharding = NamedSharding(mesh, PartitionSpec(*softmax_out_spec[:-1], None))
        arg_shardings = (dz_sharding, softmax_out_sharding)
        out_shardings = dz_sharding  # dx is sharded like dz

        impl = partial(impl, scale_factor=scale_factor)
        return mesh, impl, out_shardings, arg_shardings


class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Softmax Fwd Primitive

    Computes ``softmax(scale_factor * logits)`` over the last axis with a fused kernel.
    """

    name = "te_scaled_softmax_forward_ffi"
    multiple_results = False
    impl_static_args = (1,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        attn_batches = batch * heads

        dtype = dtypes.canonicalize_dtype(dtype)
        if (
            dtype in [jnp.float16, jnp.bfloat16]
            and 16 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
            and q_seqlen % 4 == 0  # q_seqlen must be a multiple of 4
            and attn_batches % 4 == 0  # batch * heads must be a multiple of 4
        ):
            # k_seqlen is already known to lie in the supported range here.
            batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen)
            return q_seqlen % batch_per_block == 0
        return False

    @staticmethod
    def abstract(logits_aval, scale_factor):
        """
        te_scaled_softmax_forward abstract
        """
        return SoftmaxPrimitive.forward_abstract(logits_aval, scale_factor)

    @staticmethod
    def lowering(ctx, logits, *, scale_factor):
        """
        te_scaled_softmax_forward lowering rules
        """
        return SoftmaxPrimitive.forward_lowering(
            ScaledSoftmaxFwdPrimitive.name, ctx, logits, scale_factor=scale_factor
        )

    @staticmethod
    def impl(logits, scale_factor):
        """
        te_scaled_softmax_forward implementation (binds the inner primitive)
        """
        return SoftmaxPrimitive.forward_impl(
            ScaledSoftmaxFwdPrimitive.inner_primitive, logits, scale_factor
        )

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        """
        te_scaled_softmax_forward batcher
        """
        check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.forward_batcher(
            ScaledSoftmaxFwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """
        te_scaled_softmax_forward infer_sharding_from_operands
        """
        return ScaledSoftmaxFwdPrimitive.forward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        """
        te_scaled_softmax_forward partitioning
        """
        return ScaledSoftmaxFwdPrimitive.forward_partition(
            ScaledSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def shardy_sharding_rule(*args):
        """Shardy rule: the output has the same factors as the logits."""
        del args
        return "... -> ..."


register_primitive(ScaledSoftmaxFwdPrimitive)


class ScaledSoftmaxBwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Softmax Bwd Primitive

    Gradient of ``softmax(scale_factor * logits)`` computed from ``dz`` and the
    saved forward output.
    """

    name = "te_scaled_softmax_backward_ffi"
    multiple_results = False
    impl_static_args = (2,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        # Same constraints as the forward kernel.
        return ScaledSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )

    @staticmethod
    def abstract(dz_aval, softmax_out_aval, scale_factor):
        """
        te_scaled_softmax_backward abstract
        """
        return SoftmaxPrimitive.backward_abstract(dz_aval, softmax_out_aval, scale_factor)

    @staticmethod
    def lowering(ctx, dz, softmax_out, *, scale_factor):
        """
        te_scaled_softmax_backward lowering rules
        """
        return SoftmaxPrimitive.backward_lowering(
            ScaledSoftmaxBwdPrimitive.name, ctx, dz, softmax_out, scale_factor=scale_factor
        )

    @staticmethod
    def impl(dz, softmax_out, scale_factor):
        """
        te_scaled_softmax_backward implementation (binds the inner primitive)
        """
        return SoftmaxPrimitive.backward_impl(
            ScaledSoftmaxBwdPrimitive.inner_primitive, dz, softmax_out, scale_factor=scale_factor
        )

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        """
        te_scaled_softmax_backward batcher
        """
        check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.backward_batcher(
            ScaledSoftmaxBwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """
        te_scaled_softmax_backward infer_sharding_from_operands
        """
        return ScaledSoftmaxBwdPrimitive.backward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        """
        te_scaled_softmax_backward partitioning
        """
        return ScaledSoftmaxBwdPrimitive.backward_partition(
            ScaledSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def shardy_sharding_rule(*args):
        """Shardy rule: dz, softmax_out and dx share the same factors."""
        del args
        return "..., ... -> ..."


register_primitive(ScaledSoftmaxBwdPrimitive)


419
def scaled_softmax_bwd(
    dz: jnp.ndarray, softmax_out: jnp.ndarray, logits: jnp.ndarray, scale_factor: float
) -> jnp.ndarray:
    """
    scaled_backward wrapper
    Return FP16/BF16 tensor

    Args:
        dz: Upstream gradient with respect to the softmax output.
        softmax_out: Forward output; consumed by the fused kernel.
        logits: Forward input; only used by the pure-JAX fallback, which
            re-derives the VJP from ``jax_scaled_softmax``.
        scale_factor: Scale applied to ``logits`` before the softmax.

    Returns:
        Gradient with respect to ``logits``.
    """
    if not ScaledSoftmaxBwdPrimitive.enabled():
        # Fallback: differentiate the reference implementation w.r.t. logits.
        _, vjp_func = jax.vjp(partial(jax_scaled_softmax, scale_factor=scale_factor), logits)
        return vjp_func(dz)[0]

    return ScaledSoftmaxBwdPrimitive.outer_primitive.bind(
        dz, softmax_out, scale_factor=scale_factor
    )
433
434
435
436
437
438


class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Masked Softmax Fwd Primitive

    Fused ``softmax(scale_factor * logits)`` with a uint8 mask operand.
    """

    name = "te_scaled_masked_softmax_forward_ffi"
    multiple_results = False
    impl_static_args = (2,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        attn_batches = batch * heads

        dtype = dtypes.canonicalize_dtype(dtype)
        if (
            dtype in [jnp.float16, jnp.bfloat16]
            and 16 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
            and q_seqlen % 4 == 0  # q_seqlen must be a multiple of 4
            and attn_batches % 4 == 0  # batch * heads must be a multiple of 4
        ):
            # k_seqlen is already known to lie in the supported range here.
            batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen)
            return q_seqlen % batch_per_block == 0
        return False

    @staticmethod
    def abstract(logits_aval, mask_aval, scale_factor):  # pylint: disable=unused-argument
        """
        te_scaled_masked_softmax_forward abstract

        Expects FP16/BF16 logits shaped ``[...batch, heads, q_seqlen, k_seqlen]`` and
        a uint8 mask shaped ``[...pad_batch, 1, q_seqlen, k_seqlen]``. The flattened
        ``pad_batch`` must be 1 (broadcast) or equal to the flattened logits batch.
        """
        i_dtype = dtypes.canonicalize_dtype(logits_aval.dtype)
        assert i_dtype in [jnp.float16, jnp.bfloat16]
        i_shape = logits_aval.shape

        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
        batch = reduce(operator.mul, i_shape[:-3])
        q_seqlen = i_shape[-2]
        k_seqlen = i_shape[-1]
        assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
        assert q_seqlen > 1

        mask_dtype = dtypes.canonicalize_dtype(mask_aval.dtype)
        assert mask_dtype in [
            jnp.uint8,
        ]
        mask_shape = mask_aval.shape
        # Do not rebind ``batch`` here: it must stay the logits batch so the
        # broadcast check below actually compares mask and logits.
        pad_batch = reduce(operator.mul, mask_shape[:-3])
        assert pad_batch in (1, batch)  # 1 means broadcast
        assert mask_shape[-3] == 1  # 1 means broadcast
        assert mask_shape[-2] == q_seqlen
        assert mask_shape[-1] == k_seqlen

        out_aval = logits_aval
        return out_aval

    @staticmethod
    def lowering(ctx, logits, mask, *, scale_factor):
        """
        te_scaled_masked_softmax_forward lowering rules
        """
        return ffi.ffi_lowering(ScaledMaskedSoftmaxFwdPrimitive.name)(
            ctx, logits, mask, scale_factor=scale_factor
        )

    @staticmethod
    def impl(logits, mask, scale_factor):
        """
        te_scaled_masked_softmax_forward implementation (binds the inner primitive)
        """
        assert ScaledMaskedSoftmaxFwdPrimitive.inner_primitive is not None
        output = ScaledMaskedSoftmaxFwdPrimitive.inner_primitive.bind(
            logits, mask, scale_factor=scale_factor
        )
        return output

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        """
        te_scaled_masked_softmax_forward batcher

        The output keeps the logits' batch dimension.
        """
        check_valid_batch_dims(batch_dims)
        assert ScaledMaskedSoftmaxFwdPrimitive.outer_primitive is not None
        logits, mask = batched_args
        logits_bdim, _ = batch_dims

        out_bdims = logits_bdim
        return (
            ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
                logits, mask, scale_factor=scale_factor
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """
        te_scaled_masked_softmax_forward infer_sharding_from_operands
        """
        return ScaledMaskedSoftmaxFwdPrimitive.forward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        """
        te_scaled_masked_softmax_forward partitioning
        """
        # Two operands (logits, mask): reuse the generic two-operand rule, which
        # drops last-axis sharding on each operand and shards the output like the
        # first operand (logits).
        return ScaledMaskedSoftmaxFwdPrimitive.backward_partition(
            ScaledMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def shardy_sharding_rule(*args):
        """Shardy rule: the output follows the logits (first operand)."""
        del args
        return "...1, ...2 -> ...1"


register_primitive(ScaledMaskedSoftmaxFwdPrimitive)


class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Masked Softmax Bwd Primitive

    Gradient of the scaled masked softmax, computed from ``dz`` and the saved
    forward output (the mask is not an operand of this primitive).
    """

    name = "te_scaled_masked_softmax_backward_ffi"
    multiple_results = False
    impl_static_args = (2,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        # Same constraints as the matching forward kernel.
        return ScaledMaskedSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )

    @staticmethod
    def abstract(dz_aval, softmax_out_aval, *, scale_factor):
        """
        te_scaled_masked_softmax_backward abstract
        """
        return SoftmaxPrimitive.backward_abstract(dz_aval, softmax_out_aval, scale_factor)

    @staticmethod
    def lowering(ctx, dz, softmax_out, *, scale_factor):
        """
        te_scaled_masked_softmax_backward lowering rules
        """
        return SoftmaxPrimitive.backward_lowering(
            ScaledMaskedSoftmaxBwdPrimitive.name, ctx, dz, softmax_out, scale_factor=scale_factor
        )

    @staticmethod
    def impl(dz, softmax_out, scale_factor):
        """
        te_scaled_masked_softmax_backward implementation (binds the inner primitive)
        """
        return SoftmaxPrimitive.backward_impl(
            ScaledMaskedSoftmaxBwdPrimitive.inner_primitive,
            dz,
            softmax_out,
            scale_factor=scale_factor,
        )

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        """
        te_scaled_masked_softmax_backward batcher
        """
        check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.backward_batcher(
            ScaledMaskedSoftmaxBwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """
        te_scaled_masked_softmax_backward infer_sharding_from_operands
        """
        return ScaledMaskedSoftmaxBwdPrimitive.backward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        """
        te_scaled_masked_softmax_backward partitioning
        """
        return ScaledMaskedSoftmaxBwdPrimitive.backward_partition(
            ScaledMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def shardy_sharding_rule(*args):
        """Shardy rule: dz, softmax_out and dx share the same factors."""
        del args
        return "..., ... -> ..."


register_primitive(ScaledMaskedSoftmaxBwdPrimitive)


class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Upper Triang Masked Softmax Fwd Primitive

    Fused ``softmax(scale_factor * logits)`` with the upper triangle of the last two
    axes masked out; requires ``q_seqlen == k_seqlen``.
    """

    name = "te_scaled_upper_triang_masked_softmax_forward_ffi"
    multiple_results = False
    impl_static_args = (1,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        attn_batches = batch * heads

        dtype = dtypes.canonicalize_dtype(dtype)
        if (
            dtype in [jnp.float16, jnp.bfloat16]
            and 16 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
            and q_seqlen % 4 == 0  # q_seqlen must be a multiple of 4
            and attn_batches % 4 == 0  # batch * heads must be a multiple of 4
            and k_seqlen == q_seqlen  # square score matrix (also asserted in abstract)
        ):
            # k_seqlen is already known to lie in the supported range here.
            batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen)
            # Unlike the other variants, divisibility is required of batch * heads.
            return attn_batches % batch_per_block == 0
        return False

    @staticmethod
    def abstract(logits_aval, scale_factor):
        """
        te_scaled_upper_triang_masked_softmax_forward abstract
        """
        q_seqlen = logits_aval.shape[-2]
        k_seqlen = logits_aval.shape[-1]
        assert q_seqlen == k_seqlen
        return SoftmaxPrimitive.forward_abstract(logits_aval, scale_factor)

    @staticmethod
    def lowering(ctx, logits, *, scale_factor):
        """
        te_scaled_upper_triang_masked_softmax_forward lowering rules
        """
        return SoftmaxPrimitive.forward_lowering(
            ScaledUpperTriangMaskedSoftmaxFwdPrimitive.name, ctx, logits, scale_factor=scale_factor
        )

    @staticmethod
    def impl(logits, scale_factor):
        """
        te_scaled_upper_triang_masked_softmax_forward implementation (binds the inner primitive)
        """
        return SoftmaxPrimitive.forward_impl(
            ScaledUpperTriangMaskedSoftmaxFwdPrimitive.inner_primitive, logits, scale_factor
        )

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        """
        te_scaled_upper_triang_masked_softmax_forward batcher
        """
        check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.forward_batcher(
            ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """
        te_scaled_upper_triang_masked_softmax_forward infer_sharding_from_operands
        """
        return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.forward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        """
        te_scaled_upper_triang_masked_softmax_forward partitioning
        """
        return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.forward_partition(
            ScaledUpperTriangMaskedSoftmaxFwdPrimitive.impl,
            scale_factor,
            mesh,
            arg_infos,
            result_infos,
        )

    @staticmethod
    def shardy_sharding_rule(*args):
        """Shardy rule: the output has the same factors as the logits."""
        del args
        return "... -> ..."


register_primitive(ScaledUpperTriangMaskedSoftmaxFwdPrimitive)


class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Upper Triang Masked Softmax Bwd Primitive

    Gradient of the upper-triangular masked softmax, computed from ``dz`` and the
    saved forward output.
    """

    name = "te_scaled_upper_triang_masked_softmax_backward_ffi"
    multiple_results = False
    impl_static_args = (2,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        # Same constraints as the matching forward kernel.
        return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )

    @staticmethod
    def abstract(dz_aval, softmax_out_aval, *, scale_factor):
        """
        te_scaled_upper_triang_masked_softmax_backward abstract
        """
        return SoftmaxPrimitive.backward_abstract(dz_aval, softmax_out_aval, scale_factor)

    @staticmethod
    def lowering(ctx, dz, softmax_out, *, scale_factor):
        """
        te_scaled_upper_triang_masked_softmax_backward lowering rules
        """
        return SoftmaxPrimitive.backward_lowering(
            ScaledUpperTriangMaskedSoftmaxBwdPrimitive.name,
            ctx,
            dz,
            softmax_out,
            scale_factor=scale_factor,
        )

    @staticmethod
    def impl(dz, softmax_out, scale_factor):
        """
        te_scaled_upper_triang_masked_softmax_backward implementation (binds the inner primitive)
        """
        return SoftmaxPrimitive.backward_impl(
            ScaledUpperTriangMaskedSoftmaxBwdPrimitive.inner_primitive,
            dz,
            softmax_out,
            scale_factor=scale_factor,
        )

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        """
        te_scaled_upper_triang_masked_softmax_backward batcher
        """
        check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.backward_batcher(
            ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """
        te_scaled_upper_triang_masked_softmax_backward infer_sharding_from_operands
        """
        return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.backward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        """
        te_scaled_upper_triang_masked_softmax_backward partitioning
        """
        return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.backward_partition(
            ScaledUpperTriangMaskedSoftmaxBwdPrimitive.impl,
            scale_factor,
            mesh,
            arg_infos,
            result_infos,
        )

    @staticmethod
    def shardy_sharding_rule(*args):
        """Shardy rule: dz, softmax_out and dx share the same factors."""
        del args
        return "..., ... -> ..."


register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive)


800
801
802
def jax_scaled_softmax(
    logits: jnp.ndarray, scale_factor: float, softmax_offset: jnp.ndarray | float | None = None
):
803
804
805
    """
    JAX based implementation of scaled softmax
    """
806
807
    if softmax_offset is not None:
        return jax_general_softmax(scale_factor * logits, offset=softmax_offset)
808
809
810
    return jax.nn.softmax(scale_factor * logits)


811
812
813
814
815
816
def jax_scaled_masked_softmax(
    logits: jnp.ndarray,
    mask: jnp.ndarray,
    scale_factor: float,
    softmax_offset: jnp.ndarray | float | None = None,
):
817
818
819
    """
    JAX based implementation of scaled and masked softmax
    """
820
821
    if softmax_offset is not None:
        return jax_general_softmax(logits * scale_factor, offset=softmax_offset, where=mask != 1)
822
    return jax.nn.softmax(logits * scale_factor, where=mask != 1)
823
824


825
826
827
def jax_scaled_upper_triang_masked_softmax(
    logits: jnp.ndarray, scale_factor: float, softmax_offset: jnp.ndarray | float | None = None
):
    """
    JAX based implementation of scaled and upper triangle masked softmax

    Entries strictly above the diagonal of the last two axes are masked out,
    then the computation is delegated to ``jax_scaled_masked_softmax``.
    """
    # jnp.tril keeps the lower triangle (incl. diagonal) as 1s, so ``1 - tril``
    # is 1 exactly on the strict upper triangle, i.e. the positions to exclude.
    mask = 1 - jnp.tril(jnp.ones_like(logits))
    return jax_scaled_masked_softmax(logits, mask, scale_factor, softmax_offset)


def jax_general_softmax(
    x: jnp.ndarray,
    axis: int = -1,
    where: jnp.ndarray | None = None,
    initial: jnp.ndarray = -jnp.inf,
    offset: jnp.ndarray | float | None = None,
) -> jnp.ndarray:
    """
    JAX based implementation of general softmax with optional masking and offset.
    """
    # Compute max of x
    x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)

    if offset is not None:
        # Cast offset to x.dtype to prevent type promotion
        if isinstance(offset, (int, float)):
            offset = jnp.array(offset, dtype=x.dtype)
        else:
            offset = offset.astype(x.dtype)

        # Include offset in max: x_max = max(x_max, offset)
        # This is equivalent to computing max over [x..., offset]
        x_max = jnp.maximum(x_max, offset)

    x_safe = x if where is None else jnp.where(where, x, initial)
    unnormalized = jnp.exp(x_safe - x_max)
    denominator = jnp.sum(unnormalized, axis, where=where, keepdims=True)

    if offset is not None:
        # Add exp(offset - x_max) to denominator
        denominator = denominator + jnp.exp(offset - x_max)

    result = unnormalized / denominator
    if where is not None:
        result = jnp.where(where, result, 0)
    return result
871
872
873
874
875
876
877
878


def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
    """
    scaled_softmax_forward wrapper
    Return FP16/BF16 tensor

    Uses the fused TE kernel when the primitive is enabled, otherwise falls back
    to ``jax_scaled_softmax``.
    """
    if not ScaledSoftmaxFwdPrimitive.enabled():
        return jax_scaled_softmax(logits, scale_factor)
    return ScaledSoftmaxFwdPrimitive.outer_primitive.bind(logits, scale_factor=scale_factor)


def scaled_masked_softmax_fwd(
    logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float
) -> jnp.ndarray:
    """
    scaled_masked_softmax_forward wrapper
    Return FP16/BF16 tensor

    Uses the fused TE kernel when the primitive is enabled, otherwise falls back
    to ``jax_scaled_masked_softmax`` (positions where ``mask == 1`` are excluded).
    """
    if not ScaledMaskedSoftmaxFwdPrimitive.enabled():
        return jax_scaled_masked_softmax(logits, mask, scale_factor)
    return ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
        logits, mask, scale_factor=scale_factor
    )


def scaled_masked_softmax_bwd(
    dz: jnp.ndarray,
    softmax_out: jnp.ndarray,
    logits: jnp.ndarray,
    mask: jnp.ndarray,
    scale_factor: float,
) -> jnp.ndarray:
    """
    scaled_masked_backward wrapper
    Return FP16/BF16 tensor

    Args:
        dz: Upstream gradient with respect to the softmax output.
        softmax_out: Forward output; consumed by the fused kernel.
        logits, mask: Forward inputs; only used by the pure-JAX fallback.
        scale_factor: Scale applied to ``logits`` before the softmax.

    Returns:
        Gradient with respect to ``logits``.
    """
    if not ScaledMaskedSoftmaxBwdPrimitive.enabled():
        # Differentiate only w.r.t. logits; the mask is a constant here (its
        # cotangent would be discarded anyway, and it is typically integer).
        _, vjp_func = jax.vjp(
            lambda x: jax_scaled_masked_softmax(x, mask, scale_factor=scale_factor), logits
        )
        return vjp_func(dz)[0]
    return ScaledMaskedSoftmaxBwdPrimitive.outer_primitive.bind(
        dz, softmax_out, scale_factor=scale_factor
    )


def scaled_upper_triang_masked_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
    """
    scaled_upper_triang_masked_softmax_forward wrapper
    Return FP16/BF16 tensor

    Uses the fused TE kernel when the primitive is enabled, otherwise falls back
    to ``jax_scaled_upper_triang_masked_softmax``.
    """
    if not ScaledUpperTriangMaskedSoftmaxFwdPrimitive.enabled():
        return jax_scaled_upper_triang_masked_softmax(logits, scale_factor)
    return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
        logits, scale_factor=scale_factor
    )


930
def scaled_upper_triang_masked_softmax_bwd(
    dz: jnp.ndarray, softmax_out: jnp.ndarray, logits: jnp.ndarray, scale_factor: float
) -> jnp.ndarray:
    """
    scaled_upper_triang_masked_backward wrapper
    Return FP16/BF16 tensor

    Args:
        dz: Upstream gradient with respect to the softmax output.
        softmax_out: Forward output; consumed by the fused kernel.
        logits: Forward input; only used by the pure-JAX fallback.
        scale_factor: Scale applied to ``logits`` before the softmax.

    Returns:
        Gradient with respect to ``logits``.
    """
    if not ScaledUpperTriangMaskedSoftmaxBwdPrimitive.enabled():
        # Fallback: differentiate the reference implementation w.r.t. logits.
        _, vjp_func = jax.vjp(
            partial(jax_scaled_upper_triang_masked_softmax, scale_factor=scale_factor), logits
        )
        return vjp_func(dz)[0]
    return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive.bind(
        dz, softmax_out, scale_factor=scale_factor
    )