[JAX] Fix shard map issue when `get_all_mesh_axes()` is used (#2229)
Fix shard map issue Signed-off-by:Jeremy Berchtold <jberchtold@nvidia.com> Co-authored-by:
Phuong Nguyen <phuonguyen@nvidia.com>
Showing
Please register or sign in to comment