[JAX] Manual axis filter in `with_sharding_constraint` (#2069)
* add manual axis filer to sharding_constraint impl Signed-off-by:Phuong Nguyen <phuonguyen@nvidia.com> * fix lint Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * use abstract_mesh instead of physical_mesh Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * add a comment Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * cleanup Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> * clean unused var Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
Showing
Please register or sign in to comment