# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for softmax"""
from abc import abstractmethod
from functools import partial, reduce
import operator
import warnings

import jax
import jax.numpy as jnp
from jax import dtypes, ffi
from jax.sharding import PartitionSpec, NamedSharding

from .base import BasePrimitive, register_primitive
from .misc import get_padded_spec, check_valid_batch_dims
from ..softmax import SoftmaxType


# Public API of this module: TE softmax primitive wrappers, the kernel
# availability check, and the pure-JAX reference implementations.
__all__ = [
    "scaled_softmax_fwd",
    "scaled_softmax_bwd",
    "scaled_masked_softmax_fwd",
    "scaled_masked_softmax_bwd",
    "scaled_upper_triang_masked_softmax_fwd",
    "scaled_upper_triang_masked_softmax_bwd",
    "is_softmax_kernel_available",
    "jax_scaled_softmax",
    "jax_scaled_masked_softmax",
    "jax_scaled_upper_triang_masked_softmax",
]


def is_softmax_kernel_available(
    softmax_type: SoftmaxType,
    batch: int,
    heads: int,
    q_seqlen: int,
    k_seqlen: int,
    dtype: jnp.dtype,
) -> bool:
    """check softmax available

    Dispatch to the ``is_kernel_available`` check of the forward primitive that
    matches ``softmax_type``.

    Args:
        softmax_type: which softmax variant is requested.
        batch, heads, q_seqlen, k_seqlen: problem sizes.
        dtype: element dtype of the logits.

    Returns:
        True when the fused TE kernel can handle the given sizes/dtype.

    Raises:
        NotImplementedError: if ``softmax_type`` has no TE softmax kernel.
    """
    if softmax_type is SoftmaxType.SCALED:
        return ScaledSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )
    if softmax_type is SoftmaxType.SCALED_MASKED:
        return ScaledMaskedSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )
    if softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
        return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )

    raise NotImplementedError(f"No TE softmax kernel for softmax_type={softmax_type!r}")


class SoftmaxPrimitive(BasePrimitive):
    """
    Softmax Primitive

    Abstract base shared by the TE softmax primitives. Concrete subclasses
    implement ``is_kernel_available`` and delegate their abstract-eval,
    lowering, impl, batching and sharding rules to the ``forward_*`` /
    ``backward_*`` helpers defined here.
    """

    # Largest K_Seqlen (last logits dim) accepted by the abstract-eval checks.
    max_k_seqlen_supported = 16384
    name = "te_softmax_internal_placeholder"

    @staticmethod
    @abstractmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        raise NotImplementedError

    @staticmethod
    def get_batch_per_block(k_seqlen: int) -> int:
        """Get batch per CTA in Softmax kernels

        ``k_seqlen`` is rounded up to the next power of two. The warp size is
        that value capped at 32 threads, and when the rounded length is at most
        128 each warp processes two rows. The result is the number of rows
        handled by one 128-thread block.
        NOTE(review): these constants must stay in sync with the CUDA kernel.
        """
        threads_per_warp = 32
        threads_per_block = 128  # Depends on the kernel implementation

        pow2 = 1 << (k_seqlen - 1).bit_length()
        warp_size = pow2 if pow2 < threads_per_warp else threads_per_warp
        batches_per_warp = 2 if pow2 <= 128 else 1
        warps_per_block = threads_per_block // warp_size
        batches_per_block = warps_per_block * batches_per_warp
        return batches_per_block

    @staticmethod
    def forward_abstract(logits_aval, scale_factor):
        """
        softmax_forward abstract

        Validates dtype (FP16/BF16) and sequence-length limits. The output
        aval is the input logits aval unchanged.
        """
        del scale_factor
        i_dtype = dtypes.canonicalize_dtype(logits_aval.dtype)
        assert i_dtype in [jnp.float16, jnp.bfloat16]
        i_shape = logits_aval.shape
        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
        q_seqlen = i_shape[-2]
        k_seqlen = i_shape[-1]
        assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
        assert q_seqlen > 1

        out_aval = logits_aval
        return out_aval

    @staticmethod
    def forward_lowering(name, ctx, logits, *, scale_factor):
        """
        softmax_forward lowering rules

        Lowers to the FFI custom-call target registered under ``name``.
        """
        return ffi.ffi_lowering(name)(ctx, logits, scale_factor=scale_factor)

    @staticmethod
    def forward_impl(primitive, logits, scale_factor):
        """
        softmax_forward implementation
        """
        assert primitive is not None
        output = primitive.bind(logits, scale_factor=scale_factor)
        return output

    @staticmethod
    def forward_batcher(primitive, batched_args, batch_dims, *, scale_factor):
        """
        softmax_forward batcher

        The kernel already handles arbitrary leading dims, so the batched
        logits are bound directly and keep their batch dim.
        """
        assert primitive is not None
        (logits,) = batched_args
        (logits_bdim,) = batch_dims

        out_bdims = logits_bdim
        return primitive.bind(logits, scale_factor=scale_factor), out_bdims

    @classmethod
    def forward_infer_sharding_from_operands(cls, scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_forward infer_sharding_from_operands

        Output follows the logits sharding, with the last (hidden) dim forced
        to be replicated.
        """
        del scale_factor, result_infos  # Unused.
        logits_spec = get_padded_spec(arg_infos[0])
        if logits_spec[-1] is not None:
            warnings.warn(
                f"Sharding the hidden dimension is not supported in {cls.name}! "
                "Forcing XLA to not shard the hidden dim, which might introduce extra "
                "collective ops and hurt performance."
            )
        out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec[:-1], None))
        return out_sharding

    @classmethod
    def forward_partition(cls, impl, scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_forward partitioning

        Returns ``(mesh, impl, out_shardings, arg_shardings)``. The single
        operand and the output share one sharding with the hidden dim replicated.
        """
        del result_infos
        logits_spec = get_padded_spec(arg_infos[0])
        if logits_spec[-1] is not None:
            warnings.warn(
                f"Sharding the hidden dimension is not supported in {cls.name}! "
                "Forcing XLA to not shard the hidden dim, which might introduce extra "
                "collective ops and hurt performance."
            )
        out_shardings = NamedSharding(mesh, PartitionSpec(*logits_spec[:-1], None))
        arg_shardings = (out_shardings,)
        impl = partial(impl, scale_factor=scale_factor)
        return mesh, impl, out_shardings, arg_shardings

    @staticmethod
    def backward_abstract(
        dz_aval, softmax_out_aval, scale_factor=None
    ):  # pylint: disable=unused-argument
        """
        softmax_backward abstract

        Both inputs must share dtype (FP16/BF16) and shape. The gradient aval
        is the ``dz`` aval unchanged.
        """
        dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        softmax_out_dtype = dtypes.canonicalize_dtype(softmax_out_aval.dtype)
        assert dz_dtype == softmax_out_dtype
        assert dz_dtype in [jnp.float16, jnp.bfloat16]
        assert softmax_out_dtype in [jnp.float16, jnp.bfloat16]

        assert dz_aval.shape == softmax_out_aval.shape

        dx_aval = dz_aval
        return dx_aval

    @staticmethod
    def backward_lowering(name, ctx, dz, softmax_out, *, scale_factor):
        """
        softmax_backward lowering rules

        Lowers to the FFI custom-call target registered under ``name``.
        """
        return ffi.ffi_lowering(name)(ctx, dz, softmax_out, scale_factor=scale_factor)

    @staticmethod
    def backward_impl(primitive, dz, softmax_out, scale_factor):
        """
        softmax_backward implementation
        """
        assert primitive is not None
        dx = primitive.bind(dz, softmax_out, scale_factor=scale_factor)
        return dx

    @staticmethod
    def backward_batcher(primitive, batched_args, batch_dims, *, scale_factor):
        """
        softmax_backward batcher

        The output batch dim follows ``softmax_out``'s batch dim.
        """
        assert primitive is not None
        dz, softmax_out = batched_args
        _, softmax_out_bdim = batch_dims

        out_bdims = softmax_out_bdim
        return primitive.bind(dz, softmax_out, scale_factor=scale_factor), out_bdims

    @classmethod
    def backward_infer_sharding_from_operands(cls, scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_backward infer_sharding_from_operands

        The gradient follows the ``dz`` sharding, with the hidden dim replicated.
        """
        del scale_factor, result_infos  # Unused.
        dz_spec = get_padded_spec(arg_infos[0])
        if dz_spec[-1] is not None:
            warnings.warn(
                f"Sharding the hidden dimension is not supported in {cls.name}! "
                "Forcing XLA to not shard the hidden dim, which might introduce extra "
                "collective ops and hurt performance."
            )
        dx_sharding = NamedSharding(mesh, PartitionSpec(*dz_spec[:-1], None))
        return dx_sharding

    @classmethod
    def backward_partition(cls, impl, scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_backward partition

        Two-operand partition rule. Each operand keeps its own sharding with
        the last dim replicated, and the output uses the first operand's sharding.
        """
        del result_infos

        dz_spec = get_padded_spec(arg_infos[0])
        softmax_out_spec = get_padded_spec(arg_infos[1])
        if dz_spec[-1] is not None or softmax_out_spec[-1] is not None:
            warnings.warn(
                f"Sharding the hidden dimension is not supported in {cls.name}! "
                "Forcing XLA to not shard the hidden dim, which might introduce extra "
                "collective ops and hurt performance."
            )

        dz_sharding = NamedSharding(mesh, PartitionSpec(*dz_spec[:-1], None))
        softmax_out_sharding = NamedSharding(mesh, PartitionSpec(*softmax_out_spec[:-1], None))
        dx_sharding = dz_sharding
        arg_shardings = (dz_sharding, softmax_out_sharding)
        out_shardings = dx_sharding

        impl = partial(impl, scale_factor=scale_factor)
        return mesh, impl, out_shardings, arg_shardings


class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Softmax Fwd Primitive
    """

    name = "te_scaled_softmax_forward_ffi"
    multiple_results = False
    impl_static_args = (1,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        attn_batches = batch * heads

        dtype = dtypes.canonicalize_dtype(dtype)
        if (
            dtype in [jnp.float16, jnp.bfloat16]
            and 16 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
            and q_seqlen % 4 == 0  # q_seqlen must be a multiple of 4
            and attn_batches % 4 == 0  # batch * heads must be a multiple of 4
        ):
            # k_seqlen is already range-checked above.
            batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen)
            return q_seqlen % batch_per_block == 0
        return False

    @staticmethod
    def abstract(logits_aval, scale_factor):  # pylint: disable=unused-argument
        """
        te_scaled_softmax_forward abstract
        """
        return SoftmaxPrimitive.forward_abstract(logits_aval, scale_factor)

    @staticmethod
    def lowering(ctx, logits, *, scale_factor):
        """
        te_scaled_softmax_forward lowering rules
        """
        return SoftmaxPrimitive.forward_lowering(
            ScaledSoftmaxFwdPrimitive.name, ctx, logits, scale_factor=scale_factor
        )

    @staticmethod
    def impl(logits, scale_factor):
        """Bind the inner primitive."""
        return SoftmaxPrimitive.forward_impl(
            ScaledSoftmaxFwdPrimitive.inner_primitive, logits, scale_factor
        )

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        """Batching rule: validate batch dims, then rebind the outer primitive."""
        check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.forward_batcher(
            ScaledSoftmaxFwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """Output sharding follows logits with the hidden dim replicated."""
        return ScaledSoftmaxFwdPrimitive.forward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        """Single-operand partition rule."""
        return ScaledSoftmaxFwdPrimitive.forward_partition(
            ScaledSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def shardy_sharding_rule(*args):
        """Output has the same factors as the logits."""
        del args
        return "... -> ..."


register_primitive(ScaledSoftmaxFwdPrimitive)


class ScaledSoftmaxBwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Softmax Bwd Primitive
    """

    name = "te_scaled_softmax_backward_ffi"
    multiple_results = False
    impl_static_args = (2,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        # The backward kernel uses the forward kernel's size constraints.
        return ScaledSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )

    @staticmethod
    def abstract(dz_aval, softmax_out_aval, scale_factor):
        """
        te_scaled_softmax_backward abstract
        """
        return SoftmaxPrimitive.backward_abstract(dz_aval, softmax_out_aval, scale_factor)

    @staticmethod
    def lowering(ctx, dz, softmax_out, *, scale_factor):
        """
        te_scaled_softmax_backward lowering rules
        """
        out = SoftmaxPrimitive.backward_lowering(
            ScaledSoftmaxBwdPrimitive.name, ctx, dz, softmax_out, scale_factor=scale_factor
        )

        return out

    @staticmethod
    def impl(dz, softmax_out, scale_factor):
        """Bind the inner primitive."""
        return SoftmaxPrimitive.backward_impl(
            ScaledSoftmaxBwdPrimitive.inner_primitive, dz, softmax_out, scale_factor=scale_factor
        )

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        """Batching rule: validate batch dims, then rebind the outer primitive."""
        check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.backward_batcher(
            ScaledSoftmaxBwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """Gradient sharding follows dz with the hidden dim replicated."""
        return ScaledSoftmaxBwdPrimitive.backward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        """Two-operand (dz, softmax_out) partition rule."""
        return ScaledSoftmaxBwdPrimitive.backward_partition(
            ScaledSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def shardy_sharding_rule(*args):
        """dz, softmax_out and dx all share the same factors."""
        del args
        return "..., ... -> ..."


register_primitive(ScaledSoftmaxBwdPrimitive)


def scaled_softmax_bwd(
    dz: jnp.ndarray, softmax_out: jnp.ndarray, logits: jnp.ndarray, scale_factor: float
) -> jnp.ndarray:
    """
    scaled_backward wrapper
    Return FP16/BF16 tensor

    When the TE primitive is disabled, the gradient is obtained by
    differentiating ``jax_scaled_softmax`` w.r.t. ``logits``, and ``softmax_out``
    is unused. Otherwise the fused kernel consumes ``dz`` and ``softmax_out``.
    """
    if not ScaledSoftmaxBwdPrimitive.enabled():
        _, vjp_func = jax.vjp(partial(jax_scaled_softmax, scale_factor=scale_factor), logits)
        return vjp_func(dz)[0]

    return ScaledSoftmaxBwdPrimitive.outer_primitive.bind(
        dz, softmax_out, scale_factor=scale_factor
    )


class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Masked Softmax Fwd Primitive
    """

    name = "te_scaled_masked_softmax_forward_ffi"
    multiple_results = False
    impl_static_args = (2,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        attn_batches = batch * heads

        dtype = dtypes.canonicalize_dtype(dtype)
        if (
            dtype in [jnp.float16, jnp.bfloat16]
            and 16 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
            and q_seqlen % 4 == 0  # q_seqlen must be a multiple of 4
            and attn_batches % 4 == 0  # batch * heads must be a multiple of 4
        ):
            # k_seqlen is already range-checked above.
            batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen)
            return q_seqlen % batch_per_block == 0
        return False

    @staticmethod
    def abstract(logits_aval, mask_aval, scale_factor):  # pylint: disable=unused-argument
        """
        te_scaled_masked_softmax_forward abstract

        logits: [...Batch, Head, Q_Seqlen, K_Seqlen] in FP16/BF16.
        mask:   [...Pad_Batch, 1, Q_Seqlen, K_Seqlen] in uint8, where the
                flattened Pad_Batch is either 1 (broadcast) or equal to Batch.
        """
        i_dtype = dtypes.canonicalize_dtype(logits_aval.dtype)
        assert i_dtype in [jnp.float16, jnp.bfloat16]
        i_shape = logits_aval.shape

        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen].
        # initial=1 so inputs without leading batch dims don't crash reduce().
        batch = reduce(operator.mul, i_shape[:-3], 1)
        q_seqlen = i_shape[-2]
        k_seqlen = i_shape[-1]
        assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
        assert q_seqlen > 1

        mask_dtype = dtypes.canonicalize_dtype(mask_aval.dtype)
        assert mask_dtype in [
            jnp.uint8,
        ]
        mask_shape = mask_aval.shape
        # Must not rebind `batch` here, otherwise the check below is vacuous.
        pad_batch = reduce(operator.mul, mask_shape[:-3], 1)
        assert pad_batch in (1, batch)  # 1 means broadcast
        assert mask_shape[-3] == 1  # 1 means broadcast
        assert mask_shape[-2] == q_seqlen
        assert mask_shape[-1] == k_seqlen

        out_aval = logits_aval
        return out_aval

    @staticmethod
    def lowering(ctx, logits, mask, *, scale_factor):
        """
        te_scaled_masked_softmax_forward lowering rules
        """
        return ffi.ffi_lowering(ScaledMaskedSoftmaxFwdPrimitive.name)(
            ctx, logits, mask, scale_factor=scale_factor
        )

    @staticmethod
    def impl(logits, mask, scale_factor):
        """Bind the inner primitive."""
        assert ScaledMaskedSoftmaxFwdPrimitive.inner_primitive is not None
        output = ScaledMaskedSoftmaxFwdPrimitive.inner_primitive.bind(
            logits, mask, scale_factor=scale_factor
        )
        return output

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        """Batching rule: output batch dim follows the logits batch dim."""
        check_valid_batch_dims(batch_dims)
        assert ScaledMaskedSoftmaxFwdPrimitive.outer_primitive is not None
        logits, mask = batched_args
        logits_bdim, _ = batch_dims

        out_bdims = logits_bdim
        return (
            ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
                logits, mask, scale_factor=scale_factor
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """Output sharding follows logits with the hidden dim replicated."""
        return ScaledMaskedSoftmaxFwdPrimitive.forward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        """Two-operand (logits, mask) partition rule.

        Reuses ``backward_partition``. That helper shards each of its two
        operands with the last dim replicated and gives the output the first
        operand's sharding, which is what this forward op needs.
        """
        return ScaledMaskedSoftmaxFwdPrimitive.backward_partition(
            ScaledMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def shardy_sharding_rule(*args):
        """Output follows logits; the mask has independent factors."""
        del args
        return "...1, ...2 -> ...1"


register_primitive(ScaledMaskedSoftmaxFwdPrimitive)


class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Masked Softmax Bwd Primitive
    """

    name = "te_scaled_masked_softmax_backward_ffi"
    multiple_results = False
    impl_static_args = (2,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        # The backward kernel uses the matching forward kernel's size constraints.
        return ScaledMaskedSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )

    @staticmethod
    def abstract(dz_aval, softmax_out_aval, *, scale_factor):
        """
        te_scaled_masked_softmax_backward abstract
        """
        return SoftmaxPrimitive.backward_abstract(dz_aval, softmax_out_aval, scale_factor)

    @staticmethod
    def lowering(ctx, dz, softmax_out, *, scale_factor):
        """
        te_scaled_masked_softmax_backward lowering rules
        """
        return SoftmaxPrimitive.backward_lowering(
            ScaledMaskedSoftmaxBwdPrimitive.name, ctx, dz, softmax_out, scale_factor=scale_factor
        )

    @staticmethod
    def impl(dz, softmax_out, scale_factor):
        """Bind the inner primitive."""
        return SoftmaxPrimitive.backward_impl(
            ScaledMaskedSoftmaxBwdPrimitive.inner_primitive,
            dz,
            softmax_out,
            scale_factor=scale_factor,
        )

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        """Batching rule: validate batch dims, then rebind the outer primitive."""
        check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.backward_batcher(
            ScaledMaskedSoftmaxBwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """Gradient sharding follows dz with the hidden dim replicated."""
        return ScaledMaskedSoftmaxBwdPrimitive.backward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        """Two-operand (dz, softmax_out) partition rule."""
        return ScaledMaskedSoftmaxBwdPrimitive.backward_partition(
            ScaledMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def shardy_sharding_rule(*args):
        """dz, softmax_out and dx all share the same factors."""
        del args
        return "..., ... -> ..."


register_primitive(ScaledMaskedSoftmaxBwdPrimitive)


class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Upper Triang Masked Softmax Fwd Primitive
    """

    name = "te_scaled_upper_triang_masked_softmax_forward_ffi"
    multiple_results = False
    impl_static_args = (1,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        attn_batches = batch * heads

        dtype = dtypes.canonicalize_dtype(dtype)
        if (
            dtype in [jnp.float16, jnp.bfloat16]
            and 16 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
            and q_seqlen % 4 == 0  # q_seqlen must be a multiple of 4
            and attn_batches % 4 == 0  # batch * heads must be a multiple of 4
            and k_seqlen == q_seqlen  # causal mask requires square logits
        ):
            # k_seqlen is already range-checked above. Unlike the other variants,
            # this kernel's divisibility requirement is on batch * heads.
            batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen)
            return attn_batches % batch_per_block == 0
        return False

    @staticmethod
    def abstract(logits_aval, scale_factor):  # pylint: disable=unused-argument
        """
        te_scaled_upper_triang_masked_softmax_forward abstract
        """
        q_seqlen = logits_aval.shape[-2]
        k_seqlen = logits_aval.shape[-1]
        assert q_seqlen == k_seqlen
        return SoftmaxPrimitive.forward_abstract(logits_aval, scale_factor)

    @staticmethod
    def lowering(ctx, logits, *, scale_factor):
        """
        te_scaled_upper_triang_masked_softmax_forward lowering rules
        """
        return SoftmaxPrimitive.forward_lowering(
            ScaledUpperTriangMaskedSoftmaxFwdPrimitive.name, ctx, logits, scale_factor=scale_factor
        )

    @staticmethod
    def impl(logits, scale_factor):
        """Bind the inner primitive."""
        return SoftmaxPrimitive.forward_impl(
            ScaledUpperTriangMaskedSoftmaxFwdPrimitive.inner_primitive, logits, scale_factor
        )

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        """Batching rule: validate batch dims, then rebind the outer primitive."""
        check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.forward_batcher(
            ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """Output sharding follows logits with the hidden dim replicated."""
        return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.forward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        """Single-operand partition rule."""
        return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.forward_partition(
            ScaledUpperTriangMaskedSoftmaxFwdPrimitive.impl,
            scale_factor,
            mesh,
            arg_infos,
            result_infos,
        )

    @staticmethod
    def shardy_sharding_rule(*args):
        """Output has the same factors as the logits."""
        del args
        return "... -> ..."


register_primitive(ScaledUpperTriangMaskedSoftmaxFwdPrimitive)


class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Upper Triang Masked Softmax Bwd Primitive
    """

    name = "te_scaled_upper_triang_masked_softmax_backward_ffi"
    multiple_results = False
    impl_static_args = (2,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        # The backward kernel uses the matching forward kernel's size constraints.
        return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )

    @staticmethod
    def abstract(dz_aval, softmax_out_aval, *, scale_factor):
        """
        te_scaled_upper_triang_masked_backward abstract
        """
        return SoftmaxPrimitive.backward_abstract(dz_aval, softmax_out_aval, scale_factor)

    @staticmethod
    def lowering(ctx, dz, softmax_out, *, scale_factor):
        """
        te_scaled_upper_triang_masked_backward lowering rules
        """
        return SoftmaxPrimitive.backward_lowering(
            ScaledUpperTriangMaskedSoftmaxBwdPrimitive.name,
            ctx,
            dz,
            softmax_out,
            scale_factor=scale_factor,
        )

    @staticmethod
    def impl(dz, softmax_out, scale_factor):
        """Bind the inner primitive."""
        return SoftmaxPrimitive.backward_impl(
            ScaledUpperTriangMaskedSoftmaxBwdPrimitive.inner_primitive,
            dz,
            softmax_out,
            scale_factor=scale_factor,
        )

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        """Batching rule: validate batch dims, then rebind the outer primitive."""
        check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.backward_batcher(
            ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        """Gradient sharding follows dz with the hidden dim replicated."""
        return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.backward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        """Two-operand (dz, softmax_out) partition rule."""
        return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.backward_partition(
            ScaledUpperTriangMaskedSoftmaxBwdPrimitive.impl,
            scale_factor,
            mesh,
            arg_infos,
            result_infos,
        )

    @staticmethod
    def shardy_sharding_rule(*args):
        """dz, softmax_out and dx all share the same factors."""
        del args
        return "..., ... -> ..."


register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive)


def jax_scaled_softmax(logits: jnp.ndarray, scale_factor: float):
    """
    JAX based implementation of scaled softmax

    Computes ``softmax(scale_factor * logits)`` along the last axis.
    """
    return jax.nn.softmax(scale_factor * logits)


def jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float):
    """
    JAX based implementation of scaled and masked softmax

    Positions where ``mask == 1`` are excluded from the softmax along the last
    axis. ``jax.nn.softmax`` writes 0 to those positions in the output.
    """
    return jax.nn.softmax(logits * scale_factor, where=mask != 1)


def jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: float):
    """
    JAX based implementation of scaled and upper triangle masked softmax

    Builds a mask that is 1 strictly above the diagonal of the last two dims
    (key index > query index) and 0 elsewhere, then applies the masked softmax.
    """
    mask = 1 - jnp.tril(jnp.ones_like(logits))
    return jax_scaled_masked_softmax(logits, mask, scale_factor)


def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
    """
    scaled_softmax_forward wrapper
    Return FP16/BF16 tensor

    Falls back to ``jax_scaled_softmax`` when the TE primitive is disabled.
    """
    if not ScaledSoftmaxFwdPrimitive.enabled():
        return jax_scaled_softmax(logits, scale_factor)
    return ScaledSoftmaxFwdPrimitive.outer_primitive.bind(logits, scale_factor=scale_factor)


def scaled_masked_softmax_fwd(
    logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float
) -> jnp.ndarray:
    """
    scaled_masked_softmax_forward wrapper
    Return FP16/BF16 tensor

    Falls back to ``jax_scaled_masked_softmax`` when the TE primitive is disabled.
    """
    if not ScaledMaskedSoftmaxFwdPrimitive.enabled():
        return jax_scaled_masked_softmax(logits, mask, scale_factor)
    return ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
        logits, mask, scale_factor=scale_factor
    )


def scaled_masked_softmax_bwd(
    dz: jnp.ndarray,
    softmax_out: jnp.ndarray,
    logits: jnp.ndarray,
    mask: jnp.ndarray,
    scale_factor: float,
) -> jnp.ndarray:
    """
    scaled_masked_backward wrapper
    Return FP16/BF16 tensor

    When the TE primitive is disabled, the gradient is obtained by
    differentiating ``jax_scaled_masked_softmax`` w.r.t. ``logits`` only, and
    ``softmax_out`` is unused. Otherwise the fused kernel consumes ``dz`` and
    ``softmax_out``.
    """
    if not ScaledMaskedSoftmaxBwdPrimitive.enabled():
        # Close over the integer mask so vjp differentiates w.r.t. logits alone.
        _, vjp_func = jax.vjp(
            partial(jax_scaled_masked_softmax, mask=mask, scale_factor=scale_factor), logits
        )
        return vjp_func(dz)[0]
    return ScaledMaskedSoftmaxBwdPrimitive.outer_primitive.bind(
        dz, softmax_out, scale_factor=scale_factor
    )


def scaled_upper_triang_masked_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
    """
    scaled_upper_triang_masked_softmax_forward wrapper
    Return FP16/BF16 tensor

    Falls back to ``jax_scaled_upper_triang_masked_softmax`` when the TE
    primitive is disabled.
    """
    if not ScaledUpperTriangMaskedSoftmaxFwdPrimitive.enabled():
        return jax_scaled_upper_triang_masked_softmax(logits, scale_factor)
    return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
        logits, scale_factor=scale_factor
    )


def scaled_upper_triang_masked_softmax_bwd(
    dz: jnp.ndarray, softmax_out: jnp.ndarray, logits: jnp.ndarray, scale_factor: float
) -> jnp.ndarray:
    """
    scaled_upper_triang_masked_backward wrapper
    Return FP16/BF16 tensor

    When the TE primitive is disabled, the gradient is obtained by
    differentiating ``jax_scaled_upper_triang_masked_softmax`` w.r.t. ``logits``,
    and ``softmax_out`` is unused.
    """
    if not ScaledUpperTriangMaskedSoftmaxBwdPrimitive.enabled():
        _, vjp_func = jax.vjp(
            partial(jax_scaled_upper_triang_masked_softmax, scale_factor=scale_factor), logits
        )
        return vjp_func(dz)[0]
    return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive.bind(
        dz, softmax_out, scale_factor=scale_factor
    )