"transformer_engine/pytorch/quantization.py" did not exist on "d3d7ed2c89b72d4bc9169b1fc37306c3fea06df4"
Unverified Commit d75bf43f authored by Phuong Nguyen, committed by GitHub

[JAX] CollectiveGemm (#2166)



* init cgemm + unit tests

* UB bootstrap with NCCL, no MPI dependency

* add NVLINK-P2P check + error message

* skip tests if no NVLINK available

* use std::vector to store ncclComm_t

* update misuse of TP warning
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 4d145786
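This commit threads a `collective_op_set` argument (plus a `batch_sequence_transpose` flag) through the JAX `dense` and `layernorm_mlp` APIs so tensor-parallel collectives can be overlapped with the GEMM. As a rough orientation before the diff, a call might look like the sketch below; only the `dense()` signature and the `tex.noop_collective_op_set` default come from this commit, while the import paths are assumptions about the package layout:

```python
# Hedged usage sketch -- import paths below are assumptions, not part of this diff.
import jax.numpy as jnp
from transformer_engine.jax.dense import dense          # assumed module location
import transformer_engine.jax.cpp_extensions as tex     # assumed target of the `tex` alias

x = jnp.ones((8, 128, 512), dtype=jnp.bfloat16)     # (batch, seq, hidden_in)
kernel = jnp.ones((512, 1024), dtype=jnp.bfloat16)  # (hidden_in, hidden_out)

out = dense(
    x,
    kernel,
    contracting_dims=((2,), (0,)),
    # A real CollectiveOpSet would request e.g. all-gather on the forward GEMM and
    # reduce-scatter on the backward GEMM; the noop set keeps the previous behaviour.
    collective_op_set=tex.noop_collective_op_set,
)
```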
......@@ -11,6 +11,7 @@ customizable contracting dimensions for flexible tensor operations.
from typing import Tuple, Sequence
from functools import partial
import warnings
import jax
import jax.numpy as jnp
......@@ -62,10 +63,13 @@ def dense(
kernel: jnp.ndarray,
bias: jnp.ndarray = None,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
batch_sequence_transpose: bool = False,
input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None,
quantizer_set: QuantizerSet = noop_quantizer_set,
output_axes: Tuple[str, ...] = None,
using_global_amax_of_x: bool = False,
collective_op_set: tex.CollectiveOpSet = tex.noop_collective_op_set,
quantizer_set: QuantizerSet = noop_quantizer_set,
):
"""Perform dense layer transformation with optional quantization.
......@@ -78,12 +82,20 @@ def dense(
kernel: Weight matrix for the dense layer transformation
bias: Optional bias tensor to add after the transformation
contracting_dims: Tuple of sequences specifying which dimensions to contract
quantizer_set: QuantizerSet which contains quantizers for different tensor types
batch_sequence_transpose: Transpose the batch and sequence dimensions of the input tensor.
input_axes: Logical axes for sharding the activation input
kernel_axes: Logical axes for sharding the weight matrix
output_axes: Logical axes for sharding the output
using_global_amax_of_x: Indicate whether to use the global amax for x. Only works when using current-scaling. Default is False.
collective_op_set: A set of CollectiveOp objects for forward and backward passes.
quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns:
Transformed output tensor
"""
if batch_sequence_transpose:
warnings.warn("batch_sequence_transpose is not well tested, use with caution!")
if not get_quantize_config().is_fp8_enabled():
input_dtype = x.dtype
kernel = kernel.astype(input_dtype)
......@@ -93,32 +105,30 @@ def dense(
kernel,
bias,
contracting_dims,
batch_sequence_transpose,
input_axes,
kernel_axes,
quantizer_set,
output_axes,
using_global_amax_of_x,
collective_op_set,
quantizer_set,
)
return output
@partial(
jax.custom_vjp,
nondiff_argnums=(
3,
4,
5,
7,
),
)
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8, 9))
def _dense(
x,
kernel,
bias,
contracting_dims,
batch_sequence_transpose,
input_axes,
kernel_axes,
quantizer_set,
output_axes,
using_global_amax_of_x,
collective_op_set,
quantizer_set,  # needs to be a diff arg for DelayedScaling state management
):
"""Internal implementation of dense layer transformation with custom VJP.
......@@ -130,10 +140,13 @@ def _dense(
kernel: Weight matrix
bias: Optional bias tensor
contracting_dims: Contracting dimensions specification
batch_sequence_transpose: Transpose the batch and sequence dimensions of the input tensor.
input_axes: Logical axes for sharding the activation input
output_axes: Logical axes for sharding the output
kernel_axes: Logical axes for sharding the weight matrix
quantizer_set: QuantizerSet which contains quantizers for different tensor types
using_global_amax_of_x: Indicate whether to use the global amax for x. Only works when using current-scaling. Default is False.
collective_op_set: A set of CollectiveOp objects for forward and backward passes.
quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns:
Transformed output tensor
......@@ -143,10 +156,13 @@ def _dense(
kernel,
bias,
contracting_dims,
batch_sequence_transpose,
input_axes,
kernel_axes,
quantizer_set,
output_axes,
using_global_amax_of_x,
collective_op_set,
quantizer_set,
)
return output
......@@ -156,10 +172,13 @@ def _dense_fwd_rule(
kernel,
bias,
contracting_dims,
batch_sequence_transpose,
input_axes,
kernel_axes,
quantizer_set,
output_axes,
using_global_amax_of_x,
collective_op_set,
quantizer_set,
):
"""Forward pass rule for dense layer transformation.
......@@ -202,9 +221,12 @@ def _dense_fwd_rule(
casted_x.get_tensor(usage=TensorUsage.LHS),
casted_kernel.get_tensor(usage=TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims),
transpose_batch_sequence=batch_sequence_transpose,
bias=bias if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
collective_op=collective_op_set.forward,
)
output = with_sharding_constraint_by_logical_axes(output, output_axes)
if use_bias and tex.gemm_uses_jax_dot():
bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
......@@ -223,8 +245,16 @@ def _dense_fwd_rule(
def _dense_bwd_rule(
contracting_dims, input_axes, kernel_axes, using_global_amax_of_x, ctx, grad
): # pylint: disable=unused-argument
contracting_dims,
batch_sequence_transpose,
input_axes,
kernel_axes,
output_axes,
using_global_amax_of_x,
collective_op_set,
ctx,
grad,
):
"""Backward pass rule for dense layer transformation.
Returns:
......@@ -239,6 +269,7 @@ def _dense_bwd_rule(
quantizer_set,
flatten_axis_k,
) = ctx
grad = with_sharding_constraint_by_logical_axes(grad, output_axes)
fwd_x_contracting_dims, fwd_k_contracting_dims = map(
tex.sanitize_dims, (casted_x_lhs.ndim, casted_kernel_rhs.ndim), contracting_dims
......@@ -266,8 +297,9 @@ def _dense_bwd_rule(
casted_grad.get_tensor(usage=TensorUsage.LHS),
casted_kernel_rhs,
contracting_dims=(g_contracting_dim, k_contracting_dim),
transpose_batch_sequence=batch_sequence_transpose,
collective_op=collective_op_set.backward,
)
dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)
# GEMM TN
# x_non_contracting_dims
......@@ -279,7 +311,10 @@ def _dense_bwd_rule(
casted_x_lhs,
casted_grad.get_tensor(usage=TensorUsage.RHS),
contracting_dims=(x_contracting_dim, g_contracting_dim),
transpose_batch_sequence=batch_sequence_transpose,
)
dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
return dgrad, wgrad, dbias, quantizer_set
......
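The decorator change above widens `nondiff_argnums` to cover the new static arguments (`batch_sequence_transpose`, `output_axes`, `using_global_amax_of_x`, `collective_op_set`) while keeping `quantizer_set` differentiable for DelayedScaling state management. For readers unfamiliar with the pattern, here is a minimal, self-contained sketch (not from this repo) of `jax.custom_vjp` with `nondiff_argnums`:

```python
# Minimal sketch of the custom_vjp pattern: configuration flags go into
# nondiff_argnums (treated as static), array-valued arguments stay differentiable.
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.custom_vjp, nondiff_argnums=(2,))
def matmul(x, w, transpose_x):
    a = x.T if transpose_x else x
    return a @ w

def matmul_fwd(x, w, transpose_x):
    out = matmul(x, w, transpose_x)
    return out, (x, w)                    # residuals saved for the backward pass

def matmul_bwd(transpose_x, res, g):      # non-diff args are passed first to bwd
    x, w = res
    a = x.T if transpose_x else x
    da = g @ w.T
    dw = a.T @ g
    dx = da.T if transpose_x else da
    return dx, dw                         # gradients only for the differentiable args

matmul.defvjp(matmul_fwd, matmul_bwd)
```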
......@@ -53,6 +53,7 @@ def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[
return drop_path_shape
# TODO(Phuong): move this function to sharding.py
def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules:
"""
Extend the given Flax logical axis rules with the predefined TransformerLayer's
......
......@@ -41,6 +41,7 @@ def layernorm_mlp(
norm_type: str,
zero_centered_gamma: bool = False,
epsilon: float = 1e-6,
batch_sequence_transpose: bool = False,
norm_input_axes: Tuple[str, ...] = None,
dot_1_input_axes: Tuple[str, ...] = None,
dot_2_input_axes: Tuple[str, ...] = None,
......@@ -49,6 +50,10 @@ def layernorm_mlp(
ffn1_ckpt_name: str = "ffn1",
ffn2_ckpt_name: str = "ffn2",
activation_type: Sequence[Union[str, Callable]] = ("gelu",),
collective_op_sets: Tuple[tex.CollectiveOpSet] = (
tex.noop_collective_op_set,
tex.noop_collective_op_set,
),
quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set),
) -> jnp.ndarray:
"""Apply layer normalization followed by MLP block.
......@@ -72,6 +77,7 @@ def layernorm_mlp(
norm_type: Type of normalization ("layernorm" or "rmsnorm")
zero_centered_gamma: Whether to use zero-centered gamma for normalization
epsilon: Small constant for numerical stability in normalization
batch_sequence_transpose: Whether to transpose the batch and sequence dimensions
norm_input_axes: Logical axes for sharding the layernorm input
dot_1_input_axes: Logical axes for sharding the first matrix multiplication
dot_2_input_axes: Logical axes for sharding the second matrix multiplication
......@@ -80,6 +86,7 @@ def layernorm_mlp(
ffn1_ckpt_name: Name for checkpointing the first feed-forward network
ffn2_ckpt_name: Name for checkpointing the second feed-forward network
activation_type: Activation function(s) to apply after the first dense layer transformation
collective_op_sets: Tuple of two CollectiveOpSet objects for the two dense layer transformations
quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations
Returns:
......@@ -122,6 +129,7 @@ def layernorm_mlp(
norm_type,
zero_centered_gamma,
epsilon,
batch_sequence_transpose,
norm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
......@@ -130,12 +138,13 @@ def layernorm_mlp(
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
collective_op_sets,
quantizer_sets,
)
return output
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17))
@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
def _layernorm_mlp(
x: jnp.ndarray,
gamma: jnp.ndarray,
......@@ -147,6 +156,7 @@ def _layernorm_mlp(
norm_type: str,
zero_centered_gamma: bool,
epsilon: float,
batch_sequence_transpose: bool,
norm_input_axes: Tuple[str, ...],
dot_1_input_axes: Tuple[str, ...],
dot_2_input_axes: Tuple[str, ...],
......@@ -155,6 +165,7 @@ def _layernorm_mlp(
ffn1_ckpt_name: str,
ffn2_ckpt_name: str,
activation_type: Sequence[Union[str, Callable]],
collective_op_sets: Tuple[tex.CollectiveOpSet],
quantizer_sets,
):
"""Internal implementation of layernorm_mlp with custom VJP.
......@@ -174,12 +185,16 @@ def _layernorm_mlp(
norm_type: Type of normalization
zero_centered_gamma: Whether to use zero-centered gamma
epsilon: Small constant for numerical stability
batch_sequence_transpose: Whether to transpose the batch and sequence dimensions
norm_input_axes: Logical axes for layernorm sharding
dot_1_input_axes: Logical axes for first matrix multiplication sharding
dot_2_input_axes: Logical axes for second matrix multiplication sharding
kernel_1_axes: Logical axes for first weight matrix sharding
kernel_2_axes: Logical axes for second weight matrix sharding
ffn1_ckpt_name: Name for first feed-forward network checkpointing
ffn2_ckpt_name: Name for second feed-forward network checkpointing
activation_type: Activation function(s)
collective_op_sets: Tuple of two CollectiveOpSet objects for the two dense layer transformations
quantizer_sets: Tuple of quantizer sets
Returns:
......@@ -196,6 +211,7 @@ def _layernorm_mlp(
norm_type,
zero_centered_gamma,
epsilon,
batch_sequence_transpose,
norm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
......@@ -204,6 +220,7 @@ def _layernorm_mlp(
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
collective_op_sets,
quantizer_sets,
)
return output
......@@ -220,6 +237,7 @@ def _layernorm_mlp_fwd_rule(
norm_type,
zero_centered_gamma,
epsilon,
batch_sequence_transpose,
norm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
......@@ -228,6 +246,7 @@ def _layernorm_mlp_fwd_rule(
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
collective_op_sets,
quantizer_sets,
):
"""Forward pass rule for layernorm_mlp.
......@@ -247,6 +266,10 @@ def _layernorm_mlp_fwd_rule(
del kernel_1_axes, kernel_2_axes
ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
collective_op_set_1, collective_op_set_2 = collective_op_sets
assert not collective_op_set_1.forward.is_reduce_scatter
assert not collective_op_set_2.forward.is_all_gather
# x should be in shape of (batch..., hidden)
# Kernel_1 should be in shape of (hidden_in, activation_len, intermediate)
......@@ -287,8 +310,10 @@ def _layernorm_mlp_fwd_rule(
casted_ln_out.get_tensor(TensorUsage.LHS),
casted_kernel_1.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims),
transpose_batch_sequence=batch_sequence_transpose,
bias=bias_1 if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False,
collective_op=collective_op_set_1.forward,
)
if use_bias_1 and tex.gemm_uses_jax_dot():
......@@ -326,8 +351,10 @@ def _layernorm_mlp_fwd_rule(
casted_act_out.get_tensor(TensorUsage.LHS),
casted_kernel_2.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims),
transpose_batch_sequence=batch_sequence_transpose,
bias=bias_2 if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False,
collective_op=collective_op_set_2.forward,
)
if use_bias_2 and tex.gemm_uses_jax_dot():
......@@ -335,6 +362,8 @@ def _layernorm_mlp_fwd_rule(
bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
dot_2_output += jnp.reshape(bias_2, bias_2_new_shape)
# sharding of outputs should be the same as dot_1's input
dot_2_output = with_sharding_constraint_by_logical_axes(dot_2_output, dot_1_input_axes)
dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name)
ctx = (
......@@ -364,6 +393,7 @@ def _layernorm_mlp_bwd_rule(
norm_type,
zero_centered_gamma,
epsilon,
batch_sequence_transpose,
norm_input_axes,
dot_1_input_axes,
dot_2_input_axes,
......@@ -372,6 +402,7 @@ def _layernorm_mlp_bwd_rule(
ffn1_ckpt_name,
ffn2_ckpt_name,
activation_type,
collective_op_sets,
ctx,
grad,
):
......@@ -410,6 +441,10 @@ def _layernorm_mlp_bwd_rule(
) = ctx
ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
collective_op_set_1, collective_op_set_2 = collective_op_sets
assert not collective_op_set_1.backward.is_all_gather
assert not collective_op_set_2.backward.is_reduce_scatter
# Since the sharding of outputs should be the same as dot_1's input
grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)
......@@ -436,6 +471,8 @@ def _layernorm_mlp_bwd_rule(
casted_grad.get_tensor(TensorUsage.LHS),
casted_kernel_2,
contracting_dims=(g_contracting_dims_2, k_contracting_dims_2),
transpose_batch_sequence=batch_sequence_transpose,
collective_op=collective_op_set_2.backward,
)
dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes)
......@@ -450,6 +487,7 @@ def _layernorm_mlp_bwd_rule(
casted_act_out,
casted_grad.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, g_contracting_dims),
transpose_batch_sequence=batch_sequence_transpose,
)
wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes)
......@@ -476,6 +514,8 @@ def _layernorm_mlp_bwd_rule(
casted_dact_out.get_tensor(TensorUsage.LHS),
casted_kernel_1,
contracting_dims=(g_contracting_dims_1, k_contracting_dims_1),
transpose_batch_sequence=batch_sequence_transpose,
collective_op=collective_op_set_1.backward,
)
dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes)
......@@ -486,6 +526,7 @@ def _layernorm_mlp_bwd_rule(
casted_ln_out,
casted_dact_out.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, g_contracting_dims),
transpose_batch_sequence=batch_sequence_transpose,
)
wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes)
......
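The asserts in `_layernorm_mlp_fwd_rule` and `_layernorm_mlp_bwd_rule` above encode the usual sequence-parallel MLP pattern: the first GEMM's forward collective may all-gather but never reduce-scatter, the second GEMM's forward may reduce-scatter but never all-gather, and the roles swap in the backward pass. For background, this is the unfused communication pattern that the collective GEMM overlaps, sketched in plain JAX (not this repo's API), intended to run under `jax.experimental.shard_map` with a `"tp"` mesh axis:

```python
# Background sketch only: explicit AG/RS around the two MLP GEMMs.
import jax
import jax.numpy as jnp

def sp_mlp_body(x_seq_shard, w1_shard, w2_shard):
    # x_seq_shard: (seq / tp, hidden), w1_shard: (hidden, ffn / tp),
    # w2_shard:    (ffn / tp, hidden)
    x = jax.lax.all_gather(x_seq_shard, "tp", axis=0, tiled=True)   # AG before GEMM 1
    h = jax.nn.gelu(x @ w1_shard)                                    # GEMM 1 (column-parallel)
    y_partial = h @ w2_shard                                         # GEMM 2 (row-parallel, partial sums)
    # RS after GEMM 2: sum partials over tp and scatter back to sequence shards
    return jax.lax.psum_scatter(y_partial, "tp", scatter_dimension=0, tiled=True)
```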
......@@ -13,6 +13,7 @@ from contextlib import contextmanager
from dataclasses import dataclass
from typing import Callable, Optional
import warnings
import jax
import jax.numpy as jnp
from jax.interpreters import pxla
......@@ -364,3 +365,21 @@ def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mes
if axis != global_mesh_resource().pp_resource:
x = lax_paral_op(x, jax.lax.pmax, axis, mesh)
return x
def tpsp_axis_size():
"""
Get the size of the tensor parallelism axis.
Return 1 if no TP axis is set.
"""
return get_mesh_axis_size(global_mesh_resource().tpsp_resource)
def dp_or_fsdp_axis_size():
"""
Get the size of the data parallelism or FSDP axis.
Return 1 if no DP/FSDP axis is set.
"""
dp_size = get_mesh_axis_size(global_mesh_resource().dp_resource)
fsdp_size = get_mesh_axis_size(global_mesh_resource().fsdp_resource)
return dp_size if dp_size > 1 else fsdp_size
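A hedged usage sketch for the two helpers above. It assumes the global mesh resource has already been configured elsewhere so that `tpsp_resource`, `dp_resource`, and `fsdp_resource` map to real mesh axes; with no mesh resource set, both helpers simply return 1. The per-rank sequence-length computation is a hypothetical illustration of what the returned sizes mean, not code from this commit:

```python
# Hypothetical illustration: interpret the axis sizes returned by the helpers.
seq_len, hidden = 8192, 4096
tp = tpsp_axis_size()            # tensor(-sequence)-parallel world size, or 1
dp = dp_or_fsdp_axis_size()      # data-parallel size, falling back to the FSDP size

# Under sequence parallelism, each rank holds seq_len // tp tokens.
local_seq_len = seq_len // tp
print(f"TP={tp}, DP/FSDP={dp}, local sequence length={local_seq_len}")
```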