Unverified Commit 4732ed76 authored by jberchtold-nvidia's avatar jberchtold-nvidia Committed by GitHub
Browse files

[JAX] Update jax_scaled_masked_softmax to match TE kernel implementation (#1822)



Update jax_scaled_masked_softmax to match TE kernel implementation
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>
parent 30e30811
......@@ -809,13 +809,7 @@ def jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_fact
"""
JAX based implementation of scaled and masked softmax
"""
if mask is not None:
logits += jax.lax.select(
mask > 0,
jnp.full(mask.shape, -1e10).astype(logits.dtype),
jnp.full(mask.shape, 0.0).astype(logits.dtype),
)
return jax.nn.softmax(logits * scale_factor)
return jax.nn.softmax(logits * scale_factor, where=mask != 1)
def jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: float):
......@@ -823,12 +817,7 @@ def jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: fl
JAX based implementation of scaled and upper triangle masked softmax
"""
mask = 1 - jnp.tril(jnp.ones_like(logits))
logits += jax.lax.select(
mask > 0,
jnp.full(mask.shape, -1e10).astype(logits.dtype),
jnp.full(mask.shape, 0.0).astype(logits.dtype),
)
return jax.nn.softmax(logits * scale_factor)
return jax_scaled_masked_softmax(logits, mask, scale_factor)
def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment