Unverified commit 4732ed76, authored by jberchtold-nvidia, committed by GitHub
Browse files

[JAX] Update jax_scaled_masked_softmax to match TE kernel implementation (#1822)



Update jax_scaled_masked_softmax to match TE kernel implementation
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent 30e30811
...@@ -809,13 +809,7 @@ def jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_fact ...@@ -809,13 +809,7 @@ def jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_fact
""" """
JAX based implementation of scaled and masked softmax JAX based implementation of scaled and masked softmax
""" """
if mask is not None: return jax.nn.softmax(logits * scale_factor, where=mask != 1)
logits += jax.lax.select(
mask > 0,
jnp.full(mask.shape, -1e10).astype(logits.dtype),
jnp.full(mask.shape, 0.0).astype(logits.dtype),
)
return jax.nn.softmax(logits * scale_factor)
def jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: float): def jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: float):
...@@ -823,12 +817,7 @@ def jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: fl ...@@ -823,12 +817,7 @@ def jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: fl
JAX based implementation of scaled and upper triangle masked softmax JAX based implementation of scaled and upper triangle masked softmax
""" """
mask = 1 - jnp.tril(jnp.ones_like(logits)) mask = 1 - jnp.tril(jnp.ones_like(logits))
logits += jax.lax.select( return jax_scaled_masked_softmax(logits, mask, scale_factor)
mask > 0,
jnp.full(mask.shape, -1e10).astype(logits.dtype),
jnp.full(mask.shape, 0.0).astype(logits.dtype),
)
return jax.nn.softmax(logits * scale_factor)
def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray: def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment