"tests/pytorch/attention/test_kv_cache.py" did not exist on "65c2798a720a36e4499a75592e9caa8ae8d8996c"
Unverified commit cae1c436, authored by Phuong Nguyen and committed by GitHub

[JAX] TE Gemm custom call clean up (#2030)



* rm batched_dims, sequence_dim, sequence_parallel_output
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* rm lhs_quantized_colwise and rhs_quantized_colwise
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* rm unnecessary transpose_batch_sequence arg from some modules
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent dd083bdf
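The net effect of the bullets above is a smaller keyword surface for the custom GEMM call. Below is a minimal usage sketch (not part of the diff) of the simplified call after this cleanup; the import path is an assumption based on how `dense.py` refers to the extension module as `tex`, and the exact return structure depends on the fusion flags.

```python
# Hedged usage sketch (not part of the diff): invoking the custom GEMM after
# this cleanup. The import path is an assumption; adjust to your install.
import jax.numpy as jnp
from transformer_engine.jax import cpp_extensions as tex

x = jnp.zeros((8, 128, 1024), dtype=jnp.bfloat16)  # (batch, seqlen, hidden_in)
w = jnp.zeros((1024, 4096), dtype=jnp.bfloat16)    # (hidden_in, hidden_out)

# Before this commit the call also threaded batched_dims=((0,), ()),
# sequence_parallel_output and sequence_dim through to the primitive.
out = tex.gemm(
    x,
    w,
    contracting_dims=((-1,), (0,)),  # contract over hidden_in
)
# Leading LHS dimensions are treated as ordinary non-contracting output
# dimensions, so `out` has shape (8, 128, 4096) with no batched_dims needed.
# With fuse_bias/fuse_gelu enabled, additional fusion outputs are returned.
```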
......@@ -333,7 +333,6 @@ class TestDistributedLayernormMLP:
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
ln_mlp_single = LayerNormMLP(
layernorm_type=layernorm_type,
transpose_batch_sequence=False, # input: [batch, seqlen, hidden]
intermediate_dim=INTERMEDIATE,
activations=activation_type,
use_bias=use_bias,
......@@ -352,7 +351,6 @@ class TestDistributedLayernormMLP:
):
ln_mlp_sharded = LayerNormMLP(
layernorm_type=layernorm_type,
transpose_batch_sequence=False,
intermediate_dim=INTERMEDIATE,
activations=activation_type,
scale_axes=LN_SCALE_AXES,
......
......@@ -155,7 +155,7 @@ class GemmPrimitive(BasePrimitive):
name = "te_gemm_ffi"
multiple_results = True
impl_static_args = (6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17)
impl_static_args = (6, 7, 8, 9, 10, 11, 12)
inner_primitive = None
outer_primitive = None
......@@ -169,22 +169,13 @@ class GemmPrimitive(BasePrimitive):
gelu_input,
out_dtype,
contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode,
fuse_bias,
fuse_gelu,
grad,
use_split_accumulator,
sequence_parallel_output,
sequence_dim,
):
del lhs_quantized_colwise, rhs_quantized_colwise, use_split_accumulator
del (
sequence_parallel_output,
sequence_dim,
)
del use_split_accumulator
def _dims_are_consecutive(dims):
if len(dims) <= 1:
......@@ -207,27 +198,6 @@ class GemmPrimitive(BasePrimitive):
f"{rhs_contracting_dims}."
)
(
lhs_batch_dims,
rhs_batch_dims,
) = map(sanitize_dims, operand_ndims, batched_dims)
assert _dims_are_consecutive(lhs_batch_dims), (
"cuBLAS GEMM expected consecutive batch dimensions for LHS operand, but got "
f"{lhs_batch_dims}."
)
assert _dims_are_consecutive(rhs_batch_dims), (
"cuBLAS GEMM expected consecutive batch dimensions for RHS operand, but got "
f"{rhs_batch_dims}."
)
if len(lhs_batch_dims) == 0:
assert (
len(rhs_batch_dims) == 0
), "cuBLAS GEMM RHS operand cannot be batched if LHS operand is not batched."
elif len(rhs_batch_dims) != 0:
assert all(bdim in lhs_contracting_dims for bdim in lhs_batch_dims) and all(
bdim in rhs_contracting_dims for bdim in rhs_batch_dims
), "cuBLAS GEMM batched dimensions must be contracting when both operands are batched."
lhs_contracting_size, rhs_contracting_size = map(
lambda shape, dims: reduce(operator.mul, [shape[dim] for dim in dims]),
(lhs.shape, rhs.shape),
......@@ -341,19 +311,13 @@ class GemmPrimitive(BasePrimitive):
gelu_input,
out_dtype,
contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode,
fuse_bias,
fuse_gelu,
grad,
use_split_accumulator,
sequence_parallel_output,
sequence_dim,
):
del batched_dims, lhs_quantized_colwise, rhs_quantized_colwise, out_dtype
del sequence_parallel_output, sequence_dim
del out_dtype
lhs_aval, _, rhs_aval, *_ = ctx.avals_in
lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_aval.ndim, rhs_aval.ndim), contracting_dims)
......@@ -395,16 +359,11 @@ class GemmPrimitive(BasePrimitive):
gelu_input,
out_dtype,
contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode,
fuse_bias,
fuse_gelu,
grad,
use_split_accumulator,
sequence_parallel_output,
sequence_dim,
):
lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims)
lhs_transposed, rhs_transposed = _get_gemm_layout(
......@@ -414,14 +373,14 @@ class GemmPrimitive(BasePrimitive):
lhs_scale_inv,
scaling_mode,
lhs.shape,
is_colwise=lhs_quantized_colwise,
is_colwise=lhs_transposed,
flatten_axis=max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims),
)
rhs_scale_inv = apply_padding_to_scale_inv(
rhs_scale_inv,
scaling_mode,
rhs.shape,
is_colwise=rhs_quantized_colwise,
is_colwise=not rhs_transposed,
flatten_axis=min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1,
)
......@@ -434,55 +393,34 @@ class GemmPrimitive(BasePrimitive):
gelu_input,
out_dtype=out_dtype,
contracting_dims=contracting_dims,
batched_dims=batched_dims,
lhs_quantized_colwise=lhs_quantized_colwise,
rhs_quantized_colwise=rhs_quantized_colwise,
scaling_mode=scaling_mode,
fuse_bias=fuse_bias,
fuse_gelu=fuse_gelu,
grad=grad,
use_split_accumulator=use_split_accumulator,
sequence_parallel_output=sequence_parallel_output,
sequence_dim=sequence_dim,
)
return outputs[:-3] # discard workspace arrays
@staticmethod
def batcher(
batched_args,
jax_batch_dims,
batch_dims,
out_dtype,
contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode,
fuse_bias,
fuse_gelu,
grad,
use_split_accumulator,
sequence_parallel_output,
sequence_dim,
):
assert GemmPrimitive.outer_primitive is not None
lhs, _, rhs, *_ = batched_args
lhs_bdims, _, rhs_bdims, *_ = jax_batch_dims
arg_lhs_bdims, arg_rhs_bdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), batched_dims)
arg_lhs_bdims = (None,) if len(arg_lhs_bdims) == 0 else arg_lhs_bdims
assert all(bdim == arg_bdim for bdim, arg_bdim in zip(lhs_bdims, arg_lhs_bdims)), (
"User-specified batch dimension(s) for cuBLAS GEMM LHS operand does not match batch "
f"dimensions inferred by JAX/XLA, expected {lhs_bdims} but got {arg_lhs_bdims}."
)
arg_rhs_bdims = (None,) if len(arg_rhs_bdims) == 0 else arg_rhs_bdims
assert all(bdim == arg_bdim for bdim, arg_bdim in zip(rhs_bdims, arg_rhs_bdims)), (
"User-specified batch dimension(s) for cuBLAS GEMM RHS operand does not match batch "
f"dimensions inferred by JAX/XLA, expected {lhs_bdims} but got {arg_lhs_bdims}."
)
lhs_bdims, _, rhs_bdims, *_ = batch_dims
# Output is batched like the non-contracting batch dimensions of the LHS operand
lhs_cdims = sanitize_dims(lhs.ndim, contracting_dims)
lhs_non_contracting_bdims = tuple(dim for dim in lhs_bdims if dim not in lhs_cdims)
out_bdims = (None,) if len(lhs_non_contracting_bdims) == 0 else lhs_non_contracting_bdims
# Batched GEMM is not supported
assert (
lhs_bdims is None and rhs_bdims is None
), f"(Batching is not supported, got lhs_bdims={lhs_bdims}, rhs_bdims={rhs_bdims})"
out_bdims = (None,)
# Bias gradient is never batched
bias_bdims = (None,)
......@@ -497,16 +435,11 @@ class GemmPrimitive(BasePrimitive):
*batched_args,
out_dtype=out_dtype,
contracting_dims=contracting_dims,
batched_dims=batched_dims,
lhs_quantized_colwise=lhs_quantized_colwise,
rhs_quantized_colwise=rhs_quantized_colwise,
scaling_mode=scaling_mode,
fuse_bias=fuse_bias,
fuse_gelu=fuse_gelu,
grad=grad,
use_split_accumulator=use_split_accumulator,
sequence_parallel_output=sequence_parallel_output,
sequence_dim=sequence_dim,
),
(out_bdims, bias_bdims, pre_gelu_bdims),
)
......@@ -515,11 +448,7 @@ class GemmPrimitive(BasePrimitive):
def _parse_operand_output_specs(
arg_infos,
contracting_dims,
batched_dims,
sequence_parallel_output,
sequence_dim,
):
del sequence_dim, sequence_parallel_output, batched_dims
lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos)
lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs))
......@@ -586,44 +515,30 @@ class GemmPrimitive(BasePrimitive):
(lhs_specs, rhs_specs, bias_specs, gelu_specs),
(out_specs, bias_specs, gelu_specs),
reduce_spec,
0,
)
@staticmethod
def infer_sharding_from_operands(
out_dtype,
contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode,
fuse_bias,
fuse_gelu,
grad,
use_split_accumulator,
sequence_parallel_output,
sequence_dim,
mesh,
arg_infos,
result_infos,
):
del (
out_dtype,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode,
grad,
)
del use_split_accumulator, result_infos
(_, (out_specs, dbias_specs, pre_gelu_specs), *_) = (
GemmPrimitive._parse_operand_output_specs(
arg_infos,
contracting_dims,
batched_dims,
sequence_parallel_output,
sequence_dim,
)
(_, (out_specs, dbias_specs, pre_gelu_specs), _) = (
GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims)
)
out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs))
......@@ -643,16 +558,11 @@ class GemmPrimitive(BasePrimitive):
def partition(
out_dtype,
contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode,
fuse_bias,
fuse_gelu,
grad,
use_split_accumulator,
sequence_parallel_output,
sequence_dim,
mesh,
arg_infos,
result_infos,
......@@ -663,14 +573,7 @@ class GemmPrimitive(BasePrimitive):
(lhs_specs, rhs_specs, bias_input_specs, gelu_input_specs),
(out_specs, dbias_specs, pre_gelu_specs),
reduce_spec,
_,
) = GemmPrimitive._parse_operand_output_specs(
arg_infos,
contracting_dims,
batched_dims,
sequence_parallel_output,
sequence_dim,
)
) = GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims)
# Assemble argument shardings
# NOTE: Block scale inverses match their operands, but tensor scale inverses are unsharded.
......@@ -717,19 +620,14 @@ class GemmPrimitive(BasePrimitive):
gelu_input,
out_dtype=out_dtype,
contracting_dims=contracting_dims,
batched_dims=batched_dims,
lhs_quantized_colwise=lhs_quantized_colwise,
rhs_quantized_colwise=rhs_quantized_colwise,
scaling_mode=scaling_mode,
fuse_bias=fuse_bias,
fuse_gelu=fuse_gelu,
grad=grad,
use_split_accumulator=use_split_accumulator,
sequence_parallel_output=sequence_parallel_output,
sequence_dim=sequence_dim,
)
# All-Reduce/Reduce-Scatter GEMM output
# All-Reduce GEMM output
if reduce_spec is not None:
outputs[0] = jax.lax.psum(outputs[0], reduce_spec)
......@@ -741,54 +639,42 @@ class GemmPrimitive(BasePrimitive):
def shardy_sharding_rule(
out_dtype,
contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode,
fuse_bias,
fuse_gelu,
grad,
use_split_accumulator,
sequence_parallel_output,
sequence_dim,
mesh,
operand_types,
result_types,
):
del lhs_quantized_colwise, rhs_quantized_colwise, out_dtype, grad, use_split_accumulator
del sequence_parallel_output, sequence_dim, mesh, result_types
del out_dtype, grad, use_split_accumulator
del mesh, result_types
prefix = "GemmPrimitive_"
def _generate_operand_rules(name, ndim, cdims, bdims):
def _generate_operand_rules(name, ndim, cdims):
specs = []
ldims = tuple(i for i in range(ndim) if i not in bdims + cdims)
ldims = tuple(i for i in range(ndim) if i not in cdims)
for i in range(ndim):
dim_name = None
if i in bdims:
dim_idx = bdims.index(i) if len(bdims) > 1 else ""
dim_name = f"b{dim_idx}"
elif i in cdims:
dim_idx = cdims.index(i) if len(cdims) > 1 else ""
if i in cdims:
dim_idx = cdims.index(i)
dim_name = f"k{dim_idx}"
else:
dim_idx = ldims.index(i) if len(ldims) > 1 else ""
dim_idx = ldims.index(i)
dim_name = f"{name}_l{dim_idx}"
specs.append(prefix + dim_name)
return specs
lhs, _, rhs, *_ = operand_types
operand_ndims = (len(lhs.shape), len(rhs.shape))
(lhs_cdims, rhs_cdims), (lhs_bdims, rhs_bdims) = map(
lambda dims: map(sanitize_dims, operand_ndims, dims),
(contracting_dims, batched_dims),
)
(lhs_cdims, rhs_cdims) = map(sanitize_dims, operand_ndims, contracting_dims)
lhs_specs, rhs_specs = map(
_generate_operand_rules,
("lhs", "rhs"),
operand_ndims,
(lhs_cdims, rhs_cdims),
(lhs_bdims, rhs_bdims),
)
lhs_scale_specs = ("…1",)
rhs_scale_specs = ("…2",)
......@@ -840,13 +726,10 @@ def _te_gemm(
lhs_quantizer: Quantizer = None,
rhs_quantizer: Quantizer = None,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
batched_dims: Tuple[Sequence[int], Sequence[int]] = ((), ()),
fuse_bias: bool = False,
fuse_gelu: bool = False,
grad: bool = False,
use_split_accumulator: bool = QuantizeConfig.FP8_2X_ACC_FPROP,
sequence_parallel_output: bool = False,
sequence_dim: int = None,
) -> Tuple[jax.Array, ...]:
# Prepare non-quantized GEMM operands
......@@ -857,7 +740,6 @@ def _te_gemm(
scaling_mode = ScalingMode.NO_SCALING
lhs_is_transposed, rhs_is_transposed = _get_gemm_layout((lhs.ndim, rhs.ndim), contracting_dims)
lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims)
lhs_bdims, rhs_bdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), batched_dims)
# Quantize operands (if necessary)
lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims)
......@@ -876,7 +758,6 @@ def _te_gemm(
lhs_scale_inv = lhs_q.scale_inv
if lhs_q.data_layout == "T":
lhs_cdims = transpose_dims(lhs_q.ndim, lhs_cdims, flatten_axis=lhs_q.flatten_axis)
lhs_bdims = transpose_dims(lhs_q.ndim, lhs_bdims, flatten_axis=lhs_q.flatten_axis)
if isinstance(rhs_q, ScaledTensor):
assert isinstance(lhs_q, ScaledTensor) or lhs_quantizer is not None, (
......@@ -894,7 +775,6 @@ def _te_gemm(
rhs_scale_inv = rhs_q.scale_inv
if rhs_q.data_layout == "T":
rhs_cdims = transpose_dims(rhs_q.ndim, rhs_cdims, flatten_axis=rhs_q.flatten_axis)
rhs_bdims = transpose_dims(rhs_q.ndim, rhs_bdims, flatten_axis=rhs_q.flatten_axis)
# Dummy empties for bias and gelu
out_dtype = lhs_q.dq_dtype if isinstance(lhs_q, ScaledTensor) else lhs_data.dtype
......@@ -912,16 +792,11 @@ def _te_gemm(
gelu_input,
out_dtype=out_dtype,
contracting_dims=(lhs_cdims, rhs_cdims),
batched_dims=(lhs_bdims, rhs_bdims),
lhs_quantized_colwise=lhs_q.is_colwise if isinstance(lhs_q, ScaledTensor) else False,
rhs_quantized_colwise=rhs_q.is_colwise if isinstance(rhs_q, ScaledTensor) else False,
scaling_mode=scaling_mode,
fuse_bias=fuse_bias,
fuse_gelu=fuse_gelu,
grad=grad,
use_split_accumulator=use_split_accumulator,
sequence_parallel_output=sequence_parallel_output,
sequence_dim=sequence_dim,
)
......@@ -1124,10 +999,8 @@ def _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision):
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
if lhs.data_layout == "T":
lhs_contract = transpose_dims(lhs.data.ndim, lhs_contract, flatten_axis=lhs.flatten_axis)
lhs_batch = transpose_dims(lhs.data.ndim, lhs_batch, flatten_axis=lhs.flatten_axis)
if rhs.data_layout == "T":
rhs_contract = transpose_dims(rhs.data.ndim, rhs_contract, flatten_axis=rhs.flatten_axis)
rhs_batch = transpose_dims(rhs.data.ndim, rhs_batch, flatten_axis=rhs.flatten_axis)
dim_nums = (lhs_contract, rhs_contract), (lhs_batch, rhs_batch)
......@@ -1239,7 +1112,6 @@ def gemm(
lhs: Union[jnp.ndarray, ScaledTensor],
rhs: Union[jnp.ndarray, ScaledTensor],
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
batched_dims: Tuple[Sequence[int], Sequence[int]] = ((), ()),
lhs_quantizer: Quantizer = None,
rhs_quantizer: Quantizer = None,
**kwargs,
......@@ -1258,11 +1130,6 @@ def gemm(
Object for down-casting the RHS operand for quantized GEMM.
contracting_dims: Tuple[Sequence[int], Sequence[int]], default = ((-1, ), (0, ))
Tuple of sequences representing the contracting dimensions of the operands.
batched_dims: Tuple[Sequence[int], Sequence[int]], default = ((), ()),
Tuple of sequences representing the batched dimensions of the operands. This is *not* used
to perform a batched matrix multiplication, but it is required for TE's custom cuBLAS GEMM
call to avoid a potentially undesirable reduction in any batched contracting dimensions
when invoked with sharded operands (e.g. when computing weight gradients in a Flax module).
bias: jax.Array, default = None
Optional additive bias term, required for forward GEMM with bias fusion. Only supported
with TE's custom call to cuBLAS GEMM.
......@@ -1282,15 +1149,6 @@ def gemm(
Enable promoting some intermediate sums to higher precision when accumulating the result in
the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed. Only
supported with TE's custom call to cuBLAS GEMM.
sequence_parallel_output: bool, default = False
Produces an output with the first non-batched non-contracting dimension sharded with the
same spec as operand contracting dimensions. This effectively converts the `jax.lax.psum`
for the GEMM output into a `jax.lax.psum_scatter`. Only supported with TE's custom call to
cuBLAS GEMM.
sequence_dim: int, default = None
Index of the sequence dimension for the LHS operand. This controls which dimension of the
GEMM output is scattered when `sequence_parallel_output=True`. When `None`, the first
non-batched non-contracting dimension is assumed to be the sequence dimension.
Returns
-------
......@@ -1329,14 +1187,6 @@ def gemm(
"`jax.lax.dot_general` and `jax.nn.scaled_matmul` backends used when the custom cuBLAS "
"GEMM primitive is disabled."
)
assert (
not kwargs.get("sequence_parallel_output", False)
and kwargs.get("sequence_dim", None) is None
), (
"TE GEMM was invoked with sequence-parallelism options that are not supported by the "
"`jax.lax.dot_general` and `jax.nn.scaled_matmul` backedns used when the custom cuBLAS "
"GEMM primitive is disabled."
)
return _jax_gemm(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer)
outputs = _te_gemm(
......@@ -1345,7 +1195,6 @@ def gemm(
lhs_quantizer=lhs_quantizer,
rhs_quantizer=rhs_quantizer,
contracting_dims=contracting_dims,
batched_dims=batched_dims,
**kwargs,
)
......
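A note on the batcher change above: it now rejects vmapped batch dimensions outright (`lhs_bdims is None and rhs_bdims is None`) instead of reconciling them against a user-supplied `batched_dims`. The sketch below uses a plain `jnp.einsum` as a stand-in for the custom call to illustrate the implied usage pattern; it is an illustration, not the TE primitive.

```python
# Stand-in sketch (plain einsum, not the TE primitive) of the batching contract
# implied by the new batcher: batch axes stay as ordinary leading dims of the
# operands instead of being mapped with jax.vmap.
import jax
import jax.numpy as jnp

def gemm_like(a, b):
    # stand-in for the custom GEMM with contracting_dims=((-1,), (0,))
    return jnp.einsum("...k,kn->...n", a, b)

a = jnp.zeros((4, 16, 32))  # (batch, seq, hidden_in)
b = jnp.zeros((32, 64))     # (hidden_in, hidden_out)

out = gemm_like(a, b)       # (4, 16, 64): leading dims pass straight through

# With the real primitive, the equivalent of the line below would now trip the
# "Batching is not supported" assertion in GemmPrimitive.batcher:
# out_vmapped = jax.vmap(gemm_like, in_axes=(0, None))(a, b)
```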
......@@ -8,7 +8,7 @@ architectures, including support for quantization and automatic differentiation.
It implements matrix multiplication with optional bias addition and supports
customizable contracting dimensions for flexible tensor operations.
"""
import warnings
from typing import Tuple, Sequence
from functools import partial
import jax
......@@ -22,17 +22,6 @@ from .quantize import (
TensorUsage,
)
from .sharding import get_sequence_parallel_dim
DENSE_BATCH_FIRST_WARNING_ISSUED = False
def _issue_batch_first_warning(msg):
global DENSE_BATCH_FIRST_WARNING_ISSUED
if not DENSE_BATCH_FIRST_WARNING_ISSUED:
warnings.warn(msg, UserWarning)
DENSE_BATCH_FIRST_WARNING_ISSUED = True
def dense(
x: jnp.ndarray,
......@@ -41,8 +30,6 @@ def dense(
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None,
batch_first: bool = True,
sequence_parallel_output: bool = False,
quantizer_set: QuantizerSet = noop_quantizer_set,
):
"""Perform dense layer transformation with optional quantization.
......@@ -56,9 +43,6 @@ def dense(
kernel: Weight matrix for the dense layer transformation
bias: Optional bias tensor to add after the transformation
contracting_dims: Tuple of sequences specifying which dimensions to contract
batch_first: Assume that X is batched in the first dimension.
sequence_parallel_output: Produce an output that sharded in the first non-batched dim. Only
supported for TE custom GEMM with row-parallel kernel axes.
quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns:
......@@ -79,14 +63,19 @@ def dense(
contracting_dims,
input_axes,
kernel_axes,
batch_first,
sequence_parallel_output,
quantizer_set,
)
return output
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7))
@partial(
jax.custom_vjp,
nondiff_argnums=(
3,
4,
5,
),
)
def _dense(
x,
kernel,
......@@ -94,8 +83,6 @@ def _dense(
contracting_dims,
input_axes,
kernel_axes,
batch_first,
sequence_parallel_output,
quantizer_set,
):
"""Internal implementation of dense layer transformation with custom VJP.
......@@ -110,9 +97,6 @@ def _dense(
contracting_dims: Contracting dimensions specification
input_axes: Logical axes for sharding the activation input
kernel_axes: Logical axes for sharding the weight matrix
batch_first: Assume that X is batched in the first dimension if it has more than 2 dims.
sequence_parallel_output: Produce an output that sharded in the first non-batched dim. Only
supported for TE custom GEMM with row-parallel kernel axes.
quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns:
......@@ -125,8 +109,6 @@ def _dense(
contracting_dims,
input_axes,
kernel_axes,
batch_first,
sequence_parallel_output,
quantizer_set,
)
return output
......@@ -139,8 +121,6 @@ def _dense_fwd_rule(
contracting_dims,
input_axes,
kernel_axes,
batch_first,
sequence_parallel_output,
quantizer_set,
):
"""Forward pass rule for dense layer transformation.
......@@ -159,23 +139,6 @@ def _dense_fwd_rule(
not x_is_transposed and not k_is_transposed
), "Dense layer only supports `NN` layout inputs, i.e. non-transposed X and Kernel."
# Determine X batch dimension
# - If `batch_first=True` -> (batch, leading..., contracting...)
# - Otherwise -> (leading..., batch, contracting...)
# NOTE: Always assume a single batch dimension
x_bdim = None
num_cdims = len(x_contracting_dims)
if x.ndim >= num_cdims + 2:
# Assume X is batched if it has at least +2 dimensions more than the number of contracting
# dimensions.
if not batch_first:
_issue_batch_first_warning(
"TE/JAX `dense()` layer implementation does not officially support sequence-first "
"inputs and may produce incorrect results when `batch_first=False`. Use "
"sequence-first inputs at your own discretion.",
)
x_bdim = 0 if batch_first else x.ndim - num_cdims - 1
flatten_axis_x = -len(x_contracting_dims)
flatten_axis_k = len(k_contracting_dims) - len(kernel.shape)
......@@ -198,10 +161,8 @@ def _dense_fwd_rule(
casted_x.get_tensor(usage=TensorUsage.LHS),
casted_kernel.get_tensor(usage=TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims),
batched_dims=((x_bdim,), ()),
bias=bias if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
sequence_parallel_output=sequence_parallel_output and not tex.gemm_uses_jax_dot(),
)
if use_bias and tex.gemm_uses_jax_dot():
......@@ -216,13 +177,12 @@ def _dense_fwd_rule(
use_bias,
quantizer_set,
flatten_axis_k,
x_bdim,
)
return output, ctx
def _dense_bwd_rule(
contracting_dims, input_axes, kernel_axes, batch_first, sequence_parallel_output, ctx, grad
contracting_dims, input_axes, kernel_axes, ctx, grad
): # pylint: disable=unused-argument
"""Backward pass rule for dense layer transformation.
......@@ -237,7 +197,6 @@ def _dense_bwd_rule(
use_bias,
quantizer_set,
flatten_axis_k,
x_bdim,
) = ctx
fwd_x_contracting_dims, fwd_k_contracting_dims = map(
......@@ -262,21 +221,10 @@ def _dense_bwd_rule(
dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims
)
# Get sequence-parallel dimension of the FWD input (if it exists)
sequence_dim = get_sequence_parallel_dim(input_axes, fwd_x_contracting_dims, (x_bdim,))
dgrad = tex.gemm(
casted_grad.get_tensor(usage=TensorUsage.LHS),
casted_kernel_rhs,
contracting_dims=(g_contracting_dim, k_contracting_dim),
batched_dims=((x_bdim,), ()),
sequence_parallel_output=(
sequence_dim is not None
and not sequence_parallel_output
and not tex.gemm_uses_jax_dot()
),
sequence_dim=(
None if sequence_parallel_output or tex.gemm_uses_jax_dot() else sequence_dim
),
)
dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)
......@@ -290,7 +238,6 @@ def _dense_bwd_rule(
casted_x_lhs,
casted_grad.get_tensor(usage=TensorUsage.RHS),
contracting_dims=(x_contracting_dim, g_contracting_dim),
batched_dims=((x_bdim,), (x_bdim,)),
)
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
......
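For context on the `nondiff_argnums` change in `dense.py` above: dropping `batch_first` and `sequence_parallel_output` from `_dense` shrinks the static-argument indices from `(3, 4, 5, 6, 7)` to `(3, 4, 5)`, because those indices name positional arguments that are constant with respect to differentiation. The toy below (not TE code) shows the same bookkeeping in a self-contained `jax.custom_vjp` example.

```python
# Toy illustration (not TE code) of the nondiff_argnums bookkeeping changed above.
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5))
def toy_dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes):
    del input_axes, kernel_axes  # sharding hints, unused in this toy
    x_cdims, k_cdims = contracting_dims
    x_cdims = tuple(d % x.ndim for d in x_cdims)
    k_cdims = tuple(d % kernel.ndim for d in k_cdims)
    out = jax.lax.dot_general(x, kernel, ((x_cdims, k_cdims), ((), ())))
    return out + bias

def _toy_dense_fwd(x, kernel, bias, contracting_dims, input_axes, kernel_axes):
    out = toy_dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes)
    return out, (x, kernel)

def _toy_dense_bwd(contracting_dims, input_axes, kernel_axes, res, grad):
    # Non-diff arguments come first in the bwd rule, then residuals and grad.
    # This toy bwd assumes x: (batch, seq, hidden), kernel: (hidden, out) and
    # contracting_dims=((-1,), (0,)).
    x, kernel = res
    dx = jnp.einsum("bso,ho->bsh", grad, kernel)
    dkernel = jnp.einsum("bsh,bso->ho", x, grad)
    dbias = grad.sum(axis=(0, 1))
    return dx, dkernel, dbias

toy_dense.defvjp(_toy_dense_fwd, _toy_dense_bwd)

# Gradients flow only to x, kernel, bias; the trailing three args are static.
x, k, b = jnp.ones((2, 4, 8)), jnp.ones((8, 16)), jnp.zeros((16,))
f = lambda x, k, b: toy_dense(x, k, b, ((-1,), (0,)), None, None)
dx, dk, db = jax.grad(lambda *args: f(*args).sum(), argnums=(0, 1, 2))(x, k, b)
```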
......@@ -15,12 +15,12 @@ from jax import lax
from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name
from ..dense import dense, _issue_batch_first_warning as _dense_warning
from ..dense import dense
from ..layernorm import canonicalize_norm_type
from ..layernorm import layernorm
from ..layernorm_dense import layernorm_dense, _issue_batch_first_warning as _ln_dense_warning
from ..layernorm_mlp import layernorm_mlp, _issue_batch_first_warning as _ln_mlp_warning
from ..layernorm_dense import layernorm_dense
from ..layernorm_mlp import layernorm_mlp
from ..activation import activation
from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes
......@@ -273,10 +273,6 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods
-----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters.
transpose_batch_sequence : bool, default = False
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. If set to True, the input tensors
should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
"""
epsilon: float = 1e-6
......@@ -287,7 +283,6 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods
bias_init: Initializer = nn.initializers.zeros
bias_axes: Tuple[str, ...] = ("embed",)
dtype: DType = jnp.float32
transpose_batch_sequence: bool = False
def __post_init__(self):
self.scale_init = _obtain_default_layernorm_scale_init_if_need(
......@@ -414,17 +409,11 @@ class DenseGeneral(TransformerEngineBase):
Indicate the logical axes of sharding constraint to the input, like
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
sharding constraint.
sequence_parallel_output: bool, default = False
Produce a sequence-parallel output with the first non-batch dimension sharded over
Optimization parameters
-----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters.
transpose_batch_sequence : bool, default = True
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. If set to True, the input tensors
should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
"""
features: Union[Iterable[int], int]
......@@ -438,17 +427,9 @@ class DenseGeneral(TransformerEngineBase):
low_rank_adaptation_alpha: float = None
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
transpose_batch_sequence: bool = False
input_axes: Tuple[str, ...] = ()
sequence_parallel_output: bool = False
def __post_init__(self):
if self.transpose_batch_sequence:
_dense_warning(
"TE/JAX DenseGeneral() module does not officially support sequence-first inputs "
"and may produce incorrect results when `transpose_batch_sequence=True`. Use "
"sequence-first inputs at your own discretion."
)
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal", dtype=self.dtype
......@@ -513,7 +494,6 @@ class DenseGeneral(TransformerEngineBase):
input_axes=self.input_axes,
kernel_axes=self.kernel_axes,
quantizer_set=quantizer_set,
sequence_parallel_output=self.sequence_parallel_output,
)
if self.enable_low_rank_adaptation:
......@@ -631,10 +611,6 @@ class LayerNormDenseGeneral(TransformerEngineBase):
-----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters.
transpose_batch_sequence : bool, default = True
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. If set to True, the input tensors
should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
depth_scaling: float, default = None
The factor to scale the output from `DenseGeneral`. It should be a float
value or None. When None is set, then no scaling is applied.
......@@ -660,18 +636,11 @@ class LayerNormDenseGeneral(TransformerEngineBase):
low_rank_adaptation_alpha: float = None
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
transpose_batch_sequence: bool = True
layernorm_input_axes: Tuple[str, ...] = None
dot_input_axes: Tuple[str, ...] = None
depth_scaling: float = None
def __post_init__(self):
if self.transpose_batch_sequence:
_ln_dense_warning(
"TE/JAX LayerNormDenseGeneral() module does not officially support sequence-first "
"inputs and may produce incorrect results when `transpose_batch_sequence=True`. "
"Use sequence-first inputs at your own discretion."
)
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(
1.0,
......@@ -949,10 +918,6 @@ class LayerNormMLP(TransformerEngineBase):
-----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters.
transpose_batch_sequence : bool, default = True
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. If set to True, the input tensors
should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
"""
intermediate_dim: int = 2048
......@@ -981,7 +946,6 @@ class LayerNormMLP(TransformerEngineBase):
low_rank_adaptation_alpha: float = None
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
transpose_batch_sequence: bool = True
layernorm_input_axes: Tuple[str, ...] = None
dot_1_input_axes: Tuple[str, ...] = None
dot_2_input_axes: Tuple[str, ...] = None
......@@ -989,12 +953,6 @@ class LayerNormMLP(TransformerEngineBase):
ffn2_ckpt_name: str = "ffn2"
def __post_init__(self):
if self.transpose_batch_sequence:
_ln_mlp_warning(
"TE/JAX LayerNormMLP() module does not officially support sequence-first inputs "
"and may produce incorrect results when `transpose_batch_sequence=True`. Use "
"sequence-first inputs at your own discretion."
)
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal", dtype=self.dtype
......
......@@ -1167,7 +1167,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
epsilon=self.layernorm_epsilon,
axis=-1,
features=(3, self.num_attention_heads * self.head_dim),
transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=self.return_layernorm_output,
scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,),
......@@ -1194,7 +1193,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
epsilon=self.layernorm_epsilon,
axis=-1,
features=self.num_attention_heads * self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=(self.return_layernorm_output or is_self_attn),
scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,),
......@@ -1219,7 +1217,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
kv_proj = DenseGeneral(
axis=-1,
features=(2, self.num_gqa_groups * self.head_dim),
transpose_batch_sequence=self.transpose_batch_sequence,
kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
kernel_init=kv_init,
use_bias=self.use_bias,
......@@ -1238,7 +1235,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
DenseGeneral,
axis=-1,
features=self.num_gqa_groups * self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
kernel_axes=(W_FSDP_AXES, W_TP_AXES),
use_bias=self.use_bias,
bias_init=self.bias_init,
......@@ -1255,7 +1251,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
epsilon=self.layernorm_epsilon,
axis=-1,
features=self.num_attention_heads * self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=True,
scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,),
......@@ -1420,7 +1415,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
out = DenseGeneral(
features=inputs_q.shape[-1],
transpose_batch_sequence=self.transpose_batch_sequence,
axis=-1,
kernel_init=self.kernel_init,
kernel_axes=(W_TP_AXES, W_FSDP_AXES),
......@@ -1432,7 +1426,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
dtype=self.dtype,
name="out",
sequence_parallel_output=self.enable_sequence_parallel,
)(x)
out = checkpoint_name(out, "out_proj")
......@@ -2023,7 +2016,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.layernorm_epsilon,
transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=self.apply_residual_connection_post_layernorm,
intermediate_dim=self.mlp_hidden_size,
activations=self.mlp_activations,
......@@ -2078,7 +2070,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
epsilon=self.layernorm_epsilon,
scale_axes=(W_NO_SHARD_AXES,),
bias_axes=(W_NO_SHARD_AXES,),
transpose_batch_sequence=self.transpose_batch_sequence,
dtype=self.dtype,
name="output_layernorm",
)(z)
......
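With `transpose_batch_sequence` removed from the Flax modules above (and `sequence_parallel_output` from `DenseGeneral`), construction simplifies and inputs are expected batch-first, i.e. `(batch, seqlen, hidden)`. A hedged construction sketch follows; the `transformer_engine.jax.flax` import path is an assumption, and only fields visible in this diff are shown.

```python
# Hedged construction sketch (not part of the diff); field values are
# illustrative and other constructor options are omitted.
import jax.numpy as jnp
from transformer_engine.jax.flax import LayerNormMLP, DenseGeneral

mlp = LayerNormMLP(
    layernorm_type="rmsnorm",
    intermediate_dim=2048,
    activations=("gelu",),
    dtype=jnp.bfloat16,
    # transpose_batch_sequence=True  # removed field; passing it now raises TypeError
)

proj = DenseGeneral(
    features=1024,
    axis=-1,
    dtype=jnp.bfloat16,
    # sequence_parallel_output=True  # removed field; the sharded GEMM output is
    # now all-reduced rather than reduce-scattered (see the partition rule above)
)
```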
......@@ -9,7 +9,6 @@ architectures. It supports various normalization types, quantization, and
distributed training through sharding constraints.
"""
import warnings
from functools import partial
from typing import Tuple
......@@ -24,17 +23,6 @@ from .quantize import (
with_sharding_constraint_by_logical_axes,
TensorUsage,
)
from .sharding import get_sequence_parallel_dim
LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED = False
def _issue_batch_first_warning(msg):
global LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED
if not LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED:
warnings.warn(msg, UserWarning)
LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED = True
def layernorm_dense(
......@@ -49,7 +37,6 @@ def layernorm_dense(
layernorm_input_axes: Tuple[str, ...] = None,
dot_input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None,
batch_first: bool = True,
quantizer_set: QuantizerSet = noop_quantizer_set,
) -> jnp.ndarray:
"""Apply layer normalization followed by dense layer transformation.
......@@ -70,7 +57,6 @@ def layernorm_dense(
layernorm_input_axes: Logical axes for sharding the layernorm input
dot_input_axes: Logical axes for sharding the matrix multiplication input
kernel_axes: Logical axes for sharding the weight matrix
batch_first: Assume that X is batched in the first dimension if it has more than 2 dims.
quantizer_set: Set of quantizers for different tensor types
Returns:
......@@ -94,7 +80,6 @@ def layernorm_dense(
layernorm_input_axes,
dot_input_axes,
kernel_axes,
batch_first,
quantizer_set,
)
return output
......@@ -109,7 +94,6 @@ def layernorm_dense(
8,
9,
10,
11,
),
)
def _layernorm_dense(
......@@ -124,7 +108,6 @@ def _layernorm_dense(
layernorm_input_axes: Tuple[str, ...],
dot_input_axes: Tuple[str, ...],
kernel_axes: Tuple[str, ...],
batch_first: bool,
quantizer_set,
):
"""Internal implementation of layernorm_dense with custom VJP.
......@@ -144,7 +127,6 @@ def _layernorm_dense(
epsilon: Small constant for numerical stability
layernorm_input_axes: Logical axes for layernorm sharding
dot_input_axes: Logical axes for matrix multiplication sharding
batch_first: Assume that X is batched in the first dimension.
quantizer_set: Set of quantizers
Returns:
......@@ -162,7 +144,6 @@ def _layernorm_dense(
layernorm_input_axes,
dot_input_axes,
kernel_axes,
batch_first,
quantizer_set,
)
return output
......@@ -180,7 +161,6 @@ def _layernorm_dense_fwd_rule(
layernorm_input_axes,
dot_input_axes,
kernel_axes,
batch_first,
quantizer_set,
):
"""Forward pass rule for layernorm_dense.
......@@ -198,17 +178,6 @@ def _layernorm_dense_fwd_rule(
k_contracting_dims = (0,)
assert x.shape[-1] == kernel.shape[0]
x_bdim = None
if x.ndim > 2:
if not batch_first:
_issue_batch_first_warning(
"TE/JAX `layernorm_dense()` fused-layer implementation does not officially "
"support sequence-first inputs and may produce incorrect results when "
"`batch_first=False` or `transpose_batch_sequence=True`. Use sequence-first "
"inputs at your own discretion."
)
x_bdim = 0 if batch_first else x.ndim - 2
x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)
casted_ln_out, mu, rsigma = tex.normalization_fwd(
......@@ -237,7 +206,6 @@ def _layernorm_dense_fwd_rule(
casted_ln_out.get_tensor(TensorUsage.LHS),
casted_kernel.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims),
batched_dims=((x_bdim,), ()),
bias=bias if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
)
......@@ -261,7 +229,6 @@ def _layernorm_dense_fwd_rule(
use_bias,
quantizer_set,
flatten_axis,
x_bdim,
)
return output, ctx
......@@ -272,9 +239,8 @@ def _layernorm_dense_bwd_rule(
zero_centered_gamma,
epsilon,
layernorm_input_axes,
dot_input_axes, # pylint: disable=unused-argument
dot_input_axes,
kernel_axes,
batch_first, # pylint: disable=unused-argument
ctx,
grad,
):
......@@ -289,6 +255,7 @@ def _layernorm_dense_bwd_rule(
Returns:
Tuple of gradients for all input parameters
"""
del dot_input_axes
(
casted_ln_out,
casted_kernel,
......@@ -304,7 +271,6 @@ def _layernorm_dense_bwd_rule(
use_bias,
quantizer_set,
flatten_axis,
x_bdim,
) = ctx
casted_grad, dbias = tex.quantize_dbias(
......@@ -325,16 +291,10 @@ def _layernorm_dense_bwd_rule(
)
# NT GEMM
sequence_dim = get_sequence_parallel_dim(
layernorm_input_axes, x_contracting_dims_in_fwd, (x_bdim,)
)
dgrad = tex.gemm(
casted_grad.get_tensor(TensorUsage.LHS),
casted_kernel,
contracting_dims=(g_constracting_dim, k_constracting_dim),
batched_dims=((x_bdim,), ()),
sequence_parallel_output=sequence_dim is not None and not tex.gemm_uses_jax_dot(),
sequence_dim=sequence_dim if not tex.gemm_uses_jax_dot() else None,
)
dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes)
......@@ -348,7 +308,6 @@ def _layernorm_dense_bwd_rule(
casted_ln_out,
casted_grad.get_tensor(TensorUsage.RHS),
contracting_dims=(x_constracting_dim, g_constracting_dim),
batched_dims=((x_bdim,), (x_bdim,)),
)
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
......
......@@ -13,7 +13,6 @@ The implementation supports various normalization types, activation functions,
quantization, and distributed training through sharding constraints.
"""
import warnings
from typing import List, Tuple, Sequence, Union, Callable
from functools import partial
......@@ -29,19 +28,6 @@ from .quantize import (
noop_quantizer_set,
TensorUsage,
)
from .sharding import (
get_sequence_parallel_dim,
)
LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED = False
def _issue_batch_first_warning(msg):
global LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED
if not LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED:
warnings.warn(msg, UserWarning)
LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED = True
def layernorm_mlp(
......@@ -61,7 +47,6 @@ def layernorm_mlp(
ffn1_ckpt_name: str = "ffn1",
ffn2_ckpt_name: str = "ffn2",
activation_type: Sequence[Union[str, Callable]] = ("gelu",),
batch_first: bool = True,
quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set),
) -> jnp.ndarray:
"""Apply layer normalization followed by MLP block.
......@@ -93,7 +78,6 @@ def layernorm_mlp(
ffn1_ckpt_name: Name for checkpointing the first feed-forward network
ffn2_ckpt_name: Name for checkpointing the second feed-forward network
activation_type: Activation function(s) to apply after the first dense layer transformation
batch_first: Assume that X is batched in the first dimension if it has more than 2 dims.
quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations
Returns:
......@@ -139,13 +123,12 @@ def layernorm_mlp(
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
batch_first,
quantizer_sets,
)
return output
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18))
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
def _layernorm_mlp(
x: jnp.ndarray,
gamma: jnp.ndarray,
......@@ -165,7 +148,6 @@ def _layernorm_mlp(
ffn1_ckpt_name: str,
ffn2_ckpt_name: str,
activation_type: Sequence[Union[str, Callable]],
batch_first: bool,
quantizer_sets,
):
"""Internal implementation of layernorm_mlp with custom VJP.
......@@ -191,7 +173,6 @@ def _layernorm_mlp(
ffn1_ckpt_name: Name for first feed-forward network checkpointing
ffn2_ckpt_name: Name for second feed-forward network checkpointing
activation_type: Activation function(s)
batch_first: Assume that X is batched in the first dimension.
quantizer_sets: Tuple of quantizer sets
Returns:
......@@ -216,7 +197,6 @@ def _layernorm_mlp(
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
batch_first,
quantizer_sets,
)
return output
......@@ -241,7 +221,6 @@ def _layernorm_mlp_fwd_rule(
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
batch_first,
quantizer_sets,
):
"""Forward pass rule for layernorm_mlp.
......@@ -274,17 +253,6 @@ def _layernorm_mlp_fwd_rule(
assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]]
x_bdim = None
if x.ndim > 2:
if not batch_first:
_issue_batch_first_warning(
"TE/JAX `layernorm_mlp()` fused-layer implementation does not officially "
"support sequence-first inputs and may produce incorrect results when "
"`batch_first=False` or `transpose_batch_sequence=True`. Use sequence-first "
"inputs at your own discretion."
)
x_bdim = 0 if batch_first else x.ndim - 2
use_bias_1 = bias_1 is not None
use_bias_2 = bias_1 is not None
......@@ -312,7 +280,6 @@ def _layernorm_mlp_fwd_rule(
casted_ln_out.get_tensor(TensorUsage.LHS),
casted_kernel_1.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims),
batched_dims=((x_bdim,), ()),
bias=bias_1 if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False,
)
......@@ -337,16 +304,12 @@ def _layernorm_mlp_fwd_rule(
# NN GEMM
# (batch..., hidden_in) x (hidden_out, hidden_in)
sequence_dim = get_sequence_parallel_dim(norm_input_axes, x_contracting_dims, (x_bdim,))
dot_2_output = tex.gemm(
casted_act_out.get_tensor(TensorUsage.LHS),
casted_kernel_2.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims),
batched_dims=((x_bdim,), ()),
bias=bias_2 if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False,
sequence_parallel_output=sequence_dim is not None and not tex.gemm_uses_jax_dot(),
sequence_dim=sequence_dim if not tex.gemm_uses_jax_dot() else None,
)
if use_bias_2 and tex.gemm_uses_jax_dot():
......@@ -374,8 +337,6 @@ def _layernorm_mlp_fwd_rule(
use_bias_1,
use_bias_2,
quantizer_sets,
x_bdim,
sequence_dim,
)
return dot_2_output, ctx
......@@ -393,7 +354,6 @@ def _layernorm_mlp_bwd_rule(
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
batch_first,
ctx,
grad,
):
......@@ -410,7 +370,7 @@ def _layernorm_mlp_bwd_rule(
Returns:
Tuple of gradients for all input parameters
"""
del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name, batch_first
del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name
(
x,
mu,
......@@ -429,8 +389,6 @@ def _layernorm_mlp_bwd_rule(
use_bias_1,
use_bias_2,
quantizer_sets,
x_bdim,
sequence_dim,
) = ctx
ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
......@@ -457,7 +415,6 @@ def _layernorm_mlp_bwd_rule(
casted_grad.get_tensor(TensorUsage.LHS),
casted_kernel_2,
contracting_dims=(g_contracting_dims_2, k_contracting_dims_2),
batched_dims=((x_bdim,), ()),
)
dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)
......@@ -472,7 +429,6 @@ def _layernorm_mlp_bwd_rule(
casted_act_out,
casted_grad.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, g_contracting_dims),
batched_dims=((x_bdim,), (x_bdim,)),
)
wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes)
......@@ -500,9 +456,6 @@ def _layernorm_mlp_bwd_rule(
casted_dact_out.get_tensor(TensorUsage.LHS),
casted_kernel_1,
contracting_dims=(g_contracting_dims_1, k_contracting_dims_1),
batched_dims=((x_bdim,), ()),
sequence_parallel_output=sequence_dim is not None and not tex.gemm_uses_jax_dot(),
sequence_dim=sequence_dim if not tex.gemm_uses_jax_dot() else None,
)
dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes)
......@@ -513,7 +466,6 @@ def _layernorm_mlp_bwd_rule(
casted_ln_out,
casted_dact_out.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, g_contracting_dims),
batched_dims=((x_bdim,), (x_bdim,)),
)
wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes)
......
......@@ -86,30 +86,6 @@ def get_sharding_map_logic_axis_to_mesh_axis():
return te_logical_axis_to_mesh_axis
def get_sequence_parallel_dim(logical_axes, contracting_dims, batch_dims):
"""
Get the index for the sequence-parallel dimension based on the given logical axes.
The sequence-parallel dimension is assumed to be the only sharded non-batched non-contracting
dimension.
"""
if not logical_axes:
return None
pspec = generate_pspec(logical_axes, with_flax_rules=True, padded=True)
ldims = [i for i in range(len(logical_axes)) if i not in set(contracting_dims + batch_dims)]
lspecs = [pspec[i] for i in ldims if pspec[i] is not None]
if len(lspecs) == 0:
return None
assert len(lspecs) == 1, (
"Expected only 1 non-batched non-contracting dimension to be sharded for "
f"sequence-parallelism, but found {len(lspecs)}: {pspec} @ idx {ldims}"
)
return pspec.index(lspecs[0])
def generate_pspec(logical_axis_names, with_flax_rules=False, padded=False):
"""
Convert logical axes to PartitionSpec
......