Commit d46d5db4 (unverified), authored by jberchtold-nvidia, committed by GitHub
Browse files

[JAX] Handle meshes set with jax.set_mesh (#2532)



* Handle meshes set with jax.set_mesh
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 6fd62098
...@@ -37,6 +37,15 @@ W_TP_AXES = "nvte_w_tp" ...@@ -37,6 +37,15 @@ W_TP_AXES = "nvte_w_tp"
W_JOINED_AXES = "nvte_w_joined" W_JOINED_AXES = "nvte_w_joined"
def _get_mesh():
    """
    Return the mesh that is currently in effect.

    A mesh entered via ``with mesh:`` is read from the thread resources
    environment and takes precedence when it is present and non-empty.
    Otherwise the result of ``jax.sharding.get_abstract_mesh()`` is returned,
    which covers meshes installed via ``jax.set_mesh(mesh)``.
    """
    context_mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
    if context_mesh is None or context_mesh.empty:
        # No mesh from a `with mesh:` context; fall back to the mesh set
        # through `jax.set_mesh(mesh)`.
        return jax.sharding.get_abstract_mesh()
    return context_mesh
def _get_mesh_info(resource: str, mesh: jax.sharding.Mesh): def _get_mesh_info(resource: str, mesh: jax.sharding.Mesh):
assert resource in mesh.axis_names, f"{resource} is not in the axis_names of Mesh {mesh}." assert resource in mesh.axis_names, f"{resource} is not in the axis_names of Mesh {mesh}."
return mesh.shape[resource], resource return mesh.shape[resource], resource
...@@ -63,7 +72,7 @@ def is_mesh_available() -> bool: ...@@ -63,7 +72,7 @@ def is_mesh_available() -> bool:
""" """
Check if a physical mesh is available. Check if a physical mesh is available.
""" """
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh mesh = _get_mesh()
return mesh is not None and not mesh.empty return mesh is not None and not mesh.empty
...@@ -71,7 +80,7 @@ def get_sharding_map_logic_axis_to_mesh_axis(): ...@@ -71,7 +80,7 @@ def get_sharding_map_logic_axis_to_mesh_axis():
""" """
Generate a dict to map logical axes to mesh axes. Generate a dict to map logical axes to mesh axes.
""" """
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh mesh = _get_mesh()
if mesh is None or mesh.empty: if mesh is None or mesh.empty:
# If no mesh is defined, return an empty dict and do not require a MeshResource context to be present # If no mesh is defined, return an empty dict and do not require a MeshResource context to be present
return {} return {}
...@@ -130,7 +139,7 @@ def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec): ...@@ -130,7 +139,7 @@ def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec):
if pspec is None: if pspec is None:
return x return x
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh mesh = _get_mesh()
if mesh.empty: if mesh.empty:
return x return x
...@@ -211,7 +220,7 @@ def get_all_mesh_axes(): ...@@ -211,7 +220,7 @@ def get_all_mesh_axes():
""" """
Get all name of mesh axes Get all name of mesh axes
""" """
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh mesh = _get_mesh()
return mesh.axis_names return mesh.axis_names
...@@ -251,7 +260,7 @@ def get_num_devices_in_mesh(mesh=None): ...@@ -251,7 +260,7 @@ def get_num_devices_in_mesh(mesh=None):
by the global mesh. by the global mesh.
""" """
if mesh is None: if mesh is None:
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh mesh = _get_mesh()
if mesh.empty: if mesh.empty:
return 1 return 1
return np.prod(list(mesh.shape.values())) return np.prod(list(mesh.shape.values()))
...@@ -264,7 +273,7 @@ def get_mesh_axis_size(axis, mesh=None): ...@@ -264,7 +273,7 @@ def get_mesh_axis_size(axis, mesh=None):
by the global mesh. by the global mesh.
""" """
if mesh is None: if mesh is None:
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh mesh = _get_mesh()
if axis is None: if axis is None:
return 1 return 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment