Unverified commit b8ba734e authored by Tian Zheng, committed by GitHub

[Paddle] Add parallel support (#357)



* [Paddle] Add TP, DP, PP, FSDP
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Minor fix
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Fix CI failure
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Remove set_nccl_overlap_warning_if_tp
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Improve variable naming
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Refactor FP8 Buffer
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Stylistic changes
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Fix FP32 parallel training
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Fix numel performance issue
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Squashed commit of the following:

commit 79e2e5fd774e67dcdda9aae01a9f31a6479c5d70
Author: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>
Date:   Sun Aug 20 14:39:16 2023 +0000

    Add TP test
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

commit 1d40ad60540490f97ed82ba877cc6eda8902cbf6
Author: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>
Date:   Sun Aug 20 14:22:25 2023 +0000

    Fix tp_size when disabled
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

commit 6632f735a0c8251862355fc74622af59fae3a509
Author: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>
Date:   Sun Aug 20 05:52:18 2023 +0000

    Add TP for attention and transformer layer
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Add shape check
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Add FSDP check for stage 1,2,3
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Review changes
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Fix group_sharding test
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Support NVTE_FUSE_ATTN
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Fix CI errors
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

---------
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 6aa1fcc8
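For reference, a minimal sketch of how the data-parallel and sharded (FSDP-style) paths exercised by this PR can be driven from user code, assuming Paddle's standard `paddle.DataParallel` and `paddle.distributed.sharding.group_sharded_parallel` APIs. The layer, optimizer, and sharding level below are illustrative assumptions, not taken from this diff.

```python
# Hedged sketch: data-parallel vs. sharded training of a TE Paddle layer.
# Launch with `python -m paddle.distributed.launch --gpus 0,1 this_script.py`.
import paddle
import paddle.distributed as dist
from paddle.distributed.sharding import group_sharded_parallel
import transformer_engine.paddle as te

dist.init_parallel_env()

model = te.Linear(1024, 1024)                      # any TE Paddle layer works here
opt = paddle.optimizer.AdamW(parameters=model.parameters())

USE_SHARDING = True
if USE_SHARDING:
    # FSDP-style sharding; levels "os" / "os_g" / "p_g_os" roughly correspond
    # to the stage 1 / 2 / 3 checks added by this PR.
    model, opt, _ = group_sharded_parallel(model, opt, level="p_g_os")
else:
    # Plain data parallel: gradients are all-reduced across ranks.
    model = paddle.DataParallel(model)

out = model(paddle.randn([32, 1024]))
out.mean().backward()
opt.step()
```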
......@@ -4,7 +4,7 @@
"""LayerNormLinear API"""
import os
from typing import Union, Tuple, Dict, Any
from typing import Union, Tuple, Dict, Any, Optional
import paddle
import paddle.nn.functional as F
......@@ -21,9 +21,22 @@ from ..cpp_extensions import (
from .base import TransformerEngineBaseLayer
from .linear import _linear_fwd, _linear_bwd
from ..constants import TE_DType, FP8FwdTensors, FP8BwdTensors
from ..constants import TE_DType, FP8FwdTensors, FP8BwdTensors, GemmParallelModes, dist_group_type
from ..distributed import (
allreduce,
get_tp_group_and_world_size,
identity,
track_rng_state,
set_tensor_dist_attr,
set_weight_tensor_dist_attr,
)
from ..fp8 import get_fp8_te_dtype
from ..utils import cast_if_needed, cast_if_needed_inplace, assert_dim_for_fp8_forward_exec
from ..utils import (
assert_dim_for_fp8_forward_exec,
cast_if_needed,
cast_if_needed_inplace,
divide,
)
__all__ = ["LayerNormLinear", "_layernorm_fwd_fp8_cast", "_layernorm_bwd"]
......@@ -128,9 +141,13 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int,
zero_centered_gamma: bool,
parallel_mode: Union[str, None],
tensor_parallel: bool,
tp_group: Union[dist_group_type, None],
tp_size: int,
) -> Union[Tuple[paddle.Tensor, ...], paddle.Tensor]:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
in_features = ln_weight.shape[0]
assert inp.shape[-1] == in_features, "GEMM not possible"
inputmat = inp.reshape((-1, in_features))
if fp8_enabled:
......@@ -169,6 +186,9 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
fp8_calibration,
fp8_meta,
activation_dtype,
parallel_mode,
tensor_parallel,
tp_group,
is_grad_enabled,
)
......@@ -192,6 +212,10 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
ctx.return_layernorm_output = return_layernorm_output
ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
ctx.zero_centered_gamma = zero_centered_gamma
ctx.parallel_mode = parallel_mode
ctx.tensor_parallel = tensor_parallel
ctx.tp_group = tp_group
ctx.tp_size = tp_size
ctx.requires_dgrad = not inp.stop_gradient
ctx.requires_bgrad = use_bias and not bias.stop_gradient
ctx.requires_ln_bgrad = not ln_bias.stop_gradient
......@@ -208,6 +232,8 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
...]) -> Tuple[Union[paddle.Tensor, None], ...]:
with TransformerEngineBaseLayer.prepare_backward(ctx.fp8_enabled,
ctx.fp8_meta,
ctx.tp_group,
ctx.tp_size,
name="_LayerNormLinear"):
(
inputmat,
......@@ -262,6 +288,9 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
ctx.fp8_meta,
True, # Always compute dgrad to feed into LayerNorm bwd
ctx.activation_dtype,
ctx.parallel_mode,
ctx.tensor_parallel,
ctx.tp_group,
)
if not ctx.fp8_enabled:
......@@ -307,6 +336,8 @@ class LayerNormLinear(TransformerEngineBaseLayer):
bias_attr: Union[paddle.ParamAttr, None, bool] = None,
return_layernorm_output: bool = False,
zero_centered_gamma: bool = False,
parallel_mode: Optional[str] = None,
tp_group: Union[dist_group_type, None] = None,
backend: str = 'transformer_engine',
) -> None:
super().__init__()
......@@ -322,9 +353,23 @@ class LayerNormLinear(TransformerEngineBaseLayer):
self._bias_attr = bias_attr
self._dtype = self._helper.get_default_dtype()
# Set parallel configs
self.tp_group, self.tp_size = get_tp_group_and_world_size(tp_group,
enable_tp=parallel_mode
is not None)
self.tensor_parallel = self.tp_size > 1
self.parallel_mode = parallel_mode
assert (self.parallel_mode
in GemmParallelModes), f"parallel_mode {parallel_mode} not supported"
if self.parallel_mode == "column":
self.out_features = divide(self.out_features, self.tp_size)
elif self.parallel_mode == "row":
self.in_features = divide(self.in_features, self.tp_size)
# LayerNorm weights
self.ln_weight = self.create_parameter(
shape=[in_features],
shape=[self.in_features],
attr=paddle.ParamAttr(initializer=Constant(
value=0.0 if self.zero_centered_gamma else 1.0)),
dtype=self._dtype,
......@@ -332,34 +377,48 @@ class LayerNormLinear(TransformerEngineBaseLayer):
)
self.ln_bias = self.create_parameter(
shape=[in_features],
shape=[self.in_features],
attr=paddle.ParamAttr(initializer=Constant(value=0.0)),
dtype=self._dtype,
is_bias=True,
)
# Linear weights
# Initialize Linear weight parameter
with track_rng_state(enable=self.tensor_parallel):
# TE linear weight is in column major
self.weight = self.create_parameter(
shape=[out_features, in_features]
if self.backend == 'transformer_engine' else [in_features, out_features],
shape=[self.out_features, self.in_features]
if self.backend == 'transformer_engine' else [self.in_features, self.out_features],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False,
)
set_weight_tensor_dist_attr(self.weight, self.tensor_parallel, self.parallel_mode,
self.backend)
# Initialize Linear bias parameter
self.has_bias = self._bias_attr is not False
use_default_bias = self._bias_attr is None or self._bias_attr is True
if self.has_bias:
self.bias = self.create_parameter(
shape=[out_features],
shape=[self.out_features],
attr=self._bias_attr if not use_default_bias else paddle.ParamAttr(
initializer=Constant(value=0.0)),
dtype=self._dtype,
is_bias=True,
)
if parallel_mode == "column":
set_tensor_dist_attr(self.bias, self.tensor_parallel, axis=0)
else:
self.bias = None
# For RPL, bias has to be added after TP collectives
# So it cannot be fused with the GEMM
if self.parallel_mode == "row" and self.tensor_parallel and self.has_bias:
self.gemm_bias_fused_add = False
else:
self.gemm_bias_fused_add = True
# These many SMs are subtracted from the total SM count when calling forward
# and backward LayerNorm C APIs. These envvars can be used to prevent the LN
# kernels from using all SMs in the device. This is useful for cases such as
......@@ -385,8 +444,8 @@ class LayerNormLinear(TransformerEngineBaseLayer):
self.ln_weight,
self.ln_bias,
self.weight,
self.bias,
self.has_bias,
self.bias if self.gemm_bias_fused_add else None,
self.has_bias and self.gemm_bias_fused_add,
self.eps,
self.fp8_enabled,
self.fp8_calibration,
......@@ -397,10 +456,19 @@ class LayerNormLinear(TransformerEngineBaseLayer):
self.fwd_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.parallel_mode,
self.tensor_parallel,
self.tp_group,
self.tp_size,
)
if self.return_layernorm_output:
out, ln_out = out
if not self.gemm_bias_fused_add:
out = out + cast_if_needed_inplace(self.bias, self.activation_dtype)
if self.return_layernorm_output:
return out, ln_out
return out
......@@ -418,7 +486,12 @@ class LayerNormLinear(TransformerEngineBaseLayer):
weight=self.ln_weight,
bias=self.ln_bias,
epsilon=self.eps)
out = F.linear(ln_out, self.weight, self.bias)
if self.parallel_mode == 'column' and self.tensor_parallel:
ln_out = identity(ln_out, self.tp_group)
out = F.linear(ln_out, self.weight, self.bias if self.gemm_bias_fused_add else None)
if self.parallel_mode == 'row' and self.tensor_parallel:
out = allreduce(out, self.tp_group)
out = out + self.bias if self.bias is not None else out
if self.return_layernorm_output:
return out, ln_out
return out
......
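As a usage illustration for the LayerNormLinear changes above, here is a minimal, hedged sketch of the column- and row-parallel modes; the process-group setup, tensor shapes, and the top-level `transformer_engine.paddle` import are assumptions for illustration, not part of this diff.

```python
# Sketch: tensor-parallel LayerNormLinear, assuming the job is launched with
# paddle.distributed.launch and the layer is importable as shown.
import paddle
import paddle.distributed as dist
import transformer_engine.paddle as te

dist.init_parallel_env()
tp_group = dist.new_group(ranks=list(range(dist.get_world_size())))

# Column parallel: out_features is divided by tp_size, so each rank holds a
# slice of the weight and produces a slice of the output.
col = te.LayerNormLinear(in_features=1024, out_features=4096,
                         parallel_mode="column", tp_group=tp_group)

# Row parallel: in_features is divided by tp_size and the partial GEMM
# outputs are all-reduced across tp_group before the bias is added.
row = te.LayerNormLinear(in_features=4096, out_features=1024,
                         parallel_mode="row", tp_group=tp_group)

x = paddle.randn([8, 128, 1024])
y = col(x)          # local shape [8, 128, 4096 // tp_size]
```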
......@@ -4,25 +4,38 @@
"""LayerNormMLP API"""
import os
from typing import Union, Tuple, Dict, Any
from typing import Union, Tuple, Dict, Any, Optional
import paddle
import paddle.nn.functional as F
from paddle.nn.initializer import Constant
from .base import TransformerEngineBaseLayer
from .layernorm_linear import _layernorm_fwd_fp8_cast, _layernorm_bwd
from .linear import _linear_fwd_fp8, _linear_fwd_non_fp8, _linear_bwd_fp8, _linear_bwd_non_fp8
from ..constants import TE_DType, FP8FwdTensors, FP8BwdTensors, dist_group_type
from ..cpp_extensions import (
cast_from_fp8,
dgelu_cast_transpose_bgrad_fp8,
gelu_fp8,
transpose,
)
from .base import TransformerEngineBaseLayer
from .layernorm_linear import _layernorm_fwd_fp8_cast, _layernorm_bwd
from .linear import _linear_fwd_fp8, _linear_fwd_non_fp8, _linear_bwd_fp8, _linear_bwd_non_fp8
from ..constants import TE_DType, FP8FwdTensors, FP8BwdTensors
from ..distributed import (
allreduce,
get_tp_group_and_world_size,
identity,
track_rng_state,
set_tensor_dist_attr,
set_weight_tensor_dist_attr,
)
from ..fp8 import get_fp8_te_dtype
from ..utils import cast_if_needed, assert_dim_for_fp8_forward_exec, get_paddle_act_func
from ..utils import (
assert_dim_for_fp8_forward_exec,
cast_if_needed,
cast_if_needed_inplace,
divide,
get_paddle_act_func,
)
__all__ = ["LayerNormMLP"]
......@@ -43,7 +56,11 @@ def _mlp_forward(
fp8_calibration: bool,
fp8_meta: Dict[str, Any],
activation_dtype: paddle.dtype,
activation: str,
is_grad_enabled: bool,
set_parallel_mode: bool,
tensor_parallel: bool,
tp_group: Union[dist_group_type, None],
):
if fp8_enabled:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
......@@ -56,6 +73,9 @@ def _mlp_forward(
use_fc1_bias,
fp8_meta,
activation_dtype,
'column' if set_parallel_mode else None,
tensor_parallel,
tp_group,
is_grad_enabled,
)
......@@ -75,6 +95,9 @@ def _mlp_forward(
use_fc2_bias,
fp8_meta,
activation_dtype,
'row' if set_parallel_mode else None,
tensor_parallel,
tp_group,
is_grad_enabled,
)
else:
......@@ -88,7 +111,10 @@ def _mlp_forward(
fp8_calibration,
fp8_meta,
activation_dtype,
activation='gelu',
'column' if set_parallel_mode else None,
tensor_parallel,
tp_group,
activation=activation,
)
fc2_out = _linear_fwd_non_fp8(
......@@ -101,6 +127,9 @@ def _mlp_forward(
fp8_calibration,
fp8_meta,
activation_dtype,
'row' if set_parallel_mode else None,
tensor_parallel,
tp_group,
)
return (
fc1_out,
......@@ -136,6 +165,9 @@ def _mlp_backward(
requires_dgrad: bool,
activation_dtype: paddle.dtype,
activation: str,
set_parallel_mode: bool,
tensor_parallel: bool,
tp_group: Union[dist_group_type, None],
):
(
fc1_dgrad,
......@@ -179,6 +211,9 @@ def _mlp_backward(
True,
requires_fc2_wgrad,
activation_dtype,
'row' if set_parallel_mode else None,
tensor_parallel,
tp_group,
)
# GELU Bwd
......@@ -193,7 +228,7 @@ def _mlp_backward(
if requires_fc1_bgrad:
fc1_bgrad = fc1_bgrad_
# FC2 Bwd
# FC1 Bwd
requires_fc1_wgrad = not fc1_weight.stop_gradient
dgelu_no_fp8, fc1_input_no_fp8, fc1_input_t = None, None, None
if requires_fc1_wgrad:
......@@ -231,6 +266,9 @@ def _mlp_backward(
requires_dgrad,
requires_fc1_wgrad,
activation_dtype,
'column' if set_parallel_mode else None,
tensor_parallel,
tp_group,
)
else:
dgelu, fc2_wgrad, fc2_bgrad = _linear_bwd_non_fp8(
......@@ -240,6 +278,9 @@ def _mlp_backward(
requires_fc2_bgrad,
True,
activation_dtype,
'row' if set_parallel_mode else None,
tensor_parallel,
tp_group,
gelu_input=fc1_out,
activation=activation,
)
......@@ -250,6 +291,9 @@ def _mlp_backward(
requires_fc1_bgrad,
requires_dgrad,
activation_dtype,
'column' if set_parallel_mode else None,
tensor_parallel,
tp_group,
)
return (
fc1_dgrad,
......@@ -286,9 +330,13 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
bwd_ln_sm_margin: int,
zero_centered_gamma: bool,
activation: str,
set_parallel_mode: bool,
tensor_parallel: bool,
tp_group: Union[dist_group_type, None],
tp_size: int,
) -> Union[Tuple[paddle.Tensor, ...], paddle.Tensor]:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
in_features = ln_weight.shape[0]
assert inp.shape[-1] == in_features, "GEMM not possible"
inputmat = inp.reshape((-1, in_features))
if fp8_enabled:
......@@ -341,7 +389,11 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
fp8_calibration,
fp8_meta,
activation_dtype,
activation,
is_grad_enabled,
set_parallel_mode,
tensor_parallel,
tp_group,
)
if is_grad_enabled:
......@@ -369,6 +421,10 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
ctx.return_layernorm_output = return_layernorm_output
ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
ctx.zero_centered_gamma = zero_centered_gamma
ctx.set_parallel_mode = set_parallel_mode
ctx.tensor_parallel = tensor_parallel
ctx.tp_group = tp_group
ctx.tp_size = tp_size
ctx.requires_dgrad = not inp.stop_gradient
ctx.requires_fc1_bgrad = use_fc1_bias and not fc1_bias.stop_gradient
ctx.requires_fc2_bgrad = use_fc2_bias and not fc2_bias.stop_gradient
......@@ -387,6 +443,8 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
...]) -> Tuple[Union[paddle.Tensor, None], ...]:
with TransformerEngineBaseLayer.prepare_backward(ctx.fp8_enabled,
ctx.fp8_meta,
ctx.tp_group,
ctx.tp_size,
name="_LayerNormMLP"):
(
inputmat,
......@@ -442,6 +500,9 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
True,
ctx.activation_dtype,
ctx.activation,
ctx.set_parallel_mode,
ctx.tensor_parallel,
ctx.tp_group,
)
if not ctx.fp8_enabled:
# fc2_bias is fused with gemm for non-FP8 path
......@@ -491,6 +552,8 @@ class LayerNormMLP(TransformerEngineBaseLayer):
activation: str = "gelu",
return_layernorm_output: bool = False,
zero_centered_gamma: bool = False,
set_parallel_mode: bool = False,
tp_group: Optional[dist_group_type] = None,
backend: str = 'transformer_engine',
) -> None:
super().__init__()
......@@ -507,6 +570,17 @@ class LayerNormMLP(TransformerEngineBaseLayer):
self._bias_attr = bias_attr
self._dtype = self._helper.get_default_dtype()
# Set parallel configs
self.tp_group, self.tp_size = get_tp_group_and_world_size(tp_group,
enable_tp=set_parallel_mode)
self.tensor_parallel = self.tp_size > 1
self.set_parallel_mode = set_parallel_mode
if self.set_parallel_mode:
self.size_per_partition = divide(self.ffn_hidden_size, self.tp_size)
else:
self.size_per_partition = self.ffn_hidden_size
# LayerNorm weights
self.ln_weight = self.create_parameter(
shape=[self.hidden_size],
......@@ -524,36 +598,47 @@ class LayerNormMLP(TransformerEngineBaseLayer):
)
# FC1 weights
with track_rng_state(enable=self.tensor_parallel):
self.fc1_weight = self.create_parameter(
shape=[self.ffn_hidden_size, self.hidden_size]
if self.backend == 'transformer_engine' else [self.hidden_size, self.ffn_hidden_size],
shape=[self.size_per_partition, self.hidden_size] if self.backend
== 'transformer_engine' else [self.hidden_size, self.size_per_partition],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False,
)
set_weight_tensor_dist_attr(self.fc1_weight,
self.tensor_parallel,
parallel_mode='column',
backend=self.backend)
self.has_bias = self._bias_attr is not False
if self._bias_attr is None or self._bias_attr is True:
use_default_bias = self._bias_attr is None or self._bias_attr is True
if use_default_bias:
self._bias_attr = paddle.ParamAttr(initializer=Constant(value=0.0))
if self.has_bias:
self.fc1_bias = self.create_parameter(
shape=[self.ffn_hidden_size],
shape=[self.size_per_partition],
attr=self._bias_attr,
dtype=self._dtype,
is_bias=True,
)
set_tensor_dist_attr(self.fc1_bias, self.tensor_parallel, axis=0)
else:
self.fc1_bias = None
# FC2 weights
self.fc2_weight = self.create_parameter(
shape=[self.hidden_size, self.ffn_hidden_size]
if self.backend == 'transformer_engine' else [self.ffn_hidden_size, self.hidden_size],
shape=[self.hidden_size, self.size_per_partition] if self.backend
== 'transformer_engine' else [self.size_per_partition, self.hidden_size],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False,
)
set_weight_tensor_dist_attr(self.fc2_weight,
self.tensor_parallel,
parallel_mode='row',
backend=self.backend)
if self.has_bias:
self.fc2_bias = self.create_parameter(
......@@ -565,6 +650,13 @@ class LayerNormMLP(TransformerEngineBaseLayer):
else:
self.fc2_bias = None
# For RPL, bias has to be added after TP collectives
# So it cannot be fused with the GEMM
if self.set_parallel_mode and self.tensor_parallel and self.has_bias:
self.gemm_bias_fused_add = False
else:
self.gemm_bias_fused_add = True
# These many SMs are subtracted from the total SM count when calling forward
# and backward LayerNorm C APIs. These envvars can be used to prevent the LN
# kernels from using all SMs in the device. This is useful for cases such as
......@@ -606,12 +698,20 @@ class LayerNormMLP(TransformerEngineBaseLayer):
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.activation,
self.set_parallel_mode,
self.tensor_parallel,
self.tp_group,
self.tp_size,
)
if self.return_layernorm_output:
out, ln_out = out
return out, ln_out
if not self.gemm_bias_fused_add:
out = out + cast_if_needed_inplace(self.fc2_bias, self.activation_dtype)
if self.return_layernorm_output:
return out, ln_out
return out
def _pd_forward(
......@@ -628,11 +728,16 @@ class LayerNormMLP(TransformerEngineBaseLayer):
weight=self.ln_weight,
bias=self.ln_bias,
epsilon=self.eps)
if self.set_parallel_mode and self.tensor_parallel:
ln_out = identity(ln_out, self.tp_group)
fc1_out = F.linear(ln_out, self.fc1_weight, self.fc1_bias)
act_func = get_paddle_act_func(self.activation)
act_out = act_func(fc1_out)
out = F.linear(act_out, self.fc2_weight, self.fc2_bias)
out = F.linear(act_out, self.fc2_weight,
self.fc2_bias if self.gemm_bias_fused_add else None)
if self.set_parallel_mode and self.tensor_parallel:
out = allreduce(out, self.tp_group)
out = out + self.fc2_bias if self.fc2_bias is not None else out
if self.return_layernorm_output:
return out, ln_out
return out
......
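A corresponding hedged sketch for LayerNormMLP: with `set_parallel_mode=True`, FC1 becomes column-parallel (its `ffn_hidden_size` is split across the group) and FC2 row-parallel (its output is all-reduced), so the layer returns the full `hidden_size` on every rank. Group setup and sizes below are illustrative assumptions.

```python
# Sketch: tensor-parallel LayerNormMLP using the new set_parallel_mode flag.
import paddle
import paddle.distributed as dist
import transformer_engine.paddle as te

dist.init_parallel_env()
tp_group = dist.new_group(ranks=list(range(dist.get_world_size())))

mlp = te.LayerNormMLP(hidden_size=1024, ffn_hidden_size=4096,
                      set_parallel_mode=True, tp_group=tp_group)

x = paddle.randn([8, 128, 1024])
y = mlp(x)          # [8, 128, 1024]; FC2's all-reduce restores the full output
```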
......@@ -3,7 +3,7 @@
# See LICENSE for license information.
"""Linear API"""
from typing import Union, Tuple, Dict, Any
from typing import Union, Tuple, Dict, Any, Optional
import paddle
import paddle.nn.functional as F
......@@ -17,13 +17,22 @@ from .base import (
_2X_ACC_WGRAD,
)
from ..fp8 import get_fp8_te_dtype
from ..constants import FP8FwdTensors, FP8BwdTensors
from ..constants import FP8FwdTensors, FP8BwdTensors, GemmParallelModes, dist_group_type
from ..cpp_extensions import gemm, fp8_gemm, cast_to_fp8, cast_transpose
from ..distributed import (
allreduce,
get_tp_group_and_world_size,
identity,
track_rng_state,
set_tensor_dist_attr,
set_weight_tensor_dist_attr,
)
from ..fp8 import get_fp8_te_dtype
from ..utils import (
assert_dim_for_fp8_forward_exec,
cast_if_needed,
cast_if_needed_inplace,
assert_dim_for_fp8_forward_exec,
divide,
get_bias_dtype,
)
......@@ -39,12 +48,15 @@ def _linear_fwd_fp8(
use_bias: bool,
fp8_meta: Dict[str, Any],
activation_dtype: paddle.dtype,
parallel_mode: Union[str, None],
tensor_parallel: bool,
tp_group: Union[dist_group_type, None],
is_grad_enabled: bool,
):
"""FP8 path of Linear Fwd"""
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
bias_dtype = get_bias_dtype(activation_dtype)
bias = cast_if_needed_inplace(bias, bias_dtype)
bias = cast_if_needed(bias, bias_dtype)
if is_grad_enabled:
weight_fp8, weight_t_fp8 = cast_transpose(
......@@ -78,6 +90,10 @@ def _linear_fwd_fp8(
use_split_accumulator=_2X_ACC_FPROP,
)
# Row Parallel Linear
if parallel_mode == "row" and tensor_parallel:
out = allreduce(out, tp_group)
return out, weight_t_fp8
......@@ -91,6 +107,9 @@ def _linear_fwd_non_fp8(
fp8_calibration: bool,
fp8_meta: Dict[str, Any],
activation_dtype: paddle.dtype,
parallel_mode: Union[str, None],
tensor_parallel: bool,
tp_group: Union[dist_group_type, None],
activation: str = "",
):
"""Non-FP8 path of Linear Fwd"""
......@@ -123,6 +142,9 @@ def _linear_fwd_non_fp8(
return out, gelu_out
out, _, _ = outputs
# Row Parallel Linear
if parallel_mode == "row" and tensor_parallel:
out = allreduce(out, tp_group)
return out
......@@ -137,6 +159,9 @@ def _linear_fwd(
fp8_calibration: bool,
fp8_meta: Dict[str, Any],
activation_dtype: paddle.dtype,
parallel_mode: Union[str, None],
tensor_parallel: bool,
tp_group: Union[dist_group_type, None],
is_grad_enabled: bool,
):
if fp8_enabled:
......@@ -149,6 +174,9 @@ def _linear_fwd(
use_bias,
fp8_meta,
activation_dtype,
parallel_mode,
tensor_parallel,
tp_group,
is_grad_enabled,
)
else:
......@@ -162,6 +190,9 @@ def _linear_fwd(
fp8_calibration,
fp8_meta,
activation_dtype,
parallel_mode,
tensor_parallel,
tp_group,
)
return (
out,
......@@ -184,6 +215,9 @@ def _linear_bwd_fp8(
requires_dgrad: bool,
requires_wgrad: bool,
activation_dtype: paddle.dtype,
parallel_mode: Union[str, None],
tensor_parallel: bool,
tp_group: Union[dist_group_type, None],
):
dgrad, wgrad = None, None
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
......@@ -202,6 +236,9 @@ def _linear_bwd_fp8(
get_workspace(),
use_split_accumulator=_2X_ACC_DGRAD,
)
if parallel_mode == "column" and tensor_parallel:
dgrad = allreduce(dgrad, tp_group)
if requires_wgrad:
if not fp8_meta["recipe"].override_linear_precision.wgrad:
wgrad = fp8_gemm(
......@@ -236,6 +273,9 @@ def _linear_bwd_non_fp8(
requires_bgrad: bool,
requires_dgrad: bool,
activation_dtype: paddle.dtype,
parallel_mode: Union[str, None],
tensor_parallel: bool,
tp_group: Union[dist_group_type, None],
gelu_input: Union[paddle.Tensor, None] = None,
activation: str = "",
):
......@@ -255,6 +295,9 @@ def _linear_bwd_non_fp8(
gelu_input=gelu_input,
grad=True,
)
if parallel_mode == "column" and tensor_parallel:
dgrad = allreduce(dgrad, tp_group)
if requires_wgrad:
wgrad, bgrad, _ = gemm(
inputmat,
......@@ -288,6 +331,9 @@ def _linear_bwd(
fp8_meta: Dict[str, Any],
requires_dgrad: bool,
activation_dtype: paddle.dtype,
parallel_mode: Union[str, None],
tensor_parallel: bool,
tp_group: Union[dist_group_type, None],
):
dgrad, wgrad, bgrad = None, None, None
requires_wgrad = not weight.stop_gradient
......@@ -307,6 +353,9 @@ def _linear_bwd(
requires_dgrad,
requires_wgrad,
activation_dtype,
parallel_mode,
tensor_parallel,
tp_group,
)
else:
dgrad, wgrad, bgrad = _linear_bwd_non_fp8(
......@@ -316,6 +365,9 @@ def _linear_bwd(
requires_bgrad,
requires_dgrad,
activation_dtype,
parallel_mode,
tensor_parallel,
tp_group,
)
return dgrad, wgrad, bgrad
......@@ -335,6 +387,10 @@ class _Linear(paddle.autograd.PyLayer):
fp8_meta: Dict[str, Any],
activation_dtype: paddle.dtype,
is_grad_enabled: bool,
parallel_mode: Union[str, None],
tensor_parallel: bool,
tp_group: Union[dist_group_type, None],
tp_size: int,
) -> paddle.Tensor:
# Make sure input dimensions are compatible
in_features = weight.shape[-1]
......@@ -385,6 +441,9 @@ class _Linear(paddle.autograd.PyLayer):
fp8_calibration,
fp8_meta,
activation_dtype,
parallel_mode,
tensor_parallel,
tp_group,
is_grad_enabled,
)
......@@ -402,6 +461,10 @@ class _Linear(paddle.autograd.PyLayer):
ctx.fp8_meta = fp8_meta
ctx.use_bias = use_bias
ctx.inp_shape = inp.shape
ctx.parallel_mode = parallel_mode
ctx.tensor_parallel = tensor_parallel
ctx.tp_group = tp_group
ctx.tp_size = tp_size
ctx.requires_dgrad = not inp.stop_gradient
ctx.requires_bgrad = use_bias and not bias.stop_gradient
......@@ -411,6 +474,8 @@ class _Linear(paddle.autograd.PyLayer):
def backward(ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]:
with TransformerEngineBaseLayer.prepare_backward(ctx.fp8_enabled,
ctx.fp8_meta,
ctx.tp_group,
ctx.tp_size,
name="_Linear"):
(
inputmat,
......@@ -444,6 +509,9 @@ class _Linear(paddle.autograd.PyLayer):
ctx.fp8_meta,
ctx.requires_dgrad,
ctx.activation_dtype,
ctx.parallel_mode,
ctx.tensor_parallel,
ctx.tp_group,
)
if not ctx.fp8_enabled:
......@@ -474,6 +542,8 @@ class Linear(TransformerEngineBaseLayer):
out_features: int,
weight_attr: Union[paddle.ParamAttr, None] = None,
bias_attr: Union[paddle.ParamAttr, None, bool] = None,
parallel_mode: Optional[str] = None,
tp_group: Union[dist_group_type, None] = None,
backend: str = 'transformer_engine',
) -> None:
super().__init__()
......@@ -484,28 +554,56 @@ class Linear(TransformerEngineBaseLayer):
self._bias_attr = bias_attr
self._dtype = self._helper.get_default_dtype()
# Set parallel configs
self.tp_group, self.tp_size = get_tp_group_and_world_size(tp_group,
enable_tp=parallel_mode
is not None)
self.tensor_parallel = self.tp_size > 1
self.parallel_mode = parallel_mode
assert (self.parallel_mode
in GemmParallelModes), f"parallel_mode {parallel_mode} not supported"
if self.parallel_mode == "column":
self.out_features = divide(self.out_features, self.tp_size)
elif self.parallel_mode == "row":
self.in_features = divide(self.in_features, self.tp_size)
# Initialize weight parameter
with track_rng_state(enable=self.tensor_parallel):
# TE linear weight is in column major
self.weight = self.create_parameter(
shape=[out_features, in_features]
if self.backend == 'transformer_engine' else [in_features, out_features],
shape=[self.out_features, self.in_features]
if self.backend == 'transformer_engine' else [self.in_features, self.out_features],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False,
)
set_weight_tensor_dist_attr(self.weight, self.tensor_parallel, self.parallel_mode,
self.backend)
# Initialize bias parameter
self.has_bias = self._bias_attr is not False
use_default_bias = self._bias_attr is None or self._bias_attr is True
if self.has_bias:
self.bias = self.create_parameter(
shape=[out_features],
shape=[self.out_features],
attr=self._bias_attr if not use_default_bias else paddle.ParamAttr(
initializer=Constant(value=0.0)),
dtype=self._dtype,
is_bias=True,
)
if parallel_mode == "column":
set_tensor_dist_attr(self.bias, self.tensor_parallel, axis=0)
else:
self.bias = None
# For RPL, bias has to be added after TP collectives
# So it cannot be fused with the GEMM
if self.parallel_mode == "row" and self.tensor_parallel and self.has_bias:
self.gemm_bias_fused_add = False
else:
self.gemm_bias_fused_add = True
def _te_forward(
self,
inp: paddle.Tensor,
......@@ -521,15 +619,22 @@ class Linear(TransformerEngineBaseLayer):
out = _Linear.apply(
self.weight,
inp,
self.bias,
self.has_bias,
self.bias if self.gemm_bias_fused_add else None,
self.has_bias and self.gemm_bias_fused_add,
self.fp8_enabled,
self.fp8_calibration,
self.fp8_meta,
self.activation_dtype,
paddle.is_grad_enabled(),
self.parallel_mode,
self.tensor_parallel,
self.tp_group,
self.tp_size,
)
if not self.gemm_bias_fused_add:
out = out + cast_if_needed_inplace(self.bias, self.activation_dtype)
return out
def _pd_forward(
......@@ -537,7 +642,13 @@ class Linear(TransformerEngineBaseLayer):
inp: paddle.Tensor,
) -> paddle.Tensor:
"""Calls Paddle OP"""
return F.linear(inp, self.weight, self.bias)
if self.parallel_mode == 'column' and self.tensor_parallel:
inp = identity(inp, self.tp_group)
out = F.linear(inp, self.weight, self.bias if self.gemm_bias_fused_add else None)
if self.parallel_mode == 'row' and self.tensor_parallel:
out = allreduce(out, self.tp_group)
out = out + self.bias if self.bias is not None else out
return out
def forward(self, *args, **kwargs):
"""forward"""
......
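The Linear changes follow the same pattern. A short sketch, assuming the same distributed setup as above, pairing a column-parallel and a row-parallel Linear Megatron-style, where the row-parallel bias is added only after the all-reduce (the `gemm_bias_fused_add = False` path above):

```python
# Sketch: column -> row parallel Linear pair, as in a Megatron-style MLP.
import paddle
import paddle.distributed as dist
import transformer_engine.paddle as te

dist.init_parallel_env()
tp_group = dist.new_group(ranks=list(range(dist.get_world_size())))

fc1 = te.Linear(1024, 4096, parallel_mode="column", tp_group=tp_group)
fc2 = te.Linear(4096, 1024, parallel_mode="row", tp_group=tp_group)

x = paddle.randn([32, 1024])
h = paddle.nn.functional.gelu(fc1(x))   # [32, 4096 // tp_size] on each rank
y = fc2(h)                              # [32, 1024], partial sums all-reduced
```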
......@@ -7,15 +7,11 @@ from typing import Optional, Union
import paddle
from transformer_engine.paddle.constants import (
AttnMaskTypes,
LayerTypes,
)
from transformer_engine.paddle.layer import (LayerNormMLP, LayerNorm, MultiHeadAttention)
from .base import TransformerEngineBaseLayer
from . import LayerNormMLP, LayerNorm, MultiHeadAttention
from ..constants import AttnMaskTypes, LayerTypes, dist_group_type
class TransformerLayer(TransformerEngineBaseLayer):
class TransformerLayer(paddle.nn.Layer):
r"""
TransformerLayer is made up of an attention block and a feedforward network (MLP).
This standard layer is based on the paper "Attention Is All You Need".
......@@ -64,6 +60,16 @@ class TransformerLayer(TransformerEngineBaseLayer):
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
Parallelism parameters
----------------------
set_parallel_mode : bool, default = `False`
if set to `True`, QKV and FC1 layers are used as Column Parallel
whereas PROJ and FC2 is used as Row Parallel as described
`here <https://arxiv.org/pdf/1909.08053.pdf>`_.
tp_group : ProcessGroup, default = `None`
tensor parallel process group.
"""
def __init__(self,
......@@ -82,6 +88,8 @@ class TransformerLayer(TransformerEngineBaseLayer):
layer_type: str = "encoder",
zero_centered_gamma: bool = False,
activation: str = 'gelu',
set_parallel_mode: bool = False,
tp_group: Optional[dist_group_type] = None,
backend: str = 'transformer_engine') -> None:
super().__init__()
......@@ -90,6 +98,8 @@ class TransformerLayer(TransformerEngineBaseLayer):
self.layer_type = layer_type
self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
self.self_attn_mask_type = self_attn_mask_type
self.set_parallel_mode = set_parallel_mode
self.tp_group = tp_group
assert (self_attn_mask_type
in AttnMaskTypes), f"self_attn_mask_type {self_attn_mask_type} not supported"
......@@ -107,6 +117,8 @@ class TransformerLayer(TransformerEngineBaseLayer):
"params_dtype": params_dtype,
"return_layernorm_output": apply_residual_connection_post_layernorm,
"zero_centered_gamma": zero_centered_gamma,
"set_parallel_mode": set_parallel_mode,
"tp_group": tp_group,
"backend": backend,
}
......@@ -136,6 +148,8 @@ class TransformerLayer(TransformerEngineBaseLayer):
activation=activation,
return_layernorm_output=apply_residual_connection_post_layernorm,
zero_centered_gamma=zero_centered_gamma,
set_parallel_mode=set_parallel_mode,
tp_group=tp_group,
backend=backend,
)
......
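Finally, a hedged end-to-end sketch of the new TransformerLayer arguments documented above; the constructor sizes and input layout are illustrative assumptions, and the job is assumed to be launched with `paddle.distributed.launch`.

```python
# Sketch: tensor-parallel TransformerLayer. QKV/FC1 run column-parallel and
# PROJ/FC2 row-parallel when set_parallel_mode=True.
import paddle
import paddle.distributed as dist
import transformer_engine.paddle as te

dist.init_parallel_env()
tp_group = dist.new_group(ranks=list(range(dist.get_world_size())))

layer = te.TransformerLayer(hidden_size=1024,
                            ffn_hidden_size=4096,
                            num_attention_heads=16,
                            set_parallel_mode=True,
                            tp_group=tp_group)

# The last dimension must match hidden_size; check the layer docs for the
# expected batch/sequence layout of the first two dimensions.
x = paddle.randn([4, 512, 1024])
y = layer(x)
```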