Commit 3e6859e2 authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] Sharding specs for TE GEMM custom call operands (#2023)



* new gemm operand specs processing
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* fix for lhs_non_specs
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent 06947e87
...@@ -511,22 +511,6 @@ class GemmPrimitive(BasePrimitive): ...@@ -511,22 +511,6 @@ class GemmPrimitive(BasePrimitive):
(out_bdims, bias_bdims, pre_gelu_bdims), (out_bdims, bias_bdims, pre_gelu_bdims),
) )
@staticmethod
def _decompose_operand_specs(specs, contracting_dims, batch_dims):
ndims = len(specs)
cdims, bdims = map(sanitize_dims, (ndims, ndims), (contracting_dims, batch_dims))
# Batch specs
bspecs = tuple(specs[i] for i in bdims)
# Non-batch leading dimension specs
lspecs = tuple(specs[i] for i in range(ndims) if i not in cdims + bdims)
# Non-batch contracting dimension specs
cspecs = tuple(specs[i] for i in range(ndims) if i in cdims and i not in bdims)
return bspecs, lspecs, cspecs
@staticmethod @staticmethod
def _parse_operand_output_specs( def _parse_operand_output_specs(
arg_infos, arg_infos,
...@@ -535,112 +519,74 @@ class GemmPrimitive(BasePrimitive): ...@@ -535,112 +519,74 @@ class GemmPrimitive(BasePrimitive):
sequence_parallel_output, sequence_parallel_output,
sequence_dim, sequence_dim,
): ):
del sequence_dim, sequence_parallel_output, batched_dims
lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos) lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos)
lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs))
lhs_cdims, rhs_cdims, lhs_bdims, rhs_bdims = map(
sanitize_dims, 2 * [lhs_ndim, rhs_ndim], contracting_dims + batched_dims
)
(
(lhs_bspecs, lhs_lspecs, lhs_cspecs),
(rhs_bspecs, rhs_lspecs, rhs_cspecs),
) = map(
GemmPrimitive._decompose_operand_specs,
(lhs_specs, rhs_specs),
(lhs_cdims, rhs_cdims),
(lhs_bdims, rhs_bdims),
)
# Batched dimensions must have the same sharding
if len(lhs_bdims) > 0 and len(rhs_bdims) > 0:
assert all(
lhs_bspec == rhs_bspec for lhs_bspec, rhs_bspec in zip(lhs_bspecs, rhs_bspecs)
), (
"cuBLAS GEMM operand batch dimensions must have the same sharding: "
f"{lhs_specs} @ idx {lhs_bdims} x {rhs_specs} @ idx {rhs_bdims}."
)
# Only one each of the non-batched leading dimensions and non-batched contracting lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs))
# dimensions can be sharded lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_ndim, rhs_ndim), contracting_dims)
lhs_ldims, rhs_ldims = map( lhs_non_cdims, rhs_non_cdims = map(
lambda ndim, exclude: tuple(dim for dim in range(ndim) if dim not in exclude), lambda ndim, cdims: tuple(i for i in range(ndim) if i not in cdims),
(lhs_ndim, rhs_ndim), (lhs_ndim, rhs_ndim),
(lhs_bdims + lhs_cdims, rhs_bdims + rhs_cdims), (lhs_cdims, rhs_cdims),
)
(lhs_lspec_not_none, rhs_lspec_not_none, lhs_cspec_not_none, rhs_cspec_not_none) = map(
lambda specs: tuple(spec for spec in specs if spec is not None),
(lhs_lspecs, rhs_lspecs, lhs_cspecs, rhs_cspecs),
)
assert len(lhs_lspec_not_none) <= 1 and len(rhs_lspec_not_none) <= 1, (
"cuBLAS GEMM operands can have only one sharded non-batched leading dimension: "
f"{lhs_specs} @ idx {lhs_ldims} x {rhs_specs} @ idx {rhs_ldims}."
)
assert len(lhs_cspec_not_none) <= 1 and len(rhs_cspec_not_none) <= 1, (
"cuBLAS GEMM operands can have only one sharded non-batched contracting dimension: "
f"{lhs_specs} @ idx {lhs_cdims} x {rhs_specs} @ idx {rhs_cdims}."
) )
lhs_non_cspecs, lhs_cspecs, rhs_non_cspecs, rhs_cspecs = map(
# Extract single leading and contracting dimension specs lambda specs, dims: tuple(specs[i] for i in dims),
(lhs_cspec, rhs_cspec) = map( (lhs_specs, lhs_specs, rhs_specs, rhs_specs),
lambda specs: None if len(specs) == 0 else specs[0], (lhs_non_cdims, lhs_cdims, rhs_non_cdims, rhs_cdims),
(lhs_cspec_not_none, rhs_cspec_not_none),
) )
# Partitioning rules: reduce_spec = None
# ([B], M, K1) x ([B], N, K2)^T = ([B], M, N) for l in lhs_cspecs:
# 1. K1 == K2 != None for r in rhs_cspecs:
# - Require non-batched non-contracting dims of both LHS and RHS to be unsharded. if l is not None and l == r:
# - If `sequence_parallel_output=True`, then reduce-scatter the output. assert reduce_spec is None, "Multiple reduce dimension is detected!"
# - Otherwise, all-reduce the output. reduce_spec = l
# 2. Otherwise
# - Require contracting dimensions of both LHS and RHS to be unsharded. if reduce_spec is not None:
# - Require non-batched non-contracting dims of LHS to be unsharded. # Other non-reduce cdims (if exists) need to be unsharded
reduce_output = rhs_cspec is not None and lhs_cspec == rhs_cspec lhs_cspecs = tuple(s if s == reduce_spec else None for s in lhs_cspecs)
reduce_spec = scatter_dim = None rhs_cspecs = tuple(s if s == reduce_spec else None for s in rhs_cspecs)
if reduce_output:
reduce_spec = rhs_cspec # Non-batched non-contracting dims of RHS needs to be unsharded (i.e. FSDP)
if sequence_parallel_output: # Check if spec is not the batch-dim is not needed as rhs_non_cspecs never includes batch-dim
# If the sequence dimension is not specified, assume it to be the first # rhs_specs only includes batch-dim in the Wgrad GEMM, but there batch-dim belongs to rhs_cspecs
# non-batched non-contracting dimension of the LHS operand. rhs_non_cspecs = tuple(
scatter_dim = sequence_dim if sequence_dim is not None else lhs_ldims[0] None if spec in lhs_non_cspecs else spec for spec in rhs_non_cspecs
# Always require the non-batched non-contracting dims of LHS to be unsharded
# NOTE: This will all-gather sequence-parallel inputs and preserve tensor-parallel params.
lhs_specs = tuple(
lhs_specs[i] if i in set(lhs_bdims + lhs_cdims) else None for i in range(lhs_ndim)
)
if reduce_output:
# When reducing GEMM output, require non-batched non-contracting dims of the RHS
# operand to be unsharded (i.e. FSDP)
rhs_specs = tuple(
None if i not in set(rhs_bdims + rhs_cdims) else rhs_specs[i]
for i in range(rhs_ndim)
) )
else: else:
# Otherwise, require contracting dims of both operands to be unsharded # Otherwise, require contracting dims of both operands to be unsharded
lhs_specs = tuple(None if i in lhs_cdims else lhs_specs[i] for i in range(lhs_ndim)) lhs_cspecs = (None,) * len(lhs_cspecs)
rhs_specs = tuple(None if i in rhs_cdims else rhs_specs[i] for i in range(rhs_ndim)) rhs_cspecs = (None,) * len(rhs_cspecs)
# Combine modified LHS and RHS specs into the output # Non-batched non-contracting dims of LHS to be unsharded, i.e gather SP dim
lhs_non_contracting_specs, rhs_non_contracting_specs = map( # The spec for batch_dim in lhs_non_cspecs won't ever appear in the rhs_non_cspecs as
lambda specs, cdims: tuple(specs[i] for i in range(len(specs)) if i not in cdims), # rhs_non_cspecs never has batch-dim. Hence, spec for batch_dim of lhs_non_cspecs won't be
(lhs_specs, rhs_specs), # overwrite
# Minor note: This causes MaxText TP (= Megatron TP + activation_hidden sharding) gathering x for
# dW1 = x^T * dY1 which is unexpected. This is a known issue and no solution has found yet.
lhs_non_cspecs = tuple(None if spec in rhs_non_cspecs else spec for spec in lhs_non_cspecs)
out_specs = lhs_non_cspecs + rhs_non_cspecs
# specs = merge(cspecs, non_cspecs)
lhs_specs, rhs_specs = map(
lambda cdims, cspecs, non_cspecs: (
cspecs + non_cspecs if cdims[0] == 0 else non_cspecs + cspecs
),
(lhs_cdims, rhs_cdims), (lhs_cdims, rhs_cdims),
(lhs_cspecs, rhs_cspecs),
(lhs_non_cspecs, rhs_non_cspecs),
) )
out_specs = [*lhs_non_contracting_specs, *rhs_non_contracting_specs]
# Bias and Pre-GeLU sharding is based on GEMM output before any scatter # Bias and Pre-GeLU sharding is based on GEMM output before any scatter
bias_specs = tuple(list(out_specs[len(lhs_non_contracting_specs) :]).copy()) bias_specs = tuple(list(rhs_non_cspecs).copy())
gelu_specs = tuple(list(out_specs).copy()) gelu_specs = tuple(list(out_specs).copy())
# Set output scatter dim to the tensor-parallel spec
if sequence_parallel_output:
out_specs[scatter_dim] = reduce_spec
return ( return (
(lhs_specs, rhs_specs, bias_specs, gelu_specs), (lhs_specs, rhs_specs, bias_specs, gelu_specs),
(out_specs, bias_specs, gelu_specs), (out_specs, bias_specs, gelu_specs),
reduce_spec, reduce_spec,
scatter_dim, 0,
) )
@staticmethod @staticmethod
...@@ -717,7 +663,7 @@ class GemmPrimitive(BasePrimitive): ...@@ -717,7 +663,7 @@ class GemmPrimitive(BasePrimitive):
(lhs_specs, rhs_specs, bias_input_specs, gelu_input_specs), (lhs_specs, rhs_specs, bias_input_specs, gelu_input_specs),
(out_specs, dbias_specs, pre_gelu_specs), (out_specs, dbias_specs, pre_gelu_specs),
reduce_spec, reduce_spec,
scatter_dim, _,
) = GemmPrimitive._parse_operand_output_specs( ) = GemmPrimitive._parse_operand_output_specs(
arg_infos, arg_infos,
contracting_dims, contracting_dims,
...@@ -785,12 +731,7 @@ class GemmPrimitive(BasePrimitive): ...@@ -785,12 +731,7 @@ class GemmPrimitive(BasePrimitive):
# All-Reduce/Reduce-Scatter GEMM output # All-Reduce/Reduce-Scatter GEMM output
if reduce_spec is not None: if reduce_spec is not None:
if scatter_dim is None: outputs[0] = jax.lax.psum(outputs[0], reduce_spec)
outputs[0] = jax.lax.psum(outputs[0], reduce_spec)
else:
outputs[0] = jax.lax.psum_scatter(
outputs[0], reduce_spec, scatter_dimension=scatter_dim, tiled=True
)
return outputs return outputs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment