Unverified commit 25a82192 authored by Alp Dener, committed by GitHub

[JAX] Fixing GemmPrimitive partitioning rules to handle tensor-parallelism correctly for sequence-parallel inputs (#1980)

* updated GemmPrimitive partitioning rules to explicitly control all-reduce vs. reduce-scatter for sequence-parallelism
Signed-off-by: Alp Dener <adener@nvidia.com>

* corrected handling of FSDP sharding for the RHS operand
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* use correct logical axes variable to identify sequence-parallel dim in LayerNormDenseGeneral
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed linting issues
Signed-off-by: Alp Dener <adener@nvidia.com>

* added assert on sequence-parallel options when GemmPrimitive is disabled
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Alp Dener <adener@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent a99c056b
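
For context, the behavioral difference this commit exposes is whether the partial GEMM outputs produced under tensor-parallelism are combined with an all-reduce (full output replicated along the tensor-parallel axis) or a reduce-scatter (output sharded along the sequence dimension). A minimal JAX sketch of that choice, using hypothetical mesh and axis names rather than TE code:

import jax
import jax.numpy as jnp
from functools import partial
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# Hypothetical 1-D tensor-parallel mesh; the axis name "tp" is illustrative.
mesh = Mesh(jax.devices(), axis_names=("tp",))

def local_gemm(lhs, rhs, sequence_parallel_output):
    # lhs: (M, K/tp) and rhs: (N, K/tp) local shards -> partial (M, N) output per device.
    partial_out = jnp.einsum("mk,nk->mn", lhs, rhs)
    if sequence_parallel_output:
        # Reduce-scatter: each device keeps an (M/tp, N) slice of the summed output.
        return jax.lax.psum_scatter(partial_out, "tp", scatter_dimension=0, tiled=True)
    # All-reduce: every device holds the full summed (M, N) output.
    return jax.lax.psum(partial_out, "tp")

all_reduce_gemm = shard_map(
    partial(local_gemm, sequence_parallel_output=False),
    mesh=mesh, in_specs=(P(None, "tp"), P(None, "tp")), out_specs=P(None, None),
)
reduce_scatter_gemm = shard_map(
    partial(local_gemm, sequence_parallel_output=True),
    mesh=mesh, in_specs=(P(None, "tp"), P(None, "tp")), out_specs=P("tp", None),
)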
......@@ -155,7 +155,7 @@ class GemmPrimitive(BasePrimitive):
name = "te_gemm_ffi"
multiple_results = True
impl_static_args = (6, 7, 8, 9, 10, 11, 12, 13, 14, 15)
impl_static_args = (6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17)
inner_primitive = None
outer_primitive = None
......@@ -177,8 +177,14 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu,
grad,
use_split_accumulator,
sequence_parallel_output,
sequence_dim,
):
del lhs_quantized_colwise, rhs_quantized_colwise, use_split_accumulator
del (
sequence_parallel_output,
sequence_dim,
)
def _dims_are_consecutive(dims):
if len(dims) <= 1:
......@@ -343,8 +349,12 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu,
grad,
use_split_accumulator,
sequence_parallel_output,
sequence_dim,
):
del batched_dims, lhs_quantized_colwise, rhs_quantized_colwise, out_dtype
del sequence_parallel_output, sequence_dim
lhs_aval, _, rhs_aval, *_ = ctx.avals_in
lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_aval.ndim, rhs_aval.ndim), contracting_dims)
lhs_transposed, rhs_transposed = _get_gemm_layout(
......@@ -393,6 +403,8 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu,
grad,
use_split_accumulator,
sequence_parallel_output,
sequence_dim,
):
lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims)
lhs_transposed, rhs_transposed = _get_gemm_layout(
......@@ -430,6 +442,8 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu=fuse_gelu,
grad=grad,
use_split_accumulator=use_split_accumulator,
sequence_parallel_output=sequence_parallel_output,
sequence_dim=sequence_dim,
)
return outputs[:-3] # discard workspace arrays
......@@ -447,6 +461,8 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu,
grad,
use_split_accumulator,
sequence_parallel_output,
sequence_dim,
):
assert GemmPrimitive.outer_primitive is not None
lhs, _, rhs, *_ = batched_args
......@@ -489,6 +505,8 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu=fuse_gelu,
grad=grad,
use_split_accumulator=use_split_accumulator,
sequence_parallel_output=sequence_parallel_output,
sequence_dim=sequence_dim,
),
(out_bdims, bias_bdims, pre_gelu_bdims),
)
......@@ -510,7 +528,13 @@ class GemmPrimitive(BasePrimitive):
return bspecs, lspecs, cspecs
@staticmethod
def _parse_operand_output_specs(arg_infos, contracting_dims, batched_dims):
def _parse_operand_output_specs(
arg_infos,
contracting_dims,
batched_dims,
sequence_parallel_output,
sequence_dim,
):
lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos)
lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs))
lhs_cdims, rhs_cdims, lhs_bdims, rhs_bdims = map(
......@@ -556,96 +580,66 @@ class GemmPrimitive(BasePrimitive):
)
# Extract single leading and contracting dimension specs
(lhs_lspec, rhs_lspec, lhs_cspec, rhs_cspec) = map(
(lhs_cspec, rhs_cspec) = map(
lambda specs: None if len(specs) == 0 else specs[0],
(lhs_lspec_not_none, rhs_lspec_not_none, lhs_cspec_not_none, rhs_cspec_not_none),
)
# Reproducing jax.nn.scaled_matmul() custom partitioning for arbitrary GEMM layouts
# with row-wise LHS:(B, M, K1) and row-wise RHS:(B, N, K2) operands.
# 1. K1 == K2 != None and N == None
# LHS: (B, M, K)
# RHS: (B, None, K)
# OUT: (B, M, None) --(AR)-> (B, M, None)
# 2. K1 == K2 != None and M == N != None
# LHS: (B, M, K)
# RHS: (B, N, K)--(AG)->(B, None, K)
# OUT: (B, M, None) --(RS)--> (B, M, N)
# 3. M == N
# LHS: (B, M, K)--(AG)->(B, M, None)
# RHS: (B, M, K)--(AG)->(B, None, None)
# OUT: (B, M, None)
# 4. M != N
# LHS: (B, M, K)--(AG)->(B, M, None)
# RHS: (B, N, K)--(AG)->(B, N, None)
# OUT: (B, M, N)
reduce_flag = lhs_cspec is not None and lhs_cspec == rhs_cspec
all_reduce_output = reduce_flag and rhs_lspec is None
reduce_scatter_output = reduce_flag and lhs_lspec is not None and lhs_lspec == rhs_lspec
all_reduce_spec = reduce_scatter_spec = scatter_dim = None
lhs_non_contracting_specs, rhs_non_contracting_specs = map(
lambda specs, cdims: tuple(specs[i] for i in range(len(specs)) if i not in cdims),
(lhs_specs, rhs_specs),
(lhs_cdims, rhs_cdims),
)
out_specs = (*lhs_non_contracting_specs, *rhs_non_contracting_specs)
if reduce_scatter_output:
# All-gather (if necessary) the non-batch non-contracting dimension of RHS
# (B, N, K) --(AG)-> (B, None, K)
# (B, M, K) x (B, None, K)^T = (B, M, None) --(RS)-> (B, M, N)
rhs_spec = tuple(
rhs_spec[i] if i in set(rhs_bdims + rhs_cdims) else None for i in range(rhs_ndim)
)
reduce_scatter_spec = lhs_cspec
scatter_dim = out_specs.index(rhs_lspec)
elif all_reduce_output:
# Set all output trailing dimensions to zero
out_specs = (
*lhs_non_contracting_specs,
*[None for _ in range(len(rhs_non_contracting_specs))],
)
all_reduce_spec = lhs_cspec
else:
# All-gather (if necessary) the non-batch contracting dimensions
# (B, M, K) --(AG)-> (B, M, None)
# (B, N, K) --(AG)-> (B, N, None)
# (B, M, None) x (B, N, None)^T = (B, M, N)
(lhs_cspec_not_none, rhs_cspec_not_none),
)
# Partitioning rules:
# ([B], M, K1) x ([B], N, K2)^T = ([B], M, N)
# 1. K1 == K2 != None
# - Require non-batched non-contracting dims of both LHS and RHS to be unsharded.
# - If `sequence_parallel_output=True`, then reduce-scatter the output.
# - Otherwise, all-reduce the output.
# 2. Otherwise
# - Require contracting dimensions of both LHS and RHS to be unsharded.
# - Require non-batched non-contracting dims of LHS to be unsharded.
reduce_output = rhs_cspec is not None and lhs_cspec == rhs_cspec
reduce_spec = scatter_dim = None
if reduce_output:
reduce_spec = rhs_cspec
if sequence_parallel_output:
# If the sequence dimension is not specified, assume it to be the first
# non-batched non-contracting dimension of the LHS operand.
scatter_dim = sequence_dim if sequence_dim is not None else lhs_ldims[0]
# Always require the non-batched non-contracting dims of LHS to be unsharded
# NOTE: This will all-gather sequence-parallel inputs and preserve tensor-parallel params.
lhs_specs = tuple(
None if i in lhs_cdims and i not in lhs_bdims else lhs_specs[i]
for i in range(lhs_ndim)
lhs_specs[i] if i in set(lhs_bdims + lhs_cdims) else None for i in range(lhs_ndim)
)
rhs_specs = tuple(
None if i in rhs_cdims and i not in rhs_bdims else rhs_specs[i]
for i in range(rhs_ndim)
)
# Check if RHS non-contracting spec also appears in the LHS non-contracting specs
if rhs_lspec is not None and rhs_lspec in tuple(
lhs_specs[i] for i in range(lhs_ndim) if i not in lhs_cdims
):
# All-gather (if necessary) the non-batch non-contracting dimensions of RHS
# (B, N, None) --(AG)-> (B, None, None)
# (B, M, None) x (B, None, None)^T = (B, M, None)
if reduce_output:
# When reducing GEMM output, require non-batched non-contracting dims of the RHS
# operand to be unsharded (i.e. FSDP)
rhs_specs = tuple(
None if i not in set(rhs_bdims + rhs_cdims) else rhs_specs[i]
for i in range(rhs_ndim)
)
# Set all output trailing dimensions to zero
out_specs = (
*lhs_non_contracting_specs,
*[None for _ in range(len(rhs_non_contracting_specs))],
else:
# Otherwise, require contracting dims of both operands to be unsharded
lhs_specs = tuple(None if i in lhs_cdims else lhs_specs[i] for i in range(lhs_ndim))
rhs_specs = tuple(None if i in rhs_cdims else rhs_specs[i] for i in range(rhs_ndim))
# Combine modified LHS and RHS specs into the output
lhs_non_contracting_specs, rhs_non_contracting_specs = map(
lambda specs, cdims: tuple(specs[i] for i in range(len(specs)) if i not in cdims),
(lhs_specs, rhs_specs),
(lhs_cdims, rhs_cdims),
)
out_specs = [*lhs_non_contracting_specs, *rhs_non_contracting_specs]
# Bias and Pre-GeLU sharding is based on GEMM output before any scatter
bias_specs = tuple(list(out_specs[len(lhs_non_contracting_specs) :]).copy())
gelu_specs = tuple(list(out_specs).copy())
# Bias and Pre-GeLU sharding is based on GEMM output
bias_specs = out_specs[len(lhs_non_contracting_specs) :]
gelu_specs = out_specs
# Set output scatter dim to the tensor-parallel spec
if sequence_parallel_output:
out_specs[scatter_dim] = reduce_spec
return (
(lhs_specs, rhs_specs, bias_specs, gelu_specs),
(out_specs, bias_specs, gelu_specs),
all_reduce_spec,
reduce_scatter_spec,
reduce_spec,
scatter_dim,
)
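# --- Illustrative aside (not part of the diff): a toy re-derivation of the output spec
# --- produced by rule 1 above for a typical row-parallel projection, assuming mesh axes
# --- "dp" (data) and "tp" (tensor); the axis names are hypothetical.
lhs_specs = ("dp", None, "tp")   # activations (batch, seq, hidden_in), contracting dim = 2
rhs_specs = (None, "tp")         # kernel (hidden_out, hidden_in), contracting dim = 1
lhs_cspec, rhs_cspec = lhs_specs[2], rhs_specs[1]

# Matching sharded contracting specs -> the local GEMM produces partial sums over "tp".
reduce_output = rhs_cspec is not None and lhs_cspec == rhs_cspec   # True
out_specs = ["dp", None, None]                                     # (batch, seq, hidden_out)

sequence_parallel_output = True
if reduce_output and sequence_parallel_output:
    # Reduce-scatter instead of all-reduce: the sequence dim inherits the contracting spec.
    out_specs[1] = rhs_cspec
print(out_specs)   # ['dp', 'tp', None]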
......@@ -661,6 +655,8 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu,
grad,
use_split_accumulator,
sequence_parallel_output,
sequence_dim,
mesh,
arg_infos,
result_infos,
......@@ -675,7 +671,13 @@ class GemmPrimitive(BasePrimitive):
del use_split_accumulator, result_infos
(_, (out_specs, dbias_specs, pre_gelu_specs), *_) = (
GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims, batched_dims)
GemmPrimitive._parse_operand_output_specs(
arg_infos,
contracting_dims,
batched_dims,
sequence_parallel_output,
sequence_dim,
)
)
out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs))
......@@ -703,6 +705,8 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu,
grad,
use_split_accumulator,
sequence_parallel_output,
sequence_dim,
mesh,
arg_infos,
result_infos,
......@@ -712,10 +716,15 @@ class GemmPrimitive(BasePrimitive):
(
(lhs_specs, rhs_specs, bias_input_specs, gelu_input_specs),
(out_specs, dbias_specs, pre_gelu_specs),
all_reduce_spec,
reduce_scatter_spec,
reduce_spec,
scatter_dim,
) = GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims, batched_dims)
) = GemmPrimitive._parse_operand_output_specs(
arg_infos,
contracting_dims,
batched_dims,
sequence_parallel_output,
sequence_dim,
)
# Assemble argument shardings
# NOTE: Block scale inverses match their operands, but tensor scale inverses are unsharded.
......@@ -770,20 +779,17 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu=fuse_gelu,
grad=grad,
use_split_accumulator=use_split_accumulator,
sequence_parallel_output=sequence_parallel_output,
sequence_dim=sequence_dim,
)
# All-Reduce/Reduce-Scatter GEMM output
if all_reduce_spec is not None:
outputs[0] = jax.lax.psum(outputs[0], all_reduce_spec)
if fuse_gelu and not grad:
outputs[2] = jax.lax.psum(outputs[2], all_reduce_spec)
elif reduce_scatter_spec is not None:
if reduce_spec is not None:
if scatter_dim is None:
outputs[0] = jax.lax.psum(outputs[0], reduce_spec)
else:
outputs[0] = jax.lax.psum_scatter(
outputs[0], reduce_scatter_spec, scatter_dimension=scatter_dim, tiled=True
)
if fuse_gelu and not grad:
outputs[2] = jax.lax.psum_scatter(
outputs[2], reduce_scatter_spec, scatter_dimension=scatter_dim, tiled=True
outputs[0], reduce_spec, scatter_dimension=scatter_dim, tiled=True
)
return outputs
......@@ -802,12 +808,14 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu,
grad,
use_split_accumulator,
sequence_parallel_output,
sequence_dim,
mesh,
operand_types,
result_types,
):
del lhs_quantized_colwise, rhs_quantized_colwise, out_dtype, grad, use_split_accumulator
del mesh, result_types
del sequence_parallel_output, sequence_dim, mesh, result_types
prefix = "GemmPrimitive_"
......@@ -896,6 +904,8 @@ def _te_gemm(
fuse_gelu: bool = False,
grad: bool = False,
use_split_accumulator: bool = QuantizeConfig.FP8_2X_ACC_FPROP,
sequence_parallel_output: bool = False,
sequence_dim: int = None,
) -> Tuple[jax.Array, ...]:
# Prepare non-quantized GEMM operands
......@@ -969,6 +979,8 @@ def _te_gemm(
fuse_gelu=fuse_gelu,
grad=grad,
use_split_accumulator=use_split_accumulator,
sequence_parallel_output=sequence_parallel_output,
sequence_dim=sequence_dim,
)
......@@ -1307,9 +1319,9 @@ def gemm(
Tuple of sequences representing the contracting dimensions of the operands.
batched_dims: Tuple[Sequence[int], Sequence[int]], default = ((), ()),
Tuple of sequences representing the batched dimensions of the operands. This is *not* used
to perform a batched matrix multiplication, but it is required to avoid a potentially
undesirable reduction in any batched contracting dimensions when invoked with sharded
operands (e.g. when computing weight gradients in a Flax module).
to perform a batched matrix multiplication, but it is required for TE's custom cuBLAS GEMM
call to avoid a potentially undesirable reduction in any batched contracting dimensions
when invoked with sharded operands (e.g. when computing weight gradients in a Flax module).
bias: jax.Array, default = None
Optional additive bias term, required for forward GEMM with bias fusion. Only supported
with TE's custom call to cuBLAS GEMM.
......@@ -1327,7 +1339,17 @@ def gemm(
TE's custom call to cuBLAS GEMM.
use_split_accumulator: bool, default = True
Enable promoting some intermediate sums to higher precision when accumulating the result in
the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed.
the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed. Only
supported with TE's custom call to cuBLAS GEMM.
sequence_parallel_output: bool, default = False
Produces an output with the first non-batched non-contracting dimension sharded with the
same spec as operand contracting dimensions. This effectively converts the `jax.lax.psum`
for the GEMM output into a `jax.lax.psum_scatter`. Only supported with TE's custom call to
cuBLAS GEMM.
sequence_dim: int, default = None
Index of the sequence dimension for the LHS operand. This controls which dimension of the
GEMM output is scattered when `sequence_parallel_output=True`. When `None`, the first
non-batched non-contracting dimension is assumed to be the sequence dimension.
Returns
-------
......@@ -1358,12 +1380,20 @@ def gemm(
if not GemmPrimitive.enabled():
assert kwargs.get("bias", None) is None and not fuse_gelu, (
"TE GEMM was invoked with bias fusion options that are not supported by the "
"`jax.lax.dot_general` and `jnp.scaled_matmul` backends used when the custom cuBLAS "
"`jax.lax.dot_general` and `jax.nn.scaled_matmul` backends used when the custom cuBLAS "
"GEMM primitive is disabled."
)
assert kwargs.get("gelu_input", None) is None and not fuse_bias, (
"TE GEMM was invoked with GeLU fusion options that are not supported by the "
"`jax.lax.dot_general` and `jnp.scaled_matmul` backends used when the custom cuBLAS "
"`jax.lax.dot_general` and `jax.nn.scaled_matmul` backends used when the custom cuBLAS "
"GEMM primitive is disabled."
)
assert (
not kwargs.get("sequence_parallel_output", False)
and kwargs.get("sequence_dim", None) is None
), (
"TE GEMM was invoked with sequence-parallelism options that are not supported by the "
"`jax.lax.dot_general` and `jax.nn.scaled_matmul` backedns used when the custom cuBLAS "
"GEMM primitive is disabled."
)
return _jax_gemm(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer)
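# --- Illustrative aside (not part of the diff): a hedged usage sketch of the `gemm(...)`
# --- call documented above with the new keyword arguments. Shapes and sharding are
# --- illustrative, quantizers are omitted, and the return structure follows the
# --- (elided) Returns section above.
import jax.numpy as jnp

lhs = jnp.zeros((8, 128, 1024), dtype=jnp.bfloat16)   # (batch, seq, hidden_in)
rhs = jnp.zeros((4096, 1024), dtype=jnp.bfloat16)     # (hidden_out, hidden_in), row-parallel
outputs = gemm(
    lhs,
    rhs,
    contracting_dims=((2,), (1,)),   # contract hidden_in on both operands
    batched_dims=((0,), ()),         # LHS batch dim only
    sequence_parallel_output=True,   # reduce-scatter the output instead of all-reducing it
    sequence_dim=1,                  # scatter along the sequence dimension
)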
......
......@@ -22,6 +22,7 @@ from .quantize import (
TensorUsage,
)
from .sharding import get_sequence_parallel_dim
DENSE_BATCH_FIRST_WARNING_ISSUED = False
......@@ -41,6 +42,7 @@ def dense(
input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None,
batch_first: bool = True,
sequence_parallel_output: bool = False,
quantizer_set: QuantizerSet = noop_quantizer_set,
):
"""Perform dense layer transformation with optional quantization.
......@@ -55,6 +57,8 @@ def dense(
bias: Optional bias tensor to add after the transformation
contracting_dims: Tuple of sequences specifying which dimensions to contract
batch_first: Assume that X is batched in the first dimension.
sequence_parallel_output: Produce an output that is sharded along the first non-batched dim.
Only supported for TE custom GEMM with row-parallel kernel axes.
quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns:
......@@ -69,13 +73,31 @@ def dense(
output += jnp.reshape(bias, bias_new_shape)
else:
output = _dense(
x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set
x,
kernel,
bias,
contracting_dims,
input_axes,
kernel_axes,
batch_first,
sequence_parallel_output,
quantizer_set,
)
return output
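# --- Illustrative aside (not part of the diff): a hedged usage sketch of the functional
# --- `dense(...)` API with the new flag. The logical axis names are hypothetical Flax
# --- axes and a row-parallel kernel is assumed, as the flag requires.
import jax.numpy as jnp

x = jnp.zeros((8, 128, 1024), dtype=jnp.bfloat16)     # (batch, seq, hidden_in), TP-sharded hidden
kernel = jnp.zeros((1024, 4096), dtype=jnp.bfloat16)  # (hidden_in, hidden_out), row-parallel
bias = jnp.zeros((4096,), dtype=jnp.bfloat16)

out = dense(
    x,
    kernel,
    bias=bias,
    contracting_dims=((2,), (0,)),
    input_axes=("batch", "seq", "hidden_tp"),
    kernel_axes=("hidden_tp", "mlp"),
    sequence_parallel_output=True,   # reduce-scatter the output onto the sequence dim
)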
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6))
def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set):
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7))
def _dense(
x,
kernel,
bias,
contracting_dims,
input_axes,
kernel_axes,
batch_first,
sequence_parallel_output,
quantizer_set,
):
"""Internal implementation of dense layer transformation with custom VJP.
This function implements the core dense layer transformation logic with support
......@@ -88,20 +110,38 @@ def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_fir
contracting_dims: Contracting dimensions specification
input_axes: Logical axes for sharding the activation input
kernel_axes: Logical axes for sharding the weight matrix
quantizer_set: QuantizerSet which contains quantizers for different tensor types
batch_first: Assume that X is batched in the first dimension if it has more than 2 dims.
sequence_parallel_output: Produce an output that is sharded along the first non-batched dim.
Only supported for TE custom GEMM with row-parallel kernel axes.
quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns:
Transformed output tensor
"""
output, _ = _dense_fwd_rule(
x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set
x,
kernel,
bias,
contracting_dims,
input_axes,
kernel_axes,
batch_first,
sequence_parallel_output,
quantizer_set,
)
return output
def _dense_fwd_rule(
x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set
x,
kernel,
bias,
contracting_dims,
input_axes,
kernel_axes,
batch_first,
sequence_parallel_output,
quantizer_set,
):
"""Forward pass rule for dense layer transformation.
......@@ -161,6 +201,7 @@ def _dense_fwd_rule(
batched_dims=((x_bdim,), ()),
bias=bias if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
sequence_parallel_output=sequence_parallel_output and not tex.gemm_uses_jax_dot(),
)
if use_bias and tex.gemm_uses_jax_dot():
......@@ -181,7 +222,7 @@ def _dense_fwd_rule(
def _dense_bwd_rule(
contracting_dims, input_axes, kernel_axes, batch_first, ctx, grad
contracting_dims, input_axes, kernel_axes, batch_first, sequence_parallel_output, ctx, grad
): # pylint: disable=unused-argument
"""Backward pass rule for dense layer transformation.
......@@ -220,11 +261,22 @@ def _dense_bwd_rule(
k_contracting_dim = tuple(
dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims
)
# Get sequence-parallel dimension of the FWD input (if it exists)
sequence_dim = get_sequence_parallel_dim(input_axes, fwd_x_contracting_dims, (x_bdim,))
dgrad = tex.gemm(
casted_grad.get_tensor(usage=TensorUsage.LHS),
casted_kernel_rhs,
contracting_dims=(g_contracting_dim, k_contracting_dim),
batched_dims=((x_bdim,), ()),
sequence_parallel_output=(
sequence_dim is not None
and not sequence_parallel_output
and not tex.gemm_uses_jax_dot()
),
sequence_dim=(
None if sequence_parallel_output or tex.gemm_uses_jax_dot() else sequence_dim
),
)
dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)
......
......@@ -415,6 +415,8 @@ class DenseGeneral(TransformerEngineBase):
Indicate the logical axes of sharding constraint to the input, like
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
sharding constraint.
sequence_parallel_output: bool, default = False
Produce a sequence-parallel output with the first non-batched dimension sharded over the same
mesh axes as the operands' contracting dimensions.
Optimization parameters
-----------------------
......@@ -439,6 +441,7 @@ class DenseGeneral(TransformerEngineBase):
dtype: DType = jnp.float32
transpose_batch_sequence: bool = False
input_axes: Tuple[str, ...] = ()
sequence_parallel_output: bool = False
def __post_init__(self):
if self.transpose_batch_sequence:
......@@ -511,6 +514,7 @@ class DenseGeneral(TransformerEngineBase):
input_axes=self.input_axes,
kernel_axes=self.kernel_axes,
quantizer_set=quantizer_set,
sequence_parallel_output=self.sequence_parallel_output,
)
if self.enable_low_rank_adaptation:
......
......@@ -1425,6 +1425,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
dtype=self.dtype,
name="out",
sequence_parallel_output=self.enable_sequence_parallel,
)(x)
out = checkpoint_name(out, "out_proj")
......
......@@ -24,6 +24,7 @@ from .quantize import (
with_sharding_constraint_by_logical_axes,
TensorUsage,
)
from .sharding import get_sequence_parallel_dim
LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED = False
......@@ -324,11 +325,16 @@ def _layernorm_dense_bwd_rule(
)
# NT GEMM
sequence_dim = get_sequence_parallel_dim(
layernorm_input_axes, x_contracting_dims_in_fwd, (x_bdim,)
)
dgrad = tex.gemm(
casted_grad.get_tensor(TensorUsage.LHS),
casted_kernel,
contracting_dims=(g_constracting_dim, k_constracting_dim),
batched_dims=((x_bdim,), ()),
sequence_parallel_output=sequence_dim is not None and not tex.gemm_uses_jax_dot(),
sequence_dim=sequence_dim if not tex.gemm_uses_jax_dot() else None,
)
dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes)
......
......@@ -29,7 +29,10 @@ from .quantize import (
noop_quantizer_set,
TensorUsage,
)
from .sharding import get_non_contracting_logical_axes
from .sharding import (
get_non_contracting_logical_axes,
get_sequence_parallel_dim,
)
LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED = False
......@@ -342,6 +345,7 @@ def _layernorm_mlp_fwd_rule(
# NN GEMM
# (batch..., hidden_in) x (hidden_out, hidden_in)
sequence_dim = get_sequence_parallel_dim(norm_input_axes, x_contracting_dims, (x_bdim,))
dot_2_output = tex.gemm(
casted_act_out.get_tensor(TensorUsage.LHS),
casted_kernel_2.get_tensor(TensorUsage.RHS),
......@@ -349,6 +353,8 @@ def _layernorm_mlp_fwd_rule(
batched_dims=((x_bdim,), ()),
bias=bias_2 if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False,
sequence_parallel_output=sequence_dim is not None and not tex.gemm_uses_jax_dot(),
sequence_dim=sequence_dim if not tex.gemm_uses_jax_dot() else None,
)
if use_bias_2 and tex.gemm_uses_jax_dot():
......@@ -377,6 +383,7 @@ def _layernorm_mlp_fwd_rule(
use_bias_2,
quantizer_sets,
x_bdim,
sequence_dim,
)
return dot_2_output, ctx
......@@ -431,6 +438,7 @@ def _layernorm_mlp_bwd_rule(
use_bias_2,
quantizer_sets,
x_bdim,
sequence_dim,
) = ctx
ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
......@@ -501,6 +509,8 @@ def _layernorm_mlp_bwd_rule(
casted_kernel_1,
contracting_dims=(g_contracting_dims_1, k_contracting_dims_1),
batched_dims=((x_bdim,), ()),
sequence_parallel_output=sequence_dim is not None and not tex.gemm_uses_jax_dot(),
sequence_dim=sequence_dim if not tex.gemm_uses_jax_dot() else None,
)
dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes)
......
......@@ -86,10 +86,52 @@ def get_sharding_map_logic_axis_to_mesh_axis():
return te_logical_axis_to_mesh_axis
def generate_pspec(logical_axis_names):
def get_sequence_parallel_dim(logical_axes, contracting_dims, batch_dims):
"""
Get the index for the sequence-parallel dimension based on the given logical axes.
The sequence-parallel dimension is assumed to be the only sharded non-batched non-contracting
dimension.
"""
if not logical_axes:
return None
pspec = generate_pspec(logical_axes, with_flax_rules=True, padded=True)
ldims = [i for i in range(len(logical_axes)) if i not in set(contracting_dims + batch_dims)]
lspecs = [pspec[i] for i in ldims if pspec[i] is not None]
if len(lspecs) == 0:
return None
assert len(lspecs) == 1, (
"Expected only 1 non-batched non-contracting dimension to be sharded for "
f"sequence-parallelism, but found {len(lspecs)}: {pspec} @ idx {ldims}"
)
return pspec.index(lspecs[0])
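# --- Illustrative aside (not part of the diff): a hedged usage sketch of the helper above,
# --- assuming Flax logical-axis rules are active. The rule mapping is illustrative, not
# --- TE's defaults.
import flax.linen as nn

with nn.logical_axis_rules((("batch", "data"), ("seq", "tensor"), ("hidden", None))):
    # Input logical axes (batch, seq, hidden): dim 0 is batched and dim 2 is contracted,
    # leaving "seq" as the only sharded non-batched non-contracting dimension.
    sp_dim = get_sequence_parallel_dim(
        logical_axes=("batch", "seq", "hidden"),
        contracting_dims=(2,),
        batch_dims=(0,),
    )
    assert sp_dim == 1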
def generate_pspec(logical_axis_names, with_flax_rules=False, padded=False):
"""
Convert logical axes to PartitionSpec
"""
rules = None
if with_flax_rules:
try:
import flax
rules = dict(flax.linen.get_logical_axis_rules())
except ImportError:
pass
if rules is None:
warnings.warn(
"Transformer Engine logical axes, such as BATCH_AXES, SEQLEN_AXES, etc. are deprecated"
" and removed in a future version. Please use Flax logical axes with the"
" `flax.linen.logical_axis_rules()` context and optionally use"
" `transformer_engine.jax.flax.extend_logical_axis_rules()` to extend Flax axis rules"
" with Transformer Engine logical axes.",
DeprecationWarning,
)
rules = get_sharding_map_logic_axis_to_mesh_axis()
# mesh_axis_names = [rules[name] for name in logical_axis_names]
mesh_axis_names = []
......@@ -97,6 +139,8 @@ def generate_pspec(logical_axis_names):
axis_name = rules[name] if name in rules else None
mesh_axis_names.append(axis_name)
pspec = jax.sharding.PartitionSpec(*mesh_axis_names)
if padded:
pspec = get_padded_spec(pspec, len(mesh_axis_names))
return pspec
......