[JAX] Reduce L1 tests/jax/test_distributed_softmax.py test runtime (#2031)
* Pytest timings Signed-off-by:Jeremy Berchtold <jberchtold@nvidia.com> * Reduce softmax test shape sizes Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Switch softmax tests to use shardy by default Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> --------- Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com>
Showing
Please register or sign in to comment