Unverified commit be7f43f1, authored by jberchtold-nvidia, committed by GitHub
Browse files

[JAX] Fix shard map issue when `get_all_mesh_axes()` is used (#2229)



Fix shard map issue
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent f936c2ac
......@@ -131,7 +131,22 @@ def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec):
# We want to exclude the axes that are already used by shard_map; shard_map
# only sets those in the abstract mesh, not in the physical one.
manual_axis_names = get_abstract_mesh().manual_axes
cleaned_axis_names = tuple(name if name not in manual_axis_names else None for name in pspec)
# Multiple mesh axes can be mapped to a single shape axis, so we need to unpack and process tuples here too
def filter_manual_axes(name_or_tuple):
    """Strip axes listed in ``manual_axis_names`` from one partition-spec entry.

    Args:
        name_or_tuple: One entry of the partition spec. Either a single axis
            name (or ``None``), or a tuple of axis names when several mesh
            axes are mapped onto the same array dimension.

    Returns:
        ``None`` when the entry is itself a manual axis, or is a tuple that
        contains only manual axes. Otherwise the entry with manual axes
        removed; a tuple input always yields a tuple.
    """
    if isinstance(name_or_tuple, tuple):
        kept = tuple(n for n in name_or_tuple if n not in manual_axis_names)
        # An empty result means every axis in the tuple was manual; collapse
        # it to None, the same value returned for a single manual axis.
        return kept or None
    return None if name_or_tuple in manual_axis_names else name_or_tuple
cleaned_axis_names = tuple(filter_manual_axes(name_or_tuple) for name_or_tuple in pspec)
if cleaned_axis_names == (None,) * len(cleaned_axis_names):
return x
cleaned_pspec = PartitionSpec(*cleaned_axis_names)
return jax.lax.with_sharding_constraint(x, cleaned_pspec)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment