"git@developer.sourcefind.cn:liming6/sshd-tool.git" did not exist on "d198770e96a6eac4d7b6233e6f411e339b32ce3d"
Unverified Commit d75bf43f authored by Phuong Nguyen, committed by GitHub

[JAX] CollectiveGemm (#2166)



* init cgemm + unit tests

* UB bootstrap with NCCL, no MPI dependency

* add NVLINK-P2P check + error message

* skip tests if no NVLINK available

* use std::vector to store ncclComm_t

* update misuse of TP warning
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 4d145786
@@ -11,6 +11,7 @@ customizable contracting dimensions for flexible tensor operations.
 from typing import Tuple, Sequence
 from functools import partial
+import warnings
 import jax
 import jax.numpy as jnp
@@ -62,10 +63,13 @@ def dense(
     kernel: jnp.ndarray,
     bias: jnp.ndarray = None,
     contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
+    batch_sequence_transpose: bool = False,
     input_axes: Tuple[str, ...] = None,
     kernel_axes: Tuple[str, ...] = None,
-    quantizer_set: QuantizerSet = noop_quantizer_set,
+    output_axes: Tuple[str, ...] = None,
     using_global_amax_of_x: bool = False,
+    collective_op_set: tex.CollectiveOpSet = tex.noop_collective_op_set,
+    quantizer_set: QuantizerSet = noop_quantizer_set,
 ):
     """Perform dense layer transformation with optional quantization.
@@ -78,12 +82,20 @@ def dense(
         kernel: Weight matrix for the dense layer transformation
         bias: Optional bias tensor to add after the transformation
         contracting_dims: Tuple of sequences specifying which dimensions to contract
-        quantizer_set: QuantizerSet which contains quantizers for different tensor types
+        batch_sequence_transpose: Transpose the batch and sequence dimensions of the input tensor.
+        input_axes: Logical axes for sharding the activation input
+        kernel_axes: Logical axes for sharding the weight matrix
+        output_axes: Logical axes for sharding the output
         using_global_amax_of_x: Indicate whether to use global amax for x. Only works when using current-scaling. Default is False.
+        collective_op_set: A set of CollectiveOp objects for forward and backward passes.
+        quantizer_set: QuantizerSet which contains quantizers for different tensor types

     Returns:
         Transformed output tensor
     """
+    if batch_sequence_transpose:
+        warnings.warn("batch_sequence_transpose is not well tested, use with caution!")
+
     if not get_quantize_config().is_fp8_enabled():
         input_dtype = x.dtype
         kernel = kernel.astype(input_dtype)
@@ -93,32 +105,30 @@ def dense(
         kernel,
         bias,
         contracting_dims,
+        batch_sequence_transpose,
         input_axes,
         kernel_axes,
-        quantizer_set,
+        output_axes,
         using_global_amax_of_x,
+        collective_op_set,
+        quantizer_set,
     )
     return output
-@partial(
-    jax.custom_vjp,
-    nondiff_argnums=(
-        3,
-        4,
-        5,
-        7,
-    ),
-)
+@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8, 9))
 def _dense(
     x,
     kernel,
     bias,
     contracting_dims,
+    batch_sequence_transpose,
     input_axes,
     kernel_axes,
-    quantizer_set,
+    output_axes,
     using_global_amax_of_x,
+    collective_op_set,
+    quantizer_set,  # need to be a diff_arg for DelayedScaling state management
 ):
     """Internal implementation of dense layer transformation with custom VJP.
@@ -130,10 +140,13 @@ def _dense(
         kernel: Weight matrix
         bias: Optional bias tensor
         contracting_dims: Contracting dimensions specification
+        batch_sequence_transpose: Transpose the batch and sequence dimensions of the input tensor.
         input_axes: Logical axes for sharding the activation input
+        output_axes: Logical axes for sharding the output
         kernel_axes: Logical axes for sharding the weight matrix
-        quantizer_set: QuantizerSet which contains quantizers for different tensor types
         using_global_amax_of_x: Indicate whether to use global amax for x. Only works when using current-scaling. Default is False.
+        collective_op_set: A set of CollectiveOp objects for forward and backward passes.
+        quantizer_set: QuantizerSet which contains quantizers for different tensor types

     Returns:
         Transformed output tensor
@@ -143,10 +156,13 @@ def _dense(
         kernel,
         bias,
         contracting_dims,
+        batch_sequence_transpose,
         input_axes,
         kernel_axes,
-        quantizer_set,
+        output_axes,
         using_global_amax_of_x,
+        collective_op_set,
+        quantizer_set,
     )
     return output
@@ -156,10 +172,13 @@ def _dense_fwd_rule(
     kernel,
     bias,
     contracting_dims,
+    batch_sequence_transpose,
     input_axes,
     kernel_axes,
-    quantizer_set,
+    output_axes,
     using_global_amax_of_x,
+    collective_op_set,
+    quantizer_set,
 ):
     """Forward pass rule for dense layer transformation.
@@ -202,9 +221,12 @@ def _dense_fwd_rule(
         casted_x.get_tensor(usage=TensorUsage.LHS),
         casted_kernel.get_tensor(usage=TensorUsage.RHS),
         contracting_dims=(x_contracting_dims, k_contracting_dims),
+        transpose_batch_sequence=batch_sequence_transpose,
         bias=bias if not tex.gemm_uses_jax_dot() else None,
         fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
+        collective_op=collective_op_set.forward,
     )
+    output = with_sharding_constraint_by_logical_axes(output, output_axes)

     if use_bias and tex.gemm_uses_jax_dot():
         bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
@@ -223,8 +245,16 @@ def _dense_fwd_rule(
 def _dense_bwd_rule(
-    contracting_dims, input_axes, kernel_axes, using_global_amax_of_x, ctx, grad
-):  # pylint: disable=unused-argument
+    contracting_dims,
+    batch_sequence_transpose,
+    input_axes,
+    kernel_axes,
+    output_axes,
+    using_global_amax_of_x,
+    collective_op_set,
+    ctx,
+    grad,
+):
     """Backward pass rule for dense layer transformation.

     Returns:
@@ -239,6 +269,7 @@ def _dense_bwd_rule(
         quantizer_set,
         flatten_axis_k,
     ) = ctx
+    grad = with_sharding_constraint_by_logical_axes(grad, output_axes)

     fwd_x_contracting_dims, fwd_k_contracting_dims = map(
         tex.sanitize_dims, (casted_x_lhs.ndim, casted_kernel_rhs.ndim), contracting_dims
@@ -266,8 +297,9 @@ def _dense_bwd_rule(
         casted_grad.get_tensor(usage=TensorUsage.LHS),
         casted_kernel_rhs,
         contracting_dims=(g_contracting_dim, k_contracting_dim),
+        transpose_batch_sequence=batch_sequence_transpose,
+        collective_op=collective_op_set.backward,
     )
-    dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)

     # GEMM TN
     # x_non_contracting_dims
@@ -279,7 +311,10 @@ def _dense_bwd_rule(
         casted_x_lhs,
         casted_grad.get_tensor(usage=TensorUsage.RHS),
         contracting_dims=(x_contracting_dim, g_contracting_dim),
+        transpose_batch_sequence=batch_sequence_transpose,
     )
+    dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)
     wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)

     return dgrad, wgrad, dbias, quantizer_set
...
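Editorial note on usage: the hunks above thread a new `collective_op_set` argument (default `tex.noop_collective_op_set`), plus `batch_sequence_transpose` and `output_axes`, through `dense` and its custom-VJP rules, so the forward GEMM runs `collective_op_set.forward` and the dgrad GEMM runs `collective_op_set.backward`. A minimal sketch of a call is below; the import paths and tensor shapes are assumptions for illustration, and constructing a non-noop CollectiveOpSet is not shown in this diff.

import jax.numpy as jnp
# Assumed import paths -- the diff only confirms the names dense,
# tex.CollectiveOpSet and tex.noop_collective_op_set.
from transformer_engine.jax.dense import dense
from transformer_engine.jax import cpp_extensions as tex

x = jnp.zeros((8, 128, 1024), dtype=jnp.bfloat16)      # (batch, seq, hidden) -- illustrative shapes
kernel = jnp.zeros((1024, 4096), dtype=jnp.bfloat16)   # (hidden, intermediate)

# Default behaviour: the no-op set issues no collective, same as the old signature.
y = dense(
    x,
    kernel,
    contracting_dims=((2,), (0,)),
    collective_op_set=tex.noop_collective_op_set,
)
# A user-built CollectiveOpSet (construction not shown in this diff) would instead let the
# forward/backward GEMMs perform their all-gather or reduce-scatter as part of the matmul.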
@@ -53,6 +53,7 @@ def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[
     return drop_path_shape

+# TODO(Phuong): move this function to sharding.py
 def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
     """
     Extend the given Flax logical axis rules with the predefined TransformerLayer's
...
@@ -41,6 +41,7 @@ def layernorm_mlp(
     norm_type: str,
     zero_centered_gamma: bool = False,
     epsilon: float = 1e-6,
+    batch_sequence_transpose: bool = False,
     norm_input_axes: Tuple[str, ...] = None,
     dot_1_input_axes: Tuple[str, ...] = None,
     dot_2_input_axes: Tuple[str, ...] = None,
@@ -49,6 +50,10 @@ def layernorm_mlp(
     ffn1_ckpt_name: str = "ffn1",
     ffn2_ckpt_name: str = "ffn2",
     activation_type: Sequence[Union[str, Callable]] = ("gelu",),
+    collective_op_sets: Tuple[tex.CollectiveOpSet] = (
+        tex.noop_collective_op_set,
+        tex.noop_collective_op_set,
+    ),
     quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set),
 ) -> jnp.ndarray:
     """Apply layer normalization followed by MLP block.
@@ -72,6 +77,7 @@ def layernorm_mlp(
         norm_type: Type of normalization ("layernorm" or "rmsnorm")
         zero_centered_gamma: Whether to use zero-centered gamma for normalization
         epsilon: Small constant for numerical stability in normalization
+        batch_sequence_transpose: Whether to transpose the batch and sequence dimensions
         norm_input_axes: Logical axes for sharding the layernorm input
         dot_1_input_axes: Logical axes for sharding the first matrix multiplication
         dot_2_input_axes: Logical axes for sharding the second matrix multiplication
@@ -80,6 +86,7 @@ def layernorm_mlp(
         ffn1_ckpt_name: Name for checkpointing the first feed-forward network
         ffn2_ckpt_name: Name for checkpointing the second feed-forward network
         activation_type: Activation function(s) to apply after the first dense layer transformation
+        collective_op_sets: Tuple of two collective gemm config sets for the two dense layer transformations
         quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations

     Returns:
@@ -122,6 +129,7 @@ def layernorm_mlp(
         norm_type,
         zero_centered_gamma,
         epsilon,
+        batch_sequence_transpose,
         norm_input_axes,
         dot_1_input_axes,
         dot_2_input_axes,
@@ -130,12 +138,13 @@ def layernorm_mlp(
         ffn1_ckpt_name,
         ffn2_ckpt_name,
         activation_type,
+        collective_op_sets,
         quantizer_sets,
     )
     return output
-@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
+@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
 def _layernorm_mlp(
     x: jnp.ndarray,
     gamma: jnp.ndarray,
@@ -147,6 +156,7 @@ def _layernorm_mlp(
     norm_type: str,
     zero_centered_gamma: bool,
     epsilon: float,
+    batch_sequence_transpose: bool,
     norm_input_axes: Tuple[str, ...],
     dot_1_input_axes: Tuple[str, ...],
     dot_2_input_axes: Tuple[str, ...],
@@ -155,6 +165,7 @@ def _layernorm_mlp(
     ffn1_ckpt_name: str,
     ffn2_ckpt_name: str,
     activation_type: Sequence[Union[str, Callable]],
+    collective_op_sets: Tuple[tex.CollectiveOpSet],
     quantizer_sets,
 ):
     """Internal implementation of layernorm_mlp with custom VJP.
@@ -174,12 +185,16 @@ def _layernorm_mlp(
         norm_type: Type of normalization
         zero_centered_gamma: Whether to use zero-centered gamma
         epsilon: Small constant for numerical stability
+        batch_sequence_transpose: Whether to transpose the batch and sequence dimensions
         norm_input_axes: Logical axes for layernorm sharding
         dot_1_input_axes: Logical axes for first matrix multiplication sharding
         dot_2_input_axes: Logical axes for second matrix multiplication sharding
+        kernel_1_axes: Logical axes for first weight matrix sharding
+        kernel_2_axes: Logical axes for second weight matrix sharding
         ffn1_ckpt_name: Name for first feed-forward network checkpointing
         ffn2_ckpt_name: Name for second feed-forward network checkpointing
         activation_type: Activation function(s)
+        collective_op_sets: Tuple of two collective gemm config sets for the two dense layer transformations
         quantizer_sets: Tuple of quantizer sets

     Returns:
@@ -196,6 +211,7 @@ def _layernorm_mlp(
         norm_type,
         zero_centered_gamma,
         epsilon,
+        batch_sequence_transpose,
         norm_input_axes,
         dot_1_input_axes,
         dot_2_input_axes,
@@ -204,6 +220,7 @@ def _layernorm_mlp(
         ffn1_ckpt_name,
         ffn2_ckpt_name,
         activation_type,
+        collective_op_sets,
         quantizer_sets,
     )
     return output
@@ -220,6 +237,7 @@ def _layernorm_mlp_fwd_rule(
     norm_type,
     zero_centered_gamma,
     epsilon,
+    batch_sequence_transpose,
     norm_input_axes,
     dot_1_input_axes,
     dot_2_input_axes,
@@ -228,6 +246,7 @@ def _layernorm_mlp_fwd_rule(
     ffn1_ckpt_name,
     ffn2_ckpt_name,
     activation_type,
+    collective_op_sets,
     quantizer_sets,
 ):
     """Forward pass rule for layernorm_mlp.
@@ -247,6 +266,10 @@ def _layernorm_mlp_fwd_rule(
     del kernel_1_axes, kernel_2_axes

     ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
+    collective_op_set_1, collective_op_set_2 = collective_op_sets
+    assert not collective_op_set_1.forward.is_reduce_scatter
+    assert not collective_op_set_2.forward.is_all_gather

     # x should be in shape of (batch..., hidden)
     # Kernel_1 should be in shape of (hidden_in, activation_len, intermediate)
@@ -287,8 +310,10 @@ def _layernorm_mlp_fwd_rule(
         casted_ln_out.get_tensor(TensorUsage.LHS),
         casted_kernel_1.get_tensor(TensorUsage.RHS),
         contracting_dims=(x_contracting_dims, k_contracting_dims),
+        transpose_batch_sequence=batch_sequence_transpose,
         bias=bias_1 if not tex.gemm_uses_jax_dot() else None,
         fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False,
+        collective_op=collective_op_set_1.forward,
     )

     if use_bias_1 and tex.gemm_uses_jax_dot():
@@ -326,8 +351,10 @@ def _layernorm_mlp_fwd_rule(
         casted_act_out.get_tensor(TensorUsage.LHS),
         casted_kernel_2.get_tensor(TensorUsage.RHS),
         contracting_dims=(x_contracting_dims, k_contracting_dims),
+        transpose_batch_sequence=batch_sequence_transpose,
         bias=bias_2 if not tex.gemm_uses_jax_dot() else None,
         fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False,
+        collective_op=collective_op_set_2.forward,
     )

     if use_bias_2 and tex.gemm_uses_jax_dot():
@@ -335,6 +362,8 @@ def _layernorm_mlp_fwd_rule(
         bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
         dot_2_output += jnp.reshape(bias_2, bias_2_new_shape)

+    # sharding of outputs should be the same as dot_1's input
+    dot_2_output = with_sharding_constraint_by_logical_axes(dot_2_output, dot_1_input_axes)
     dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)

     ctx = (
@@ -364,6 +393,7 @@ def _layernorm_mlp_bwd_rule(
     norm_type,
     zero_centered_gamma,
     epsilon,
+    batch_sequence_transpose,
     norm_input_axes,
     dot_1_input_axes,
     dot_2_input_axes,
@@ -372,6 +402,7 @@ def _layernorm_mlp_bwd_rule(
     ffn1_ckpt_name,
     ffn2_ckpt_name,
     activation_type,
+    collective_op_sets,
     ctx,
     grad,
 ):
@@ -410,6 +441,10 @@ def _layernorm_mlp_bwd_rule(
     ) = ctx

     ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
+    collective_op_set_1, collective_op_set_2 = collective_op_sets
+    assert not collective_op_set_1.backward.is_all_gather
+    assert not collective_op_set_2.backward.is_reduce_scatter

     # Since the sharding of outputs should be the same as dot_1's input
     grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)
@@ -436,6 +471,8 @@ def _layernorm_mlp_bwd_rule(
         casted_grad.get_tensor(TensorUsage.LHS),
         casted_kernel_2,
         contracting_dims=(g_contracting_dims_2, k_contracting_dims_2),
+        transpose_batch_sequence=batch_sequence_transpose,
+        collective_op=collective_op_set_2.backward,
     )
     dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)
@@ -450,6 +487,7 @@ def _layernorm_mlp_bwd_rule(
         casted_act_out,
         casted_grad.get_tensor(TensorUsage.RHS),
         contracting_dims=(x_contracting_dims, g_contracting_dims),
+        transpose_batch_sequence=batch_sequence_transpose,
     )
     wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes)
@@ -476,6 +514,8 @@ def _layernorm_mlp_bwd_rule(
         casted_dact_out.get_tensor(TensorUsage.LHS),
         casted_kernel_1,
         contracting_dims=(g_contracting_dims_1, k_contracting_dims_1),
+        transpose_batch_sequence=batch_sequence_transpose,
+        collective_op=collective_op_set_1.backward,
     )
     dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes)
@@ -486,6 +526,7 @@ def _layernorm_mlp_bwd_rule(
         casted_ln_out,
         casted_dact_out.get_tensor(TensorUsage.RHS),
         contracting_dims=(x_contracting_dims, g_contracting_dims),
+        transpose_batch_sequence=batch_sequence_transpose,
     )
     wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes)
...
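Editorial note on the layernorm_mlp changes above: each of the two GEMMs gets its own CollectiveOpSet via a two-element `collective_op_sets` tuple, unpacked as (ffn1, ffn2), and the forward/backward rules assert the expected pattern (the first GEMM's forward collective must not be a reduce-scatter and the second's must not be an all-gather, with the reverse constraint in the backward pass). A hedged sketch of a call with the default no-op pair follows; the import paths, positional argument layout, and tensor shapes are assumptions for illustration and are not taken from this diff.

import jax.numpy as jnp
from transformer_engine.jax.layernorm_mlp import layernorm_mlp   # assumed import path
from transformer_engine.jax import cpp_extensions as tex         # assumed import path

hidden, ffn = 1024, 4096
x = jnp.zeros((8, 128, hidden), dtype=jnp.bfloat16)          # (batch..., hidden), per the fwd-rule comment
gamma = jnp.ones((hidden,), dtype=jnp.float32)
beta = jnp.zeros((hidden,), dtype=jnp.float32)
kernel_1 = jnp.zeros((hidden, 1, ffn), dtype=jnp.bfloat16)   # (hidden_in, activation_len, intermediate)
kernel_2 = jnp.zeros((ffn, hidden), dtype=jnp.bfloat16)

out = layernorm_mlp(
    x, gamma, beta,
    [kernel_1, kernel_2], [None, None],   # assumed: kernels/biases passed as two-element lists
    norm_type="layernorm",
    activation_type=("gelu",),
    batch_sequence_transpose=False,
    collective_op_sets=(tex.noop_collective_op_set, tex.noop_collective_op_set),
)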
@@ -13,6 +13,7 @@ from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import Callable, Optional
 import warnings
+
 import jax
 import jax.numpy as jnp
 from jax.interpreters import pxla
@@ -364,3 +365,21 @@ def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mes
         if axis != global_mesh_resource().pp_resource:
             x = lax_paral_op(x, jax.lax.pmax, axis, mesh)
     return x
+
+
+def tpsp_axis_size():
+    """
+    Get the size of the tensor parallelism axis.
+    Return 1 if no TP axis is set.
+    """
+    return get_mesh_axis_size(global_mesh_resource().tpsp_resource)
+
+
+def dp_or_fsdp_axis_size():
+    """
+    Get the size of the data parallelism or FSDP axis.
+    Return 1 if no DP/FSDP axis is set.
+    """
+    dp_size = get_mesh_axis_size(global_mesh_resource().dp_resource)
+    fsdp_size = get_mesh_axis_size(global_mesh_resource().fsdp_resource)
+    return dp_size if dp_size > 1 else fsdp_size
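Editorial note on the new sharding helpers: tpsp_axis_size() and dp_or_fsdp_axis_size() report the tensor(-sequence)-parallel and data-parallel-or-FSDP axis sizes from the globally registered mesh resource, falling back to 1 when an axis is not set, which is the kind of information the collective GEMM path needs to size its all-gather or reduce-scatter. The snippet below mirrors the same lookup-and-fallback logic against a plain jax.sharding.Mesh purely as an illustration; the real helpers take no arguments and read the axis names from global_mesh_resource(), and the 8-device 2x4 mesh here is an assumption.

from typing import Optional

import numpy as np
import jax
from jax.sharding import Mesh

def _axis_size(mesh: Mesh, axis_name: Optional[str]) -> int:
    """Size of a named mesh axis; 1 if the axis is unset or absent (mirrors the fallback above)."""
    if axis_name is None or axis_name not in mesh.shape:
        return 1
    return mesh.shape[axis_name]

# Illustration only: assumes at least 8 devices, arranged as (fsdp=2, tpsp=4) with no plain "dp" axis.
devices = np.array(jax.devices()[:8]).reshape(2, 4)
mesh = Mesh(devices, axis_names=("fsdp", "tpsp"))

tp_size = _axis_size(mesh, "tpsp")                   # 4
dp_size = _axis_size(mesh, "dp")                     # 1, axis not set
fsdp_size = _axis_size(mesh, "fsdp")                 # 2
dp_or_fsdp = dp_size if dp_size > 1 else fsdp_size   # 2, same fallback as dp_or_fsdp_axis_size()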