Unverified Commit 8dc2756e authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] Manual axis filter in `with_sharding_constraint` (#2069)



* add manual axis filer to sharding_constraint impl
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* fix lint
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* use abstract_mesh instead of physical_mesh
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* add a comment
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* cleanup
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* clean unused var
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent aa0659e5
...@@ -15,10 +15,10 @@ from dataclasses import dataclass ...@@ -15,10 +15,10 @@ from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import Callable, Optional from typing import Callable, Optional
import warnings import warnings
from jax.interpreters import pxla
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax.sharding import PartitionSpec from jax.interpreters import pxla
from jax.sharding import PartitionSpec, get_abstract_mesh
import numpy as np import numpy as np
_PXLA_THREAD_RESOURCES = pxla.thread_resources _PXLA_THREAD_RESOURCES = pxla.thread_resources
...@@ -122,8 +122,10 @@ def generate_pspec(logical_axis_names, with_flax_rules=False, padded=False): ...@@ -122,8 +122,10 @@ def generate_pspec(logical_axis_names, with_flax_rules=False, padded=False):
def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec): def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec):
""" """
A wrapper function to jax.lax.with_sharding_constraint to A wrapper function to jax.lax.with_sharding_constraint
support the case that Mesh is empty. 1. Does nothing if mesh is empty.
2. If all mesh axes are manual axes, replaces pspec with all Nones.
3. Otherwise, strips only the manual axes.
""" """
if pspec is None: if pspec is None:
return x return x
...@@ -131,7 +133,14 @@ def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec): ...@@ -131,7 +133,14 @@ def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec):
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
if mesh.empty: if mesh.empty:
return x return x
return jax.lax.with_sharding_constraint(x, pspec)
# We want to exclude the axes that already used by shard_map and shard_map
# only sets those in the abstract_mesh, not the physical one
manual_axis_names = get_abstract_mesh().manual_axes
cleaned_axis_names = tuple(name if name not in manual_axis_names else None for name in pspec)
cleaned_pspec = PartitionSpec(*cleaned_axis_names)
return jax.lax.with_sharding_constraint(x, cleaned_pspec)
def with_sharding_constraint_by_logical_axes( def with_sharding_constraint_by_logical_axes(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment