Unverified Commit 25a82192 authored by Alp Dener's avatar Alp Dener Committed by GitHub
Browse files

[JAX] Fixing GemmPrimitive partitioning rules to handle tensor-parallelism...


[JAX] Fixing GemmPrimitive partitioning rules to handle tensor-parallelism correctly for sequence-parallel inputs (#1980)

* updated GemmPrimitive partitioning rules to explicitly control all-reduce vs. reduce-scatter for sequence-parallelism
Signed-off-by: default avatarAlp Dener <adener@nvidia.com>

* corrected handling of FSDP sharding for the RHS operand
Signed-off-by: default avatarAlp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* use correct logical axes variable to identify sequence-parallel dim in LayerNormDenseGeneral
Signed-off-by: default avatarAlp Dener <adener@nvidia.com>

* fixed linting issues
Signed-off-by: default avatarAlp Dener <adener@nvidia.com>

* added assert on sequence-parallel options when GemmPrimitive is disabled
Signed-off-by: default avatarAlp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarAlp Dener <adener@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent a99c056b
...@@ -155,7 +155,7 @@ class GemmPrimitive(BasePrimitive): ...@@ -155,7 +155,7 @@ class GemmPrimitive(BasePrimitive):
name = "te_gemm_ffi" name = "te_gemm_ffi"
multiple_results = True multiple_results = True
impl_static_args = (6, 7, 8, 9, 10, 11, 12, 13, 14, 15) impl_static_args = (6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17)
inner_primitive = None inner_primitive = None
outer_primitive = None outer_primitive = None
...@@ -177,8 +177,14 @@ class GemmPrimitive(BasePrimitive): ...@@ -177,8 +177,14 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu, fuse_gelu,
grad, grad,
use_split_accumulator, use_split_accumulator,
sequence_parallel_output,
sequence_dim,
): ):
del lhs_quantized_colwise, rhs_quantized_colwise, use_split_accumulator del lhs_quantized_colwise, rhs_quantized_colwise, use_split_accumulator
del (
sequence_parallel_output,
sequence_dim,
)
def _dims_are_consecutive(dims): def _dims_are_consecutive(dims):
if len(dims) <= 1: if len(dims) <= 1:
...@@ -343,8 +349,12 @@ class GemmPrimitive(BasePrimitive): ...@@ -343,8 +349,12 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu, fuse_gelu,
grad, grad,
use_split_accumulator, use_split_accumulator,
sequence_parallel_output,
sequence_dim,
): ):
del batched_dims, lhs_quantized_colwise, rhs_quantized_colwise, out_dtype del batched_dims, lhs_quantized_colwise, rhs_quantized_colwise, out_dtype
del sequence_parallel_output, sequence_dim
lhs_aval, _, rhs_aval, *_ = ctx.avals_in lhs_aval, _, rhs_aval, *_ = ctx.avals_in
lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_aval.ndim, rhs_aval.ndim), contracting_dims) lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_aval.ndim, rhs_aval.ndim), contracting_dims)
lhs_transposed, rhs_transposed = _get_gemm_layout( lhs_transposed, rhs_transposed = _get_gemm_layout(
...@@ -393,6 +403,8 @@ class GemmPrimitive(BasePrimitive): ...@@ -393,6 +403,8 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu, fuse_gelu,
grad, grad,
use_split_accumulator, use_split_accumulator,
sequence_parallel_output,
sequence_dim,
): ):
lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims)
lhs_transposed, rhs_transposed = _get_gemm_layout( lhs_transposed, rhs_transposed = _get_gemm_layout(
...@@ -430,6 +442,8 @@ class GemmPrimitive(BasePrimitive): ...@@ -430,6 +442,8 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu=fuse_gelu, fuse_gelu=fuse_gelu,
grad=grad, grad=grad,
use_split_accumulator=use_split_accumulator, use_split_accumulator=use_split_accumulator,
sequence_parallel_output=sequence_parallel_output,
sequence_dim=sequence_dim,
) )
return outputs[:-3] # discard workspace arrays return outputs[:-3] # discard workspace arrays
...@@ -447,6 +461,8 @@ class GemmPrimitive(BasePrimitive): ...@@ -447,6 +461,8 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu, fuse_gelu,
grad, grad,
use_split_accumulator, use_split_accumulator,
sequence_parallel_output,
sequence_dim,
): ):
assert GemmPrimitive.outer_primitive is not None assert GemmPrimitive.outer_primitive is not None
lhs, _, rhs, *_ = batched_args lhs, _, rhs, *_ = batched_args
...@@ -489,6 +505,8 @@ class GemmPrimitive(BasePrimitive): ...@@ -489,6 +505,8 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu=fuse_gelu, fuse_gelu=fuse_gelu,
grad=grad, grad=grad,
use_split_accumulator=use_split_accumulator, use_split_accumulator=use_split_accumulator,
sequence_parallel_output=sequence_parallel_output,
sequence_dim=sequence_dim,
), ),
(out_bdims, bias_bdims, pre_gelu_bdims), (out_bdims, bias_bdims, pre_gelu_bdims),
) )
...@@ -510,7 +528,13 @@ class GemmPrimitive(BasePrimitive): ...@@ -510,7 +528,13 @@ class GemmPrimitive(BasePrimitive):
return bspecs, lspecs, cspecs return bspecs, lspecs, cspecs
@staticmethod @staticmethod
def _parse_operand_output_specs(arg_infos, contracting_dims, batched_dims): def _parse_operand_output_specs(
arg_infos,
contracting_dims,
batched_dims,
sequence_parallel_output,
sequence_dim,
):
lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos) lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos)
lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs)) lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs))
lhs_cdims, rhs_cdims, lhs_bdims, rhs_bdims = map( lhs_cdims, rhs_cdims, lhs_bdims, rhs_bdims = map(
...@@ -556,96 +580,66 @@ class GemmPrimitive(BasePrimitive): ...@@ -556,96 +580,66 @@ class GemmPrimitive(BasePrimitive):
) )
# Extract single leading and contracting dimension specs # Extract single leading and contracting dimension specs
(lhs_lspec, rhs_lspec, lhs_cspec, rhs_cspec) = map( (lhs_cspec, rhs_cspec) = map(
lambda specs: None if len(specs) == 0 else specs[0], lambda specs: None if len(specs) == 0 else specs[0],
(lhs_lspec_not_none, rhs_lspec_not_none, lhs_cspec_not_none, rhs_cspec_not_none), (lhs_cspec_not_none, rhs_cspec_not_none),
) )
# Reproducing jax.nn.scaled_matmul() custom partitioning for arbitrary GEMM layouts # Partitioning rules:
# with row-wise LHS:(B, M, K1) and row-wise RHS:(B, N, K2) operands. # ([B], M, K1) x ([B], N, K2)^T = ([B], M, N)
# 1. K1 == K2 != None and N == None # 1. K1 == K2 != None
# LHS: (B, M, K) # - Require non-batched non-contracting dims of both LHS and RHS to be unsharded.
# RHS: (B, None, K) # - If `sequence_parallel_output=True`, then reduce-scatter the output.
# OUT: (B, M, None) --(AR)-> (B, M, None) # - Otherwise, all-reduce the output.
# 2. K1 == K2 != None and M == N != None # 2. Otherwise
# LHS: (B, M, K) # - Require contracting dimensions of both LHS and RHS to be unsharded.
# RHS: (B, N, K)--(AG)->(B, None, K) # - Require non-batched non-contracting dims of LHS to be unsharded.
# OUT: (B, M, None) --(RS)--> (B, M, N) reduce_output = rhs_cspec is not None and lhs_cspec == rhs_cspec
# 3. M == N reduce_spec = scatter_dim = None
# LHS: (B, M, K)--(AG)->(B, M, None) if reduce_output:
# RHS: (B, M, K)--(AG)->(B, None, None) reduce_spec = rhs_cspec
# OUT: (B, M, None) if sequence_parallel_output:
# 4. M != N # If the sequence dimension is not specified, assume it to be the first
# LHS: (B, M, K)--(AG)->(B, M, None) # non-batched non-contracting dimension of the LHS operand.
# RHS: (B, N, K)--(AG)->(B, N, None) scatter_dim = sequence_dim if sequence_dim is not None else lhs_ldims[0]
# OUT: (B, M, N)
reduce_flag = lhs_cspec is not None and lhs_cspec == rhs_cspec # Always require the non-batched non-contracting dims of LHS to be unsharded
all_reduce_output = reduce_flag and rhs_lspec is None # NOTE: This will all-gather sequence-parallel inputs and preserve tensor-parallel params.
reduce_scatter_output = reduce_flag and lhs_lspec is not None and lhs_lspec == rhs_lspec lhs_specs = tuple(
all_reduce_spec = reduce_scatter_spec = scatter_dim = None lhs_specs[i] if i in set(lhs_bdims + lhs_cdims) else None for i in range(lhs_ndim)
)
if reduce_output:
# When reducing GEMM output, require non-batched non-contracting dims of the RHS
# operand to be unsharded (i.e. FSDP)
rhs_specs = tuple(
None if i not in set(rhs_bdims + rhs_cdims) else rhs_specs[i]
for i in range(rhs_ndim)
)
else:
# Otherwise, require contracting dims of both operands to be unsharded
lhs_specs = tuple(None if i in lhs_cdims else lhs_specs[i] for i in range(lhs_ndim))
rhs_specs = tuple(None if i in rhs_cdims else rhs_specs[i] for i in range(rhs_ndim))
# Combine modified LHS and RHS specs into the output
lhs_non_contracting_specs, rhs_non_contracting_specs = map( lhs_non_contracting_specs, rhs_non_contracting_specs = map(
lambda specs, cdims: tuple(specs[i] for i in range(len(specs)) if i not in cdims), lambda specs, cdims: tuple(specs[i] for i in range(len(specs)) if i not in cdims),
(lhs_specs, rhs_specs), (lhs_specs, rhs_specs),
(lhs_cdims, rhs_cdims), (lhs_cdims, rhs_cdims),
) )
out_specs = (*lhs_non_contracting_specs, *rhs_non_contracting_specs) out_specs = [*lhs_non_contracting_specs, *rhs_non_contracting_specs]
if reduce_scatter_output:
# All-gather (if necessary) the non-batch non-contracting dimension of RHS
# (B, N, K) --(AG)-> (B, None, K)
# (B, M, K) x (B, None, K)^T = (B, M, None) --(RS)-> (B, M, N)
rhs_spec = tuple(
rhs_spec[i] if i in set(rhs_bdims + rhs_cdims) else None for i in range(rhs_ndim)
)
reduce_scatter_spec = lhs_cspec
scatter_dim = out_specs.index(rhs_lspec)
elif all_reduce_output:
# Set all output trailing dimensions to zero
out_specs = (
*lhs_non_contracting_specs,
*[None for _ in range(len(rhs_non_contracting_specs))],
)
all_reduce_spec = lhs_cspec
else:
# All-gather (if necessary) the non-batch contracting dimensions
# (B, M, K) --(AG)-> (B, M, None)
# (B, N, K) --(AG)-> (B, N, None)
# (B, M, None) x (B, N, None)^T = (B, M, N)
lhs_specs = tuple(
None if i in lhs_cdims and i not in lhs_bdims else lhs_specs[i]
for i in range(lhs_ndim)
)
rhs_specs = tuple(
None if i in rhs_cdims and i not in rhs_bdims else rhs_specs[i]
for i in range(rhs_ndim)
)
# Check if RHS non-contracting spec also appears in the LHS non-contracting specs
if rhs_lspec is not None and rhs_lspec in tuple(
lhs_specs[i] for i in range(lhs_ndim) if i not in lhs_cdims
):
# All-gather (if necessary) the non-batch non-contracting dimensions of RHS
# (B, N, None) --(AG)-> (B, None, None)
# (B, M, None) x (B, None, None)^T = (B, M, None)
rhs_specs = tuple(
None if i not in set(rhs_bdims + rhs_cdims) else rhs_specs[i]
for i in range(rhs_ndim)
)
# Set all output trailing dimensions to zero
out_specs = (
*lhs_non_contracting_specs,
*[None for _ in range(len(rhs_non_contracting_specs))],
)
# Bias and Pre-GeLU sharding is based on GEMM output # Bias and Pre-GeLU sharding is based on GEMM output before any scatter
bias_specs = out_specs[len(lhs_non_contracting_specs) :] bias_specs = tuple(list(out_specs[len(lhs_non_contracting_specs) :]).copy())
gelu_specs = out_specs gelu_specs = tuple(list(out_specs).copy())
# Set output scatter dim to the tensor-parallel spec
if sequence_parallel_output:
out_specs[scatter_dim] = reduce_spec
return ( return (
(lhs_specs, rhs_specs, bias_specs, gelu_specs), (lhs_specs, rhs_specs, bias_specs, gelu_specs),
(out_specs, bias_specs, gelu_specs), (out_specs, bias_specs, gelu_specs),
all_reduce_spec, reduce_spec,
reduce_scatter_spec,
scatter_dim, scatter_dim,
) )
...@@ -661,6 +655,8 @@ class GemmPrimitive(BasePrimitive): ...@@ -661,6 +655,8 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu, fuse_gelu,
grad, grad,
use_split_accumulator, use_split_accumulator,
sequence_parallel_output,
sequence_dim,
mesh, mesh,
arg_infos, arg_infos,
result_infos, result_infos,
...@@ -675,7 +671,13 @@ class GemmPrimitive(BasePrimitive): ...@@ -675,7 +671,13 @@ class GemmPrimitive(BasePrimitive):
del use_split_accumulator, result_infos del use_split_accumulator, result_infos
(_, (out_specs, dbias_specs, pre_gelu_specs), *_) = ( (_, (out_specs, dbias_specs, pre_gelu_specs), *_) = (
GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims, batched_dims) GemmPrimitive._parse_operand_output_specs(
arg_infos,
contracting_dims,
batched_dims,
sequence_parallel_output,
sequence_dim,
)
) )
out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs)) out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs))
...@@ -703,6 +705,8 @@ class GemmPrimitive(BasePrimitive): ...@@ -703,6 +705,8 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu, fuse_gelu,
grad, grad,
use_split_accumulator, use_split_accumulator,
sequence_parallel_output,
sequence_dim,
mesh, mesh,
arg_infos, arg_infos,
result_infos, result_infos,
...@@ -712,10 +716,15 @@ class GemmPrimitive(BasePrimitive): ...@@ -712,10 +716,15 @@ class GemmPrimitive(BasePrimitive):
( (
(lhs_specs, rhs_specs, bias_input_specs, gelu_input_specs), (lhs_specs, rhs_specs, bias_input_specs, gelu_input_specs),
(out_specs, dbias_specs, pre_gelu_specs), (out_specs, dbias_specs, pre_gelu_specs),
all_reduce_spec, reduce_spec,
reduce_scatter_spec,
scatter_dim, scatter_dim,
) = GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims, batched_dims) ) = GemmPrimitive._parse_operand_output_specs(
arg_infos,
contracting_dims,
batched_dims,
sequence_parallel_output,
sequence_dim,
)
# Assemble argument shardings # Assemble argument shardings
# NOTE: Block scale inverses match their operands, but tensor scale inverses are unsharded. # NOTE: Block scale inverses match their operands, but tensor scale inverses are unsharded.
...@@ -770,20 +779,17 @@ class GemmPrimitive(BasePrimitive): ...@@ -770,20 +779,17 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu=fuse_gelu, fuse_gelu=fuse_gelu,
grad=grad, grad=grad,
use_split_accumulator=use_split_accumulator, use_split_accumulator=use_split_accumulator,
sequence_parallel_output=sequence_parallel_output,
sequence_dim=sequence_dim,
) )
# All-Reduce/Reduce-Scatter GEMM output # All-Reduce/Reduce-Scatter GEMM output
if all_reduce_spec is not None: if reduce_spec is not None:
outputs[0] = jax.lax.psum(outputs[0], all_reduce_spec) if scatter_dim is None:
if fuse_gelu and not grad: outputs[0] = jax.lax.psum(outputs[0], reduce_spec)
outputs[2] = jax.lax.psum(outputs[2], all_reduce_spec) else:
elif reduce_scatter_spec is not None: outputs[0] = jax.lax.psum_scatter(
outputs[0] = jax.lax.psum_scatter( outputs[0], reduce_spec, scatter_dimension=scatter_dim, tiled=True
outputs[0], reduce_scatter_spec, scatter_dimension=scatter_dim, tiled=True
)
if fuse_gelu and not grad:
outputs[2] = jax.lax.psum_scatter(
outputs[2], reduce_scatter_spec, scatter_dimension=scatter_dim, tiled=True
) )
return outputs return outputs
...@@ -802,12 +808,14 @@ class GemmPrimitive(BasePrimitive): ...@@ -802,12 +808,14 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu, fuse_gelu,
grad, grad,
use_split_accumulator, use_split_accumulator,
sequence_parallel_output,
sequence_dim,
mesh, mesh,
operand_types, operand_types,
result_types, result_types,
): ):
del lhs_quantized_colwise, rhs_quantized_colwise, out_dtype, grad, use_split_accumulator del lhs_quantized_colwise, rhs_quantized_colwise, out_dtype, grad, use_split_accumulator
del mesh, result_types del sequence_parallel_output, sequence_dim, mesh, result_types
prefix = "GemmPrimitive_" prefix = "GemmPrimitive_"
...@@ -896,6 +904,8 @@ def _te_gemm( ...@@ -896,6 +904,8 @@ def _te_gemm(
fuse_gelu: bool = False, fuse_gelu: bool = False,
grad: bool = False, grad: bool = False,
use_split_accumulator: bool = QuantizeConfig.FP8_2X_ACC_FPROP, use_split_accumulator: bool = QuantizeConfig.FP8_2X_ACC_FPROP,
sequence_parallel_output: bool = False,
sequence_dim: int = None,
) -> Tuple[jax.Array, ...]: ) -> Tuple[jax.Array, ...]:
# Prepare non-quantized GEMM operands # Prepare non-quantized GEMM operands
...@@ -969,6 +979,8 @@ def _te_gemm( ...@@ -969,6 +979,8 @@ def _te_gemm(
fuse_gelu=fuse_gelu, fuse_gelu=fuse_gelu,
grad=grad, grad=grad,
use_split_accumulator=use_split_accumulator, use_split_accumulator=use_split_accumulator,
sequence_parallel_output=sequence_parallel_output,
sequence_dim=sequence_dim,
) )
...@@ -1307,9 +1319,9 @@ def gemm( ...@@ -1307,9 +1319,9 @@ def gemm(
Tuple of sequences representing the contracting dimensions of the operands. Tuple of sequences representing the contracting dimensions of the operands.
batched_dims: Tuple[Sequence[int], Sequence[int]], default = ((), ()), batched_dims: Tuple[Sequence[int], Sequence[int]], default = ((), ()),
Tuple of sequences representing the batched dimensions of the operands. This is *not* used Tuple of sequences representing the batched dimensions of the operands. This is *not* used
to perform a batched matrix multiplication, but it is required to avoid a potentially to perform a batched matrix multiplication, but it is required for TE's custom cuBLAS GEMM
undesirable reduction in any batched contracting dimensions when invoked with sharded call to avoid a potentially undesirable reduction in any batched contracting dimensions
operands (e.g. when computing weight gradients in a Flax module). when invoked with sharded operands (e.g. when computing weight gradients in a Flax module).
bias: jax.Array, default = None bias: jax.Array, default = None
Optional additive bias term, required for forward GEMM with bias fusion. Only supported Optional additive bias term, required for forward GEMM with bias fusion. Only supported
with TE's custom call to cuBLAS GEMM. with TE's custom call to cuBLAS GEMM.
...@@ -1327,7 +1339,17 @@ def gemm( ...@@ -1327,7 +1339,17 @@ def gemm(
TE's custom call to cuBLAS GEMM. TE's custom call to cuBLAS GEMM.
use_split_accumulator: bool, default = True use_split_accumulator: bool, default = True
Enable promoting some intermediate sums to higher precision when accumulating the result in Enable promoting some intermediate sums to higher precision when accumulating the result in
the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed. the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed. Only
supported with TE's custom call to cuBLAS GEMM.
sequence_parallel_output: bool, default = False
Produces an output with the first non-batched non-contracting dimension sharded with the
same spec as operand contracting dimensions. This effectively converts the `jax.lax.psum`
for the GEMM output into a `jax.lax.psum_scatter`. Only supported with TE's custom call to
cuBLAS GEMM.
sequence_dim: int, default = None
Index of the sequence dimension for the LHS operand. This controls which dimension of the
GEMM output is scattered when `sequence_parallel_output=True`. When `None`, the first
non-batched non-contracting dimension is assumed to be the sequence dimension.
Returns Returns
------- -------
...@@ -1358,12 +1380,20 @@ def gemm( ...@@ -1358,12 +1380,20 @@ def gemm(
if not GemmPrimitive.enabled(): if not GemmPrimitive.enabled():
assert kwargs.get("bias", None) is None and not fuse_gelu, ( assert kwargs.get("bias", None) is None and not fuse_gelu, (
"TE GEMM was invoked with bias fusion options that are not supported by the " "TE GEMM was invoked with bias fusion options that are not supported by the "
"`jax.lax.dot_general` and `jnp.scaled_matmul` backends used when the custom cuBLAS " "`jax.lax.dot_general` and `jax.nn.scaled_matmul` backends used when the custom cuBLAS "
"GEMM primitive is disabled." "GEMM primitive is disabled."
) )
assert kwargs.get("gelu_input", None) is None and not fuse_bias, ( assert kwargs.get("gelu_input", None) is None and not fuse_bias, (
"TE GEMM was invoked with GeLU fusion options that are not supported by the " "TE GEMM was invoked with GeLU fusion options that are not supported by the "
"`jax.lax.dot_general` and `jnp.scaled_matmul` backends used when the custom cuBLAS " "`jax.lax.dot_general` and `jax.nn.scaled_matmul` backends used when the custom cuBLAS "
"GEMM primitive is disabled."
)
assert (
not kwargs.get("sequence_parallel_output", False)
and kwargs.get("sequence_dim", None) is None
), (
"TE GEMM was invoked with sequence-parallelism options that are not supported by the "
"`jax.lax.dot_general` and `jax.nn.scaled_matmul` backends used when the custom cuBLAS "
"GEMM primitive is disabled." "GEMM primitive is disabled."
) )
return _jax_gemm(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer) return _jax_gemm(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer)
......
...@@ -22,6 +22,7 @@ from .quantize import ( ...@@ -22,6 +22,7 @@ from .quantize import (
TensorUsage, TensorUsage,
) )
from .sharding import get_sequence_parallel_dim
DENSE_BATCH_FIRST_WARNING_ISSUED = False DENSE_BATCH_FIRST_WARNING_ISSUED = False
...@@ -41,6 +42,7 @@ def dense( ...@@ -41,6 +42,7 @@ def dense(
input_axes: Tuple[str, ...] = None, input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None,
batch_first: bool = True, batch_first: bool = True,
sequence_parallel_output: bool = False,
quantizer_set: QuantizerSet = noop_quantizer_set, quantizer_set: QuantizerSet = noop_quantizer_set,
): ):
"""Perform dense layer transformation with optional quantization. """Perform dense layer transformation with optional quantization.
...@@ -55,6 +57,8 @@ def dense( ...@@ -55,6 +57,8 @@ def dense(
bias: Optional bias tensor to add after the transformation bias: Optional bias tensor to add after the transformation
contracting_dims: Tuple of sequences specifying which dimensions to contract contracting_dims: Tuple of sequences specifying which dimensions to contract
batch_first: Assume that X is batched in the first dimension. batch_first: Assume that X is batched in the first dimension.
sequence_parallel_output: Produce an output that is sharded in the first non-batched dim. Only
supported for TE custom GEMM with row-parallel kernel axes.
quantizer_set: QuantizerSet which contains quantizers for different tensor types quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns: Returns:
...@@ -69,13 +73,31 @@ def dense( ...@@ -69,13 +73,31 @@ def dense(
output += jnp.reshape(bias, bias_new_shape) output += jnp.reshape(bias, bias_new_shape)
else: else:
output = _dense( output = _dense(
x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set x,
kernel,
bias,
contracting_dims,
input_axes,
kernel_axes,
batch_first,
sequence_parallel_output,
quantizer_set,
) )
return output return output
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6)) @partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7))
def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set): def _dense(
x,
kernel,
bias,
contracting_dims,
input_axes,
kernel_axes,
batch_first,
sequence_parallel_output,
quantizer_set,
):
"""Internal implementation of dense layer transformation with custom VJP. """Internal implementation of dense layer transformation with custom VJP.
This function implements the core dense layer transformation logic with support This function implements the core dense layer transformation logic with support
...@@ -88,20 +110,38 @@ def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_fir ...@@ -88,20 +110,38 @@ def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_fir
contracting_dims: Contracting dimensions specification contracting_dims: Contracting dimensions specification
input_axes: Logical axes for sharding the activation input input_axes: Logical axes for sharding the activation input
kernel_axes: Logical axes for sharding the weight matrix kernel_axes: Logical axes for sharding the weight matrix
quantizer_set: QuantizerSet which contains quantizers for different tensor types
batch_first: Assume that X is batched in the first dimension if it has more than 2 dims. batch_first: Assume that X is batched in the first dimension if it has more than 2 dims.
sequence_parallel_output: Produce an output that is sharded in the first non-batched dim. Only
supported for TE custom GEMM with row-parallel kernel axes.
quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns: Returns:
Transformed output tensor Transformed output tensor
""" """
output, _ = _dense_fwd_rule( output, _ = _dense_fwd_rule(
x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set x,
kernel,
bias,
contracting_dims,
input_axes,
kernel_axes,
batch_first,
sequence_parallel_output,
quantizer_set,
) )
return output return output
def _dense_fwd_rule( def _dense_fwd_rule(
x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set x,
kernel,
bias,
contracting_dims,
input_axes,
kernel_axes,
batch_first,
sequence_parallel_output,
quantizer_set,
): ):
"""Forward pass rule for dense layer transformation. """Forward pass rule for dense layer transformation.
...@@ -161,6 +201,7 @@ def _dense_fwd_rule( ...@@ -161,6 +201,7 @@ def _dense_fwd_rule(
batched_dims=((x_bdim,), ()), batched_dims=((x_bdim,), ()),
bias=bias if not tex.gemm_uses_jax_dot() else None, bias=bias if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
sequence_parallel_output=sequence_parallel_output and not tex.gemm_uses_jax_dot(),
) )
if use_bias and tex.gemm_uses_jax_dot(): if use_bias and tex.gemm_uses_jax_dot():
...@@ -181,7 +222,7 @@ def _dense_fwd_rule( ...@@ -181,7 +222,7 @@ def _dense_fwd_rule(
def _dense_bwd_rule( def _dense_bwd_rule(
contracting_dims, input_axes, kernel_axes, batch_first, ctx, grad contracting_dims, input_axes, kernel_axes, batch_first, sequence_parallel_output, ctx, grad
): # pylint: disable=unused-argument ): # pylint: disable=unused-argument
"""Backward pass rule for dense layer transformation. """Backward pass rule for dense layer transformation.
...@@ -220,11 +261,22 @@ def _dense_bwd_rule( ...@@ -220,11 +261,22 @@ def _dense_bwd_rule(
k_contracting_dim = tuple( k_contracting_dim = tuple(
dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims
) )
# Get sequence-parallel dimension of the FWD input (if it exists)
sequence_dim = get_sequence_parallel_dim(input_axes, fwd_x_contracting_dims, (x_bdim,))
dgrad = tex.gemm( dgrad = tex.gemm(
casted_grad.get_tensor(usage=TensorUsage.LHS), casted_grad.get_tensor(usage=TensorUsage.LHS),
casted_kernel_rhs, casted_kernel_rhs,
contracting_dims=(g_contracting_dim, k_contracting_dim), contracting_dims=(g_contracting_dim, k_contracting_dim),
batched_dims=((x_bdim,), ()), batched_dims=((x_bdim,), ()),
sequence_parallel_output=(
sequence_dim is not None
and not sequence_parallel_output
and not tex.gemm_uses_jax_dot()
),
sequence_dim=(
None if sequence_parallel_output or tex.gemm_uses_jax_dot() else sequence_dim
),
) )
dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)
......
...@@ -415,6 +415,8 @@ class DenseGeneral(TransformerEngineBase): ...@@ -415,6 +415,8 @@ class DenseGeneral(TransformerEngineBase):
Indicate the logical axes of sharding constraint to the input, like Indicate the logical axes of sharding constraint to the input, like
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
sharding constraint. sharding constraint.
sequence_parallel_output: bool, default = False
Produce a sequence-parallel output with the first non-batch dimension sharded over the
same mesh axis as the kernel's contracting (tensor-parallel) dimensions.
Optimization parameters Optimization parameters
----------------------- -----------------------
...@@ -439,6 +441,7 @@ class DenseGeneral(TransformerEngineBase): ...@@ -439,6 +441,7 @@ class DenseGeneral(TransformerEngineBase):
dtype: DType = jnp.float32 dtype: DType = jnp.float32
transpose_batch_sequence: bool = False transpose_batch_sequence: bool = False
input_axes: Tuple[str, ...] = () input_axes: Tuple[str, ...] = ()
sequence_parallel_output: bool = False
def __post_init__(self): def __post_init__(self):
if self.transpose_batch_sequence: if self.transpose_batch_sequence:
...@@ -511,6 +514,7 @@ class DenseGeneral(TransformerEngineBase): ...@@ -511,6 +514,7 @@ class DenseGeneral(TransformerEngineBase):
input_axes=self.input_axes, input_axes=self.input_axes,
kernel_axes=self.kernel_axes, kernel_axes=self.kernel_axes,
quantizer_set=quantizer_set, quantizer_set=quantizer_set,
sequence_parallel_output=self.sequence_parallel_output,
) )
if self.enable_low_rank_adaptation: if self.enable_low_rank_adaptation:
......
...@@ -1425,6 +1425,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1425,6 +1425,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
dtype=self.dtype, dtype=self.dtype,
name="out", name="out",
sequence_parallel_output=self.enable_sequence_parallel,
)(x) )(x)
out = checkpoint_name(out, "out_proj") out = checkpoint_name(out, "out_proj")
......
...@@ -24,6 +24,7 @@ from .quantize import ( ...@@ -24,6 +24,7 @@ from .quantize import (
with_sharding_constraint_by_logical_axes, with_sharding_constraint_by_logical_axes,
TensorUsage, TensorUsage,
) )
from .sharding import get_sequence_parallel_dim
LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED = False LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED = False
...@@ -324,11 +325,16 @@ def _layernorm_dense_bwd_rule( ...@@ -324,11 +325,16 @@ def _layernorm_dense_bwd_rule(
) )
# NT GEMM # NT GEMM
sequence_dim = get_sequence_parallel_dim(
layernorm_input_axes, x_contracting_dims_in_fwd, (x_bdim,)
)
dgrad = tex.gemm( dgrad = tex.gemm(
casted_grad.get_tensor(TensorUsage.LHS), casted_grad.get_tensor(TensorUsage.LHS),
casted_kernel, casted_kernel,
contracting_dims=(g_constracting_dim, k_constracting_dim), contracting_dims=(g_constracting_dim, k_constracting_dim),
batched_dims=((x_bdim,), ()), batched_dims=((x_bdim,), ()),
sequence_parallel_output=sequence_dim is not None and not tex.gemm_uses_jax_dot(),
sequence_dim=sequence_dim if not tex.gemm_uses_jax_dot() else None,
) )
dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes) dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes)
......
...@@ -29,7 +29,10 @@ from .quantize import ( ...@@ -29,7 +29,10 @@ from .quantize import (
noop_quantizer_set, noop_quantizer_set,
TensorUsage, TensorUsage,
) )
from .sharding import get_non_contracting_logical_axes from .sharding import (
get_non_contracting_logical_axes,
get_sequence_parallel_dim,
)
LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED = False LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED = False
...@@ -342,6 +345,7 @@ def _layernorm_mlp_fwd_rule( ...@@ -342,6 +345,7 @@ def _layernorm_mlp_fwd_rule(
# NN GEMM # NN GEMM
# (batch..., hidden_in) x (hidden_out, hidden_in) # (batch..., hidden_in) x (hidden_out, hidden_in)
sequence_dim = get_sequence_parallel_dim(norm_input_axes, x_contracting_dims, (x_bdim,))
dot_2_output = tex.gemm( dot_2_output = tex.gemm(
casted_act_out.get_tensor(TensorUsage.LHS), casted_act_out.get_tensor(TensorUsage.LHS),
casted_kernel_2.get_tensor(TensorUsage.RHS), casted_kernel_2.get_tensor(TensorUsage.RHS),
...@@ -349,6 +353,8 @@ def _layernorm_mlp_fwd_rule( ...@@ -349,6 +353,8 @@ def _layernorm_mlp_fwd_rule(
batched_dims=((x_bdim,), ()), batched_dims=((x_bdim,), ()),
bias=bias_2 if not tex.gemm_uses_jax_dot() else None, bias=bias_2 if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False, fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False,
sequence_parallel_output=sequence_dim is not None and not tex.gemm_uses_jax_dot(),
sequence_dim=sequence_dim if not tex.gemm_uses_jax_dot() else None,
) )
if use_bias_2 and tex.gemm_uses_jax_dot(): if use_bias_2 and tex.gemm_uses_jax_dot():
...@@ -377,6 +383,7 @@ def _layernorm_mlp_fwd_rule( ...@@ -377,6 +383,7 @@ def _layernorm_mlp_fwd_rule(
use_bias_2, use_bias_2,
quantizer_sets, quantizer_sets,
x_bdim, x_bdim,
sequence_dim,
) )
return dot_2_output, ctx return dot_2_output, ctx
...@@ -431,6 +438,7 @@ def _layernorm_mlp_bwd_rule( ...@@ -431,6 +438,7 @@ def _layernorm_mlp_bwd_rule(
use_bias_2, use_bias_2,
quantizer_sets, quantizer_sets,
x_bdim, x_bdim,
sequence_dim,
) = ctx ) = ctx
ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
...@@ -501,6 +509,8 @@ def _layernorm_mlp_bwd_rule( ...@@ -501,6 +509,8 @@ def _layernorm_mlp_bwd_rule(
casted_kernel_1, casted_kernel_1,
contracting_dims=(g_contracting_dims_1, k_contracting_dims_1), contracting_dims=(g_contracting_dims_1, k_contracting_dims_1),
batched_dims=((x_bdim,), ()), batched_dims=((x_bdim,), ()),
sequence_parallel_output=sequence_dim is not None and not tex.gemm_uses_jax_dot(),
sequence_dim=sequence_dim if not tex.gemm_uses_jax_dot() else None,
) )
dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes) dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes)
......
...@@ -86,17 +86,61 @@ def get_sharding_map_logic_axis_to_mesh_axis(): ...@@ -86,17 +86,61 @@ def get_sharding_map_logic_axis_to_mesh_axis():
return te_logical_axis_to_mesh_axis return te_logical_axis_to_mesh_axis
def generate_pspec(logical_axis_names): def get_sequence_parallel_dim(logical_axes, contracting_dims, batch_dims):
"""
Get the index for the sequence-parallel dimension based on the given logical axes.
The sequence-parallel dimension is assumed to be the only sharded non-batched non-contracting
dimension.
"""
if not logical_axes:
return None
pspec = generate_pspec(logical_axes, with_flax_rules=True, padded=True)
ldims = [i for i in range(len(logical_axes)) if i not in set(contracting_dims + batch_dims)]
lspecs = [pspec[i] for i in ldims if pspec[i] is not None]
if len(lspecs) == 0:
return None
assert len(lspecs) == 1, (
"Expected only 1 non-batched non-contracting dimension to be sharded for "
f"sequence-parallelism, but found {len(lspecs)}: {pspec} @ idx {ldims}"
)
return pspec.index(lspecs[0])
def generate_pspec(logical_axis_names, with_flax_rules=False, padded=False):
""" """
Convert logical axes to PartitionSpec Convert logical axes to PartitionSpec
""" """
rules = get_sharding_map_logic_axis_to_mesh_axis() rules = None
if with_flax_rules:
try:
import flax
rules = dict(flax.linen.get_logical_axis_rules())
except ImportError:
pass
if rules is None:
warnings.warn(
"Transformer Engine logical axes, such as BATCH_AXES, SEQLEN_AXES, etc. are deprecated"
" and will be removed in a future version. Please use Flax logical axes with the"
" `flax.linen.logical_axis_rules()` context and optionally use"
" `transformer_engine.jax.flax.extend_logical_axis_rules()` to extend Flax axis rules"
" with Transformer Engine logical axes.",
DeprecationWarning,
)
rules = get_sharding_map_logic_axis_to_mesh_axis()
# mesh_axis_names = [rules[name] for name in logical_axis_names] # mesh_axis_names = [rules[name] for name in logical_axis_names]
mesh_axis_names = [] mesh_axis_names = []
for name in logical_axis_names: for name in logical_axis_names:
axis_name = rules[name] if name in rules else None axis_name = rules[name] if name in rules else None
mesh_axis_names.append(axis_name) mesh_axis_names.append(axis_name)
pspec = jax.sharding.PartitionSpec(*mesh_axis_names) pspec = jax.sharding.PartitionSpec(*mesh_axis_names)
if padded:
pspec = get_padded_spec(pspec, len(mesh_axis_names))
return pspec return pspec
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment