jberchtold-nvidia authored
Update jax_scaled_masked_softmax to match TE kernel implementation

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
4732ed76