Unverified commit ee541e83, authored by Frédéric Bastien, committed by GitHub
Browse files

Stop using global mesh for custom_partitioning. (#1112)


Signed-off-by: Frederic Bastien <fbastien@nvidia.com>
Co-authored-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
parent 350a4ff1
...@@ -467,7 +467,7 @@ class ActLuFp8Primitive(BasePrimitive): ...@@ -467,7 +467,7 @@ class ActLuFp8Primitive(BasePrimitive):
local_x, local_amax = ActLuFp8Primitive.impl( local_x, local_amax = ActLuFp8Primitive.impl(
x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_enum x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_enum
) )
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
return local_x, global_updated_amax return local_x, global_updated_amax
......
...@@ -1011,7 +1011,7 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -1011,7 +1011,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
) )
global_dbias = local_dbias global_dbias = local_dbias
if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS: if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS:
global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias) global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
return local_dq, local_dk, local_dv, global_dbias return local_dq, local_dk, local_dv, global_dbias
return mesh, sharded_impl, out_shardings, arg_shardings return mesh, sharded_impl, out_shardings, arg_shardings
......
...@@ -533,8 +533,8 @@ class LayerNormBwdPrimitive(BasePrimitive): ...@@ -533,8 +533,8 @@ class LayerNormBwdPrimitive(BasePrimitive):
local_dx, local_dgamma, local_dbeta = LayerNormBwdPrimitive.impl( local_dx, local_dgamma, local_dbeta = LayerNormBwdPrimitive.impl(
dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon
) )
global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma) global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma, mesh)
global_dbeta = all_reduce_sum_along_dp_fsdp(local_dbeta) global_dbeta = all_reduce_sum_along_dp_fsdp(local_dbeta, mesh)
return local_dx, global_dgamma, global_dbeta return local_dx, global_dgamma, global_dbeta
return mesh, sharded_impl, out_shardings, arg_shardings return mesh, sharded_impl, out_shardings, arg_shardings
...@@ -935,7 +935,7 @@ class RmsNormBwdPrimitive(BasePrimitive): ...@@ -935,7 +935,7 @@ class RmsNormBwdPrimitive(BasePrimitive):
def sharded_impl(dz, x, rsigma, gamma): def sharded_impl(dz, x, rsigma, gamma):
local_dx, local_dgamma = RmsNormBwdPrimitive.impl(dz, x, rsigma, gamma, epsilon=epsilon) local_dx, local_dgamma = RmsNormBwdPrimitive.impl(dz, x, rsigma, gamma, epsilon=epsilon)
global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma) global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma, mesh)
return local_dx, global_dgamma return local_dx, global_dgamma
return mesh, sharded_impl, out_shardings, arg_shardings return mesh, sharded_impl, out_shardings, arg_shardings
...@@ -1228,7 +1228,7 @@ class LayerNormFwdFp8Primitive(BasePrimitive): ...@@ -1228,7 +1228,7 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon, epsilon=epsilon,
) )
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
return local_x, local_mu, local_rsigma, global_updated_amax return local_x, local_mu, local_rsigma, global_updated_amax
...@@ -1481,7 +1481,7 @@ class RmsNormFwdFp8Primitive(BasePrimitive): ...@@ -1481,7 +1481,7 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
local_x, local_rsigma, local_amax = RmsNormFwdFp8Primitive.impl( local_x, local_rsigma, local_amax = RmsNormFwdFp8Primitive.impl(
x, gamma, amax, scale, scale_inv, out_dtype=out_dtype, epsilon=epsilon x, gamma, amax, scale, scale_inv, out_dtype=out_dtype, epsilon=epsilon
) )
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
return local_x, local_rsigma, global_updated_amax return local_x, local_rsigma, global_updated_amax
......
...@@ -157,7 +157,7 @@ class CastFP8Primitive(BasePrimitive): ...@@ -157,7 +157,7 @@ class CastFP8Primitive(BasePrimitive):
local_cx, local_updated_amax = CastFP8Primitive.impl( local_cx, local_updated_amax = CastFP8Primitive.impl(
x, amax, scale, scale_inv, out_dtype=out_dtype x, amax, scale, scale_inv, out_dtype=out_dtype
) )
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax) global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax, mesh)
return local_cx, global_updated_amax return local_cx, global_updated_amax
......
...@@ -390,7 +390,7 @@ class CastTransposePrimitive(BasePrimitive): ...@@ -390,7 +390,7 @@ class CastTransposePrimitive(BasePrimitive):
static_axis_boundary=static_axis_boundary, static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary, transpose_axis_boundary=transpose_axis_boundary,
) )
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax) global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_updated_amax, mesh)
return local_cx, local_cxt, global_updated_amax return local_cx, local_cxt, global_updated_amax
...@@ -646,8 +646,8 @@ class DBiasCastTransposePrimitive(BasePrimitive): ...@@ -646,8 +646,8 @@ class DBiasCastTransposePrimitive(BasePrimitive):
static_axis_boundary=static_axis_boundary, static_axis_boundary=static_axis_boundary,
transpose_axis_boundary=transpose_axis_boundary, transpose_axis_boundary=transpose_axis_boundary,
) )
global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias) global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
return local_out, local_t_out, global_dbias, global_updated_amax return local_out, local_t_out, global_dbias, global_updated_amax
return mesh, sharded_impl, out_shardings, arg_shardings return mesh, sharded_impl, out_shardings, arg_shardings
...@@ -981,8 +981,8 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive): ...@@ -981,8 +981,8 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive):
act_enum=act_enum, act_enum=act_enum,
) )
) )
global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias) global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
return local_out, local_t_out, global_dbias, global_updated_amax return local_out, local_t_out, global_dbias, global_updated_amax
return mesh, sharded_impl, out_shardings, arg_shardings return mesh, sharded_impl, out_shardings, arg_shardings
...@@ -1225,7 +1225,7 @@ class DgatedActLuCastTransposePrimitive(BasePrimitive): ...@@ -1225,7 +1225,7 @@ class DgatedActLuCastTransposePrimitive(BasePrimitive):
static_axis_boundary=static_axis_boundary, static_axis_boundary=static_axis_boundary,
act_enum=act_enum, act_enum=act_enum,
) )
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
return local_out, local_t_out, global_updated_amax return local_out, local_t_out, global_updated_amax
return mesh, sharded_impl, out_shardings, arg_shardings return mesh, sharded_impl, out_shardings, arg_shardings
......
...@@ -30,8 +30,7 @@ W_TP_AXES = "nvte_w_tp" ...@@ -30,8 +30,7 @@ W_TP_AXES = "nvte_w_tp"
W_JOINED_AXES = "nvte_w_joined" W_JOINED_AXES = "nvte_w_joined"
def _get_mesh_info(resource: str): def _get_mesh_info(resource: str, mesh: jax.sharding.Mesh):
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
assert resource in mesh.axis_names, f"{resource} is not in the axis_names of Mesh {mesh}." assert resource in mesh.axis_names, f"{resource} is not in the axis_names of Mesh {mesh}."
return mesh.shape[resource], resource return mesh.shape[resource], resource
...@@ -132,12 +131,12 @@ def get_padded_spec(spec, ndim): ...@@ -132,12 +131,12 @@ def get_padded_spec(spec, ndim):
return spec + (None,) * (ndim - len(spec)) return spec + (None,) * (ndim - len(spec))
def lax_paral_op(x: jnp.array, ops: Callable, mesh_resource: str): def lax_paral_op(x: jnp.array, ops: Callable, mesh_resource: str, mesh: jax.sharding.Mesh):
""" """
A wrapper function to invoke lax.p* operations, like psum. A wrapper function to invoke lax.p* operations, like psum.
""" """
if mesh_resource is not None: if mesh_resource is not None:
_, resource = _get_mesh_info(mesh_resource) _, resource = _get_mesh_info(mesh_resource, mesh)
return ops(x, resource) return ops(x, resource)
return x return x
...@@ -201,22 +200,22 @@ def global_mesh_resource() -> MeshResource: ...@@ -201,22 +200,22 @@ def global_mesh_resource() -> MeshResource:
return _GLOBAL_MESH_RESOURCE return _GLOBAL_MESH_RESOURCE
def all_reduce_sum_along_dp_fsdp(x: jnp.array): def all_reduce_sum_along_dp_fsdp(x: jnp.array, mesh: jax.sharding.Mesh):
""" """
All-Reduce (Sum) along DP and FSDP mesh axes. All-Reduce (Sum) along DP and FSDP mesh axes.
""" """
x = lax_paral_op(x, jax.lax.psum, global_mesh_resource().dp_resource) x = lax_paral_op(x, jax.lax.psum, global_mesh_resource().dp_resource, mesh)
return lax_paral_op(x, jax.lax.psum, global_mesh_resource().fsdp_resource) return lax_paral_op(x, jax.lax.psum, global_mesh_resource().fsdp_resource, mesh)
def all_reduce_max_along_all_axes_except_PP(x: jnp.array): def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mesh):
""" """
All-Reduce (Max) along all mesh axes. All-Reduce (Max) along all mesh axes.
""" """
all_axes = get_all_mesh_axes() all_axes = get_all_mesh_axes()
for axis in all_axes: for axis in all_axes:
if axis != global_mesh_resource().pp_resource: if axis != global_mesh_resource().pp_resource:
x = lax_paral_op(x, jax.lax.pmax, axis) x = lax_paral_op(x, jax.lax.pmax, axis, mesh)
return x return x
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment