Unverified commit cae1c436, authored by Phuong Nguyen, committed by GitHub
Browse files

[JAX] TE Gemm custom call clean up (#2030)



* rm batch_dim, sequence_dim, sequence_parallel_output
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* rm lhs_quantized_colwise and rhs_quantized_colwise
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* rm unnecessary transpose_batch_sequence arg from some modules
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent dd083bdf
...@@ -333,7 +333,6 @@ class TestDistributedLayernormMLP: ...@@ -333,7 +333,6 @@ class TestDistributedLayernormMLP:
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
ln_mlp_single = LayerNormMLP( ln_mlp_single = LayerNormMLP(
layernorm_type=layernorm_type, layernorm_type=layernorm_type,
transpose_batch_sequence=False, # input: [batch, seqlen, hidden]
intermediate_dim=INTERMEDIATE, intermediate_dim=INTERMEDIATE,
activations=activation_type, activations=activation_type,
use_bias=use_bias, use_bias=use_bias,
...@@ -352,7 +351,6 @@ class TestDistributedLayernormMLP: ...@@ -352,7 +351,6 @@ class TestDistributedLayernormMLP:
): ):
ln_mlp_sharded = LayerNormMLP( ln_mlp_sharded = LayerNormMLP(
layernorm_type=layernorm_type, layernorm_type=layernorm_type,
transpose_batch_sequence=False,
intermediate_dim=INTERMEDIATE, intermediate_dim=INTERMEDIATE,
activations=activation_type, activations=activation_type,
scale_axes=LN_SCALE_AXES, scale_axes=LN_SCALE_AXES,
......
...@@ -155,7 +155,7 @@ class GemmPrimitive(BasePrimitive): ...@@ -155,7 +155,7 @@ class GemmPrimitive(BasePrimitive):
name = "te_gemm_ffi" name = "te_gemm_ffi"
multiple_results = True multiple_results = True
impl_static_args = (6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17) impl_static_args = (6, 7, 8, 9, 10, 11, 12)
inner_primitive = None inner_primitive = None
outer_primitive = None outer_primitive = None
...@@ -169,22 +169,13 @@ class GemmPrimitive(BasePrimitive): ...@@ -169,22 +169,13 @@ class GemmPrimitive(BasePrimitive):
gelu_input, gelu_input,
out_dtype, out_dtype,
contracting_dims, contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode, scaling_mode,
fuse_bias, fuse_bias,
fuse_gelu, fuse_gelu,
grad, grad,
use_split_accumulator, use_split_accumulator,
sequence_parallel_output,
sequence_dim,
): ):
del lhs_quantized_colwise, rhs_quantized_colwise, use_split_accumulator del use_split_accumulator
del (
sequence_parallel_output,
sequence_dim,
)
def _dims_are_consecutive(dims): def _dims_are_consecutive(dims):
if len(dims) <= 1: if len(dims) <= 1:
...@@ -207,27 +198,6 @@ class GemmPrimitive(BasePrimitive): ...@@ -207,27 +198,6 @@ class GemmPrimitive(BasePrimitive):
f"{rhs_contracting_dims}." f"{rhs_contracting_dims}."
) )
(
lhs_batch_dims,
rhs_batch_dims,
) = map(sanitize_dims, operand_ndims, batched_dims)
assert _dims_are_consecutive(lhs_batch_dims), (
"cuBLAS GEMM expected consecutive batch dimensions for LHS operand, but got "
f"{lhs_batch_dims}."
)
assert _dims_are_consecutive(rhs_batch_dims), (
"cuBLAS GEMM expected consecutive batch dimensions for RHS operand, but got "
f"{rhs_batch_dims}."
)
if len(lhs_batch_dims) == 0:
assert (
len(rhs_batch_dims) == 0
), "cuBLAS GEMM RHS operand cannot be batched if LHS operand is not batched."
elif len(rhs_batch_dims) != 0:
assert all(bdim in lhs_contracting_dims for bdim in lhs_batch_dims) and all(
bdim in rhs_contracting_dims for bdim in rhs_batch_dims
), "cuBLAS GEMM batched dimensions must be contracting when both operands are batched."
lhs_contracting_size, rhs_contracting_size = map( lhs_contracting_size, rhs_contracting_size = map(
lambda shape, dims: reduce(operator.mul, [shape[dim] for dim in dims]), lambda shape, dims: reduce(operator.mul, [shape[dim] for dim in dims]),
(lhs.shape, rhs.shape), (lhs.shape, rhs.shape),
...@@ -341,19 +311,13 @@ class GemmPrimitive(BasePrimitive): ...@@ -341,19 +311,13 @@ class GemmPrimitive(BasePrimitive):
gelu_input, gelu_input,
out_dtype, out_dtype,
contracting_dims, contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode, scaling_mode,
fuse_bias, fuse_bias,
fuse_gelu, fuse_gelu,
grad, grad,
use_split_accumulator, use_split_accumulator,
sequence_parallel_output,
sequence_dim,
): ):
del batched_dims, lhs_quantized_colwise, rhs_quantized_colwise, out_dtype del out_dtype
del sequence_parallel_output, sequence_dim
lhs_aval, _, rhs_aval, *_ = ctx.avals_in lhs_aval, _, rhs_aval, *_ = ctx.avals_in
lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_aval.ndim, rhs_aval.ndim), contracting_dims) lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_aval.ndim, rhs_aval.ndim), contracting_dims)
...@@ -395,16 +359,11 @@ class GemmPrimitive(BasePrimitive): ...@@ -395,16 +359,11 @@ class GemmPrimitive(BasePrimitive):
gelu_input, gelu_input,
out_dtype, out_dtype,
contracting_dims, contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode, scaling_mode,
fuse_bias, fuse_bias,
fuse_gelu, fuse_gelu,
grad, grad,
use_split_accumulator, use_split_accumulator,
sequence_parallel_output,
sequence_dim,
): ):
lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims)
lhs_transposed, rhs_transposed = _get_gemm_layout( lhs_transposed, rhs_transposed = _get_gemm_layout(
...@@ -414,14 +373,14 @@ class GemmPrimitive(BasePrimitive): ...@@ -414,14 +373,14 @@ class GemmPrimitive(BasePrimitive):
lhs_scale_inv, lhs_scale_inv,
scaling_mode, scaling_mode,
lhs.shape, lhs.shape,
is_colwise=lhs_quantized_colwise, is_colwise=lhs_transposed,
flatten_axis=max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims), flatten_axis=max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims),
) )
rhs_scale_inv = apply_padding_to_scale_inv( rhs_scale_inv = apply_padding_to_scale_inv(
rhs_scale_inv, rhs_scale_inv,
scaling_mode, scaling_mode,
rhs.shape, rhs.shape,
is_colwise=rhs_quantized_colwise, is_colwise=not rhs_transposed,
flatten_axis=min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1, flatten_axis=min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1,
) )
...@@ -434,55 +393,34 @@ class GemmPrimitive(BasePrimitive): ...@@ -434,55 +393,34 @@ class GemmPrimitive(BasePrimitive):
gelu_input, gelu_input,
out_dtype=out_dtype, out_dtype=out_dtype,
contracting_dims=contracting_dims, contracting_dims=contracting_dims,
batched_dims=batched_dims,
lhs_quantized_colwise=lhs_quantized_colwise,
rhs_quantized_colwise=rhs_quantized_colwise,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
fuse_bias=fuse_bias, fuse_bias=fuse_bias,
fuse_gelu=fuse_gelu, fuse_gelu=fuse_gelu,
grad=grad, grad=grad,
use_split_accumulator=use_split_accumulator, use_split_accumulator=use_split_accumulator,
sequence_parallel_output=sequence_parallel_output,
sequence_dim=sequence_dim,
) )
return outputs[:-3] # discard workspace arrays return outputs[:-3] # discard workspace arrays
@staticmethod @staticmethod
def batcher( def batcher(
batched_args, batched_args,
jax_batch_dims, batch_dims,
out_dtype, out_dtype,
contracting_dims, contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode, scaling_mode,
fuse_bias, fuse_bias,
fuse_gelu, fuse_gelu,
grad, grad,
use_split_accumulator, use_split_accumulator,
sequence_parallel_output,
sequence_dim,
): ):
assert GemmPrimitive.outer_primitive is not None assert GemmPrimitive.outer_primitive is not None
lhs, _, rhs, *_ = batched_args lhs_bdims, _, rhs_bdims, *_ = batch_dims
lhs_bdims, _, rhs_bdims, *_ = jax_batch_dims
arg_lhs_bdims, arg_rhs_bdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), batched_dims)
arg_lhs_bdims = (None,) if len(arg_lhs_bdims) == 0 else arg_lhs_bdims
assert all(bdim == arg_bdim for bdim, arg_bdim in zip(lhs_bdims, arg_lhs_bdims)), (
"User-specified batch dimension(s) for cuBLAS GEMM LHS operand does not match batch "
f"dimensions inferred by JAX/XLA, expected {lhs_bdims} but got {arg_lhs_bdims}."
)
arg_rhs_bdims = (None,) if len(arg_rhs_bdims) == 0 else arg_rhs_bdims
assert all(bdim == arg_bdim for bdim, arg_bdim in zip(rhs_bdims, arg_rhs_bdims)), (
"User-specified batch dimension(s) for cuBLAS GEMM RHS operand does not match batch "
f"dimensions inferred by JAX/XLA, expected {lhs_bdims} but got {arg_lhs_bdims}."
)
# Output is batched like the non-contracting batch dimensions of the LHS operand # Batched GEMM is not supported
lhs_cdims = sanitize_dims(lhs.ndim, contracting_dims) assert (
lhs_non_contracting_bdims = tuple(dim for dim in lhs_bdims if dim not in lhs_cdims) lhs_bdims is None and rhs_bdims is None
out_bdims = (None,) if len(lhs_non_contracting_bdims) == 0 else lhs_non_contracting_bdims ), f"(Batching is not supported, got lhs_bdims={lhs_bdims}, rhs_bdims={rhs_bdims})"
out_bdims = (None,)
# Bias gradient is never batched # Bias gradient is never batched
bias_bdims = (None,) bias_bdims = (None,)
...@@ -497,16 +435,11 @@ class GemmPrimitive(BasePrimitive): ...@@ -497,16 +435,11 @@ class GemmPrimitive(BasePrimitive):
*batched_args, *batched_args,
out_dtype=out_dtype, out_dtype=out_dtype,
contracting_dims=contracting_dims, contracting_dims=contracting_dims,
batched_dims=batched_dims,
lhs_quantized_colwise=lhs_quantized_colwise,
rhs_quantized_colwise=rhs_quantized_colwise,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
fuse_bias=fuse_bias, fuse_bias=fuse_bias,
fuse_gelu=fuse_gelu, fuse_gelu=fuse_gelu,
grad=grad, grad=grad,
use_split_accumulator=use_split_accumulator, use_split_accumulator=use_split_accumulator,
sequence_parallel_output=sequence_parallel_output,
sequence_dim=sequence_dim,
), ),
(out_bdims, bias_bdims, pre_gelu_bdims), (out_bdims, bias_bdims, pre_gelu_bdims),
) )
...@@ -515,11 +448,7 @@ class GemmPrimitive(BasePrimitive): ...@@ -515,11 +448,7 @@ class GemmPrimitive(BasePrimitive):
def _parse_operand_output_specs( def _parse_operand_output_specs(
arg_infos, arg_infos,
contracting_dims, contracting_dims,
batched_dims,
sequence_parallel_output,
sequence_dim,
): ):
del sequence_dim, sequence_parallel_output, batched_dims
lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos) lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos)
lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs)) lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs))
...@@ -586,44 +515,30 @@ class GemmPrimitive(BasePrimitive): ...@@ -586,44 +515,30 @@ class GemmPrimitive(BasePrimitive):
(lhs_specs, rhs_specs, bias_specs, gelu_specs), (lhs_specs, rhs_specs, bias_specs, gelu_specs),
(out_specs, bias_specs, gelu_specs), (out_specs, bias_specs, gelu_specs),
reduce_spec, reduce_spec,
0,
) )
@staticmethod @staticmethod
def infer_sharding_from_operands( def infer_sharding_from_operands(
out_dtype, out_dtype,
contracting_dims, contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode, scaling_mode,
fuse_bias, fuse_bias,
fuse_gelu, fuse_gelu,
grad, grad,
use_split_accumulator, use_split_accumulator,
sequence_parallel_output,
sequence_dim,
mesh, mesh,
arg_infos, arg_infos,
result_infos, result_infos,
): ):
del ( del (
out_dtype, out_dtype,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode, scaling_mode,
grad, grad,
) )
del use_split_accumulator, result_infos del use_split_accumulator, result_infos
(_, (out_specs, dbias_specs, pre_gelu_specs), *_) = ( (_, (out_specs, dbias_specs, pre_gelu_specs), _) = (
GemmPrimitive._parse_operand_output_specs( GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims)
arg_infos,
contracting_dims,
batched_dims,
sequence_parallel_output,
sequence_dim,
)
) )
out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs)) out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs))
...@@ -643,16 +558,11 @@ class GemmPrimitive(BasePrimitive): ...@@ -643,16 +558,11 @@ class GemmPrimitive(BasePrimitive):
def partition( def partition(
out_dtype, out_dtype,
contracting_dims, contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode, scaling_mode,
fuse_bias, fuse_bias,
fuse_gelu, fuse_gelu,
grad, grad,
use_split_accumulator, use_split_accumulator,
sequence_parallel_output,
sequence_dim,
mesh, mesh,
arg_infos, arg_infos,
result_infos, result_infos,
...@@ -663,14 +573,7 @@ class GemmPrimitive(BasePrimitive): ...@@ -663,14 +573,7 @@ class GemmPrimitive(BasePrimitive):
(lhs_specs, rhs_specs, bias_input_specs, gelu_input_specs), (lhs_specs, rhs_specs, bias_input_specs, gelu_input_specs),
(out_specs, dbias_specs, pre_gelu_specs), (out_specs, dbias_specs, pre_gelu_specs),
reduce_spec, reduce_spec,
_, ) = GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims)
) = GemmPrimitive._parse_operand_output_specs(
arg_infos,
contracting_dims,
batched_dims,
sequence_parallel_output,
sequence_dim,
)
# Assemble argument shardings # Assemble argument shardings
# NOTE: Block scale inverses match their operands, but tensor scale inverses are unsharded. # NOTE: Block scale inverses match their operands, but tensor scale inverses are unsharded.
...@@ -717,19 +620,14 @@ class GemmPrimitive(BasePrimitive): ...@@ -717,19 +620,14 @@ class GemmPrimitive(BasePrimitive):
gelu_input, gelu_input,
out_dtype=out_dtype, out_dtype=out_dtype,
contracting_dims=contracting_dims, contracting_dims=contracting_dims,
batched_dims=batched_dims,
lhs_quantized_colwise=lhs_quantized_colwise,
rhs_quantized_colwise=rhs_quantized_colwise,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
fuse_bias=fuse_bias, fuse_bias=fuse_bias,
fuse_gelu=fuse_gelu, fuse_gelu=fuse_gelu,
grad=grad, grad=grad,
use_split_accumulator=use_split_accumulator, use_split_accumulator=use_split_accumulator,
sequence_parallel_output=sequence_parallel_output,
sequence_dim=sequence_dim,
) )
# All-Reduce/Reduce-Scatter GEMM output # All-Reduce GEMM output
if reduce_spec is not None: if reduce_spec is not None:
outputs[0] = jax.lax.psum(outputs[0], reduce_spec) outputs[0] = jax.lax.psum(outputs[0], reduce_spec)
...@@ -741,54 +639,42 @@ class GemmPrimitive(BasePrimitive): ...@@ -741,54 +639,42 @@ class GemmPrimitive(BasePrimitive):
def shardy_sharding_rule( def shardy_sharding_rule(
out_dtype, out_dtype,
contracting_dims, contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode, scaling_mode,
fuse_bias, fuse_bias,
fuse_gelu, fuse_gelu,
grad, grad,
use_split_accumulator, use_split_accumulator,
sequence_parallel_output,
sequence_dim,
mesh, mesh,
operand_types, operand_types,
result_types, result_types,
): ):
del lhs_quantized_colwise, rhs_quantized_colwise, out_dtype, grad, use_split_accumulator del out_dtype, grad, use_split_accumulator
del sequence_parallel_output, sequence_dim, mesh, result_types del mesh, result_types
prefix = "GemmPrimitive_" prefix = "GemmPrimitive_"
def _generate_operand_rules(name, ndim, cdims, bdims): def _generate_operand_rules(name, ndim, cdims):
specs = [] specs = []
ldims = tuple(i for i in range(ndim) if i not in bdims + cdims) ldims = tuple(i for i in range(ndim) if i not in cdims)
for i in range(ndim): for i in range(ndim):
dim_name = None dim_name = None
if i in bdims: if i in cdims:
dim_idx = bdims.index(i) if len(bdims) > 1 else "" dim_idx = cdims.index(i)
dim_name = f"b{dim_idx}"
elif i in cdims:
dim_idx = cdims.index(i) if len(cdims) > 1 else ""
dim_name = f"k{dim_idx}" dim_name = f"k{dim_idx}"
else: else:
dim_idx = ldims.index(i) if len(ldims) > 1 else "" dim_idx = ldims.index(i)
dim_name = f"{name}_l{dim_idx}" dim_name = f"{name}_l{dim_idx}"
specs.append(prefix + dim_name) specs.append(prefix + dim_name)
return specs return specs
lhs, _, rhs, *_ = operand_types lhs, _, rhs, *_ = operand_types
operand_ndims = (len(lhs.shape), len(rhs.shape)) operand_ndims = (len(lhs.shape), len(rhs.shape))
(lhs_cdims, rhs_cdims), (lhs_bdims, rhs_bdims) = map( (lhs_cdims, rhs_cdims) = map(sanitize_dims, operand_ndims, contracting_dims)
lambda dims: map(sanitize_dims, operand_ndims, dims),
(contracting_dims, batched_dims),
)
lhs_specs, rhs_specs = map( lhs_specs, rhs_specs = map(
_generate_operand_rules, _generate_operand_rules,
("lhs", "rhs"), ("lhs", "rhs"),
operand_ndims, operand_ndims,
(lhs_cdims, rhs_cdims), (lhs_cdims, rhs_cdims),
(lhs_bdims, rhs_bdims),
) )
lhs_scale_specs = ("…1",) lhs_scale_specs = ("…1",)
rhs_scale_specs = ("…2",) rhs_scale_specs = ("…2",)
...@@ -840,13 +726,10 @@ def _te_gemm( ...@@ -840,13 +726,10 @@ def _te_gemm(
lhs_quantizer: Quantizer = None, lhs_quantizer: Quantizer = None,
rhs_quantizer: Quantizer = None, rhs_quantizer: Quantizer = None,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)), contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
batched_dims: Tuple[Sequence[int], Sequence[int]] = ((), ()),
fuse_bias: bool = False, fuse_bias: bool = False,
fuse_gelu: bool = False, fuse_gelu: bool = False,
grad: bool = False, grad: bool = False,
use_split_accumulator: bool = QuantizeConfig.FP8_2X_ACC_FPROP, use_split_accumulator: bool = QuantizeConfig.FP8_2X_ACC_FPROP,
sequence_parallel_output: bool = False,
sequence_dim: int = None,
) -> Tuple[jax.Array, ...]: ) -> Tuple[jax.Array, ...]:
# Prepare non-quantized GEMM operands # Prepare non-quantized GEMM operands
...@@ -857,7 +740,6 @@ def _te_gemm( ...@@ -857,7 +740,6 @@ def _te_gemm(
scaling_mode = ScalingMode.NO_SCALING scaling_mode = ScalingMode.NO_SCALING
lhs_is_transposed, rhs_is_transposed = _get_gemm_layout((lhs.ndim, rhs.ndim), contracting_dims) lhs_is_transposed, rhs_is_transposed = _get_gemm_layout((lhs.ndim, rhs.ndim), contracting_dims)
lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims)
lhs_bdims, rhs_bdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), batched_dims)
# Quantize operands (if necessary) # Quantize operands (if necessary)
lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims) lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims)
...@@ -876,7 +758,6 @@ def _te_gemm( ...@@ -876,7 +758,6 @@ def _te_gemm(
lhs_scale_inv = lhs_q.scale_inv lhs_scale_inv = lhs_q.scale_inv
if lhs_q.data_layout == "T": if lhs_q.data_layout == "T":
lhs_cdims = transpose_dims(lhs_q.ndim, lhs_cdims, flatten_axis=lhs_q.flatten_axis) lhs_cdims = transpose_dims(lhs_q.ndim, lhs_cdims, flatten_axis=lhs_q.flatten_axis)
lhs_bdims = transpose_dims(lhs_q.ndim, lhs_bdims, flatten_axis=lhs_q.flatten_axis)
if isinstance(rhs_q, ScaledTensor): if isinstance(rhs_q, ScaledTensor):
assert isinstance(lhs_q, ScaledTensor) or lhs_quantizer is not None, ( assert isinstance(lhs_q, ScaledTensor) or lhs_quantizer is not None, (
...@@ -894,7 +775,6 @@ def _te_gemm( ...@@ -894,7 +775,6 @@ def _te_gemm(
rhs_scale_inv = rhs_q.scale_inv rhs_scale_inv = rhs_q.scale_inv
if rhs_q.data_layout == "T": if rhs_q.data_layout == "T":
rhs_cdims = transpose_dims(rhs_q.ndim, rhs_cdims, flatten_axis=rhs_q.flatten_axis) rhs_cdims = transpose_dims(rhs_q.ndim, rhs_cdims, flatten_axis=rhs_q.flatten_axis)
rhs_bdims = transpose_dims(rhs_q.ndim, rhs_bdims, flatten_axis=rhs_q.flatten_axis)
# Dummy empties for bias and gelu # Dummy empties for bias and gelu
out_dtype = lhs_q.dq_dtype if isinstance(lhs_q, ScaledTensor) else lhs_data.dtype out_dtype = lhs_q.dq_dtype if isinstance(lhs_q, ScaledTensor) else lhs_data.dtype
...@@ -912,16 +792,11 @@ def _te_gemm( ...@@ -912,16 +792,11 @@ def _te_gemm(
gelu_input, gelu_input,
out_dtype=out_dtype, out_dtype=out_dtype,
contracting_dims=(lhs_cdims, rhs_cdims), contracting_dims=(lhs_cdims, rhs_cdims),
batched_dims=(lhs_bdims, rhs_bdims),
lhs_quantized_colwise=lhs_q.is_colwise if isinstance(lhs_q, ScaledTensor) else False,
rhs_quantized_colwise=rhs_q.is_colwise if isinstance(rhs_q, ScaledTensor) else False,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
fuse_bias=fuse_bias, fuse_bias=fuse_bias,
fuse_gelu=fuse_gelu, fuse_gelu=fuse_gelu,
grad=grad, grad=grad,
use_split_accumulator=use_split_accumulator, use_split_accumulator=use_split_accumulator,
sequence_parallel_output=sequence_parallel_output,
sequence_dim=sequence_dim,
) )
...@@ -1124,10 +999,8 @@ def _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision): ...@@ -1124,10 +999,8 @@ def _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision):
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
if lhs.data_layout == "T": if lhs.data_layout == "T":
lhs_contract = transpose_dims(lhs.data.ndim, lhs_contract, flatten_axis=lhs.flatten_axis) lhs_contract = transpose_dims(lhs.data.ndim, lhs_contract, flatten_axis=lhs.flatten_axis)
lhs_batch = transpose_dims(lhs.data.ndim, lhs_batch, flatten_axis=lhs.flatten_axis)
if rhs.data_layout == "T": if rhs.data_layout == "T":
rhs_contract = transpose_dims(rhs.data.ndim, rhs_contract, flatten_axis=rhs.flatten_axis) rhs_contract = transpose_dims(rhs.data.ndim, rhs_contract, flatten_axis=rhs.flatten_axis)
rhs_batch = transpose_dims(rhs.data.ndim, rhs_batch, flatten_axis=rhs.flatten_axis)
dim_nums = (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) dim_nums = (lhs_contract, rhs_contract), (lhs_batch, rhs_batch)
...@@ -1239,7 +1112,6 @@ def gemm( ...@@ -1239,7 +1112,6 @@ def gemm(
lhs: Union[jnp.ndarray, ScaledTensor], lhs: Union[jnp.ndarray, ScaledTensor],
rhs: Union[jnp.ndarray, ScaledTensor], rhs: Union[jnp.ndarray, ScaledTensor],
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)), contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
batched_dims: Tuple[Sequence[int], Sequence[int]] = ((), ()),
lhs_quantizer: Quantizer = None, lhs_quantizer: Quantizer = None,
rhs_quantizer: Quantizer = None, rhs_quantizer: Quantizer = None,
**kwargs, **kwargs,
...@@ -1258,11 +1130,6 @@ def gemm( ...@@ -1258,11 +1130,6 @@ def gemm(
Object for down-casting the RHS operand for quantized GEMM. Object for down-casting the RHS operand for quantized GEMM.
contracting_dims: Tuple[Sequence[int], Sequence[int]], default = ((-1, ), (0, )) contracting_dims: Tuple[Sequence[int], Sequence[int]], default = ((-1, ), (0, ))
Tuple of sequences representing the contracting dimensions of the operands. Tuple of sequences representing the contracting dimensions of the operands.
batched_dims: Tuple[Sequence[int], Sequence[int]], default = ((), ()),
Tuple of sequences representing the batched dimensions of the operands. This is *not* used
to perform a batched matrix multiplication, but it is required for TE's custom cuBLAS GEMM
call to avoid a potentially undesirable reduction in any batched contracting dimensions
when invoked with sharded operands (e.g. when computing weight gradients in a Flax module).
bias: jax.Array, default = None bias: jax.Array, default = None
Optional additive bias term, required for forward GEMM with bias fusion. Only supported Optional additive bias term, required for forward GEMM with bias fusion. Only supported
with TE's custom call to cuBLAS GEMM. with TE's custom call to cuBLAS GEMM.
...@@ -1282,15 +1149,6 @@ def gemm( ...@@ -1282,15 +1149,6 @@ def gemm(
Enable promoting some intermediate sums to higher precision when accumulating the result in Enable promoting some intermediate sums to higher precision when accumulating the result in
the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed. Only the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed. Only
supported with TE's custom call to cuBLAS GEMM. supported with TE's custom call to cuBLAS GEMM.
sequence_parallel_output: bool, default = False
Produces an output with the first non-batched non-contracting dimension sharded with the
same spec as operand contracting dimensions. This effectively converts the `jax.lax.psum`
for the GEMM output into a `jax.lax.psum_scatter`. Only supported with TE's custom call to
cuBLAS GEMM.
sequence_dim: int, default = None
Index of the sequence dimension for the LHS operand. This controls which dimension of the
GEMM output is scattered when `sequence_parallel_output=True`. When `None`, the first
non-batched non-contracting dimension is assumed to be the sequence dimension.
Returns Returns
------- -------
...@@ -1329,14 +1187,6 @@ def gemm( ...@@ -1329,14 +1187,6 @@ def gemm(
"`jax.lax.dot_general` and `jax.nn.scaled_matmul` backends used when the custom cuBLAS " "`jax.lax.dot_general` and `jax.nn.scaled_matmul` backends used when the custom cuBLAS "
"GEMM primitive is disabled." "GEMM primitive is disabled."
) )
assert (
not kwargs.get("sequence_parallel_output", False)
and kwargs.get("sequence_dim", None) is None
), (
"TE GEMM was invoked with sequence-parallelism options that are not supported by the "
"`jax.lax.dot_general` and `jax.nn.scaled_matmul` backedns used when the custom cuBLAS "
"GEMM primitive is disabled."
)
return _jax_gemm(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer) return _jax_gemm(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer)
outputs = _te_gemm( outputs = _te_gemm(
...@@ -1345,7 +1195,6 @@ def gemm( ...@@ -1345,7 +1195,6 @@ def gemm(
lhs_quantizer=lhs_quantizer, lhs_quantizer=lhs_quantizer,
rhs_quantizer=rhs_quantizer, rhs_quantizer=rhs_quantizer,
contracting_dims=contracting_dims, contracting_dims=contracting_dims,
batched_dims=batched_dims,
**kwargs, **kwargs,
) )
......
...@@ -8,7 +8,7 @@ architectures, including support for quantization and automatic differentiation. ...@@ -8,7 +8,7 @@ architectures, including support for quantization and automatic differentiation.
It implements matrix multiplication with optional bias addition and supports It implements matrix multiplication with optional bias addition and supports
customizable contracting dimensions for flexible tensor operations. customizable contracting dimensions for flexible tensor operations.
""" """
import warnings
from typing import Tuple, Sequence from typing import Tuple, Sequence
from functools import partial from functools import partial
import jax import jax
...@@ -22,17 +22,6 @@ from .quantize import ( ...@@ -22,17 +22,6 @@ from .quantize import (
TensorUsage, TensorUsage,
) )
from .sharding import get_sequence_parallel_dim
DENSE_BATCH_FIRST_WARNING_ISSUED = False
def _issue_batch_first_warning(msg):
global DENSE_BATCH_FIRST_WARNING_ISSUED
if not DENSE_BATCH_FIRST_WARNING_ISSUED:
warnings.warn(msg, UserWarning)
DENSE_BATCH_FIRST_WARNING_ISSUED = True
def dense( def dense(
x: jnp.ndarray, x: jnp.ndarray,
...@@ -41,8 +30,6 @@ def dense( ...@@ -41,8 +30,6 @@ def dense(
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
input_axes: Tuple[str, ...] = None, input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None,
batch_first: bool = True,
sequence_parallel_output: bool = False,
quantizer_set: QuantizerSet = noop_quantizer_set, quantizer_set: QuantizerSet = noop_quantizer_set,
): ):
"""Perform dense layer transformation with optional quantization. """Perform dense layer transformation with optional quantization.
...@@ -56,9 +43,6 @@ def dense( ...@@ -56,9 +43,6 @@ def dense(
kernel: Weight matrix for the dense layer transformation kernel: Weight matrix for the dense layer transformation
bias: Optional bias tensor to add after the transformation bias: Optional bias tensor to add after the transformation
contracting_dims: Tuple of sequences specifying which dimensions to contract contracting_dims: Tuple of sequences specifying which dimensions to contract
batch_first: Assume that X is batched in the first dimension.
sequence_parallel_output: Produce an output that sharded in the first non-batched dim. Only
supported for TE custom GEMM with row-parallel kernel axes.
quantizer_set: QuantizerSet which contains quantizers for different tensor types quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns: Returns:
...@@ -79,14 +63,19 @@ def dense( ...@@ -79,14 +63,19 @@ def dense(
contracting_dims, contracting_dims,
input_axes, input_axes,
kernel_axes, kernel_axes,
batch_first,
sequence_parallel_output,
quantizer_set, quantizer_set,
) )
return output return output
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7)) @partial(
jax.custom_vjp,
nondiff_argnums=(
3,
4,
5,
),
)
def _dense( def _dense(
x, x,
kernel, kernel,
...@@ -94,8 +83,6 @@ def _dense( ...@@ -94,8 +83,6 @@ def _dense(
contracting_dims, contracting_dims,
input_axes, input_axes,
kernel_axes, kernel_axes,
batch_first,
sequence_parallel_output,
quantizer_set, quantizer_set,
): ):
"""Internal implementation of dense layer transformation with custom VJP. """Internal implementation of dense layer transformation with custom VJP.
...@@ -110,9 +97,6 @@ def _dense( ...@@ -110,9 +97,6 @@ def _dense(
contracting_dims: Contracting dimensions specification contracting_dims: Contracting dimensions specification
input_axes: Logical axes for sharding the activation input input_axes: Logical axes for sharding the activation input
kernel_axes: Logical axes for sharding the weight matrix kernel_axes: Logical axes for sharding the weight matrix
batch_first: Assume that X is batched in the first dimension if it has more than 2 dims.
sequence_parallel_output: Produce an output that sharded in the first non-batched dim. Only
supported for TE custom GEMM with row-parallel kernel axes.
quantizer_set: QuantizerSet which contains quantizers for different tensor types quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns: Returns:
...@@ -125,8 +109,6 @@ def _dense( ...@@ -125,8 +109,6 @@ def _dense(
contracting_dims, contracting_dims,
input_axes, input_axes,
kernel_axes, kernel_axes,
batch_first,
sequence_parallel_output,
quantizer_set, quantizer_set,
) )
return output return output
...@@ -139,8 +121,6 @@ def _dense_fwd_rule( ...@@ -139,8 +121,6 @@ def _dense_fwd_rule(
contracting_dims, contracting_dims,
input_axes, input_axes,
kernel_axes, kernel_axes,
batch_first,
sequence_parallel_output,
quantizer_set, quantizer_set,
): ):
"""Forward pass rule for dense layer transformation. """Forward pass rule for dense layer transformation.
...@@ -159,23 +139,6 @@ def _dense_fwd_rule( ...@@ -159,23 +139,6 @@ def _dense_fwd_rule(
not x_is_transposed and not k_is_transposed not x_is_transposed and not k_is_transposed
), "Dense layer only supports `NN` layout inputs, i.e. non-transposed X and Kernel." ), "Dense layer only supports `NN` layout inputs, i.e. non-transposed X and Kernel."
# Determine X batch dimension
# - If `batch_first=True` -> (batch, leading..., contracting...)
# - Otherwise -> (leading..., batch, contracting...)
# NOTE: Always assume a single batch dimension
x_bdim = None
num_cdims = len(x_contracting_dims)
if x.ndim >= num_cdims + 2:
# Assume X is batched if it has at least +2 dimensions more than the number of contracting
# dimensions.
if not batch_first:
_issue_batch_first_warning(
"TE/JAX `dense()` layer implementation does not officially support sequence-first "
"inputs and may produce incorrect results when `batch_first=False`. Use "
"sequence-first inputs at your own discretion.",
)
x_bdim = 0 if batch_first else x.ndim - num_cdims - 1
flatten_axis_x = -len(x_contracting_dims) flatten_axis_x = -len(x_contracting_dims)
flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) flatten_axis_k = len(k_contracting_dims) - len(kernel.shape)
...@@ -198,10 +161,8 @@ def _dense_fwd_rule( ...@@ -198,10 +161,8 @@ def _dense_fwd_rule(
casted_x.get_tensor(usage=TensorUsage.LHS), casted_x.get_tensor(usage=TensorUsage.LHS),
casted_kernel.get_tensor(usage=TensorUsage.RHS), casted_kernel.get_tensor(usage=TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims), contracting_dims=(x_contracting_dims, k_contracting_dims),
batched_dims=((x_bdim,), ()),
bias=bias if not tex.gemm_uses_jax_dot() else None, bias=bias if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
sequence_parallel_output=sequence_parallel_output and not tex.gemm_uses_jax_dot(),
) )
if use_bias and tex.gemm_uses_jax_dot(): if use_bias and tex.gemm_uses_jax_dot():
...@@ -216,13 +177,12 @@ def _dense_fwd_rule( ...@@ -216,13 +177,12 @@ def _dense_fwd_rule(
use_bias, use_bias,
quantizer_set, quantizer_set,
flatten_axis_k, flatten_axis_k,
x_bdim,
) )
return output, ctx return output, ctx
def _dense_bwd_rule( def _dense_bwd_rule(
contracting_dims, input_axes, kernel_axes, batch_first, sequence_parallel_output, ctx, grad contracting_dims, input_axes, kernel_axes, ctx, grad
): # pylint: disable=unused-argument ): # pylint: disable=unused-argument
"""Backward pass rule for dense layer transformation. """Backward pass rule for dense layer transformation.
...@@ -237,7 +197,6 @@ def _dense_bwd_rule( ...@@ -237,7 +197,6 @@ def _dense_bwd_rule(
use_bias, use_bias,
quantizer_set, quantizer_set,
flatten_axis_k, flatten_axis_k,
x_bdim,
) = ctx ) = ctx
fwd_x_contracting_dims, fwd_k_contracting_dims = map( fwd_x_contracting_dims, fwd_k_contracting_dims = map(
...@@ -262,21 +221,10 @@ def _dense_bwd_rule( ...@@ -262,21 +221,10 @@ def _dense_bwd_rule(
dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims
) )
# Get sequence-parallel dimension of the FWD input (if it exists)
sequence_dim = get_sequence_parallel_dim(input_axes, fwd_x_contracting_dims, (x_bdim,))
dgrad = tex.gemm( dgrad = tex.gemm(
casted_grad.get_tensor(usage=TensorUsage.LHS), casted_grad.get_tensor(usage=TensorUsage.LHS),
casted_kernel_rhs, casted_kernel_rhs,
contracting_dims=(g_contracting_dim, k_contracting_dim), contracting_dims=(g_contracting_dim, k_contracting_dim),
batched_dims=((x_bdim,), ()),
sequence_parallel_output=(
sequence_dim is not None
and not sequence_parallel_output
and not tex.gemm_uses_jax_dot()
),
sequence_dim=(
None if sequence_parallel_output or tex.gemm_uses_jax_dot() else sequence_dim
),
) )
dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)
...@@ -290,7 +238,6 @@ def _dense_bwd_rule( ...@@ -290,7 +238,6 @@ def _dense_bwd_rule(
casted_x_lhs, casted_x_lhs,
casted_grad.get_tensor(usage=TensorUsage.RHS), casted_grad.get_tensor(usage=TensorUsage.RHS),
contracting_dims=(x_contracting_dim, g_contracting_dim), contracting_dims=(x_contracting_dim, g_contracting_dim),
batched_dims=((x_bdim,), (x_bdim,)),
) )
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
......
...@@ -15,12 +15,12 @@ from jax import lax ...@@ -15,12 +15,12 @@ from jax import lax
from jax import random as jax_random from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name from jax.ad_checkpoint import checkpoint_name
from ..dense import dense, _issue_batch_first_warning as _dense_warning from ..dense import dense
from ..layernorm import canonicalize_norm_type from ..layernorm import canonicalize_norm_type
from ..layernorm import layernorm from ..layernorm import layernorm
from ..layernorm_dense import layernorm_dense, _issue_batch_first_warning as _ln_dense_warning from ..layernorm_dense import layernorm_dense
from ..layernorm_mlp import layernorm_mlp, _issue_batch_first_warning as _ln_mlp_warning from ..layernorm_mlp import layernorm_mlp
from ..activation import activation from ..activation import activation
from ..softmax import softmax, SoftmaxType from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes from ..sharding import with_sharding_constraint_by_logical_axes
...@@ -273,10 +273,6 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods ...@@ -273,10 +273,6 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods
----------------------- -----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32 dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters. The data type used to allocate the initial parameters.
transpose_batch_sequence : bool, default = False
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. If set to True, the input tensors
should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
""" """
epsilon: float = 1e-6 epsilon: float = 1e-6
...@@ -287,7 +283,6 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods ...@@ -287,7 +283,6 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods
bias_init: Initializer = nn.initializers.zeros bias_init: Initializer = nn.initializers.zeros
bias_axes: Tuple[str, ...] = ("embed",) bias_axes: Tuple[str, ...] = ("embed",)
dtype: DType = jnp.float32 dtype: DType = jnp.float32
transpose_batch_sequence: bool = False
def __post_init__(self): def __post_init__(self):
self.scale_init = _obtain_default_layernorm_scale_init_if_need( self.scale_init = _obtain_default_layernorm_scale_init_if_need(
...@@ -414,17 +409,11 @@ class DenseGeneral(TransformerEngineBase): ...@@ -414,17 +409,11 @@ class DenseGeneral(TransformerEngineBase):
Indicate the logical axes of sharding constraint to the input, like Indicate the logical axes of sharding constraint to the input, like
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
sharding constraint. sharding constraint.
sequence_parallel_output: bool, default = False
Produce a sequence-parallel output with the first non-batch dimension sharded over
Optimization parameters Optimization parameters
----------------------- -----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32 dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters. The data type used to allocate the initial parameters.
transpose_batch_sequence : bool, default = True
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. If set to True, the input tensors
should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
""" """
features: Union[Iterable[int], int] features: Union[Iterable[int], int]
...@@ -438,17 +427,9 @@ class DenseGeneral(TransformerEngineBase): ...@@ -438,17 +427,9 @@ class DenseGeneral(TransformerEngineBase):
low_rank_adaptation_alpha: float = None low_rank_adaptation_alpha: float = None
axis: Union[Iterable[int], int] = -1 axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32 dtype: DType = jnp.float32
transpose_batch_sequence: bool = False
input_axes: Tuple[str, ...] = () input_axes: Tuple[str, ...] = ()
sequence_parallel_output: bool = False
def __post_init__(self): def __post_init__(self):
if self.transpose_batch_sequence:
_dense_warning(
"TE/JAX DenseGeneral() module does not officially support sequence-first inputs "
"and may produce incorrect results when `transpose_batch_sequence=True`. Use "
"sequence-first inputs at your own discretion."
)
if self.kernel_init is None: if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling( self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal", dtype=self.dtype 1.0, "fan_in", "truncated_normal", dtype=self.dtype
...@@ -513,7 +494,6 @@ class DenseGeneral(TransformerEngineBase): ...@@ -513,7 +494,6 @@ class DenseGeneral(TransformerEngineBase):
input_axes=self.input_axes, input_axes=self.input_axes,
kernel_axes=self.kernel_axes, kernel_axes=self.kernel_axes,
quantizer_set=quantizer_set, quantizer_set=quantizer_set,
sequence_parallel_output=self.sequence_parallel_output,
) )
if self.enable_low_rank_adaptation: if self.enable_low_rank_adaptation:
...@@ -631,10 +611,6 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -631,10 +611,6 @@ class LayerNormDenseGeneral(TransformerEngineBase):
----------------------- -----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32 dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters. The data type used to allocate the initial parameters.
transpose_batch_sequence : bool, default = True
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. If set to True, the input tensors
should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
depth_scaling: float, default = None depth_scaling: float, default = None
The factor to scale the output from `DenseGeneral`. It should be a float The factor to scale the output from `DenseGeneral`. It should be a float
value or None. When None is set, then no scaling is applied. value or None. When None is set, then no scaling is applied.
...@@ -660,18 +636,11 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -660,18 +636,11 @@ class LayerNormDenseGeneral(TransformerEngineBase):
low_rank_adaptation_alpha: float = None low_rank_adaptation_alpha: float = None
axis: Union[Iterable[int], int] = -1 axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32 dtype: DType = jnp.float32
transpose_batch_sequence: bool = True
layernorm_input_axes: Tuple[str, ...] = None layernorm_input_axes: Tuple[str, ...] = None
dot_input_axes: Tuple[str, ...] = None dot_input_axes: Tuple[str, ...] = None
depth_scaling: float = None depth_scaling: float = None
def __post_init__(self): def __post_init__(self):
if self.transpose_batch_sequence:
_ln_dense_warning(
"TE/JAX LayerNormDenseGeneral() module does not officially support sequence-first "
"inputs and may produce incorrect results when `transpose_batch_sequence=True`. "
"Use sequence-first inputs at your own discretion."
)
if self.kernel_init is None: if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling( self.kernel_init = nn.initializers.variance_scaling(
1.0, 1.0,
...@@ -949,10 +918,6 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -949,10 +918,6 @@ class LayerNormMLP(TransformerEngineBase):
----------------------- -----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32 dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters. The data type used to allocate the initial parameters.
transpose_batch_sequence : bool, default = True
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. If set to True, the input tensors
should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
""" """
intermediate_dim: int = 2048 intermediate_dim: int = 2048
...@@ -981,7 +946,6 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -981,7 +946,6 @@ class LayerNormMLP(TransformerEngineBase):
low_rank_adaptation_alpha: float = None low_rank_adaptation_alpha: float = None
axis: Union[Iterable[int], int] = -1 axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32 dtype: DType = jnp.float32
transpose_batch_sequence: bool = True
layernorm_input_axes: Tuple[str, ...] = None layernorm_input_axes: Tuple[str, ...] = None
dot_1_input_axes: Tuple[str, ...] = None dot_1_input_axes: Tuple[str, ...] = None
dot_2_input_axes: Tuple[str, ...] = None dot_2_input_axes: Tuple[str, ...] = None
...@@ -989,12 +953,6 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -989,12 +953,6 @@ class LayerNormMLP(TransformerEngineBase):
ffn2_ckpt_name: str = "ffn2" ffn2_ckpt_name: str = "ffn2"
def __post_init__(self): def __post_init__(self):
if self.transpose_batch_sequence:
_ln_mlp_warning(
"TE/JAX LayerNormMLP() module does not officially support sequence-first inputs "
"and may produce incorrect results when `transpose_batch_sequence=True`. Use "
"sequence-first inputs at your own discretion."
)
if self.kernel_init is None: if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling( self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal", dtype=self.dtype 1.0, "fan_in", "truncated_normal", dtype=self.dtype
......
...@@ -1167,7 +1167,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1167,7 +1167,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
epsilon=self.layernorm_epsilon, epsilon=self.layernorm_epsilon,
axis=-1, axis=-1,
features=(3, self.num_attention_heads * self.head_dim), features=(3, self.num_attention_heads * self.head_dim),
transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=self.return_layernorm_output, return_layernorm_output=self.return_layernorm_output,
scale_axes=(W_NO_SHARD_AXES,), scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,), ln_bias_axes=(W_NO_SHARD_AXES,),
...@@ -1194,7 +1193,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1194,7 +1193,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
epsilon=self.layernorm_epsilon, epsilon=self.layernorm_epsilon,
axis=-1, axis=-1,
features=self.num_attention_heads * self.head_dim, features=self.num_attention_heads * self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=(self.return_layernorm_output or is_self_attn), return_layernorm_output=(self.return_layernorm_output or is_self_attn),
scale_axes=(W_NO_SHARD_AXES,), scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,), ln_bias_axes=(W_NO_SHARD_AXES,),
...@@ -1219,7 +1217,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1219,7 +1217,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
kv_proj = DenseGeneral( kv_proj = DenseGeneral(
axis=-1, axis=-1,
features=(2, self.num_gqa_groups * self.head_dim), features=(2, self.num_gqa_groups * self.head_dim),
transpose_batch_sequence=self.transpose_batch_sequence,
kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES), kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
kernel_init=kv_init, kernel_init=kv_init,
use_bias=self.use_bias, use_bias=self.use_bias,
...@@ -1238,7 +1235,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1238,7 +1235,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
DenseGeneral, DenseGeneral,
axis=-1, axis=-1,
features=self.num_gqa_groups * self.head_dim, features=self.num_gqa_groups * self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
kernel_axes=(W_FSDP_AXES, W_TP_AXES), kernel_axes=(W_FSDP_AXES, W_TP_AXES),
use_bias=self.use_bias, use_bias=self.use_bias,
bias_init=self.bias_init, bias_init=self.bias_init,
...@@ -1255,7 +1251,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1255,7 +1251,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
epsilon=self.layernorm_epsilon, epsilon=self.layernorm_epsilon,
axis=-1, axis=-1,
features=self.num_attention_heads * self.head_dim, features=self.num_attention_heads * self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=True, return_layernorm_output=True,
scale_axes=(W_NO_SHARD_AXES,), scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,), ln_bias_axes=(W_NO_SHARD_AXES,),
...@@ -1420,7 +1415,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1420,7 +1415,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
out = DenseGeneral( out = DenseGeneral(
features=inputs_q.shape[-1], features=inputs_q.shape[-1],
transpose_batch_sequence=self.transpose_batch_sequence,
axis=-1, axis=-1,
kernel_init=self.kernel_init, kernel_init=self.kernel_init,
kernel_axes=(W_TP_AXES, W_FSDP_AXES), kernel_axes=(W_TP_AXES, W_FSDP_AXES),
...@@ -1432,7 +1426,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1432,7 +1426,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
dtype=self.dtype, dtype=self.dtype,
name="out", name="out",
sequence_parallel_output=self.enable_sequence_parallel,
)(x) )(x)
out = checkpoint_name(out, "out_proj") out = checkpoint_name(out, "out_proj")
...@@ -2023,7 +2016,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -2023,7 +2016,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
layernorm_type=self.layernorm_type, layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma, zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.layernorm_epsilon, epsilon=self.layernorm_epsilon,
transpose_batch_sequence=self.transpose_batch_sequence,
return_layernorm_output=self.apply_residual_connection_post_layernorm, return_layernorm_output=self.apply_residual_connection_post_layernorm,
intermediate_dim=self.mlp_hidden_size, intermediate_dim=self.mlp_hidden_size,
activations=self.mlp_activations, activations=self.mlp_activations,
...@@ -2078,7 +2070,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -2078,7 +2070,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
epsilon=self.layernorm_epsilon, epsilon=self.layernorm_epsilon,
scale_axes=(W_NO_SHARD_AXES,), scale_axes=(W_NO_SHARD_AXES,),
bias_axes=(W_NO_SHARD_AXES,), bias_axes=(W_NO_SHARD_AXES,),
transpose_batch_sequence=self.transpose_batch_sequence,
dtype=self.dtype, dtype=self.dtype,
name="output_layernorm", name="output_layernorm",
)(z) )(z)
......
...@@ -9,7 +9,6 @@ architectures. It supports various normalization types, quantization, and ...@@ -9,7 +9,6 @@ architectures. It supports various normalization types, quantization, and
distributed training through sharding constraints. distributed training through sharding constraints.
""" """
import warnings
from functools import partial from functools import partial
from typing import Tuple from typing import Tuple
...@@ -24,17 +23,6 @@ from .quantize import ( ...@@ -24,17 +23,6 @@ from .quantize import (
with_sharding_constraint_by_logical_axes, with_sharding_constraint_by_logical_axes,
TensorUsage, TensorUsage,
) )
from .sharding import get_sequence_parallel_dim
# Module-level latch: set to True once the warning below has been emitted so that
# repeated calls stay silent.
LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED = False


def _issue_batch_first_warning(msg):
    """Emit ``msg`` as a ``UserWarning`` on the first call only.

    The module-level flag ``LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED`` is flipped to
    ``True`` after the warning has been issued; every later call returns without
    warning.

    Args:
        msg: Text of the warning to emit.
    """
    global LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED
    if LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED:
        return
    # stacklevel=2 attributes the warning to the code that called this helper, so the
    # reported file/line is the call site rather than this function.
    warnings.warn(msg, UserWarning, stacklevel=2)
    # Set the latch only after warn() returns: if warnings are configured to raise,
    # the flag stays False, exactly as before.
    LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED = True
def layernorm_dense( def layernorm_dense(
...@@ -49,7 +37,6 @@ def layernorm_dense( ...@@ -49,7 +37,6 @@ def layernorm_dense(
layernorm_input_axes: Tuple[str, ...] = None, layernorm_input_axes: Tuple[str, ...] = None,
dot_input_axes: Tuple[str, ...] = None, dot_input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None,
batch_first: bool = True,
quantizer_set: QuantizerSet = noop_quantizer_set, quantizer_set: QuantizerSet = noop_quantizer_set,
) -> jnp.ndarray: ) -> jnp.ndarray:
"""Apply layer normalization followed by dense layer transformation. """Apply layer normalization followed by dense layer transformation.
...@@ -70,7 +57,6 @@ def layernorm_dense( ...@@ -70,7 +57,6 @@ def layernorm_dense(
layernorm_input_axes: Logical axes for sharding the layernorm input layernorm_input_axes: Logical axes for sharding the layernorm input
dot_input_axes: Logical axes for sharding the matrix multiplication input dot_input_axes: Logical axes for sharding the matrix multiplication input
kernel_axes: Logical axes for sharding the weight matrix kernel_axes: Logical axes for sharding the weight matrix
batch_first: Assume that X is batched in the first dimension if it has more than 2 dims.
quantizer_set: Set of quantizers for different tensor types quantizer_set: Set of quantizers for different tensor types
Returns: Returns:
...@@ -94,7 +80,6 @@ def layernorm_dense( ...@@ -94,7 +80,6 @@ def layernorm_dense(
layernorm_input_axes, layernorm_input_axes,
dot_input_axes, dot_input_axes,
kernel_axes, kernel_axes,
batch_first,
quantizer_set, quantizer_set,
) )
return output return output
...@@ -109,7 +94,6 @@ def layernorm_dense( ...@@ -109,7 +94,6 @@ def layernorm_dense(
8, 8,
9, 9,
10, 10,
11,
), ),
) )
def _layernorm_dense( def _layernorm_dense(
...@@ -124,7 +108,6 @@ def _layernorm_dense( ...@@ -124,7 +108,6 @@ def _layernorm_dense(
layernorm_input_axes: Tuple[str, ...], layernorm_input_axes: Tuple[str, ...],
dot_input_axes: Tuple[str, ...], dot_input_axes: Tuple[str, ...],
kernel_axes: Tuple[str, ...], kernel_axes: Tuple[str, ...],
batch_first: bool,
quantizer_set, quantizer_set,
): ):
"""Internal implementation of layernorm_dense with custom VJP. """Internal implementation of layernorm_dense with custom VJP.
...@@ -144,7 +127,6 @@ def _layernorm_dense( ...@@ -144,7 +127,6 @@ def _layernorm_dense(
epsilon: Small constant for numerical stability epsilon: Small constant for numerical stability
layernorm_input_axes: Logical axes for layernorm sharding layernorm_input_axes: Logical axes for layernorm sharding
dot_input_axes: Logical axes for matrix multiplication sharding dot_input_axes: Logical axes for matrix multiplication sharding
batch_first: Assume that X is batched in the first dimension.
quantizer_set: Set of quantizers quantizer_set: Set of quantizers
Returns: Returns:
...@@ -162,7 +144,6 @@ def _layernorm_dense( ...@@ -162,7 +144,6 @@ def _layernorm_dense(
layernorm_input_axes, layernorm_input_axes,
dot_input_axes, dot_input_axes,
kernel_axes, kernel_axes,
batch_first,
quantizer_set, quantizer_set,
) )
return output return output
...@@ -180,7 +161,6 @@ def _layernorm_dense_fwd_rule( ...@@ -180,7 +161,6 @@ def _layernorm_dense_fwd_rule(
layernorm_input_axes, layernorm_input_axes,
dot_input_axes, dot_input_axes,
kernel_axes, kernel_axes,
batch_first,
quantizer_set, quantizer_set,
): ):
"""Forward pass rule for layernorm_dense. """Forward pass rule for layernorm_dense.
...@@ -198,17 +178,6 @@ def _layernorm_dense_fwd_rule( ...@@ -198,17 +178,6 @@ def _layernorm_dense_fwd_rule(
k_contracting_dims = (0,) k_contracting_dims = (0,)
assert x.shape[-1] == kernel.shape[0] assert x.shape[-1] == kernel.shape[0]
x_bdim = None
if x.ndim > 2:
if not batch_first:
_issue_batch_first_warning(
"TE/JAX `layernorm_dense()` fused-layer implementation does not officially "
"support sequence-first inputs and may produce incorrect results when "
"`batch_first=False` or `transpose_batch_sequence=True`. Use sequence-first "
"inputs at your own discretion."
)
x_bdim = 0 if batch_first else x.ndim - 2
x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes) x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)
casted_ln_out, mu, rsigma = tex.normalization_fwd( casted_ln_out, mu, rsigma = tex.normalization_fwd(
...@@ -237,7 +206,6 @@ def _layernorm_dense_fwd_rule( ...@@ -237,7 +206,6 @@ def _layernorm_dense_fwd_rule(
casted_ln_out.get_tensor(TensorUsage.LHS), casted_ln_out.get_tensor(TensorUsage.LHS),
casted_kernel.get_tensor(TensorUsage.RHS), casted_kernel.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims), contracting_dims=(x_contracting_dims, k_contracting_dims),
batched_dims=((x_bdim,), ()),
bias=bias if not tex.gemm_uses_jax_dot() else None, bias=bias if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
) )
...@@ -261,7 +229,6 @@ def _layernorm_dense_fwd_rule( ...@@ -261,7 +229,6 @@ def _layernorm_dense_fwd_rule(
use_bias, use_bias,
quantizer_set, quantizer_set,
flatten_axis, flatten_axis,
x_bdim,
) )
return output, ctx return output, ctx
...@@ -272,9 +239,8 @@ def _layernorm_dense_bwd_rule( ...@@ -272,9 +239,8 @@ def _layernorm_dense_bwd_rule(
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
layernorm_input_axes, layernorm_input_axes,
dot_input_axes, # pylint: disable=unused-argument dot_input_axes,
kernel_axes, kernel_axes,
batch_first, # pylint: disable=unused-argument
ctx, ctx,
grad, grad,
): ):
...@@ -289,6 +255,7 @@ def _layernorm_dense_bwd_rule( ...@@ -289,6 +255,7 @@ def _layernorm_dense_bwd_rule(
Returns: Returns:
Tuple of gradients for all input parameters Tuple of gradients for all input parameters
""" """
del dot_input_axes
( (
casted_ln_out, casted_ln_out,
casted_kernel, casted_kernel,
...@@ -304,7 +271,6 @@ def _layernorm_dense_bwd_rule( ...@@ -304,7 +271,6 @@ def _layernorm_dense_bwd_rule(
use_bias, use_bias,
quantizer_set, quantizer_set,
flatten_axis, flatten_axis,
x_bdim,
) = ctx ) = ctx
casted_grad, dbias = tex.quantize_dbias( casted_grad, dbias = tex.quantize_dbias(
...@@ -325,16 +291,10 @@ def _layernorm_dense_bwd_rule( ...@@ -325,16 +291,10 @@ def _layernorm_dense_bwd_rule(
) )
# NT GEMM # NT GEMM
sequence_dim = get_sequence_parallel_dim(
layernorm_input_axes, x_contracting_dims_in_fwd, (x_bdim,)
)
dgrad = tex.gemm( dgrad = tex.gemm(
casted_grad.get_tensor(TensorUsage.LHS), casted_grad.get_tensor(TensorUsage.LHS),
casted_kernel, casted_kernel,
contracting_dims=(g_constracting_dim, k_constracting_dim), contracting_dims=(g_constracting_dim, k_constracting_dim),
batched_dims=((x_bdim,), ()),
sequence_parallel_output=sequence_dim is not None and not tex.gemm_uses_jax_dot(),
sequence_dim=sequence_dim if not tex.gemm_uses_jax_dot() else None,
) )
dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes) dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes)
...@@ -348,7 +308,6 @@ def _layernorm_dense_bwd_rule( ...@@ -348,7 +308,6 @@ def _layernorm_dense_bwd_rule(
casted_ln_out, casted_ln_out,
casted_grad.get_tensor(TensorUsage.RHS), casted_grad.get_tensor(TensorUsage.RHS),
contracting_dims=(x_constracting_dim, g_constracting_dim), contracting_dims=(x_constracting_dim, g_constracting_dim),
batched_dims=((x_bdim,), (x_bdim,)),
) )
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
......
...@@ -13,7 +13,6 @@ The implementation supports various normalization types, activation functions, ...@@ -13,7 +13,6 @@ The implementation supports various normalization types, activation functions,
quantization, and distributed training through sharding constraints. quantization, and distributed training through sharding constraints.
""" """
import warnings
from typing import List, Tuple, Sequence, Union, Callable from typing import List, Tuple, Sequence, Union, Callable
from functools import partial from functools import partial
...@@ -29,19 +28,6 @@ from .quantize import ( ...@@ -29,19 +28,6 @@ from .quantize import (
noop_quantizer_set, noop_quantizer_set,
TensorUsage, TensorUsage,
) )
from .sharding import (
get_sequence_parallel_dim,
)
# Module-level latch: set to True once the warning below has been emitted so that
# repeated calls stay silent.
LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED = False


def _issue_batch_first_warning(msg):
    """Emit ``msg`` as a ``UserWarning`` on the first call only.

    The module-level flag ``LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED`` is flipped to
    ``True`` after the warning has been issued; every later call returns without
    warning.

    Args:
        msg: Text of the warning to emit.
    """
    global LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED
    if LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED:
        return
    # stacklevel=2 attributes the warning to the code that called this helper, so the
    # reported file/line is the call site rather than this function.
    warnings.warn(msg, UserWarning, stacklevel=2)
    # Set the latch only after warn() returns: if warnings are configured to raise,
    # the flag stays False, exactly as before.
    LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED = True
def layernorm_mlp( def layernorm_mlp(
...@@ -61,7 +47,6 @@ def layernorm_mlp( ...@@ -61,7 +47,6 @@ def layernorm_mlp(
ffn1_ckpt_name: str = "ffn1", ffn1_ckpt_name: str = "ffn1",
ffn2_ckpt_name: str = "ffn2", ffn2_ckpt_name: str = "ffn2",
activation_type: Sequence[Union[str, Callable]] = ("gelu",), activation_type: Sequence[Union[str, Callable]] = ("gelu",),
batch_first: bool = True,
quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set), quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set),
) -> jnp.ndarray: ) -> jnp.ndarray:
"""Apply layer normalization followed by MLP block. """Apply layer normalization followed by MLP block.
...@@ -93,7 +78,6 @@ def layernorm_mlp( ...@@ -93,7 +78,6 @@ def layernorm_mlp(
ffn1_ckpt_name: Name for checkpointing the first feed-forward network ffn1_ckpt_name: Name for checkpointing the first feed-forward network
ffn2_ckpt_name: Name for checkpointing the second feed-forward network ffn2_ckpt_name: Name for checkpointing the second feed-forward network
activation_type: Activation function(s) to apply after the first dense layer transformation activation_type: Activation function(s) to apply after the first dense layer transformation
batch_first: Assume that X is batched in the first dimension if it has more than 2 dims.
quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations
Returns: Returns:
...@@ -139,13 +123,12 @@ def layernorm_mlp( ...@@ -139,13 +123,12 @@ def layernorm_mlp(
ffn1_ckpt_name, ffn1_ckpt_name,
ffn2_ckpt_name, ffn2_ckpt_name,
activation_type, activation_type,
batch_first,
quantizer_sets, quantizer_sets,
) )
return output return output
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18)) @partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
def _layernorm_mlp( def _layernorm_mlp(
x: jnp.ndarray, x: jnp.ndarray,
gamma: jnp.ndarray, gamma: jnp.ndarray,
...@@ -165,7 +148,6 @@ def _layernorm_mlp( ...@@ -165,7 +148,6 @@ def _layernorm_mlp(
ffn1_ckpt_name: str, ffn1_ckpt_name: str,
ffn2_ckpt_name: str, ffn2_ckpt_name: str,
activation_type: Sequence[Union[str, Callable]], activation_type: Sequence[Union[str, Callable]],
batch_first: bool,
quantizer_sets, quantizer_sets,
): ):
"""Internal implementation of layernorm_mlp with custom VJP. """Internal implementation of layernorm_mlp with custom VJP.
...@@ -191,7 +173,6 @@ def _layernorm_mlp( ...@@ -191,7 +173,6 @@ def _layernorm_mlp(
ffn1_ckpt_name: Name for first feed-forward network checkpointing ffn1_ckpt_name: Name for first feed-forward network checkpointing
ffn2_ckpt_name: Name for second feed-forward network checkpointing ffn2_ckpt_name: Name for second feed-forward network checkpointing
activation_type: Activation function(s) activation_type: Activation function(s)
batch_first: Assume that X is batched in the first dimension.
quantizer_sets: Tuple of quantizer sets quantizer_sets: Tuple of quantizer sets
Returns: Returns:
...@@ -216,7 +197,6 @@ def _layernorm_mlp( ...@@ -216,7 +197,6 @@ def _layernorm_mlp(
ffn1_ckpt_name, ffn1_ckpt_name,
ffn2_ckpt_name, ffn2_ckpt_name,
activation_type, activation_type,
batch_first,
quantizer_sets, quantizer_sets,
) )
return output return output
...@@ -241,7 +221,6 @@ def _layernorm_mlp_fwd_rule( ...@@ -241,7 +221,6 @@ def _layernorm_mlp_fwd_rule(
ffn1_ckpt_name, ffn1_ckpt_name,
ffn2_ckpt_name, ffn2_ckpt_name,
activation_type, activation_type,
batch_first,
quantizer_sets, quantizer_sets,
): ):
"""Forward pass rule for layernorm_mlp. """Forward pass rule for layernorm_mlp.
...@@ -274,17 +253,6 @@ def _layernorm_mlp_fwd_rule( ...@@ -274,17 +253,6 @@ def _layernorm_mlp_fwd_rule(
assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]] assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]]
x_bdim = None
if x.ndim > 2:
if not batch_first:
_issue_batch_first_warning(
"TE/JAX `layernorm_mlp()` fused-layer implementation does not officially "
"support sequence-first inputs and may produce incorrect results when "
"`batch_first=False` or `transpose_batch_sequence=True`. Use sequence-first "
"inputs at your own discretion."
)
x_bdim = 0 if batch_first else x.ndim - 2
use_bias_1 = bias_1 is not None use_bias_1 = bias_1 is not None
use_bias_2 = bias_1 is not None use_bias_2 = bias_1 is not None
...@@ -312,7 +280,6 @@ def _layernorm_mlp_fwd_rule( ...@@ -312,7 +280,6 @@ def _layernorm_mlp_fwd_rule(
casted_ln_out.get_tensor(TensorUsage.LHS), casted_ln_out.get_tensor(TensorUsage.LHS),
casted_kernel_1.get_tensor(TensorUsage.RHS), casted_kernel_1.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims), contracting_dims=(x_contracting_dims, k_contracting_dims),
batched_dims=((x_bdim,), ()),
bias=bias_1 if not tex.gemm_uses_jax_dot() else None, bias=bias_1 if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False, fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False,
) )
...@@ -337,16 +304,12 @@ def _layernorm_mlp_fwd_rule( ...@@ -337,16 +304,12 @@ def _layernorm_mlp_fwd_rule(
# NN GEMM # NN GEMM
# (batch..., hidden_in) x (hidden_out, hidden_in) # (batch..., hidden_in) x (hidden_out, hidden_in)
sequence_dim = get_sequence_parallel_dim(norm_input_axes, x_contracting_dims, (x_bdim,))
dot_2_output = tex.gemm( dot_2_output = tex.gemm(
casted_act_out.get_tensor(TensorUsage.LHS), casted_act_out.get_tensor(TensorUsage.LHS),
casted_kernel_2.get_tensor(TensorUsage.RHS), casted_kernel_2.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims), contracting_dims=(x_contracting_dims, k_contracting_dims),
batched_dims=((x_bdim,), ()),
bias=bias_2 if not tex.gemm_uses_jax_dot() else None, bias=bias_2 if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False, fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False,
sequence_parallel_output=sequence_dim is not None and not tex.gemm_uses_jax_dot(),
sequence_dim=sequence_dim if not tex.gemm_uses_jax_dot() else None,
) )
if use_bias_2 and tex.gemm_uses_jax_dot(): if use_bias_2 and tex.gemm_uses_jax_dot():
...@@ -374,8 +337,6 @@ def _layernorm_mlp_fwd_rule( ...@@ -374,8 +337,6 @@ def _layernorm_mlp_fwd_rule(
use_bias_1, use_bias_1,
use_bias_2, use_bias_2,
quantizer_sets, quantizer_sets,
x_bdim,
sequence_dim,
) )
return dot_2_output, ctx return dot_2_output, ctx
...@@ -393,7 +354,6 @@ def _layernorm_mlp_bwd_rule( ...@@ -393,7 +354,6 @@ def _layernorm_mlp_bwd_rule(
ffn1_ckpt_name, ffn1_ckpt_name,
ffn2_ckpt_name, ffn2_ckpt_name,
activation_type, activation_type,
batch_first,
ctx, ctx,
grad, grad,
): ):
...@@ -410,7 +370,7 @@ def _layernorm_mlp_bwd_rule( ...@@ -410,7 +370,7 @@ def _layernorm_mlp_bwd_rule(
Returns: Returns:
Tuple of gradients for all input parameters Tuple of gradients for all input parameters
""" """
del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name, batch_first del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name
( (
x, x,
mu, mu,
...@@ -429,8 +389,6 @@ def _layernorm_mlp_bwd_rule( ...@@ -429,8 +389,6 @@ def _layernorm_mlp_bwd_rule(
use_bias_1, use_bias_1,
use_bias_2, use_bias_2,
quantizer_sets, quantizer_sets,
x_bdim,
sequence_dim,
) = ctx ) = ctx
ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
...@@ -457,7 +415,6 @@ def _layernorm_mlp_bwd_rule( ...@@ -457,7 +415,6 @@ def _layernorm_mlp_bwd_rule(
casted_grad.get_tensor(TensorUsage.LHS), casted_grad.get_tensor(TensorUsage.LHS),
casted_kernel_2, casted_kernel_2,
contracting_dims=(g_contracting_dims_2, k_contracting_dims_2), contracting_dims=(g_contracting_dims_2, k_contracting_dims_2),
batched_dims=((x_bdim,), ()),
) )
dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes) dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)
...@@ -472,7 +429,6 @@ def _layernorm_mlp_bwd_rule( ...@@ -472,7 +429,6 @@ def _layernorm_mlp_bwd_rule(
casted_act_out, casted_act_out,
casted_grad.get_tensor(TensorUsage.RHS), casted_grad.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, g_contracting_dims), contracting_dims=(x_contracting_dims, g_contracting_dims),
batched_dims=((x_bdim,), (x_bdim,)),
) )
wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes) wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes)
...@@ -500,9 +456,6 @@ def _layernorm_mlp_bwd_rule( ...@@ -500,9 +456,6 @@ def _layernorm_mlp_bwd_rule(
casted_dact_out.get_tensor(TensorUsage.LHS), casted_dact_out.get_tensor(TensorUsage.LHS),
casted_kernel_1, casted_kernel_1,
contracting_dims=(g_contracting_dims_1, k_contracting_dims_1), contracting_dims=(g_contracting_dims_1, k_contracting_dims_1),
batched_dims=((x_bdim,), ()),
sequence_parallel_output=sequence_dim is not None and not tex.gemm_uses_jax_dot(),
sequence_dim=sequence_dim if not tex.gemm_uses_jax_dot() else None,
) )
dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes) dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes)
...@@ -513,7 +466,6 @@ def _layernorm_mlp_bwd_rule( ...@@ -513,7 +466,6 @@ def _layernorm_mlp_bwd_rule(
casted_ln_out, casted_ln_out,
casted_dact_out.get_tensor(TensorUsage.RHS), casted_dact_out.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, g_contracting_dims), contracting_dims=(x_contracting_dims, g_contracting_dims),
batched_dims=((x_bdim,), (x_bdim,)),
) )
wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes) wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes)
......
...@@ -86,30 +86,6 @@ def get_sharding_map_logic_axis_to_mesh_axis(): ...@@ -86,30 +86,6 @@ def get_sharding_map_logic_axis_to_mesh_axis():
return te_logical_axis_to_mesh_axis return te_logical_axis_to_mesh_axis
def get_sequence_parallel_dim(logical_axes, contracting_dims, batch_dims):
    """
    Get the index for the sequence-parallel dimension based on the given logical axes.

    The sequence-parallel dimension is assumed to be the only sharded non-batched non-contracting
    dimension.

    Args:
        logical_axes: Logical axis names of the operand; forwarded to `generate_pspec`.
            A falsy value (`None` or an empty sequence) short-circuits to `None`.
        contracting_dims: Iterable of dimension indices that are contracted and therefore
            excluded from the search.
        batch_dims: Iterable of dimension indices that are batched and therefore excluded
            from the search. Entries that are `None` are harmless (they match no index).

    Returns:
        The index of the single sharded non-batched non-contracting dimension, or `None`
        when there are no logical axes or no such dimension is sharded.

    Raises:
        AssertionError: If more than one non-batched non-contracting dimension is sharded.
    """
    if not logical_axes:
        return None

    pspec = generate_pspec(logical_axes, with_flax_rules=True, padded=True)

    # Build the exclusion set once instead of once per candidate index. Using a set union
    # (rather than `+`) also accepts a list for one argument and a tuple for the other.
    excluded_dims = set(contracting_dims) | set(batch_dims)
    ldims = [i for i in range(len(logical_axes)) if i not in excluded_dims]

    # Track the *indices* of the sharded candidates so the result can never point at a
    # batched/contracting dimension that happens to carry an equal spec entry.
    sharded_dims = [i for i in ldims if pspec[i] is not None]
    if not sharded_dims:
        return None

    assert len(sharded_dims) == 1, (
        "Expected only 1 non-batched non-contracting dimension to be sharded for "
        f"sequence-parallelism, but found {len(sharded_dims)}: {pspec} @ idx {ldims}"
    )

    return sharded_dims[0]
def generate_pspec(logical_axis_names, with_flax_rules=False, padded=False): def generate_pspec(logical_axis_names, with_flax_rules=False, padded=False):
""" """
Convert logical axes to PartitionSpec Convert logical axes to PartitionSpec
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment