"vscode:/vscode.git/clone" did not exist on "d399990607c482e40bdc915bcd90a06505919141"
Unverified Commit 006670de authored by jberchtold-nvidia's avatar jberchtold-nvidia Committed by GitHub
Browse files

[JAX] Fix mesh resource requirement when no mesh (#2307)



* Fix mesh resource requirement when no mesh
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* do not require meshresource if all axes are manual axes
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* remove abstract_mesh is None check
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>
parent 1269b2e2
...@@ -75,6 +75,16 @@ def get_sharding_map_logic_axis_to_mesh_axis(): ...@@ -75,6 +75,16 @@ def get_sharding_map_logic_axis_to_mesh_axis():
""" """
Generate a dict to map logical axes to mesh axes. Generate a dict to map logical axes to mesh axes.
""" """
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
if mesh is None or mesh.empty:
# If no mesh is defined, return an empty dict and do not require a MeshResource context to be present
return {}
abstract_mesh = get_abstract_mesh()
if sorted(abstract_mesh.manual_axes) == sorted(mesh.axis_names):
# If all mesh axes are manual axes, return an empty dict and do not require a MeshResource context to be present
return {}
gsr = global_mesh_resource() gsr = global_mesh_resource()
is_tpsp_enabled = gsr.tpsp_resource is not None and get_mesh_axis_size(gsr.tpsp_resource) > 1 is_tpsp_enabled = gsr.tpsp_resource is not None and get_mesh_axis_size(gsr.tpsp_resource) > 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment