Unverified Commit ee541e83 authored by Frédéric Bastien's avatar Frédéric Bastien Committed by GitHub
Browse files

Stop using global mesh for custom_partitioning. (#1112)


Signed-off-by: default avatarFrederic Bastien <fbastien@nvidia.com>
Co-authored-by: default avatarPhuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
parent 350a4ff1
......@@ -467,7 +467,7 @@ class ActLuFp8Primitive(BasePrimitive):
local_x, local_amax = ActLuFp8Primitive.impl(
x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_enum
)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
return local_x, global_updated_amax
......
......@@ -1011,7 +1011,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
)
global_dbias = local_dbias
if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS:
global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
return local_dq, local_dk, local_dv, global_dbias
return mesh, sharded_impl, out_shardings, arg_shardings
......
......@@ -533,8 +533,8 @@ class LayerNormBwdPrimitive(BasePrimitive):
local_dx, local_dgamma, local_dbeta = LayerNormBwdPrimitive.impl(
dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon
)
global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma)
global_dbeta = all_reduce_sum_along_dp_fsdp(local_dbeta)
global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma, mesh)
global_dbeta = all_reduce_sum_along_dp_fsdp(local_dbeta, mesh)
return local_dx, global_dgamma, global_dbeta
return mesh, sharded_impl, out_shardings, arg_shardings
......@@ -935,7 +935,7 @@ class RmsNormBwdPrimitive(BasePrimitive):
def sharded_impl(dz, x, rsigma, gamma):
local_dx, local_dgamma = RmsNormBwdPrimitive.impl(dz, x, rsigma, gamma, epsilon=epsilon)
global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma)
global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma, mesh)
return local_dx, global_dgamma
return mesh, sharded_impl, out_shardings, arg_shardings
......@@ -1228,7 +1228,7 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
return local_x, local_mu, local_rsigma, global_updated_amax
......@@ -1481,7 +1481,7 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
local_x, local_rsigma, local_amax = RmsNormFwdFp8Primitive.impl(
x, gamma, amax, scale, scale_inv, out_dtype=out_dtype, epsilon=epsilon
)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
return local_x, local_rsigma, global_updated_amax
......
......@@ -157,7 +157,7 @@ class CastFP8Primitive(BasePrimitive):
local_cx, local_updated_amax = CastFP8Primitive.impl(
x, amax, scale, scale_inv, out_dtype=out_dtype
)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax, mesh)
return local_cx, global_updated_amax
......
......@@ -390,7 +390,7 @@ class CastTransposePrimitive(BasePrimitive):
static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary,
)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax, mesh)
return local_cx, local_cxt, global_updated_amax
......@@ -646,8 +646,8 @@ class DBiasCastTransposePrimitive(BasePrimitive):
static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary,
)
global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
return local_out, local_t_out, global_dbias, global_updated_amax
return mesh, sharded_impl, out_shardings, arg_shardings
......@@ -981,8 +981,8 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive):
act_enum=act_enum,
)
)
global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
return local_out, local_t_out, global_dbias, global_updated_amax
return mesh, sharded_impl, out_shardings, arg_shardings
......@@ -1225,7 +1225,7 @@ class DgatedActLuCastTransposePrimitive(BasePrimitive):
static_axis_boundary=static_axis_boundary,
act_enum=act_enum,
)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
return local_out, local_t_out, global_updated_amax
return mesh, sharded_impl, out_shardings, arg_shardings
......
......@@ -30,8 +30,7 @@ W_TP_AXES = "nvte_w_tp"
W_JOINED_AXES = "nvte_w_joined"
def _get_mesh_info(resource: str):
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
def _get_mesh_info(resource: str, mesh: jax.sharding.Mesh):
assert resource in mesh.axis_names, f"{resource} is not in the axis_names of Mesh {mesh}."
return mesh.shape[resource], resource
......@@ -132,12 +131,12 @@ def get_padded_spec(spec, ndim):
return spec + (None,) * (ndim - len(spec))
def lax_paral_op(x: jnp.array, ops: Callable, mesh_resource: str):
def lax_paral_op(x: jnp.array, ops: Callable, mesh_resource: str, mesh: jax.sharding.Mesh):
"""
A wrapper function to invoke lax.p* operations, like psum.
"""
if mesh_resource is not None:
_, resource = _get_mesh_info(mesh_resource)
_, resource = _get_mesh_info(mesh_resource, mesh)
return ops(x, resource)
return x
......@@ -201,22 +200,22 @@ def global_mesh_resource() -> MeshResource:
return _GLOBAL_MESH_RESOURCE
def all_reduce_sum_along_dp_fsdp(x: jnp.array):
def all_reduce_sum_along_dp_fsdp(x: jnp.array, mesh: jax.sharding.Mesh):
"""
All-Reduce (Sum) along DP and FSDP mesh axes.
"""
x = lax_paral_op(x, jax.lax.psum, global_mesh_resource().dp_resource)
return lax_paral_op(x, jax.lax.psum, global_mesh_resource().fsdp_resource)
x = lax_paral_op(x, jax.lax.psum, global_mesh_resource().dp_resource, mesh)
return lax_paral_op(x, jax.lax.psum, global_mesh_resource().fsdp_resource, mesh)
def all_reduce_max_along_all_axes_except_PP(x: jnp.array):
def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mesh):
"""
All-Reduce (Max) along all mesh axes.
"""
all_axes = get_all_mesh_axes()
for axis in all_axes:
if axis != global_mesh_resource().pp_resource:
x = lax_paral_op(x, jax.lax.pmax, axis)
x = lax_paral_op(x, jax.lax.pmax, axis, mesh)
return x
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment