"tests/vscode:/vscode.git/clone" did not exist on "151a0af650726b86740485d70eea17191d653a10"
Unverified Commit 04add79d authored by jberchtold-nvidia's avatar jberchtold-nvidia Committed by GitHub
Browse files

[JAX] Delay MeshResource validation until first usage (#2124)



Delay MeshResource validation until first usage
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>
Co-authored-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent 62a57dd4
...@@ -41,22 +41,32 @@ def _get_mesh_info(resource: str, mesh: jax.sharding.Mesh): ...@@ -41,22 +41,32 @@ def _get_mesh_info(resource: str, mesh: jax.sharding.Mesh):
return mesh.shape[resource], resource return mesh.shape[resource], resource
def _validate_mesh_resource_configuration(): def _validate_mesh_resource_configuration(mesh_resource):
"""Validate that the mesh resource configuration is consistent and conflict-free.""" """Validate that the mesh resource configuration is consistent and conflict-free."""
gsr = global_mesh_resource() is_dp_enabled = (
mesh_resource.dp_resource is not None and get_mesh_axis_size(mesh_resource.dp_resource) > 1
is_dp_enabled = gsr.dp_resource is not None and get_mesh_axis_size(gsr.dp_resource) > 1 )
is_tp_enabled = gsr.tp_resource is not None and get_mesh_axis_size(gsr.tp_resource) > 1 is_tp_enabled = (
is_tpsp_enabled = gsr.tpsp_resource is not None and get_mesh_axis_size(gsr.tpsp_resource) > 1 mesh_resource.tp_resource is not None and get_mesh_axis_size(mesh_resource.tp_resource) > 1
is_fsdp_enabled = gsr.fsdp_resource is not None and get_mesh_axis_size(gsr.fsdp_resource) > 1 )
is_tpsp_enabled = (
mesh_resource.tpsp_resource is not None
and get_mesh_axis_size(mesh_resource.tpsp_resource) > 1
)
is_fsdp_enabled = (
mesh_resource.fsdp_resource is not None
and get_mesh_axis_size(mesh_resource.fsdp_resource) > 1
)
assert not (is_dp_enabled and is_fsdp_enabled), ( assert not (is_dp_enabled and is_fsdp_enabled), (
"Data parallelism and full-sharded data parallelism cannot be enabled at the same time." "Data parallelism and full-sharded data parallelism cannot be enabled at the same time."
f" Got dp_resource={gsr.dp_resource} and fsdp_resource={gsr.fsdp_resource}" f" Got dp_resource={mesh_resource.dp_resource} and"
f" fsdp_resource={mesh_resource.fsdp_resource}"
) )
assert not (is_tp_enabled and is_tpsp_enabled), ( assert not (is_tp_enabled and is_tpsp_enabled), (
"Tensor parallelism and tensor sequence parallelism cannot be enabled at the same time." "Tensor parallelism and tensor sequence parallelism cannot be enabled at the same time."
f" Got tp_resource={gsr.tp_resource} and tpsp_resource={gsr.tpsp_resource}" f" Got tp_resource={mesh_resource.tp_resource} and"
f" tpsp_resource={mesh_resource.tpsp_resource}"
) )
...@@ -305,7 +315,6 @@ def global_shard_guard(resource: MeshResource): ...@@ -305,7 +315,6 @@ def global_shard_guard(resource: MeshResource):
old_resources = _GLOBAL_MESH_RESOURCE old_resources = _GLOBAL_MESH_RESOURCE
try: try:
_GLOBAL_MESH_RESOURCE = resource _GLOBAL_MESH_RESOURCE = resource
_validate_mesh_resource_configuration()
yield yield
finally: finally:
_GLOBAL_MESH_RESOURCE = old_resources _GLOBAL_MESH_RESOURCE = old_resources
...@@ -322,6 +331,7 @@ def global_mesh_resource() -> MeshResource: ...@@ -322,6 +331,7 @@ def global_mesh_resource() -> MeshResource:
" context. If you are not using multiple GPUs, you can use an empty MeshResource by" " context. If you are not using multiple GPUs, you can use an empty MeshResource by"
" wrapping your program in 'with global_shard_guard(MeshResource()):'" " wrapping your program in 'with global_shard_guard(MeshResource()):'"
) )
_validate_mesh_resource_configuration(_GLOBAL_MESH_RESOURCE)
return _GLOBAL_MESH_RESOURCE return _GLOBAL_MESH_RESOURCE
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment