# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for softmax"""
from abc import abstractmethod
from functools import partial, reduce
import operator
import warnings

from packaging import version
import jax
import jax.numpy as jnp
from jax import dtypes
from jax.sharding import PartitionSpec, NamedSharding

from .base import BasePrimitive, register_primitive
from .misc import get_padded_spec, check_valid_batch_dims
from ..softmax import SoftmaxType

if version.parse(jax.__version__) >= version.parse("0.5.0"):
    from jax import ffi  # pylint: disable=ungrouped-imports
else:
    from jax.extend import ffi  # pylint: disable=ungrouped-imports


__all__ = [
    "scaled_softmax_fwd",
    "scaled_softmax_bwd",
    "scaled_masked_softmax_fwd",
    "scaled_masked_softmax_bwd",
    "scaled_upper_triang_masked_softmax_fwd",
    "scaled_upper_triang_masked_softmax_bwd",
    "is_softmax_kernel_available",
    "jax_scaled_softmax",
    "jax_scaled_masked_softmax",
    "jax_scaled_upper_triang_masked_softmax",
]


def is_softmax_kernel_available(
    softmax_type: SoftmaxType,
    batch: int,
    heads: int,
    q_seqlen: int,
    k_seqlen: int,
    dtype: jnp.dtype,
):
    """Check whether the fused TE softmax kernel supports the given problem size.

    Dispatches to the availability check of the primitive matching
    ``softmax_type``; raises NotImplementedError for unknown types.
    """
    if softmax_type is SoftmaxType.SCALED:
        return ScaledSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )
    if softmax_type is SoftmaxType.SCALED_MASKED:
        return ScaledMaskedSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )
    if softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
        return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )

    raise NotImplementedError


class SoftmaxPrimitive(BasePrimitive):
    """
    Softmax Primitive

    Shared abstract-eval, lowering, batching, and sharding helpers used by
    the concrete scaled/masked/upper-triang softmax primitives below.
    """

    max_k_seqlen_supported = 16384
    name = "te_softmax_internal_placeholder"

    @staticmethod
    @abstractmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        raise NotImplementedError

    @staticmethod
    def get_batch_per_block(k_seqlen: int) -> int:
        """Get batch per CTA in Softmax kernels"""
        threads_per_warp = 32
        threads_per_block = 128  # Depends on the kernel implementation

        # Round k_seqlen up to the next power of two to mirror the kernel's
        # warp-sizing logic.
        pow2 = 1 << (k_seqlen - 1).bit_length()
        warp_size = pow2 if pow2 < threads_per_warp else threads_per_warp
        batches_per_warp = 2 if pow2 <= 128 else 1
        warps_per_block = threads_per_block // warp_size
        batches_per_block = warps_per_block * batches_per_warp
        return batches_per_block

    @staticmethod
    def forward_abstract(logits_aval, scale_factor):
        """
        softmax_forward abstract
        """
        del scale_factor
        i_dtype = dtypes.canonicalize_dtype(logits_aval.dtype)
        assert i_dtype in [jnp.float16, jnp.bfloat16]
        i_shape = logits_aval.shape
        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
        q_seqlen = i_shape[-2]
        k_seqlen = i_shape[-1]
        assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
        assert q_seqlen > 1

        out_aval = logits_aval
        return out_aval

    @staticmethod
    def forward_lowering(name, ctx, logits, *, scale_factor):
        """
        softmax_forward lowering rules
        """
        return ffi.ffi_lowering(name)(ctx, logits, scale_factor=scale_factor)

    @staticmethod
    def forward_impl(primitive, logits, scale_factor):
        """
        softmax_forward implementation
        """
        assert primitive is not None
        output = primitive.bind(logits, scale_factor=scale_factor)
        return output

    @staticmethod
    def forward_batcher(primitive, batched_args, batch_dims, *, scale_factor):
        """
        softmax_forward batcher
        """
        assert primitive is not None
        (logits,) = batched_args
        (logits_bdim,) = batch_dims

        out_bdims = logits_bdim
        return primitive.bind(logits, scale_factor=scale_factor), out_bdims

    @classmethod
    def forward_infer_sharding_from_operands(cls, scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_forward infer_sharding_from_operands
        """
        del scale_factor, result_infos  # Unused.
        logits_spec = get_padded_spec(arg_infos[0])
        if logits_spec[-1] is not None:
            warnings.warn(
                f"Sharding the hidden dimension is not supported in {cls.name}! "
                "Forcing XLA to not shard the hidden dim, which might introduce extra "
                "collective ops and hurt performance."
            )
        out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec[:-1], None))
        return out_sharding

    @classmethod
    def forward_partition(cls, impl, scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_forward partitioning
        """
        del result_infos
        logits_spec = get_padded_spec(arg_infos[0])
        if logits_spec[-1] is not None:
            warnings.warn(
                f"Sharding the hidden dimension is not supported in {cls.name}! "
                "Forcing XLA to not shard the hidden dim, which might introduce extra "
                "collective ops and hurt performance."
            )
        out_shardings = NamedSharding(mesh, PartitionSpec(*logits_spec[:-1], None))
        arg_shardings = (out_shardings,)
        impl = partial(impl, scale_factor=scale_factor)
        return mesh, impl, out_shardings, arg_shardings

    @staticmethod
    def backward_abstract(
        dz_aval, softmax_out_aval, scale_factor=None
    ):  # pylint: disable=unused-argument
        """
        softmax_backward abstract
        """
        dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        softmax_out_dtype = dtypes.canonicalize_dtype(softmax_out_aval.dtype)
        assert dz_dtype == softmax_out_dtype
        assert dz_dtype in [jnp.float16, jnp.bfloat16]
        assert softmax_out_dtype in [jnp.float16, jnp.bfloat16]

        assert dz_aval.shape == softmax_out_aval.shape

        dx_aval = dz_aval
        return dx_aval

    @staticmethod
    def backward_lowering(name, ctx, dz, softmax_out, *, scale_factor):
        """
        softmax_backward lowering rules
        """
        return ffi.ffi_lowering(name)(ctx, dz, softmax_out, scale_factor=scale_factor)

    @staticmethod
    def backward_impl(primitive, dz, softmax_out, scale_factor):
        """
        softmax_backward implementation
        """
        assert primitive is not None
        dx = primitive.bind(dz, softmax_out, scale_factor=scale_factor)
        return dx

    @staticmethod
    def backward_batcher(primitive, batched_args, batch_dims, *, scale_factor):
        """
        softmax_backward batcher
        """
        assert primitive is not None
        dz, softmax_out = batched_args
        _, softmax_out_bdim = batch_dims

        out_bdims = softmax_out_bdim
        return primitive.bind(dz, softmax_out, scale_factor=scale_factor), out_bdims

    @classmethod
    def backward_infer_sharding_from_operands(cls, scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_backward infer_sharding_from_operands
        """
        del scale_factor, result_infos  # Unused.
        dz_spec = get_padded_spec(arg_infos[0])
        if dz_spec[-1] is not None:
            warnings.warn(
                f"Sharding the hidden dimension is not supported in {cls.name}! "
                "Forcing XLA to not shard the hidden dim, which might introduce extra "
                "collective ops and hurt performance."
            )
        dx_sharding = NamedSharding(mesh, PartitionSpec(*dz_spec[:-1], None))
        return dx_sharding

    @classmethod
    def backward_partition(cls, impl, scale_factor, mesh, arg_infos, result_infos):
        """
        softmax_backward partition
        """
        del result_infos
        dz_spec = get_padded_spec(arg_infos[0])
        softmax_out_spec = get_padded_spec(arg_infos[1])
        if dz_spec[-1] is not None or softmax_out_spec[-1] is not None:
            warnings.warn(
                f"Sharding the hidden dimension is not supported in {cls.name}! "
                "Forcing XLA to not shard the hidden dim, which might introduce extra "
                "collective ops and hurt performance."
            )

        dz_sharding = NamedSharding(mesh, PartitionSpec(*dz_spec[:-1], None))
        softmax_out_sharding = NamedSharding(mesh, PartitionSpec(*softmax_out_spec[:-1], None))
        dx_sharding = dz_sharding
        arg_shardings = (dz_sharding, softmax_out_sharding)
        out_shardings = dx_sharding

        impl = partial(impl, scale_factor=scale_factor)
        return mesh, impl, out_shardings, arg_shardings


class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Softmax Fwd Primitive
    """

    name = "te_scaled_softmax_forward_ffi"
    multiple_results = False
    impl_static_args = (1,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        attn_batches = batch * heads

        dtype = dtypes.canonicalize_dtype(dtype)
        if (
            dtype in [jnp.float16, jnp.bfloat16]
            and 16 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
            and q_seqlen % 4 == 0  # q_seqlen must be a multiple of 4
            and attn_batches % 4 == 0  # batch * heads must be a multiple of 4
        ):
            # NOTE(review): the original repeated the 0 <= k_seqlen <= max
            # check here; the enclosing condition already guarantees it.
            batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen)
            return q_seqlen % batch_per_block == 0
        return False

    @staticmethod
    def abstract(logits_aval, scale_factor):  # pylint: disable=unused-argument
        """
        te_scaled_softmax_forward abstract
        """
        return SoftmaxPrimitive.forward_abstract(logits_aval, scale_factor)

    @staticmethod
    def lowering(ctx, logits, *, scale_factor):
        """
        te_scaled_softmax_forward lowering rules
        """
        return SoftmaxPrimitive.forward_lowering(
            ScaledSoftmaxFwdPrimitive.name, ctx, logits, scale_factor=scale_factor
        )

    @staticmethod
    def impl(logits, scale_factor):
        return SoftmaxPrimitive.forward_impl(
            ScaledSoftmaxFwdPrimitive.inner_primitive, logits, scale_factor
        )

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.forward_batcher(
            ScaledSoftmaxFwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        return ScaledSoftmaxFwdPrimitive.forward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        return ScaledSoftmaxFwdPrimitive.forward_partition(
            ScaledSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def shardy_sharding_rule(*args):
        del args
        return "... -> ..."


register_primitive(ScaledSoftmaxFwdPrimitive)


class ScaledSoftmaxBwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Softmax Bwd Primitive
    """

    name = "te_scaled_softmax_backward_ffi"
    multiple_results = False
    impl_static_args = (2,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        return ScaledSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )

    @staticmethod
    def abstract(dz_aval, softmax_out_aval, scale_factor):
        """
        te_scaled_softmax_backward abstract
        """
        return SoftmaxPrimitive.backward_abstract(dz_aval, softmax_out_aval, scale_factor)

    @staticmethod
    def lowering(ctx, dz, softmax_out, *, scale_factor):
        """
        te_scaled_softmax_backward lowering rules
        """
        out = SoftmaxPrimitive.backward_lowering(
            ScaledSoftmaxBwdPrimitive.name, ctx, dz, softmax_out, scale_factor=scale_factor
        )
        return out

    @staticmethod
    def impl(dz, softmax_out, scale_factor):
        return SoftmaxPrimitive.backward_impl(
            ScaledSoftmaxBwdPrimitive.inner_primitive, dz, softmax_out, scale_factor=scale_factor
        )

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.backward_batcher(
            ScaledSoftmaxBwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        return ScaledSoftmaxBwdPrimitive.backward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        return ScaledSoftmaxBwdPrimitive.backward_partition(
            ScaledSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def shardy_sharding_rule(*args):
        del args
        return "..., ... -> ..."


register_primitive(ScaledSoftmaxBwdPrimitive)


def scaled_softmax_bwd(
    dz: jnp.ndarray, softmax_out: jnp.ndarray, logits: jnp.ndarray, scale_factor: float
) -> jnp.ndarray:
    """
    scaled_backward wrapper
    Return FP16/BF16 tensor
    """
    if not ScaledSoftmaxBwdPrimitive.enabled():
        # Fall back to differentiating the pure-JAX forward implementation.
        _, vjp_func = jax.vjp(partial(jax_scaled_softmax, scale_factor=scale_factor), logits)
        return vjp_func(dz)[0]
    return ScaledSoftmaxBwdPrimitive.outer_primitive.bind(
        dz, softmax_out, scale_factor=scale_factor
    )


class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Masked Softmax Fwd Primitive
    """

    name = "te_scaled_masked_softmax_forward_ffi"
    multiple_results = False
    impl_static_args = (2,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        attn_batches = batch * heads

        dtype = dtypes.canonicalize_dtype(dtype)
        if (
            dtype in [jnp.float16, jnp.bfloat16]
            and 16 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
            and q_seqlen % 4 == 0  # q_seqlen must be a multiple of 4
            and attn_batches % 4 == 0  # batch * heads must be a multiple of 4
        ):
            batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen)
            return q_seqlen % batch_per_block == 0
        return False

    @staticmethod
    def abstract(logits_aval, mask_aval, scale_factor):  # pylint: disable=unused-argument
        """
        te_scaled_masked_softmax_forward abstract
        """
        i_dtype = dtypes.canonicalize_dtype(logits_aval.dtype)
        assert i_dtype in [jnp.float16, jnp.bfloat16]
        i_shape = logits_aval.shape
        # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]; initial value 1 keeps
        # the reduce well-defined when there are no leading batch dims.
        batch = reduce(operator.mul, i_shape[:-3], 1)
        q_seqlen = i_shape[-2]
        k_seqlen = i_shape[-1]
        assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
        assert q_seqlen > 1

        mask_dtype = dtypes.canonicalize_dtype(mask_aval.dtype)
        assert mask_dtype in [
            jnp.uint8,
        ]
        mask_shape = mask_aval.shape
        # Bug fix: the original `pad_batch = batch = reduce(...)` clobbered
        # `batch`, which made the broadcast assertion below always true.
        pad_batch = reduce(operator.mul, mask_shape[:-3], 1)
        assert pad_batch in (1, batch)  # 1 means broadcast
        assert mask_shape[-3] == 1  # 1 means broadcast
        assert mask_shape[-2] == q_seqlen
        assert mask_shape[-1] == k_seqlen

        out_aval = logits_aval
        return out_aval

    @staticmethod
    def lowering(ctx, logits, mask, *, scale_factor):
        """
        te_scaled_masked_softmax_forward lowering rules
        """
        return ffi.ffi_lowering(ScaledMaskedSoftmaxFwdPrimitive.name)(
            ctx, logits, mask, scale_factor=scale_factor
        )

    @staticmethod
    def impl(logits, mask, scale_factor):
        assert ScaledMaskedSoftmaxFwdPrimitive.inner_primitive is not None
        output = ScaledMaskedSoftmaxFwdPrimitive.inner_primitive.bind(
            logits, mask, scale_factor=scale_factor
        )
        return output

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        check_valid_batch_dims(batch_dims)
        assert ScaledMaskedSoftmaxFwdPrimitive.outer_primitive is not None
        logits, mask = batched_args
        logits_bdim, _ = batch_dims

        out_bdims = logits_bdim
        return (
            ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
                logits, mask, scale_factor=scale_factor
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        return ScaledMaskedSoftmaxFwdPrimitive.forward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        # NOTE: reuses backward_partition because this forward primitive has
        # two operands (logits, mask), matching the backward helper's shape.
        return ScaledMaskedSoftmaxFwdPrimitive.backward_partition(
            ScaledMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def shardy_sharding_rule(*args):
        del args
        return "...1, ...2 -> ...1"


register_primitive(ScaledMaskedSoftmaxFwdPrimitive)


class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Masked Softmax Bwd Primitive
    """

    name = "te_scaled_masked_softmax_backward_ffi"
    multiple_results = False
    impl_static_args = (2,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        return ScaledSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )

    @staticmethod
    def abstract(dz_aval, softmax_out_aval, *, scale_factor):
        """
        te_scaled_masked_softmax_backward abstract
        """
        return SoftmaxPrimitive.backward_abstract(dz_aval, softmax_out_aval, scale_factor)

    @staticmethod
    def lowering(ctx, dz, softmax_out, *, scale_factor):
        """
        te_scaled_masked_softmax_backward lowering rules
        """
        return SoftmaxPrimitive.backward_lowering(
            ScaledMaskedSoftmaxBwdPrimitive.name, ctx, dz, softmax_out, scale_factor=scale_factor
        )

    @staticmethod
    def impl(dz, softmax_out, scale_factor):
        return SoftmaxPrimitive.backward_impl(
            ScaledMaskedSoftmaxBwdPrimitive.inner_primitive,
            dz,
            softmax_out,
            scale_factor=scale_factor,
        )

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.backward_batcher(
            ScaledMaskedSoftmaxBwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        return ScaledMaskedSoftmaxBwdPrimitive.backward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        return ScaledMaskedSoftmaxBwdPrimitive.backward_partition(
            ScaledMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def shardy_sharding_rule(*args):
        del args
        return "..., ... -> ..."


register_primitive(ScaledMaskedSoftmaxBwdPrimitive)


class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Upper Triang Masked Softmax Fwd Primitive
    """

    name = "te_scaled_upper_triang_masked_softmax_forward_ffi"
    multiple_results = False
    impl_static_args = (1,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        attn_batches = batch * heads

        dtype = dtypes.canonicalize_dtype(dtype)
        if (
            dtype in [jnp.float16, jnp.bfloat16]
            and 16 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
            and q_seqlen % 4 == 0  # q_seqlen must be a multiple of 4
            and attn_batches % 4 == 0  # batch * heads must be a multiple of 4
            and k_seqlen == q_seqlen  # causal masking requires a square matrix
        ):
            batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen)
            return attn_batches % batch_per_block == 0
        return False

    @staticmethod
    def abstract(logits_aval, scale_factor):  # pylint: disable=unused-argument
        """
        te_scaled_upper_triang_masked_softmax_forward abstract
        """
        q_seqlen = logits_aval.shape[-2]
        k_seqlen = logits_aval.shape[-1]
        assert q_seqlen == k_seqlen
        return SoftmaxPrimitive.forward_abstract(logits_aval, scale_factor)

    @staticmethod
    def lowering(ctx, logits, *, scale_factor):
        """
        te_scaled_upper_triang_masked_softmax_forward lowering rules
        """
        return SoftmaxPrimitive.forward_lowering(
            ScaledUpperTriangMaskedSoftmaxFwdPrimitive.name, ctx, logits, scale_factor=scale_factor
        )

    @staticmethod
    def impl(logits, scale_factor):
        return SoftmaxPrimitive.forward_impl(
            ScaledUpperTriangMaskedSoftmaxFwdPrimitive.inner_primitive, logits, scale_factor
        )

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.forward_batcher(
            ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.forward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.forward_partition(
            ScaledUpperTriangMaskedSoftmaxFwdPrimitive.impl,
            scale_factor,
            mesh,
            arg_infos,
            result_infos,
        )

    @staticmethod
    def shardy_sharding_rule(*args):
        del args
        return "... -> ..."


register_primitive(ScaledUpperTriangMaskedSoftmaxFwdPrimitive)


class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Upper Triang Masked Softmax Bwd Primitive
    """

    name = "te_scaled_upper_triang_masked_softmax_backward_ffi"
    multiple_results = False
    impl_static_args = (2,)  # scale_factor
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def is_kernel_available(
        batch: int, heads: int, q_seqlen: int, k_seqlen: int, dtype: jnp.dtype
    ) -> bool:
        """Check Softmax kernel availability based on size"""
        return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.is_kernel_available(
            batch, heads, q_seqlen, k_seqlen, dtype
        )

    @staticmethod
    def abstract(dz_aval, softmax_out_aval, *, scale_factor):
        """
        te_scaled_upper_triang_masked_backward abstract
        """
        return SoftmaxPrimitive.backward_abstract(dz_aval, softmax_out_aval, scale_factor)

    @staticmethod
    def lowering(ctx, dz, softmax_out, *, scale_factor):
        """
        te_scaled_upper_triang_masked_backward lowering rules
        """
        return SoftmaxPrimitive.backward_lowering(
            ScaledUpperTriangMaskedSoftmaxBwdPrimitive.name,
            ctx,
            dz,
            softmax_out,
            scale_factor=scale_factor,
        )

    @staticmethod
    def impl(dz, softmax_out, scale_factor):
        return SoftmaxPrimitive.backward_impl(
            ScaledUpperTriangMaskedSoftmaxBwdPrimitive.inner_primitive,
            dz,
            softmax_out,
            scale_factor=scale_factor,
        )

    @staticmethod
    def batcher(batched_args, batch_dims, *, scale_factor):
        check_valid_batch_dims(batch_dims)
        return SoftmaxPrimitive.backward_batcher(
            ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive,
            batched_args,
            batch_dims,
            scale_factor=scale_factor,
        )

    @staticmethod
    def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
        return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.backward_infer_sharding_from_operands(
            scale_factor, mesh, arg_infos, result_infos
        )

    @staticmethod
    def partition(scale_factor, mesh, arg_infos, result_infos):
        return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.backward_partition(
            ScaledUpperTriangMaskedSoftmaxBwdPrimitive.impl,
            scale_factor,
            mesh,
            arg_infos,
            result_infos,
        )

    @staticmethod
    def shardy_sharding_rule(*args):
        del args
        return "..., ... -> ..."


register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive)


def jax_scaled_softmax(logits: jnp.ndarray, scale_factor: float):
    """
    JAX based implementation of scaled softmax
    """
    return jax.nn.softmax(scale_factor * logits)


def jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float):
    """
    JAX based implementation of scaled and masked softmax
    """
    # mask == 1 marks positions to exclude from the softmax.
    return jax.nn.softmax(logits * scale_factor, where=mask != 1)


def jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: float):
    """
    JAX based implementation of scaled and upper triangle masked softmax
    """
    mask = 1 - jnp.tril(jnp.ones_like(logits))
    return jax_scaled_masked_softmax(logits, mask, scale_factor)


def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
    """
    scaled_softmax_forward wrapper
    Return FP16/BF16 tensor
    """
    if not ScaledSoftmaxFwdPrimitive.enabled():
        return jax_scaled_softmax(logits, scale_factor)
    return ScaledSoftmaxFwdPrimitive.outer_primitive.bind(logits, scale_factor=scale_factor)


def scaled_masked_softmax_fwd(
    logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float
) -> jnp.ndarray:
    """
    scaled_masked_softmax_forward wrapper
    Return FP16/BF16 tensor
    """
    if not ScaledMaskedSoftmaxFwdPrimitive.enabled():
        return jax_scaled_masked_softmax(logits, mask, scale_factor)
    return ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
        logits, mask, scale_factor=scale_factor
    )


def scaled_masked_softmax_bwd(
    dz: jnp.ndarray,
    softmax_out: jnp.ndarray,
    logits: jnp.ndarray,
    mask: jnp.ndarray,
    scale_factor: float,
) -> jnp.ndarray:
    """
    scaled_masked_backward wrapper
    Return FP16/BF16 tensor
    """
    if not ScaledMaskedSoftmaxBwdPrimitive.enabled():
        # Fall back to differentiating the pure-JAX forward implementation.
        _, vjp_func = jax.vjp(
            partial(jax_scaled_masked_softmax, scale_factor=scale_factor), logits, mask
        )
        return vjp_func(dz)[0]
    return ScaledMaskedSoftmaxBwdPrimitive.outer_primitive.bind(
        dz, softmax_out, scale_factor=scale_factor
    )


def scaled_upper_triang_masked_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
    """
    scaled_upper_triang_masked_softmax_forward wrapper
    Return FP16/BF16 tensor
    """
    if not ScaledUpperTriangMaskedSoftmaxFwdPrimitive.enabled():
        return jax_scaled_upper_triang_masked_softmax(logits, scale_factor)
    return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
        logits, scale_factor=scale_factor
    )


def scaled_upper_triang_masked_softmax_bwd(
    dz: jnp.ndarray, softmax_out: jnp.ndarray, logits: jnp.ndarray, scale_factor: float
) -> jnp.ndarray:
    """
    scaled_upper_triang_masked_backward wrapper
    Return FP16/BF16 tensor
    """
    if not ScaledUpperTriangMaskedSoftmaxBwdPrimitive.enabled():
        # Fall back to differentiating the pure-JAX forward implementation.
        _, vjp_func = jax.vjp(
            partial(jax_scaled_upper_triang_masked_softmax, scale_factor=scale_factor), logits
        )
        return vjp_func(dz)[0]
    return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive.bind(
        dz, softmax_out, scale_factor=scale_factor
    )