[JAX] Support Flax sharding constraints (#1933)
* Support flax sharding constraints Signed-off-by:Jeremy Berchtold <jberchtold@nvidia.com> * Add warning for deprecated TE logical axes Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * Update examples Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> --------- Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com>
Showing
Please register or sign in to comment