Unverified commit 71725099 authored by Shijie, committed by GitHub

[Paddle] Add RMSNorm, RoPE and SwiGLU (#599)



* use separate qkv
Signed-off-by: jaywan <jaywan@nvidia.com>

add support for GQA
Signed-off-by: jaywan <jaywan@nvidia.com>

minor changes
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

change rtol
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

fix reshape issue
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

add rmsnorm and rotary position embedding
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

update rmsnorm
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

refactor layernorm and rmsnorm
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

support swiglu
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

add fused rope
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

minor changes
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

add rope api to __init__
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

minor changes
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

fix fp8 dtype issue
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* simplify ut cases
Signed-off-by: jaywan <jaywan@nvidia.com>

* Update transformer_engine/paddle/layer/attention.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Shijie <505749828@qq.com>

* fix name issue
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

---------
Signed-off-by: Shijie Wang <jaywan@nvidia.com>
Signed-off-by: jaywan <jaywan@nvidia.com>
Signed-off-by: Shijie <505749828@qq.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 2187a8f3
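
For orientation, the pieces this PR adds (RotaryPositionEmbedding, the normalization="RMSNorm" option, and the "swiglu" activation) fit together roughly as in the sketch below. This is an illustrative usage sketch assembled from the signatures in the diff, not code from the PR; the tensor shapes, sizes, and exact keyword sets are assumptions.

# Usage sketch only (assumed shapes/kwargs); the APIs themselves are the ones added in this diff.
import paddle
import transformer_engine.paddle as te

hidden_states = paddle.rand([2, 128, 1024])           # [batch, seq, hidden]

# RoPE frequencies; dim matches the per-head size (1024 hidden / 16 heads = 64).
rope = te.RotaryPositionEmbedding(dim=64, max_position_embeddings=128)
rotary_pos_emb = rope(128)                             # (cos, sin) for a 128-token sequence

# RMSNorm + SwiGLU variants of the fused layers extended by this PR.
mlp = te.LayerNormMLP(hidden_size=1024, ffn_hidden_size=4096,
                      normalization="RMSNorm", activation="swiglu")
mlp_out = mlp(hidden_states)

layer = te.TransformerLayer(hidden_size=1024, ffn_hidden_size=4096, num_attention_heads=16,
                            normalization="RMSNorm", activation="swiglu")
layer_out = layer(hidden_states, rotary_pos_emb=rotary_pos_emb)
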
......@@ -5,12 +5,6 @@
import struct
from utils import (
assert_allclose,
create_fp8_meta,
get_fused_attention_backend,
is_fused_attention_supported,
)
import numpy as np
import paddle
import paddle.nn.functional as F
......@@ -34,6 +28,10 @@ from transformer_engine.paddle.cpp_extensions import (
cast_transpose_bgrad,
te_gelu,
gelu_fp8,
swiglu,
swiglu_fp8,
swiglu_pd,
dswiglu,
dgelu_cast_transpose_bgrad_fp8,
layernorm_fwd_fp8,
layernorm_fwd,
......@@ -62,9 +60,9 @@ GEMM_CASES = [(256, 256, 512), (32, 32, 32), (16384, 1024, 2816), (16384, 2816,
(16384, 1024, 1024)]
is_fp8_supported, reason = is_fp8_available()
SELF_ATTN_CASES = [(32, 512, 16, 64), (32, 128, 16, 64)]
CROSS_ATTN_CASES = [(32, 128, 512, 16, 64)]
FLASH_ATTN_CASES = [(4, 1024, 16, 64), (2, 2048, 16, 128)]
SELF_ATTN_CASES = [(2, 512, 12, 64)]
CROSS_ATTN_CASES = [(2, 128, 512, 12, 64)]
FLASH_ATTN_CASES = [(2, 1024, 16, 64), (2, 2048, 16, 128)]
ATTN_DTYPES = [tex.DType.kFloat16, tex.DType.kBFloat16]
......@@ -296,6 +294,55 @@ class TestActivation:
assert_allclose(x_grad_t, x.grad.T, rtol=0.1, atol=0.01)
assert_allclose(dbias, x.grad.sum(axis=0), rtol=0.1, atol=0.01)
@staticmethod
def test_swiglu_bf16():
"""
Test BF16 SwiGLU Forward
"""
a = paddle.rand(shape=(16, 32), dtype='bfloat16') * 2 - 1
swiglu_out = swiglu(a, otype=tex.DType.kBFloat16)
swiglu_ref = swiglu_pd(a)
assert_allclose(swiglu_out, swiglu_ref, rtol=1e-2)
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('fp8_dtype', [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
def test_swiglu_fp8(fp8_dtype):
"""
Test FP8 SwiGLU Forward
"""
a = paddle.rand(shape=(16, 32), dtype='float32') * 2 - 1
fp8_meta = create_fp8_meta()
swiglu_out_fp8 = swiglu_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype)
swiglu_out = cast_from_fp8(swiglu_out_fp8,
fp8_meta,
FP8FwdTensors.GEMM1_INPUT,
itype=fp8_dtype,
otype=tex.DType.kFloat32)
swiglu_ref = swiglu_pd(a)
assert_allclose(swiglu_out, swiglu_ref, rtol=0.1, atol=0.01)
@staticmethod
def test_swiglu_bwd():
"""
Test SwiGLU Backward
"""
# y = SwiGLU(x), calculate ref
x = paddle.rand(shape=(16, 32), dtype='bfloat16') * 2 - 1
x.stop_gradient = False
y = swiglu_pd(x)
y_grad = paddle.rand(shape=(16, 16), dtype='bfloat16') * 2 - 1
paddle.autograd.backward([y], [y_grad], True)
# calculate x_grad with the dswiglu custom op
x_grad = dswiglu(y_grad, x, otype=tex.DType.kBFloat16)
assert_allclose(x_grad, x.grad, rtol=0.1, atol=0.01)
class TestGemm:
"""
......
......@@ -4,6 +4,15 @@
"""Transformer Engine bindings for Paddle"""
from .fp8 import fp8_autocast
from .layer import (Linear, LayerNorm, LayerNormLinear, LayerNormMLP, FusedScaleMaskSoftmax,
DotProductAttention, MultiHeadAttention, TransformerLayer)
from .layer import (
Linear,
LayerNorm,
LayerNormLinear,
LayerNormMLP,
FusedScaleMaskSoftmax,
DotProductAttention,
MultiHeadAttention,
TransformerLayer,
RotaryPositionEmbedding,
)
from .recompute import recompute
......@@ -6,6 +6,7 @@
import math
from typing import Optional, Tuple, Union
import paddle
import paddle.nn.functional as F
import transformer_engine_paddle as tex
from .constants import TE_DType, FusedAttnBackend, FP8FwdTensors, FP8BwdTensors
from .fp8 import FP8TensorMeta
......@@ -328,6 +329,56 @@ def gelu_fp8(
return out
def swiglu(
inp: paddle.Tensor,
otype: tex.DType,
) -> paddle.Tensor:
"""Non FP8 SWIGLU"""
return tex.te_swiglu(
inp,
int(otype),
)
def swiglu_pd(inp: paddle.Tensor,) -> paddle.Tensor:
"""Native SWIGLU"""
gate_out, up_out = paddle.chunk(inp, chunks=2, axis=-1)
out = F.silu(gate_out) * up_out
return out
def swiglu_fp8(
inp: paddle.Tensor,
fp8_meta_tensor: FP8TensorMeta,
fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
otype: tex.DType,
) -> paddle.Tensor:
"""SWIGLU + FP8 cast"""
out, _, _ = tex.te_swiglu_fp8(
inp,
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_tensor.value,
int(otype),
)
return out
def dswiglu(
grad_output: paddle.Tensor,
swiglu_input: paddle.Tensor,
otype: tex.DType,
) -> paddle.Tensor:
"""dSWIGLU"""
return tex.te_dswiglu(
grad_output,
swiglu_input,
int(otype),
)
def dgelu_cast_transpose_bgrad_fp8(
grad_output: paddle.Tensor,
gelu_input: paddle.Tensor,
......@@ -404,9 +455,10 @@ def rmsnorm_fwd(
eps: float,
otype: tex.DType,
sm_margin: int = 0,
zero_centered_gamma: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Non-FP8 RMSNorm forward"""
return tex.te_rmsnorm_fwd(inp, weight, eps, int(otype), sm_margin)
return tex.te_rmsnorm_fwd(inp, weight, eps, int(otype), sm_margin, zero_centered_gamma)
def rmsnorm_fwd_fp8(
......@@ -417,12 +469,13 @@ def rmsnorm_fwd_fp8(
fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
otype: tex.DType,
sm_margin: int = 0,
zero_centered_gamma: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""RMSNorm with FP8 output"""
out, rsigma, _, _ = tex.te_rmsnorm_fwd_fp8(inp, weight, fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv, eps, fp8_tensor.value,
int(otype), sm_margin)
int(otype), sm_margin, zero_centered_gamma)
return out, rsigma
......@@ -432,9 +485,10 @@ def rmsnorm_bwd(
rsigma: paddle.Tensor,
gamma: paddle.Tensor,
sm_margin: int = 0,
zero_centered_gamma: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Non-FP8 RMSNorm backward"""
return tex.te_rmsnorm_bwd(dz, x, rsigma, gamma, sm_margin)
return tex.te_rmsnorm_bwd(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma)
def mask_to_cu_seqlens(
......
......@@ -218,6 +218,66 @@ std::vector<paddle::Tensor> te_gelu(const paddle::Tensor &input, int64_t otype)
return {output};
}
std::vector<paddle::Tensor> te_swiglu(const paddle::Tensor &input, int64_t otype) {
auto shape = GetShapeArray(input);
NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions.");
size_t M = shape[0];
size_t N = shape[1];
auto output = paddle::empty({input.shape()[0], input.shape()[1] / 2},
Nvte2PaddleDType(Int2NvteDType(otype)), input.place());
auto input_cu = MakeNvteTensor(input);
auto output_cu = MakeNvteTensor(output.data(), GetShapeArray(output), Int2NvteDType(otype));
nvte_swiglu(input_cu.data(), output_cu.data(), input.stream());
return {output};
}
std::vector<paddle::Tensor> te_swiglu_fp8(const paddle::Tensor &input, const paddle::Tensor &scale,
paddle::Tensor &amax, // NOLINT
paddle::Tensor &scale_inv, // NOLINT
int64_t index, int64_t otype) {
auto shape = GetShapeArray(input);
NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions.");
size_t M = shape[0];
size_t N = shape[1];
auto output = paddle::empty({input.shape()[0], input.shape()[1] / 2},
Nvte2PaddleDType(Int2NvteDType(otype)), input.place());
auto input_cu = MakeNvteTensor(input);
auto output_cu = MakeNvteTensor(
output.data(), GetShapeArray(output), Int2NvteDType(otype), GetDataPtr<float>(amax, index),
const_cast<void *>(GetDataPtr<float>(scale, index)), GetDataPtr<float>(scale_inv, index));
nvte_swiglu(input_cu.data(), output_cu.data(), input.stream());
return {output};
}
std::vector<paddle::Tensor> te_dswiglu(const paddle::Tensor &grad, const paddle::Tensor &input,
int64_t otype) {
auto shape = GetShapeArray(input);
NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions.");
size_t M = shape[0];
size_t N = shape[1];
auto output = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)), input.place());
auto input_cu = MakeNvteTensor(input.data(), {M, N}, Paddle2NvteDType(input.dtype()));
auto grad_cu = MakeNvteTensor(grad.data(), {M, N / 2}, Paddle2NvteDType(grad.dtype()));
auto output_cu = MakeNvteTensor(output.data(), {M, N}, Paddle2NvteDType(output.dtype()));
nvte_dswiglu(grad_cu.data(), input_cu.data(), output_cu.data(), input.stream());
return {output};
}
std::vector<paddle::Tensor> te_cast_transpose_bgrad_dgelu(const paddle::Tensor &grad_output,
const paddle::Tensor &gelu_input,
const paddle::Tensor &scale,
......@@ -406,7 +466,9 @@ std::vector<paddle::Tensor> te_layernorm_bwd(const paddle::Tensor &dz, const pad
std::vector<paddle::Tensor> te_rmsnorm_fwd(const paddle::Tensor &input,
const paddle::Tensor &weight, float eps, int64_t otype,
int64_t sm_margin) {
int64_t sm_margin, bool zero_centered_gamma) {
NVTE_CHECK(zero_centered_gamma == false,
"zero_centered_gamma is not supported yet for RMSNorm.");
auto shape = GetShapeArray(input);
NVTE_CHECK(shape.size() == 2, "Expect the grad_output to have 2 dimensions.");
......@@ -448,14 +510,16 @@ std::vector<paddle::Tensor> te_rmsnorm_fwd_fp8(const paddle::Tensor &input,
paddle::Tensor &amax, // NOLINT
paddle::Tensor &scale_inv, // NOLINT
float eps, int64_t index, int64_t otype,
int64_t sm_margin) {
int64_t sm_margin, bool zero_centered_gamma) {
NVTE_CHECK(zero_centered_gamma == false,
"zero_centered_gamma is not supported yet for RMSNorm.");
auto shape = GetShapeArray(input);
NVTE_CHECK(shape.size() == 2, "Expect the grad_output to have 2 dimensions.");
size_t N = shape[0];
size_t H = shape[1];
auto ln_out = paddle::empty_like(input, input.dtype(), input.place());
auto ln_out = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)), input.place());
auto rsigma =
paddle::empty({static_cast<int64_t>(N)}, paddle::DataType::FLOAT32, input.place());
auto input_cu = MakeNvteTensor(input);
......@@ -487,7 +551,10 @@ std::vector<paddle::Tensor> te_rmsnorm_fwd_fp8(const paddle::Tensor &input,
std::vector<paddle::Tensor> te_rmsnorm_bwd(const paddle::Tensor &dz, const paddle::Tensor &x,
const paddle::Tensor &rsigma,
const paddle::Tensor &gamma, int64_t sm_margin) {
const paddle::Tensor &gamma, int64_t sm_margin,
bool zero_centered_gamma) {
NVTE_CHECK(zero_centered_gamma == false,
"zero_centered_gamma is not supported yet for RMSNorm.");
auto dx = paddle::empty_like(x, x.dtype(), x.place());
auto dgamma = paddle::empty_like(gamma, gamma.dtype(), gamma.place());
......@@ -1374,6 +1441,25 @@ PD_BUILD_OP(te_gelu)
.Attrs({"otype: int64_t"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_gelu));
PD_BUILD_OP(te_swiglu)
.Inputs({"Input"})
.Outputs({"Output"})
.Attrs({"otype: int64_t"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_swiglu));
PD_BUILD_OP(te_swiglu_fp8)
.Inputs({"Input", "Scale", "_Amax", "_ScaleInv"})
.Outputs({"Output", "Amax", "ScaleInv"})
.SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}})
.Attrs({"index: int64_t", "otype: int64_t"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_swiglu_fp8));
PD_BUILD_OP(te_dswiglu)
.Inputs({"Grad", "Input"})
.Outputs({"Output"})
.Attrs({"otype: int64_t"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_dswiglu));
PD_BUILD_OP(te_cast_transpose_bgrad_dgelu)
.Inputs({"GradOutput", "GeluInput", "Scale", "_Amax", "_ScaleInv"})
.Outputs({"CastedDgelu", "TransposedDgelu", "Dbias", "Amax", "ScaleInv"})
......@@ -1404,20 +1490,21 @@ PD_BUILD_OP(te_layernorm_bwd)
PD_BUILD_OP(te_rmsnorm_fwd)
.Inputs({"Input", "Weight"})
.Outputs({"Output", "InvVariance"})
.Attrs({"eps: float", "otype: int64_t", "sm_margin: int64_t"})
.Attrs({"eps: float", "otype: int64_t", "sm_margin: int64_t", "zero_centered_gamma: bool"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_rmsnorm_fwd));
PD_BUILD_OP(te_rmsnorm_fwd_fp8)
.Inputs({"Input", "Weight", "Scale", "_Amax", "_ScaleInv"})
.Outputs({"Output", "InvVariance", "Amax", "ScaleInv"})
.SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}})
.Attrs({"eps: float", "index: int64_t", "otype: int64_t", "sm_margin: int64_t"})
.Attrs({"eps: float", "index: int64_t", "otype: int64_t", "sm_margin: int64_t",
"zero_centered_gamma: bool"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_rmsnorm_fwd_fp8));
PD_BUILD_OP(te_rmsnorm_bwd)
.Inputs({"Dz", "X", "Rsigma", "Gamma"})
.Outputs({"Dx", "Dgamma"})
.Attrs({"sm_margin: int64_t"})
.Attrs({"sm_margin: int64_t", "zero_centered_gamma: bool"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_rmsnorm_bwd));
PD_BUILD_OP(te_fused_attn_fwd_qkvpacked)
......
......@@ -3,7 +3,7 @@
# See LICENSE for license information.
"""Layer level Paddle APIs"""
from .attention import DotProductAttention, MultiHeadAttention
from .attention import DotProductAttention, MultiHeadAttention, RotaryPositionEmbedding
from .layernorm import LayerNorm
from .layernorm_linear import LayerNormLinear
from .layernorm_mlp import LayerNormMLP
......
......@@ -10,6 +10,10 @@ from typing import Optional, Tuple, Union
import paddle
import paddle.nn.functional as F
try:
from paddle.incubate.nn.functional import fused_rotary_position_embedding
except ImportError:
fused_rotary_position_embedding = None
import transformer_engine_paddle as tex
from .layernorm_linear import LayerNormLinear
......@@ -30,7 +34,7 @@ from ..distributed import get_tp_group_and_world_size, track_rng_state
from ..utils import attention_mask_func, divide
from ..recompute import recompute
__all__ = ["DotProductAttention", "MultiHeadAttention"]
__all__ = ["DotProductAttention", "MultiHeadAttention", "RotaryPositionEmbedding"]
def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor:
......@@ -47,6 +51,81 @@ def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor:
return hidden_states.reshape([batch, seqlen, num_gqa_groups * n_rep, head_size])
class RotaryPositionEmbedding(paddle.nn.Layer):
"""
Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864.
"""
def __init__(
self,
dim: int,
max_position_embeddings: int,
):
"""
Parameters
----------
dim: int
rotary embedding dimension
max_position_embeddings: int
max_position_embeddings before position interpolation
"""
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.inv_freq = 1.0 / (10000**(paddle.cast(paddle.arange(0, dim, 2), dtype='float32') /
self.dim))
self._set_cos_sin_cache(seq_len=max_position_embeddings)
def _set_cos_sin_cache(self, seq_len):
self.max_seq_len_cached = seq_len
# [seq_len]
t = paddle.arange(seq_len, dtype="float32")
# [seq_len, dim/2]
freqs = paddle.einsum("i,j->ij", t, self.inv_freq)
# [seq_len, dim]
emb = paddle.concat([freqs, freqs], axis=-1)
# [1, seqlen, 1, dim]
self.cos_cached = emb.cos()[None, :, None, :]
self.sin_cached = emb.sin()[None, :, None, :]
def forward(self, max_seq_len: int):
"""
Create rotary position embedding frequencies
Parameters
----------
max_seq_len: int
sequence length of a sample
"""
cos = self.cos_cached[:, :, :max_seq_len, ...]
sin = self.sin_cached[:, :, :max_seq_len, ...]
return (cos, sin)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return paddle.concat([-x2, x1], axis=-1) # shape is the same as x
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None):
"""Applies rotary positional embedding to the input."""
if position_ids is None:
# Note: Only for LlamaForCausalLMPipe model pretraining
cos = cos[:, :q.shape[1], :, :] # [bs, seq_len, 1, dim]
sin = sin[:, :q.shape[1], :, :] # [bs, seq_len, 1, dim]
else:
cos = cos.squeeze(axis=[0, 2]) # [seq_len, dim]
sin = sin.squeeze(axis=[0, 2]) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim]
sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
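
A quick sanity check of these helpers (sketch only; RotaryPositionEmbedding and apply_rotary_pos_emb are the definitions above, and the shapes are illustrative): each pair of channels is rotated by the same angle, so the per-position, per-head norm of q and k should be unchanged.

# Sketch: q/k are [batch, seq, heads, head_dim]; cos/sin come from RotaryPositionEmbedding.
import paddle

rope = RotaryPositionEmbedding(dim=64, max_position_embeddings=2048)
cos, sin = rope(max_seq_len=128)

q = paddle.rand([2, 128, 16, 64])
k = paddle.rand([2, 128, 16, 64])
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)   # position_ids defaults to None

# The rotation preserves magnitude per position/head, so the norms should match closely.
assert paddle.allclose(paddle.linalg.norm(q, p=2, axis=-1),
                       paddle.linalg.norm(q_rot, p=2, axis=-1), atol=1e-3)
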
class FusedAttnFuncPackedQKV(paddle.autograd.PyLayer):
"""Function for FusedAttention with packed QKV input"""
......@@ -450,6 +529,8 @@ class MultiHeadAttention(paddle.nn.Layer):
whether to apply layernorm to the input.
attention_type: {'self', 'cross'}, default = `self`
type of attention operation.
normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
type of normalization applied.
zero_centered_gamma: bool, default = `False`
whether to zero initialize the gamma of the layernorm operation.
backend: {'transformer_engine', 'paddle'}, default = `transformer_engine`
......@@ -491,11 +572,13 @@ class MultiHeadAttention(paddle.nn.Layer):
layernorm_epsilon: float = 1e-5,
weight_attr: Union[paddle.ParamAttr, None] = None,
bias_attr: Union[paddle.ParamAttr, None, bool] = None,
max_sequence_length: Optional[int] = None,
attn_mask_type: str = "causal",
params_dtype: Optional[paddle.dtype] = None,
return_layernorm_output: bool = False,
input_layernorm: bool = False,
attention_type: str = "self",
normalization: str = "LayerNorm",
zero_centered_gamma: bool = False,
set_parallel_mode: bool = False,
sequence_parallel: bool = False,
......@@ -509,6 +592,7 @@ class MultiHeadAttention(paddle.nn.Layer):
self.attention_type = attention_type
self.return_layernorm_output = return_layernorm_output
self.params_dtype = paddle.get_default_dtype() if params_dtype is None else params_dtype
self.max_sequence_length = max_sequence_length
self.weight_attr = weight_attr
self.bias_attr = bias_attr
self.attn_mask_type = attn_mask_type
......@@ -544,6 +628,7 @@ class MultiHeadAttention(paddle.nn.Layer):
weight_attr=self.weight_attr,
bias_attr=self.bias_attr,
return_layernorm_output=return_layernorm_output,
normalization=normalization,
zero_centered_gamma=zero_centered_gamma,
parallel_mode=qkv_parallel_mode,
sequence_parallel=self.sequence_parallel,
......@@ -571,6 +656,7 @@ class MultiHeadAttention(paddle.nn.Layer):
weight_attr=self.weight_attr,
bias_attr=self.bias_attr,
return_layernorm_output=return_layernorm_output,
normalization=normalization,
zero_centered_gamma=zero_centered_gamma,
parallel_mode=qkv_parallel_mode,
sequence_parallel=self.sequence_parallel,
......@@ -628,6 +714,7 @@ class MultiHeadAttention(paddle.nn.Layer):
hidden_states: paddle.Tensor,
attention_mask: Optional[paddle.Tensor] = None,
encoder_output: Optional[paddle.Tensor] = None,
rotary_pos_emb: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[paddle.Tensor] = None,
set_zero: bool = True,
......@@ -645,6 +732,9 @@ class MultiHeadAttention(paddle.nn.Layer):
Boolean tensor used to mask out softmax input when not using attention.
encoder_output : Optional[paddle.Tensor], default = `None`
Output of the encoder layer.
rotary_pos_emb: Tuple[paddle.Tensor, paddle.Tensor], default = `None`
Embeddings for query and key tensors for applying rotary position
embedding. By default no input embedding is applied.
core_attention_bias_type: str, default = `no_bias`
only support no_bias type currently, {`no_bias`}
core_attention_bias: Optional[paddle.Tensor], default = `None`
......@@ -675,8 +765,8 @@ class MultiHeadAttention(paddle.nn.Layer):
if input_dim == 2:
# hidden_states: [b * s_q, hidden_size]
# need to get max_seq_len from attention_mask
assert attention_mask is not None
max_seq_len = attention_mask.shape[-1]
assert self.max_sequence_length is not None, "max_sequence_length must be provided"
max_seq_len = self.max_sequence_length
elif input_dim == 3:
# hidden_states: [b, s_q, hidden_size]
max_seq_len = hidden_states.shape[1]
......@@ -723,30 +813,6 @@ class MultiHeadAttention(paddle.nn.Layer):
shape=[x.shape[0], x.shape[1], -1, self.hidden_size_per_attention_head])
for x in (query_layer, key_layer, value_layer))
with track_rng_state(enable=self.tensor_parallel, name=self.rng_state_name):
if recompute_core_attention:
context_layer = recompute(
self.core_attention,
query_layer,
key_layer,
value_layer,
attention_mask,
core_attention_bias_type,
core_attention_bias,
set_zero,
use_reentrant=False,
)
else:
context_layer = self.core_attention(
query_layer=query_layer,
key_layer=key_layer,
value_layer=value_layer,
attention_mask=attention_mask,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
set_zero=set_zero,
)
else: # cross attention
mixed_kv_layer = self.key_value(
encoder_output,
......@@ -785,29 +851,46 @@ class MultiHeadAttention(paddle.nn.Layer):
-1, max_seq_len, self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head
])
with track_rng_state(enable=self.tensor_parallel, name=self.rng_state_name):
if recompute_core_attention:
context_layer = recompute(
self.core_attention,
query_layer,
key_layer,
value_layer,
attention_mask,
core_attention_bias_type,
core_attention_bias,
set_zero,
use_reentrant=False,
)
else:
context_layer = self.core_attention(
query_layer=query_layer,
key_layer=key_layer,
value_layer=value_layer,
attention_mask=attention_mask,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
set_zero=set_zero,
)
if rotary_pos_emb is not None:
q_pos_emb, k_pos_emb = rotary_pos_emb
if fused_rotary_position_embedding is None:
query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, q_pos_emb,
k_pos_emb)
else:
query_layer, key_layer, _ = fused_rotary_position_embedding(
query_layer,
key_layer,
v=None,
sin=k_pos_emb,
cos=q_pos_emb,
position_ids=None,
use_neox_rotary_style=False,
)
with track_rng_state(enable=self.tensor_parallel, name=self.rng_state_name):
if recompute_core_attention:
context_layer = recompute(
self.core_attention,
query_layer,
key_layer,
value_layer,
attention_mask,
core_attention_bias_type,
core_attention_bias,
set_zero,
use_reentrant=False,
)
else:
context_layer = self.core_attention(
query_layer=query_layer,
key_layer=key_layer,
value_layer=value_layer,
attention_mask=attention_mask,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
set_zero=set_zero,
)
if input_dim == 3:
context_layer = paddle.reshape(
......
......@@ -17,6 +17,9 @@ from ..cpp_extensions import (
layernorm_fwd,
layernorm_fwd_fp8,
layernorm_bwd,
rmsnorm_fwd_fp8,
rmsnorm_fwd,
rmsnorm_bwd,
)
from .base import TransformerEngineBaseLayer
......@@ -44,82 +47,129 @@ from ..utils import (
__all__ = ["LayerNormLinear"]
def _layernorm_fwd_fp8_cast(
def _apply_normalization_fwd(
normalization: str,
inputmat: paddle.Tensor,
ln_weight: paddle.Tensor,
ln_bias: paddle.Tensor,
norm_weight: paddle.Tensor,
norm_bias: Union[paddle.Tensor, None],
out_fp8_index: FP8FwdTensors,
eps: float,
fp8_enabled: bool,
fp8_meta: Dict[str, Any],
activation_dtype: paddle.dtype,
return_layernorm_output: bool,
fwd_ln_sm_margin: int,
return_norm_output: bool,
fwd_norm_sm_margin: int,
zero_centered_gamma: bool,
):
"""Performs LayerNorm + FP8_Cast for FP8 path. LayerNorm only for BF16 path"""
ln_weight = cast_if_needed_inplace(ln_weight, activation_dtype)
ln_bias = cast_if_needed_inplace(ln_bias, activation_dtype)
assert normalization in ["LayerNorm", "RMSNorm"], "Unsupported normalization type!"
if normalization == "RMSNorm":
assert norm_bias is None, "RMSNorm does not support bias!"
norm_weight = cast_if_needed_inplace(norm_weight, activation_dtype)
if norm_bias is not None:
norm_bias = cast_if_needed_inplace(norm_bias, activation_dtype)
norm_kwargs = {
"inp": inputmat,
"weight": norm_weight,
"eps": eps,
"otype": TE_DType[activation_dtype],
"sm_margin": fwd_norm_sm_margin,
"zero_centered_gamma": zero_centered_gamma,
}
fwd_normalization_funcs = {
('LayerNorm', True, True): layernorm_fwd,
('LayerNorm', True, False): layernorm_fwd_fp8,
('LayerNorm', False, True): layernorm_fwd,
('LayerNorm', False, False): layernorm_fwd,
('RMSNorm', True, True): rmsnorm_fwd,
('RMSNorm', True, False): rmsnorm_fwd_fp8,
('RMSNorm', False, True): rmsnorm_fwd,
('RMSNorm', False, False): rmsnorm_fwd,
}
if normalization == "LayerNorm":
norm_kwargs["bias"] = norm_bias
norm_fwd_func = fwd_normalization_funcs[(normalization, fp8_enabled, return_norm_output)]
if fp8_enabled:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
if not return_layernorm_output:
ln_out, mu, rsigma = layernorm_fwd_fp8(
inputmat,
ln_weight,
ln_bias,
eps,
fp8_meta["scaling_fwd"],
out_fp8_index,
fp8_dtype_forward,
fwd_ln_sm_margin,
zero_centered_gamma,
)
ln_out_return = ln_out
else:
ln_out_return, mu, rsigma = layernorm_fwd(inputmat, ln_weight, ln_bias, eps,
TE_DType[activation_dtype], fwd_ln_sm_margin,
zero_centered_gamma)
ln_out = cast_to_fp8(
ln_out_return,
fp8_meta["scaling_fwd"],
out_fp8_index,
fp8_dtype_forward,
)
if not return_norm_output:
fp8_kwargs = {
"fp8_meta_tensor": fp8_meta["scaling_fwd"],
"fp8_tensor": out_fp8_index,
"otype": fp8_dtype_forward,
}
norm_kwargs.update(fp8_kwargs)
out_tuple = norm_fwd_func(**norm_kwargs)
if normalization == "LayerNorm":
norm_out_return, mu, rsigma = out_tuple
else: # RMSNorm
norm_out_return, rsigma = out_tuple
mu = None
if fp8_enabled and return_norm_output:
norm_out = cast_to_fp8(
norm_out_return,
fp8_meta["scaling_fwd"],
out_fp8_index,
fp8_dtype_forward,
)
else:
ln_out, mu, rsigma = layernorm_fwd(inputmat, ln_weight, ln_bias, eps,
TE_DType[activation_dtype], fwd_ln_sm_margin,
zero_centered_gamma)
ln_out_return = ln_out
norm_out = norm_out_return
return (
ln_out_return,
ln_out,
norm_out_return,
norm_out,
mu,
rsigma,
)
def _layernorm_bwd(
def _apply_normalization_bwd(
normalization: str,
inputmat: paddle.Tensor,
dgrad: paddle.Tensor,
ln_weight: paddle.Tensor,
mu: paddle.Tensor,
norm_weight: paddle.Tensor,
mu: Union[paddle.Tensor, None],
rsigma: paddle.Tensor,
grad_ln_out_return: paddle.Tensor,
return_layernorm_output: bool,
bwd_ln_sm_margin: int,
grad_norm_out_return: paddle.Tensor,
return_norm_output: bool,
bwd_norm_sm_margin: int,
zero_centered_gamma: bool,
):
assert normalization in ["LayerNorm", "RMSNorm"], "Unsupported normalization type!"
if normalization == "RMSNorm":
assert mu is None, "RMSNorm does not use mean (mu) statistics!"
# LayerNorm gradient
d_ln_out = dgrad.reshape(inputmat.shape)
d_norm_out = dgrad.reshape(inputmat.shape)
# Residual gradient
if return_layernorm_output:
d_ln_out = d_ln_out + grad_ln_out_return.reshape(d_ln_out.shape)
return layernorm_bwd(d_ln_out, inputmat, mu, rsigma, ln_weight, bwd_ln_sm_margin,
zero_centered_gamma)
if return_norm_output:
d_norm_out = d_norm_out + grad_norm_out_return.reshape(d_norm_out.shape)
norm_bwd_func = layernorm_bwd if normalization == "LayerNorm" else rmsnorm_bwd
norm_bwd_kwargs = {
"dz": d_norm_out,
"x": inputmat,
"rsigma": rsigma,
"gamma": norm_weight,
"sm_margin": bwd_norm_sm_margin,
"zero_centered_gamma": zero_centered_gamma,
}
if normalization == "LayerNorm":
norm_bwd_kwargs["mu"] = mu
out_tuple = norm_bwd_func(**norm_bwd_kwargs)
if normalization == "LayerNorm":
dxmat, dgamma, dbeta = out_tuple
else: # RMSNorm
dxmat, dgamma = out_tuple
dbeta = None
return dxmat, dgamma, dbeta
class _LayerNormLinear(paddle.autograd.PyLayer):
......@@ -130,7 +180,7 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
ctx,
inp: paddle.Tensor,
ln_weight: paddle.Tensor,
ln_bias: paddle.Tensor,
ln_bias: Union[paddle.Tensor, None],
weight: paddle.Tensor,
weight_fp8: Optional[paddle.Tensor],
weight_t_fp8: Optional[paddle.Tensor],
......@@ -146,6 +196,7 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int,
zero_centered_gamma: bool,
normalization: str,
parallel_mode: Union[str, None],
tensor_parallel: bool,
sequence_parallel: bool,
......@@ -153,6 +204,10 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
tp_size: int,
is_first_microbatch: bool,
) -> Union[Tuple[paddle.Tensor, ...], paddle.Tensor]:
if normalization == "RMSNorm":
assert ln_bias is None, "RMSNorm does not support bias!"
else: # LayerNorm
assert ln_bias is not None, "LayerNorm requires bias!"
# Make sure input dimensions are compatible
in_features = ln_weight.shape[0]
assert inp.shape[-1] == in_features, "GEMM not possible"
......@@ -167,7 +222,8 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
ln_out,
mu,
rsigma,
) = _layernorm_fwd_fp8_cast(
) = _apply_normalization_fwd(
normalization,
inputmat,
ln_weight,
ln_bias,
......@@ -232,9 +288,11 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
ctx.requires_dgrad = not inp.stop_gradient
ctx.requires_wgrad = not weight.stop_gradient
ctx.requires_bgrad = use_bias and not bias.stop_gradient
ctx.requires_ln_bgrad = not ln_bias.stop_gradient
ctx.requires_ln_bgrad = ln_bias is not None and not ln_bias.stop_gradient
ctx.requires_ln_wgrad = not ln_weight.stop_gradient
ctx.is_first_microbatch = is_first_microbatch
ctx.has_ln_bias = ln_bias is not None
ctx.normalization = normalization
# [*, in_features] -> [*, out_features] except first dimension changes for SP
out = out.reshape((-1, *inp.shape[1:-1], out.shape[-1]))
......@@ -314,7 +372,8 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
bgrad = bgrad_
# LayerNorm Bwd
dxmat, dgamma, dbeta = _layernorm_bwd(
dxmat, dgamma, dbeta = _apply_normalization_bwd(
ctx.normalization,
inputmat,
dgrad,
ln_weight,
......@@ -328,6 +387,8 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
bgrad = bgrad if ctx.requires_bgrad else None
bgrad_out = (bgrad,) if ctx.use_bias else ()
dbeta = dbeta if ctx.requires_ln_bgrad else None
dbeta_out = (dbeta,) if ctx.has_ln_bias else ()
if not ctx.fp8_enabled or ctx.is_first_microbatch is None:
weight_cache_grad = ()
......@@ -338,7 +399,7 @@ class _LayerNormLinear(paddle.autograd.PyLayer):
return (
dxmat.reshape(ctx.inp_shape) if ctx.requires_dgrad else None,
dgamma if ctx.requires_ln_wgrad else None,
dbeta if ctx.requires_ln_bgrad else None,
*dbeta_out,
wgrad if ctx.requires_wgrad else None,
*weight_cache_grad,
*bgrad_out,
......@@ -361,6 +422,8 @@ class LayerNormLinear(TransformerEngineBaseLayer):
optional `paddle.ParamAttr` for weight.
bias_attr: Union[paddle.ParamAttr, None, bool], default = None
optional `paddle.ParamAttr` for bias.
normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
type of normalization applied.
return_layernorm_output : bool, default = `False`
if set to `True`, output of layernorm is returned from the forward
together with the output of the linear transformation.
......@@ -395,6 +458,7 @@ class LayerNormLinear(TransformerEngineBaseLayer):
eps: float = 1e-5,
weight_attr: Union[paddle.ParamAttr, None] = None,
bias_attr: Union[paddle.ParamAttr, None, bool] = None,
normalization: str = 'LayerNorm',
return_layernorm_output: bool = False,
zero_centered_gamma: bool = False,
parallel_mode: Optional[str] = None,
......@@ -407,6 +471,8 @@ class LayerNormLinear(TransformerEngineBaseLayer):
self.in_features = in_features
self.out_features = out_features
self.eps = eps
self.normalization = normalization
assert normalization in ['LayerNorm', 'RMSNorm'], "Unsupported normalization type!"
self.return_layernorm_output = return_layernorm_output
self.zero_centered_gamma = zero_centered_gamma
self.backend = backend
......@@ -439,17 +505,20 @@ class LayerNormLinear(TransformerEngineBaseLayer):
dtype=self._dtype,
is_bias=False,
)
self.ln_bias = self.create_parameter(
shape=[self.in_features],
attr=paddle.ParamAttr(initializer=Constant(value=0.0)),
dtype=self._dtype,
is_bias=True,
)
if self.normalization != "RMSNorm":
self.ln_bias = self.create_parameter(
shape=[self.in_features],
attr=paddle.ParamAttr(initializer=Constant(value=0.0)),
dtype=self._dtype,
is_bias=True,
)
else:
self.ln_bias = None
if self.sequence_parallel:
mark_as_sequence_parallel_parameter(self.ln_weight)
mark_as_sequence_parallel_parameter(self.ln_bias)
if self.ln_bias is not None:
mark_as_sequence_parallel_parameter(self.ln_bias)
# Initialize Linear weight parameter
with track_rng_state(enable=self.tensor_parallel):
......@@ -534,6 +603,7 @@ class LayerNormLinear(TransformerEngineBaseLayer):
self.fwd_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.normalization,
self.parallel_mode,
self.tensor_parallel,
self.sequence_parallel,
......@@ -566,19 +636,24 @@ class LayerNormLinear(TransformerEngineBaseLayer):
warnings.warn(
"`is_first_microbatch` is not supported for paddle backend and is ignored.")
ln_out = F.layer_norm(x=inp,
normalized_shape=inp.shape[-1],
weight=self.ln_weight,
bias=self.ln_bias,
epsilon=self.eps)
if self.normalization == "RMSNorm":
norm = paddle.rsqrt(paddle.mean(inp**2, axis=-1, keepdim=True) + self.eps)
norm_out = inp * norm * self.ln_weight
else: # LayerNorm
norm_out = F.layer_norm(x=inp,
normalized_shape=inp.shape[-1],
weight=self.ln_weight,
bias=self.ln_bias,
epsilon=self.eps)
if self.parallel_mode == 'column' and self.tensor_parallel:
ln_out = identity(ln_out, self.tp_group)
out = F.linear(ln_out, self.weight, self.bias if self.gemm_bias_fused_add else None)
norm_out = identity(norm_out, self.tp_group)
out = F.linear(norm_out, self.weight, self.bias if self.gemm_bias_fused_add else None)
if self.parallel_mode == 'row' and self.tensor_parallel:
out, _ = allreduce(out, self.tp_group)
out = out + self.bias if self.bias is not None else out
if self.return_layernorm_output:
return out, ln_out
return out, norm_out
return out
def forward(self, *args, **kwargs):
......
......@@ -12,13 +12,17 @@ import paddle.nn.functional as F
from paddle.nn.initializer import Constant
from .base import TransformerEngineBaseLayer
from .layernorm_linear import _layernorm_fwd_fp8_cast, _layernorm_bwd
from .layernorm_linear import _apply_normalization_fwd, _apply_normalization_bwd
from .linear import _linear_fwd_fp8, _linear_fwd_non_fp8, _linear_bwd_fp8, _linear_bwd_non_fp8
from ..constants import TE_DType, FP8FwdTensors, FP8BwdTensors, dist_group_type
from ..cpp_extensions import (
cast_from_fp8,
dgelu_cast_transpose_bgrad_fp8,
gelu_fp8,
swiglu_fp8,
swiglu,
dswiglu,
cast_transpose_bgrad,
dgelu_cast_transpose_bgrad_fp8,
)
from ..distributed import (
allreduce,
......@@ -91,13 +95,22 @@ def _mlp_forward(
is_grad_enabled,
is_first_microbatch,
)
gelu_out = gelu_fp8(
fc1_out,
fp8_meta["scaling_fwd"],
fc2_input_fp8_index,
fp8_dtype_forward,
)
if activation == "gelu":
gelu_out = gelu_fp8(
fc1_out,
fp8_meta["scaling_fwd"],
fc2_input_fp8_index,
fp8_dtype_forward,
)
elif activation == "swiglu":
gelu_out = swiglu_fp8(
fc1_out,
fp8_meta["scaling_fwd"],
fc2_input_fp8_index,
fp8_dtype_forward,
)
else:
raise NotImplementedError("Activation type " + activation + " is not supported!")
fc2_out, fc2_weight_t_fp8 = _linear_fwd_fp8(
gelu_out,
......@@ -118,7 +131,7 @@ def _mlp_forward(
is_first_microbatch,
)
else:
fc1_out, gelu_out = _linear_fwd_non_fp8(
fc1_outputs = _linear_fwd_non_fp8(
inputmat,
inputmat_fp8_index,
fc1_weight,
......@@ -135,6 +148,14 @@ def _mlp_forward(
activation=activation,
)
if activation == "gelu":
fc1_out, gelu_out = fc1_outputs
elif activation == "swiglu":
fc1_out = fc1_outputs
gelu_out = swiglu(fc1_out, TE_DType[activation_dtype])
else:
raise NotImplementedError("Activation type " + activation + " is not supported!")
fc2_out = _linear_fwd_non_fp8(
gelu_out,
fc2_input_fp8_index,
......@@ -234,14 +255,23 @@ def _mlp_backward(
tp_group,
)
# GELU Bwd
dgelu, dgelu_t, fc1_bgrad_ = dgelu_cast_transpose_bgrad_fp8(
fc2_dgrad,
fc1_out,
fp8_meta["scaling_bwd"],
fc1_grad_output_fp8_index,
fp8_dtype_backward,
)
if activation == "gelu":
# GELU Bwd
dgelu, dgelu_t, fc1_bgrad_ = dgelu_cast_transpose_bgrad_fp8(
fc2_dgrad,
fc1_out,
fp8_meta["scaling_bwd"],
fc1_grad_output_fp8_index,
fp8_dtype_backward,
)
elif activation == "swiglu":
dgelu = dswiglu(fc2_dgrad, fc1_out, TE_DType[fc2_dgrad.dtype])
fc1_bgrad_, dgelu, dgelu_t = cast_transpose_bgrad(
dgelu,
fp8_meta["scaling_bwd"],
fc1_grad_output_fp8_index,
fp8_dtype_backward,
)
if requires_fc1_bgrad:
fc1_bgrad = fc1_bgrad_
......@@ -301,6 +331,10 @@ def _mlp_backward(
gelu_input=fc1_out,
activation=activation,
)
if activation == "swiglu":
dgelu = dswiglu(dgelu, fc1_out, TE_DType[dgelu.dtype])
fc1_dgrad, fc1_wgrad, fc1_bgrad = _linear_bwd_non_fp8(
fc1_input,
fc1_weight,
......@@ -331,7 +365,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
ctx,
inp: paddle.Tensor,
ln_weight: paddle.Tensor,
ln_bias: paddle.Tensor,
ln_bias: Union[paddle.Tensor, None],
fc1_weight: paddle.Tensor,
fc1_weight_fp8: Optional[paddle.Tensor],
fc1_weight_t_fp8: Optional[paddle.Tensor],
......@@ -352,6 +386,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int,
zero_centered_gamma: bool,
normalization: str,
activation: str,
set_parallel_mode: bool,
tensor_parallel: bool,
......@@ -360,6 +395,10 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
tp_size: int,
is_first_microbatch: bool,
) -> Union[Tuple[paddle.Tensor, ...], paddle.Tensor]:
if normalization == "RMSNorm":
assert ln_bias is None, "RMSNorm does not support bias!"
else: # LayerNorm
assert ln_bias is not None, "LayerNorm requires bias!"
# Make sure input dimensions are compatible
in_features = ln_weight.shape[0]
assert inp.shape[-1] == in_features, "GEMM not possible"
......@@ -370,7 +409,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
assert_dim_for_fp8_forward_exec(fc2_weight)
# only support gelu for now
assert activation == 'gelu'
assert activation in ["gelu", "swiglu"], "Only gelu and swiglu are supported for now"
# LayerNorm Fwd + FP8 Cast
(
......@@ -378,7 +417,8 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
ln_out,
mu,
rsigma,
) = _layernorm_fwd_fp8_cast(
) = _apply_normalization_fwd(
normalization,
inputmat,
ln_weight,
ln_bias,
......@@ -463,9 +503,11 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
ctx.requires_fc2_wgrad = not fc2_weight.stop_gradient
ctx.requires_fc1_bgrad = use_fc1_bias and not fc1_bias.stop_gradient
ctx.requires_fc2_bgrad = use_fc2_bias and not fc2_bias.stop_gradient
ctx.requires_ln_bgrad = not ln_bias.stop_gradient
ctx.requires_ln_bgrad = ln_bias is not None and not ln_bias.stop_gradient
ctx.requires_ln_wgrad = not ln_weight.stop_gradient
ctx.is_first_microbatch = is_first_microbatch
ctx.has_ln_bias = ln_bias is not None
ctx.normalization = normalization
# [*, in_features] -> [*, out_features] except first dimension changes for SP
fc2_out = fc2_out.reshape((-1, *inp.shape[1:-1], fc2_out.shape[-1]))
......@@ -549,7 +591,8 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
fc2_bgrad = fc2_bgrad_
# LayerNorm Bwd
dxmat, dgamma, dbeta = _layernorm_bwd(
dxmat, dgamma, dbeta = _apply_normalization_bwd(
ctx.normalization,
inputmat,
fc1_dgrad,
ln_weight,
......@@ -565,6 +608,8 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
fc2_bgrad = fc2_bgrad if ctx.requires_fc2_bgrad else None
fc1_bgrad_out = (fc1_bgrad,) if ctx.use_fc1_bias else ()
fc2_bgrad_out = (fc2_bgrad,) if ctx.use_fc2_bias else ()
dbeta = dbeta if ctx.requires_ln_bgrad else None
dbeta_out = (dbeta,) if ctx.has_ln_bias else ()
if not ctx.fp8_enabled or ctx.is_first_microbatch is None:
fc1_weight_cache_grad = ()
......@@ -577,7 +622,7 @@ class _LayerNormMLP(paddle.autograd.PyLayer):
return (
dxmat.reshape(ctx.inp_shape) if ctx.requires_dgrad else None,
dgamma if ctx.requires_ln_wgrad else None,
dbeta if ctx.requires_ln_bgrad else None,
*dbeta_out,
fc1_wgrad if ctx.requires_fc1_wgrad else None,
*fc1_weight_cache_grad,
*fc1_bgrad_out,
......@@ -604,6 +649,8 @@ class LayerNormMLP(TransformerEngineBaseLayer):
optional `paddle.ParamAttr` for weight.
bias_attr: Union[paddle.ParamAttr, None, bool], default = None
optional `paddle.ParamAttr` for bias.
normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
type of normalization applied.
activation : str, default = 'gelu'
activation function used.
Options: 'gelu', 'geglu', 'relu', 'reglu', 'squared_relu', 'swiglu'.
......@@ -641,6 +688,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
eps: float = 1e-5,
weight_attr: Union[paddle.ParamAttr, None] = None,
bias_attr: Union[paddle.ParamAttr, None, bool] = None,
normalization: str = "LayerNorm",
activation: str = "gelu",
return_layernorm_output: bool = False,
zero_centered_gamma: bool = False,
......@@ -654,6 +702,8 @@ class LayerNormMLP(TransformerEngineBaseLayer):
self.hidden_size = hidden_size
self.ffn_hidden_size = ffn_hidden_size
self.eps = eps
self.normalization = normalization
assert normalization in ["LayerNorm", "RMSNorm"], "Normalization type not supported"
self.activation = activation
self.return_layernorm_output = return_layernorm_output
self.zero_centered_gamma = zero_centered_gamma
......@@ -684,22 +734,31 @@ class LayerNormMLP(TransformerEngineBaseLayer):
is_bias=False,
)
self.ln_bias = self.create_parameter(
shape=[self.hidden_size],
attr=paddle.ParamAttr(initializer=Constant(value=0.0)),
dtype=self._dtype,
is_bias=True,
)
if self.normalization != "RMSNorm":
self.ln_bias = self.create_parameter(
shape=[self.hidden_size],
attr=paddle.ParamAttr(initializer=Constant(value=0.0)),
dtype=self._dtype,
is_bias=True,
)
else:
self.ln_bias = None
if self.sequence_parallel:
mark_as_sequence_parallel_parameter(self.ln_weight)
mark_as_sequence_parallel_parameter(self.ln_bias)
if self.ln_bias is not None:
mark_as_sequence_parallel_parameter(self.ln_bias)
# FC1 weights
if self.activation in ["swiglu"]:
fc1_output_features = self.size_per_partition * 2
else:
fc1_output_features = self.size_per_partition
with track_rng_state(enable=self.tensor_parallel):
self.fc1_weight = self.create_parameter(
shape=[self.size_per_partition, self.hidden_size] if self.backend
== 'transformer_engine' else [self.hidden_size, self.size_per_partition],
shape=[fc1_output_features, self.hidden_size] if self.backend
== 'transformer_engine' else [self.hidden_size, fc1_output_features],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False,
......@@ -717,7 +776,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
if self.has_bias:
self.fc1_bias = self.create_parameter(
shape=[self.size_per_partition],
shape=[fc1_output_features],
attr=self._bias_attr,
dtype=self._dtype,
is_bias=True,
......@@ -809,6 +868,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
self.fwd_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.normalization,
self.activation,
self.set_parallel_mode,
self.tensor_parallel,
......@@ -842,14 +902,18 @@ class LayerNormMLP(TransformerEngineBaseLayer):
warnings.warn(
"`is_first_microbatch` is not supported for paddle backend and is ignored.")
ln_out = F.layer_norm(x=inp,
normalized_shape=inp.shape[-1],
weight=self.ln_weight,
bias=self.ln_bias,
epsilon=self.eps)
if self.normalization == "RMSNorm":
norm = paddle.rsqrt(paddle.mean(inp**2, axis=-1, keepdim=True) + self.eps)
norm_out = inp * norm * self.ln_weight
else: # LayerNorm
norm_out = F.layer_norm(x=inp,
normalized_shape=inp.shape[-1],
weight=self.ln_weight,
bias=self.ln_bias,
epsilon=self.eps)
if self.set_parallel_mode and self.tensor_parallel:
ln_out = identity(ln_out, self.tp_group)
fc1_out = F.linear(ln_out, self.fc1_weight, self.fc1_bias)
norm_out = identity(norm_out, self.tp_group)
fc1_out = F.linear(norm_out, self.fc1_weight, self.fc1_bias)
act_func = get_paddle_act_func(self.activation)
act_out = act_func(fc1_out)
out = F.linear(act_out, self.fc2_weight,
......@@ -858,7 +922,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
out, _ = allreduce(out, self.tp_group)
out = out + self.fc2_bias if self.fc2_bias is not None else out
if self.return_layernorm_output:
return out, ln_out
return out, norm_out
return out
def forward(self, *args, **kwargs):
......
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""RMSNorm API"""
import os
from typing import Union, Tuple
import paddle
from paddle.nn.initializer import Constant
from ..constants import TE_DType
from ..cpp_extensions import rmsnorm_fwd, rmsnorm_bwd
from ..distributed import mark_as_sequence_parallel_parameter
__all__ = ["RMSNorm"]
class _RMSNorm(paddle.autograd.PyLayer):
"""functional RMSNorm"""
@staticmethod
def forward(
ctx,
inp: paddle.Tensor,
rmsnorm_weight: paddle.Tensor,
eps: float,
fwd_rmsnorm_sm_margin: int,
bwd_rmsnorm_sm_margin: int,
zero_centered_gamma: bool,
) -> paddle.Tensor:
# Make sure input dimensions are compatible
in_features = rmsnorm_weight.shape[0]
assert inp.shape[-1] == in_features, "RMSNorm not possible"
inputmat = inp.reshape((-1, in_features))
rmsnorm_out, rsigma = rmsnorm_fwd(inputmat, rmsnorm_weight, eps, TE_DType[inp.dtype],
fwd_rmsnorm_sm_margin, zero_centered_gamma)
ctx.save_for_backward(inputmat, rmsnorm_weight, rsigma)
ctx.inp_shape = inp.shape
ctx.bwd_rmsnorm_sm_margin = bwd_rmsnorm_sm_margin
ctx.zero_centered_gamma = zero_centered_gamma
ctx.requires_dx = not inp.stop_gradient
ctx.requires_dw = not rmsnorm_weight.stop_gradient
return rmsnorm_out.reshape(inp.shape)
@staticmethod
def backward(ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]:
inputmat, rmsnorm_weight, rsigma = ctx.saved_tensor()
d_rmsnorm_out = grad_output.reshape(inputmat.shape)
dxmat, dgamma = rmsnorm_bwd(d_rmsnorm_out, inputmat, rsigma, rmsnorm_weight,
ctx.bwd_rmsnorm_sm_margin, ctx.zero_centered_gamma)
return (
dxmat.reshape(ctx.inp_shape) if ctx.requires_dx else None,
dgamma if ctx.requires_dw else None,
)
class RMSNorm(paddle.nn.Layer):
r"""
Applies Root Mean Square Layer Normalization over a mini-batch of inputs as described in
the paper `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__
.. math::
y = \frac{x}{RMS_\varepsilon(x)} * \gamma
where
.. math::
RMS_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^nx_i^2 + \varepsilon}
:math:`\gamma` is a learnable affine transform parameter of size :attr:`hidden_size`
Parameters
----------
hidden_size : int
size of each input sample.
eps : float, default = 1e-5
a value added to the denominator of layer normalization for numerical stability.
weight_attr: Union[paddle.ParamAttr, None], default = None
optional `paddle.ParamAttr` for weight.
zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in RMSNorm is initialized to 0 and
the RMSNorm formula changes to
.. math::
y = \frac{x}{RMS_\varepsilon(x)} * (1 + \gamma)
backend: {'transformer_engine', 'paddle'}, default = 'transformer_engine'
backend to use for rmsnorm operation.
Parallelism parameters
----------------------
sequence_parallel : bool, default = `False`
if set to `True`, uses sequence parallelism.
"""
def __init__(
self,
hidden_size: int,
eps: float = 1e-5,
weight_attr: Union[paddle.ParamAttr, None] = None,
zero_centered_gamma: bool = False,
sequence_parallel: bool = False,
backend: str = "transformer_engine",
) -> None:
super().__init__()
self.eps = eps
self.zero_centered_gamma = zero_centered_gamma
self.sequence_parallel = sequence_parallel
self.backend = backend
self._dtype = self._helper.get_default_dtype()
self._weight_attr = weight_attr
if not self._weight_attr:
self._weight_attr = paddle.ParamAttr(initializer=Constant(1.0))
self.weight = self.create_parameter(
shape=[hidden_size],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False,
)
if self.sequence_parallel:
mark_as_sequence_parallel_parameter(self.weight)
# These many SMs are subtracted from the total SM count when calling forward
# and backward RMSNorm C APIs. These envvars can be used to prevent the LN
# kernels from using all SMs in the device. This is useful for cases such as
# communication overlap with RMSNorm.
self.fwd_rmsnorm_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
self.bwd_rmsnorm_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
def _te_forward(self, inp: paddle.Tensor) -> paddle.Tensor:
return _RMSNorm.apply(
inp,
self.weight,
self.eps,
self.fwd_rmsnorm_sm_margin,
self.bwd_rmsnorm_sm_margin,
self.zero_centered_gamma,
)
def _pd_forward(
self,
inp: paddle.Tensor,
) -> paddle.Tensor:
if self.zero_centered_gamma:
raise NotImplementedError(
"Paddle backend does not support RMSNorm with zero_centered_gamma.")
norm = paddle.rsqrt(paddle.mean(inp**2, axis=-1, keepdim=True) + self.eps)
y = inp * norm * self.weight
return y
def forward(self, *args, **kwargs):
if self.backend == "transformer_engine":
return self._te_forward(*args, **kwargs)
if self.backend == "paddle":
return self._pd_forward(*args, **kwargs)
raise AttributeError(f"Backend {self.backend} not supported.")
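
A minimal equivalence check between the two backends of this new layer (sketch only; the import path is an assumption about where this file lives in the package, and the shapes are illustrative):

# Sketch: TE RMSNorm kernel vs. the native Paddle reference used in _pd_forward above.
import paddle
from transformer_engine.paddle.layer.rmsnorm import RMSNorm   # assumed module path

x = paddle.rand([8, 256], dtype='float32')

rms_te = RMSNorm(hidden_size=256, eps=1e-5, backend='transformer_engine')
rms_pd = RMSNorm(hidden_size=256, eps=1e-5, backend='paddle')
rms_pd.weight.set_value(rms_te.weight)      # keep gamma identical (both default to 1.0 anyway)

y_te = rms_te(x)
y_pd = rms_pd(x)                            # x * rsqrt(mean(x**2) + eps) * weight
assert paddle.allclose(y_te, y_pd, rtol=1e-3, atol=1e-3)
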
......@@ -3,7 +3,7 @@
# See LICENSE for license information.
"""Transformer"""
from typing import Optional, Union
from typing import Optional, Tuple, Union
import warnings
import paddle
......@@ -60,6 +60,7 @@ class TransformerLayer(paddle.nn.Layer):
if set to `decoder`, an additional cross-attn block is added after self-attn.
This can be used for structures like `T5` Transformer in conjunction with the
`encoder` option.
normalization: {'LayerNorm', 'RMSNorm'}, default = `LayerNorm`
zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to
......@@ -111,11 +112,13 @@ class TransformerLayer(paddle.nn.Layer):
attention_dropout: float = 0.1,
weight_attr: Union[paddle.ParamAttr, None] = None,
bias_attr: Union[paddle.ParamAttr, None, bool] = None,
max_sequence_length: Optional[int] = None,
self_attn_mask_type: str = "causal",
params_dtype: Optional[paddle.dtype] = None,
apply_residual_connection_post_layernorm: bool = False,
output_layernorm: bool = False,
layer_type: str = "encoder",
normalization: str = "LayerNorm",
zero_centered_gamma: bool = False,
activation: str = 'gelu',
set_parallel_mode: bool = False,
......@@ -158,9 +161,11 @@ class TransformerLayer(paddle.nn.Layer):
common_attention_kwargs = {
"params_dtype": params_dtype,
"return_layernorm_output": apply_residual_connection_post_layernorm,
"normalization": normalization,
"zero_centered_gamma": zero_centered_gamma,
"set_parallel_mode": set_parallel_mode,
"sequence_parallel": self.sequence_parallel,
'max_sequence_length': max_sequence_length,
"tp_group": tp_group,
"num_gqa_groups": num_gqa_groups,
"rng_state_name": attention_dropout_rng_state_name,
......@@ -190,6 +195,7 @@ class TransformerLayer(paddle.nn.Layer):
eps=layernorm_epsilon,
weight_attr=weight_attr,
bias_attr=bias_attr,
normalization=normalization,
activation=activation,
return_layernorm_output=apply_residual_connection_post_layernorm,
zero_centered_gamma=zero_centered_gamma,
......@@ -223,6 +229,7 @@ class TransformerLayer(paddle.nn.Layer):
attention_mask: Optional[paddle.Tensor] = None,
encoder_output: Optional[paddle.Tensor] = None,
enc_dec_attn_mask: Optional[paddle.Tensor] = None,
rotary_pos_emb: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[paddle.Tensor] = None,
set_zero: bool = True,
......@@ -249,6 +256,9 @@ class TransformerLayer(paddle.nn.Layer):
enc_dec_attn_mask : Optional[paddle.Tensor], default = `None`
Boolean tensor used to mask out inter-attention softmax input if using
`layer_type="decoder"`.
rotary_pos_emb : Optional[Tuple[paddle.Tensor, paddle.Tensor]], default = `None`
Embeddings for query and key tensors for applying rotary position
embedding. By default no input embedding is applied
core_attention_bias_type: str, default = `no_bias`
core_attention_bias: Optional[paddle.Tensor], default = `None`
Bias tensor for Q * K.T
......@@ -284,6 +294,7 @@ class TransformerLayer(paddle.nn.Layer):
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
set_zero=set_zero,
rotary_pos_emb=rotary_pos_emb,
recompute_core_attention=recompute_core_attention,
is_first_microbatch=is_first_microbatch,
)
......
......@@ -7,6 +7,7 @@ from typing import Optional, Tuple, Union
import paddle
import paddle.nn.functional as F
from .cpp_extensions import swiglu_pd
def cast_if_needed(tensor: Union[paddle.Tensor, None],
......@@ -48,6 +49,8 @@ def get_paddle_act_func(activation):
funcs = {
'gelu': F.gelu,
'relu': F.relu,
'silu': F.silu,
'swiglu': swiglu_pd,
}
if activation not in funcs:
raise NotImplementedError("Activation type " + activation + " is not supported.")
......