[JAX] Update jax_scaled_masked_softmax to match TE kernel implementation (#1822)
Update jax_scaled_masked_softmax to match TE kernel implementation
Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com>
Showing
Please register or sign in to comment