Unverified Commit bd7fd0a6 authored by Shijie's avatar Shijie Committed by GitHub
Browse files

[Paddle] Support GQA (#595)



* use separate qkv
Signed-off-by: default avatarjaywan <jaywan@nvidia.com>

* add support for GQA
Signed-off-by: default avatarjaywan <jaywan@nvidia.com>

* minor changes
Signed-off-by: default avatarShijie Wang <jaywan@nvidia.com>

* change rtol
Signed-off-by: default avatarShijie Wang <jaywan@nvidia.com>

* fix reshape issue
Signed-off-by: default avatarShijie Wang <jaywan@nvidia.com>

---------
Signed-off-by: default avatarjaywan <jaywan@nvidia.com>
Signed-off-by: default avatarShijie Wang <jaywan@nvidia.com>
parent e531cd2f
......@@ -3,7 +3,6 @@
# See LICENSE for license information.
"""Test TE Paddle Layer-level APIs"""
import math
import os
from utils import assert_allclose, is_fused_attention_supported
......@@ -785,7 +784,7 @@ class TestLayerNormMLP:
@pytest.mark.parametrize('bs', [1, 2, 8])
@pytest.mark.parametrize('hidden_size, num_heads', [[1024, 16], [768, 12]])
@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[128, 128], [512, 512]])
@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[512, 512], [1024, 1024]])
@pytest.mark.parametrize('attn_type', ['self', 'cross'])
@pytest.mark.parametrize('mask_type', ['causal', 'padding'])
@pytest.mark.parametrize('math_dtype', ['bfloat16', 'float16'])
......@@ -808,24 +807,18 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen,
head_size=head_size,
dtype=math_dtype,
dropout=0.0,
qkv_layout="bs3hd" if attn_type == "self" else "bshd_bs2hd",
qkv_layout="bshd_bshd_bshd",
bias_type="no_bias",
mask_type=mask_type,
):
pytest.skip("cuDNN fused attention is not supported")
self_attn_qkv_input = paddle.normal(mean=0.0,
std=0.02,
shape=(bs, q_seqlen, 3, num_heads,
head_size)).astype(math_dtype)
cross_attn_q_input = paddle.normal(mean=0.0,
std=0.02,
shape=(bs, q_seqlen, num_heads,
head_size)).astype(math_dtype)
cross_attn_kv_input = paddle.normal(mean=0.0,
std=0.02,
shape=(bs, kv_seqlen, 2, num_heads,
head_size)).astype(math_dtype)
attn_q_input = paddle.normal(mean=0.0, std=0.02,
shape=(bs, q_seqlen, num_heads, head_size)).astype(math_dtype)
attn_k_input = paddle.normal(mean=0.0, std=0.02,
shape=(bs, kv_seqlen, num_heads, head_size)).astype(math_dtype)
attn_v_input = paddle.normal(mean=0.0, std=0.02,
shape=(bs, kv_seqlen, num_heads, head_size)).astype(math_dtype)
q_actual_seqlen = paddle.randint(low=20, high=q_seqlen, shape=(bs,), dtype='int32')
kv_actual_seqlen = paddle.randint(low=20, high=kv_seqlen, shape=(bs,),
......@@ -841,57 +834,36 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen,
for i in range(0, bs):
attn_mask[i, 0, 0:q_actual_seqlen[i], 0:kv_actual_seqlen[i]] = False
norm_factor = math.sqrt(hidden_size // num_heads)
layer_te = te.DotProductAttention(norm_factor,
head_size = hidden_size // num_heads
layer_te = te.DotProductAttention(num_heads,
head_size,
attention_dropout=0.0,
attn_mask_type=mask_type,
attention_type=attn_type,
backend='transformer_engine')
layer_pd = te.DotProductAttention(norm_factor,
layer_pd = te.DotProductAttention(num_heads,
head_size,
attention_dropout=0.0,
attn_mask_type=mask_type,
attention_type=attn_type,
backend='paddle')
def calc_attn_output_and_grad(layer, q, kv, mask, dout):
def calc_attn_output_and_grad(layer, q, k, v, mask, dout):
_q = paddle.to_tensor(q, stop_gradient=False)
_kv = paddle.to_tensor(kv, stop_gradient=False) if kv is not None else None
_k = paddle.to_tensor(k, stop_gradient=False)
_v = paddle.to_tensor(v, stop_gradient=False)
out = layer(_q, _kv, mask)
out = layer(_q, _k, _v, mask)
out.backward(dout)
return out, _q.grad, _kv.grad if _kv is not None else None
if attn_type == 'self':
out, qkv_grad, _ = calc_attn_output_and_grad(layer_te, self_attn_qkv_input, None, attn_mask,
grad_out)
out_ref, qkv_grad_ref, _ = calc_attn_output_and_grad(layer_pd, self_attn_qkv_input, None,
attn_mask, grad_out)
valid_out_ref = paddle.full_like(out_ref, 0)
for i in range(0, bs):
valid_out_ref[i, 0:q_actual_seqlen[i], :, :] = out_ref[i, 0:q_actual_seqlen[i], :, :]
q_grad = qkv_grad[:, :, 0]
k_grad = qkv_grad[:, :, 1]
v_grad = qkv_grad[:, :, 2]
q_grad_ref = qkv_grad_ref[:, :, 0]
k_grad_ref = qkv_grad_ref[:, :, 1]
v_grad_ref = qkv_grad_ref[:, :, 2]
else:
out, q_grad, kv_grad = calc_attn_output_and_grad(layer_te, cross_attn_q_input,
cross_attn_kv_input, attn_mask, grad_out)
out_ref, q_grad_ref, kv_grad_ref = calc_attn_output_and_grad(layer_pd, cross_attn_q_input,
cross_attn_kv_input, attn_mask,
grad_out)
valid_out_ref = paddle.full_like(out_ref, 0)
for i in range(0, bs):
valid_out_ref[i, 0:q_actual_seqlen[i], :, :] = out_ref[i, 0:q_actual_seqlen[i], :, :]
return out, _q.grad, _k.grad, _v.grad
k_grad = kv_grad[:, :, 0]
v_grad = kv_grad[:, :, 1]
k_grad_ref = kv_grad_ref[:, :, 0]
v_grad_ref = kv_grad_ref[:, :, 1]
out, q_grad, k_grad, v_grad = calc_attn_output_and_grad(layer_te, attn_q_input, attn_k_input,
attn_v_input, attn_mask, grad_out)
out_ref, q_grad_ref, k_grad_ref, v_grad_ref = calc_attn_output_and_grad(
layer_pd, attn_q_input, attn_k_input, attn_v_input, attn_mask, grad_out)
valid_out_ref = paddle.full_like(out_ref, 0)
for i in range(0, bs):
valid_out_ref[i, 0:q_actual_seqlen[i], :, :] = out_ref[i, 0:q_actual_seqlen[i], :, :]
valid_q_grad_ref = paddle.full_like(q_grad_ref, 0)
valid_k_grad_ref = paddle.full_like(k_grad_ref, 0)
......@@ -910,17 +882,18 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen,
@pytest.mark.parametrize('bs', [1, 2, 8])
@pytest.mark.parametrize('num_gqa_groups', [1, 4, 16])
@pytest.mark.parametrize('hidden_size, num_heads, ffn_hidden_size', [[1024, 16, 4096]])
@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[128, 128], [512, 512]])
@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[512, 512], [1024, 1024]])
@pytest.mark.parametrize('has_bias, no_dbias', [[False, True], [True, True], [True, False]])
@pytest.mark.parametrize('no_wgrad', [True, False])
@pytest.mark.parametrize('mask_type', ['causal', 'padding'])
@pytest.mark.parametrize('math_dtype', ['bfloat16', 'float16'])
@pytest.mark.parametrize('output_layernorm', [True, False])
@pytest.mark.parametrize('return_layernorm_output', [True, False])
def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size, has_bias, no_dbias,
no_wgrad, q_seqlen, kv_seqlen, mask_type, math_dtype,
output_layernorm, return_layernorm_output):
def test_transformer_encoder_layer(bs, hidden_size, num_heads, num_gqa_groups, ffn_hidden_size,
has_bias, no_dbias, no_wgrad, q_seqlen, kv_seqlen, mask_type,
math_dtype, output_layernorm, return_layernorm_output):
"""
Test Transformer Encoder Layer
"""
......@@ -932,13 +905,13 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
# Skip if cuDNN fused attention is not supported
if not is_fused_attention_supported(
num_heads=num_heads,
num_gqa_groups=num_heads,
num_gqa_groups=num_gqa_groups,
q_seqlen=q_seqlen,
kv_seqlen=kv_seqlen,
head_size=hidden_size // num_heads,
dtype=math_dtype,
dropout=0.0,
qkv_layout="bs3hd",
qkv_layout="bshd_bshd_bshd",
bias_type="no_bias",
mask_type=mask_type,
):
......@@ -962,6 +935,7 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
layer_te = te.TransformerLayer(hidden_size,
ffn_hidden_size,
num_heads,
num_gqa_groups=num_gqa_groups,
layernorm_epsilon=eps,
hidden_dropout=0.0,
attention_dropout=0.0,
......@@ -975,6 +949,7 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
layer_pd = te.TransformerLayer(hidden_size,
ffn_hidden_size,
num_heads,
num_gqa_groups=num_gqa_groups,
layernorm_epsilon=eps,
hidden_dropout=0.0,
attention_dropout=0.0,
......@@ -1088,8 +1063,9 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
@pytest.mark.parametrize('bs', [1, 2, 8])
@pytest.mark.parametrize('num_gqa_groups', [1, 4, 16])
@pytest.mark.parametrize('hidden_size, num_heads, ffn_hidden_size', [[1024, 16, 4096]])
@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[128, 128], [512, 512]])
@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[512, 512], [1024, 1024]])
@pytest.mark.parametrize('has_bias, no_dbias', [[False, True], [True, True], [True, False]])
@pytest.mark.parametrize('no_wgrad', [True, False])
@pytest.mark.parametrize('mask_type', ['causal', 'padding'])
......@@ -1097,9 +1073,9 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
@pytest.mark.parametrize('output_layernorm', [True, False])
@pytest.mark.parametrize('return_layernorm_output', [True, False])
@pytest.mark.parametrize('recompute_core_attention', [True, False])
def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size, has_bias, no_dbias,
no_wgrad, q_seqlen, kv_seqlen, mask_type, math_dtype,
output_layernorm, return_layernorm_output,
def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, ffn_hidden_size,
has_bias, no_dbias, no_wgrad, q_seqlen, kv_seqlen, mask_type,
math_dtype, output_layernorm, return_layernorm_output,
recompute_core_attention):
"""
Test Transformer Decoder Layer
......@@ -1112,39 +1088,35 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
# Skip if cuDNN fused attention is not supported
if not is_fused_attention_supported(
num_heads=num_heads,
num_gqa_groups=num_heads,
num_gqa_groups=num_gqa_groups,
q_seqlen=q_seqlen,
kv_seqlen=kv_seqlen,
head_size=hidden_size // num_heads,
dtype=math_dtype,
dropout=0.0,
qkv_layout="bs3hd",
bias_type="no_bias",
mask_type=mask_type,
):
pytest.skip("cuDNN fused attention is not supported")
if not is_fused_attention_supported(
head_size=hidden_size // num_heads,
num_heads=num_heads,
num_gqa_groups=num_heads,
q_seqlen=q_seqlen,
kv_seqlen=kv_seqlen,
dtype=math_dtype,
dropout=0.0,
qkv_layout="bshd_bs2hd",
qkv_layout="bshd_bshd_bshd",
bias_type="no_bias",
mask_type=mask_type,
):
pytest.skip("cuDNN fused attention is not supported")
encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype)
encoder_output = paddle.uniform(shape=(bs, kv_seqlen, hidden_size), dtype=math_dtype)
encoder_input = paddle.normal(mean=0.0, std=0.1,
shape=(bs, q_seqlen, hidden_size)).astype(math_dtype)
encoder_output = paddle.normal(mean=0.0, std=0.1,
shape=(bs, kv_seqlen, hidden_size)).astype(math_dtype)
q_actual_seqlen = paddle.ones(shape=(bs,), dtype='int32') * q_seqlen
kv_actual_seqlen = q_actual_seqlen
attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype='bool')
grad_out = paddle.normal(mean=0.0, std=0.2, shape=(bs, q_seqlen, hidden_size)).astype('float32')
grad_out = paddle.normal(mean=0.0, std=0.01,
shape=(bs, q_seqlen, hidden_size)).astype('float32')
# rounding to avoid numerical issues
encoder_input = paddle.round(encoder_input * 1000) / 1000
encoder_output = paddle.round(encoder_output * 1000) / 1000
grad_out = paddle.round(grad_out * 1000) / 1000
for i in range(0, bs):
grad_out[i, q_actual_seqlen[i]:, :] = 0
grad_out = grad_out.astype(math_dtype)
......@@ -1155,6 +1127,7 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
layer_te = te.TransformerLayer(hidden_size,
ffn_hidden_size,
num_heads,
num_gqa_groups=num_gqa_groups,
layernorm_epsilon=eps,
hidden_dropout=0.0,
attention_dropout=0.0,
......@@ -1168,6 +1141,7 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
layer_pd = te.TransformerLayer(hidden_size,
ffn_hidden_size,
num_heads,
num_gqa_groups=num_gqa_groups,
layernorm_epsilon=eps,
hidden_dropout=0.0,
attention_dropout=0.0,
......@@ -1319,7 +1293,7 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
assert_allclose(layer_te.self_attention.layernorm_qkv.weight.grad,
layer_pd.self_attention.layernorm_qkv.weight.grad.T,
rtol=rtol,
atol=0.1)
atol=atol)
assert_allclose(layer_te.inter_attention.layernorm_query.weight.grad,
layer_pd.inter_attention.layernorm_query.weight.grad.T,
rtol=rtol,
......@@ -1328,7 +1302,7 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
if output_layernorm:
assert_allclose(layer_te.self_attention.qkv.bias.grad,
layer_pd.self_attention.qkv.bias.grad,
rtol=0.01,
rtol=0.5,
atol=0.6)
else:
assert_allclose(layer_te.self_attention.layernorm_qkv.bias.grad,
......
......@@ -5,6 +5,12 @@
import struct
from utils import (
assert_allclose,
create_fp8_meta,
get_fused_attention_backend,
is_fused_attention_supported,
)
import numpy as np
import paddle
import paddle.nn.functional as F
......@@ -39,6 +45,8 @@ from transformer_engine.paddle.cpp_extensions import (
fused_attn_bwd_qkvpacked,
fused_attn_fwd_kvpacked,
fused_attn_bwd_kvpacked,
fused_attn_fwd,
fused_attn_bwd,
scaled_softmax_forward,
scaled_softmax_backward,
scaled_masked_softmax_forward,
......@@ -594,6 +602,7 @@ class TestFusedAttn:
self.q = _random(self.q_shape)
if self.attn_mode == "self_attn":
assert self.q_seqlen == self.kv_seqlen, "self attention requires q_seqlen == kv_seqlen"
self.kv = self.q
else:
self.kv = _random(self.kv_shape)
......@@ -774,6 +783,70 @@ class TestFusedAttn:
return out, q_grad, k_grad, v_grad
def _get_fused_attention_with_separate_qkv(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
q_tensor = paddle.to_tensor(self.q, stop_gradient=False)
k_tensor = paddle.to_tensor(self.kv, stop_gradient=False)
v_tensor = paddle.to_tensor(self.kv, stop_gradient=False)
q_cu_seqlen_tensor = paddle.to_tensor(self.q_cu_seqlen, dtype="int32", stop_gradient=True)
kv_cu_seqlen_tensor = paddle.to_tensor(self.kv_cu_seqlen, dtype="int32", stop_gradient=True)
qkv_layout = "bshd_bshd_bshd"
fused_attention_backend = get_fused_attention_backend(
num_heads=self.num_heads,
num_gqa_groups=self.num_heads,
q_seqlen=self.q_seqlen,
kv_seqlen=self.kv_seqlen,
head_size=self.head_size,
dtype=self.dtype,
dropout=self.dropout_prob,
qkv_layout=qkv_layout,
bias_type="no_bias",
mask_type="causal" if self.is_causal_masking else "padding",
)
qkv_dtype = tex.DType.kBFloat16 if self.dtype == "bfloat16" else tex.DType.kFloat16
out, softmax_aux_tensor, rng_state = fused_attn_fwd(
q_tensor,
k_tensor,
v_tensor,
q_cu_seqlen_tensor,
kv_cu_seqlen_tensor,
is_training=True,
max_seqlen_q=self.q_seqlen,
max_seqlen_kv=self.kv_seqlen,
qkv_dtype=qkv_dtype,
fused_attention_backend=fused_attention_backend,
Bias=None,
attn_scale=self.scaling_factor,
dropout=self.dropout_prob,
set_zero=False,
qkv_layout=qkv_layout,
attn_mask_type="causal" if self.is_causal_masking else "padding")
dq, dk, dv, _ = fused_attn_bwd(
q_tensor,
k_tensor,
v_tensor,
q_cu_seqlen_tensor,
kv_cu_seqlen_tensor,
rng_state,
out,
self.dout,
softmax_aux_tensor,
fused_attention_backend=fused_attention_backend,
max_seqlen_q=self.q_seqlen,
max_seqlen_kv=self.kv_seqlen,
qkv_dtype=qkv_dtype,
attn_scale=self.scaling_factor,
dropout=self.dropout_prob,
set_zero=False,
qkv_layout=qkv_layout,
attn_mask_type="causal" if self.is_causal_masking else "padding")
return out, dq, dk, dv
@pytest.mark.parametrize('b, s, h, d', SELF_ATTN_CASES)
@pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
@pytest.mark.parametrize('is_causal_masking', [True, False])
......@@ -857,6 +930,35 @@ class TestFusedAttn:
assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2)
assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2)
@pytest.mark.parametrize('b, s, h, d', FLASH_ATTN_CASES)
@pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
@pytest.mark.parametrize('is_causal_masking', [False, True])
def test_fused_attn_with_separate_qkv_forward_backward(self, b, s, h, d, dtype,
is_causal_masking):
"""
test flash attention forward + backward with separate qkv inputs
"""
if not is_fused_attention_supported(
num_heads=h,
num_gqa_groups=h,
q_seqlen=s,
kv_seqlen=s,
head_size=d,
dtype=dtype,
dropout=0.0,
qkv_layout="bshd_bshd_bshd",
bias_type="no_bias",
mask_type="causal" if is_causal_masking else "padding",
):
pytest.skip("cuDNN fused attention is not supported")
self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking)
reference_out, q_grad_ref, k_grad_ref, v_grad_ref = self._get_reference_out()
fused_attention_out, q_grad, k_grad, v_grad = self._get_fused_attention_with_separate_qkv()
assert_allclose(reference_out, fused_attention_out, rtol=1e-3, atol=1e-2)
assert_allclose(q_grad_ref, q_grad, rtol=1e-3, atol=1e-2)
assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2)
assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2)
class TestSoftmax:
"""
......
......@@ -792,6 +792,189 @@ def fused_attn_bwd_kvpacked(
return dq, dkv, dbias
def fused_attn_fwd(
q: paddle.Tensor,
k: paddle.Tensor,
v: paddle.Tensor,
cu_seqlens_q: paddle.Tensor,
cu_seqlens_kv: paddle.Tensor,
is_training: bool,
max_seqlen_q: int,
max_seqlen_kv: int,
qkv_dtype: tex.DType,
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
Bias: paddle.Tensor = None,
attn_scale: float = None,
dropout: float = 0.0,
set_zero: bool = True,
qkv_layout: str = "bshd_bshd_bshd",
bias_type: str = "no_bias",
attn_mask_type: str = "padding",
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Fused Attention FWD for unpacked QKV input"""
assert (qkv_dtype in (tex.DType.kBFloat16,
tex.DType.kFloat16)), "Only support bf16/fp16 for fused attention."
assert (cu_seqlens_q.shape == cu_seqlens_kv.shape
), "cu_seqlens_q and cu_seqlens_kv must have the same shape"
assert (qkv_layout == "bshd_bshd_bshd"
), "Only support bshd_bshd_bshd layout for unpacked QKV input for now."
b = cu_seqlens_q.shape[0] - 1
h = q.shape[-2]
d = q.shape[-1]
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
if bias_type != "no_bias":
assert Bias is not None, "bias tensor cannot be None when bias_type is not no_bias."
assert (Bias.shape == [
1, h, max_seqlen_q, max_seqlen_kv
]), "bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape."
assert (Bias.dtype == q.dtype), "bias tensor must be in the same dtype as qkv."
assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
), "Fused attention does not support this input combination."
# BF16/FP16 fused attention API from fmha_v1 apex
if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
rng_elts_per_thread = (max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_THREADS_PER_CTA -
1) // BACKEND_F16m512_THREADS_PER_CTA
# BF16/FP16 fused attention API from fmha_v2
if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS
if set_zero:
out = paddle.full(shape=[b, max_seqlen_q, h, d], fill_value=0, dtype=q.dtype)
else:
out = paddle.empty(shape=[b, max_seqlen_q, h, d], dtype=q.dtype)
if is_training:
if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, 1], dtype='float32')
else:
raise ValueError("Unsupported fused attention backend.")
else:
softmax_aux = None
rng_state = paddle.empty(shape=[
2,
], dtype=paddle.int64)
# execute kernel
tex.te_fused_attn_fwd(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
Bias,
out,
softmax_aux,
rng_state,
b,
h,
d,
max_seqlen_q,
max_seqlen_kv,
is_training,
attn_scale,
dropout,
qkv_layout,
bias_type,
attn_mask_type,
int(qkv_dtype),
rng_elts_per_thread,
)
return out, softmax_aux, rng_state
def fused_attn_bwd(
q: paddle.Tensor,
k: paddle.Tensor,
v: paddle.Tensor,
cu_seqlens_q: paddle.Tensor,
cu_seqlens_kv: paddle.Tensor,
rng_state: paddle.Tensor,
o: paddle.Tensor,
d_o: paddle.Tensor,
softmax_aux: paddle.Tensor,
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
max_seqlen_q: int,
max_seqlen_kv: int,
qkv_dtype: tex.DType,
attn_scale: float = None,
dropout: float = 0.0,
set_zero: bool = True,
qkv_layout: str = "bshd_bshd_bshd",
bias_type: str = "no_bias",
attn_mask_type: str = "padding",
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Fused Attention BWD for packed KV input"""
assert (qkv_dtype in (tex.DType.kBFloat16,
tex.DType.kFloat16)), "Only support bf16/fp16 for fused attention."
assert (cu_seqlens_q.shape == cu_seqlens_kv.shape
), "cu_seqlens_q and cu_seqlens_kv must have the same shape"
assert (qkv_layout == "bshd_bshd_bshd"
), "Only support bshd_bshd_bshd layout for unpacked QKV input for now."
b = cu_seqlens_q.shape[0] - 1
h = q.shape[-2]
d = q.shape[-1]
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
), "Fused attention does not support this input combination."
if set_zero:
dq = paddle.full(shape=q.shape, fill_value=0, dtype=q.dtype)
dk = paddle.full(shape=k.shape, fill_value=0, dtype=k.dtype)
dv = paddle.full(shape=v.shape, fill_value=0, dtype=v.dtype)
else:
dq = paddle.empty(shape=q.shape, dtype=q.dtype)
dk = paddle.empty(shape=k.shape, dtype=k.dtype)
dv = paddle.empty(shape=v.shape, dtype=v.dtype)
if bias_type != "no_bias":
dbias = paddle.empty(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
else:
dbias = None
# execute kernel
tex.te_fused_attn_bwd(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
o,
d_o,
softmax_aux,
dq,
dk,
dv,
dbias,
rng_state,
b,
h,
d,
max_seqlen_q,
max_seqlen_kv,
attn_scale,
dropout,
qkv_layout,
bias_type,
attn_mask_type,
int(qkv_dtype),
)
return dq, dk, dv, dbias
def scaled_softmax_forward(
inp: paddle.Tensor,
scale_factor: float,
......
......@@ -864,6 +864,183 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
}
void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const paddle::Tensor &V,
const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &cu_seqlens_kv,
const paddle::optional<paddle::Tensor> &Bias,
paddle::Tensor &O, // NOLINT
paddle::optional<paddle::Tensor> &softmax_aux, // NOLINT
paddle::Tensor &rng_state, // NOLINT
int64_t b, int64_t h, int64_t d, int64_t max_seqlen_q, int64_t max_seqlen_kv,
bool is_training, float attn_scale, float p_dropout,
const std::string &qkv_layout, const std::string &bias_type,
const std::string &attn_mask_type, const int64_t qkv_type,
int64_t rng_elts_per_thread) {
if (is_training && !softmax_aux) {
NVTE_ERROR("softmax_aux must be provided when training. \n");
}
auto qkv_dtype = Int2NvteDType(qkv_type);
// construct NVTE tensors
TensorWrapper te_Q, te_K, te_V, te_S, te_O, te_Bias, te_cu_seqlens_q, te_cu_seqlens_kv;
if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) {
// BF16 or FP16
te_Q = MakeNvteTensor(Q);
te_K = MakeNvteTensor(K);
te_V = MakeNvteTensor(V);
te_S = MakeNvteTensor(nullptr, std::vector<size_t>{0}, DType::kFloat32);
te_O = MakeNvteTensor(O);
} else { // TODO: support fp8
NVTE_ERROR("Fused attention only supports BF16/FP16 data types. \n");
}
if ((bias_type != "no_bias") && Bias) {
auto bias_shape = Bias->shape();
std::vector<size_t> shape{bias_shape.begin(), bias_shape.end()};
te_Bias = MakeNvteTensor(GetOptionalDataPtr(Bias), shape, DType::kFloat32);
}
te_cu_seqlens_q =
MakeNvteTensor(cu_seqlens_q.data(), {static_cast<size_t>(b + 1)}, DType::kInt32);
te_cu_seqlens_kv =
MakeNvteTensor(cu_seqlens_kv.data(), {static_cast<size_t>(b + 1)}, DType::kInt32);
// convert strings to enums
NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout);
NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type);
NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type);
// extract random number generator seed and offset
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(Q.place());
auto gen_cuda = dev_ctx->GetGenerator();
auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread);
set_rng_state<<<1, 1, 0, Q.stream()>>>(seed_offset, static_cast<int64_t *>(rng_state.data()));
auto te_rng_state = MakeNvteTensor(rng_state);
// create auxiliary output tensors
NVTETensorPack nvte_aux_tensor_pack;
nvte_tensor_pack_create(&nvte_aux_tensor_pack);
// create workspace
TensorWrapper workspace;
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(),
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv,
is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, workspace.data(), Q.stream());
// allocate memory for workspace and auxiliary output tensors
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place());
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
auto *output_s =
reinterpret_cast<transformer_engine::Tensor *>(nvte_aux_tensor_pack.tensors[0]);
output_s->data.dptr = GetOptionalDataPtr(softmax_aux);
// execute the kernel
nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(),
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv,
is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, workspace.data(), Q.stream());
// destroy tensor wrappers, but not allocated memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
}
void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const paddle::Tensor &V,
const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &cu_seqlens_kv,
const paddle::Tensor &O, const paddle::Tensor &dO,
const paddle::Tensor &softmax_aux,
paddle::Tensor &dQ, // NOLINT
paddle::Tensor &dK, // NOLINT
paddle::Tensor &dV, // NOLINT
paddle::optional<paddle::Tensor> &dBias, // NOLINT
paddle::Tensor &rng_state, // NOLINT
int64_t b, int64_t h, int64_t d, int64_t max_seqlen_q, int64_t max_seqlen_kv,
float attn_scale, float p_dropout, const std::string &qkv_layout,
const std::string &bias_type, const std::string &attn_mask_type,
int64_t qkv_type) {
TensorWrapper te_dBias;
if (bias_type != "no_bias" && dBias) {
auto bias_shape = dBias->shape();
std::vector<size_t> shape{bias_shape.begin(), bias_shape.end()};
te_dBias = MakeNvteTensor(GetOptionalDataPtr(dBias), shape, DType::kFloat32);
}
auto qkv_dtype = Int2NvteDType(qkv_type);
// construct NVTE tensors
TensorWrapper te_Q, te_K, te_V, te_O, te_dO, te_S, te_dP, te_dQ, te_dK, te_dV;
if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) {
// BF16 or FP16
te_Q = MakeNvteTensor(Q);
te_K = MakeNvteTensor(K);
te_V = MakeNvteTensor(V);
te_O = MakeNvteTensor(O);
te_dO = MakeNvteTensor(dO);
te_S = MakeNvteTensor(nullptr, std::vector<size_t>(0), DType::kFloat32);
te_dP = MakeNvteTensor(nullptr, std::vector<size_t>(0), DType::kFloat32);
te_dQ = MakeNvteTensor(dQ);
te_dK = MakeNvteTensor(dK);
te_dV = MakeNvteTensor(dV);
} else {
NVTE_ERROR("Fused attention only supports BF16/FP16 data types. \n");
}
// convert strings to enums
NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout);
NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type);
NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type);
// convert auxiliary tensors from forward into NVTETensors
NVTETensorPack nvte_aux_tensor_pack;
nvte_tensor_pack_create(&nvte_aux_tensor_pack);
nvte_aux_tensor_pack.size = 2;
auto *output_s = reinterpret_cast<Tensor *>(nvte_aux_tensor_pack.tensors[0]);
auto *fwd_rng_state = reinterpret_cast<Tensor *>(nvte_aux_tensor_pack.tensors[1]);
output_s->data.shape = std::vector<size_t>({static_cast<size_t>(b), static_cast<size_t>(h),
static_cast<size_t>(max_seqlen_q),
static_cast<size_t>(max_seqlen_kv)});
output_s->data.dptr = const_cast<void *>(softmax_aux.data());
fwd_rng_state->data.shape = std::vector<size_t>({2});
fwd_rng_state->data.dptr = const_cast<void *>(rng_state.data());
// create cu_seqlens tensorwrappers
TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv;
te_cu_seqlens_q =
MakeNvteTensor(cu_seqlens_q.data(), {static_cast<size_t>(b + 1)}, DType::kInt32);
te_cu_seqlens_kv =
MakeNvteTensor(cu_seqlens_kv.data(), {static_cast<size_t>(b + 1)}, DType::kInt32);
// create workspace
TensorWrapper workspace;
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(),
te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(),
te_dK.data(), te_dV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(),
Q.stream());
// allocate memory for workspace
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place());
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
// execute kernel
nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(),
te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(),
te_dK.data(), te_dV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(),
Q.stream());
// destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
}
std::vector<paddle::Tensor> te_scaled_softmax_forward(const paddle::Tensor &input,
float scale_factor) {
NVTE_CHECK(input.shape().size() == 4, "expected 4D tensor");
......@@ -1316,6 +1493,33 @@ PD_BUILD_OP(te_fused_attn_bwd_kvpacked)
{paddle::Optional("_dBias"), paddle::Optional("dBias")}})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_bwd_kvpacked));
PD_BUILD_OP(te_fused_attn_fwd)
.Inputs({"Q", "K", "V", "cu_seqlens_q", "cu_seqlens_kv", paddle::Optional("Bias"), "_O",
paddle::Optional("_softmax_aux"), "_rng_state"})
.Outputs({"O", paddle::Optional("softmax_aux"), "rng_state"})
.Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "max_seqlen_q: int64_t",
"max_seqlen_kv: int64_t", "is_training: bool", "attn_scale: float", "p_dropout: float",
"qkv_layout: std::string", "bias_type: std::string", "attn_mask_type: std::string",
"qkv_type: int64_t", "rng_elts_per_thread: int64_t"})
.SetInplaceMap({{"_O", "O"},
{paddle::Optional("_softmax_aux"), paddle::Optional("softmax_aux")},
{"_rng_state", "rng_state"}})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_fwd));
PD_BUILD_OP(te_fused_attn_bwd)
.Inputs({"Q", "K", "V", "cu_seqlens_q", "cu_seqlens_kv", "O", "dO", "softmax_aux", "_dQ", "_dK",
"_dV", paddle::Optional("_dBias"), "rng_state"})
.Outputs({"dQ", "dK", "dV", paddle::Optional("dBias")})
.Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "max_seqlen_q: int64_t",
"max_seqlen_kv: int64_t", "attn_scale: float", "p_dropout: float",
"qkv_layout: std::string", "bias_type: std::string", "attn_mask_type: std::string",
"qkv_type: int64_t"})
.SetInplaceMap({{"_dQ", "dQ"},
{"_dK", "dK"},
{"_dV", "dV"},
{paddle::Optional("_dBias"), paddle::Optional("dBias")}})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_bwd));
PD_BUILD_OP(te_scaled_softmax_forward)
.Inputs({"input"})
.Outputs({"softmax_results"})
......
......@@ -22,6 +22,8 @@ from ..cpp_extensions import (
fused_attn_bwd_qkvpacked,
fused_attn_fwd_kvpacked,
fused_attn_bwd_kvpacked,
fused_attn_fwd,
fused_attn_bwd,
mask_to_cu_seqlens,
)
from ..distributed import get_tp_group_and_world_size, track_rng_state
......@@ -31,6 +33,20 @@ from ..recompute import recompute
__all__ = ["DotProductAttention", "MultiHeadAttention"]
def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor:
"""
Used to repeat the key and value states for GQA.
The hidden states go from (batch, seqlen, num_gqa_groups, head_size)
to (batch, seqlen, num_heads, head_size)
"""
batch, seqlen, num_gqa_groups, head_size = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states.unsqueeze(-2).tile([1, 1, 1, n_rep, 1])
return hidden_states.reshape([batch, seqlen, num_gqa_groups * n_rep, head_size])
class FusedAttnFuncPackedQKV(paddle.autograd.PyLayer):
"""Function for FusedAttention with packed QKV input"""
......@@ -130,6 +146,50 @@ class FusedAttnFuncPackedKV(paddle.autograd.PyLayer):
return (dq, dkv, None, None, rest[0])
class FusedAttnFunc(paddle.autograd.PyLayer):
"""Function for FusedAttention with separate Q, K, V tensors"""
@staticmethod
def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_kv, attn_bias, max_seqlen_q, max_seqlen_kv,
attn_scale, qkv_dtype, dropout_p, set_zero, qkv_layout, attn_bias_type,
attn_mask_type, is_training, fused_attention_backend):
"""Forward function for FusedAttention with separate Q, K, V tensors"""
out, softmax_aux, rng_state = fused_attn_fwd(q, k, v, cu_seqlens_q, cu_seqlens_kv,
is_training, max_seqlen_q, max_seqlen_kv,
qkv_dtype, fused_attention_backend, attn_bias,
attn_scale, dropout_p, set_zero, qkv_layout,
attn_bias_type, attn_mask_type)
ctx.save_for_backward(q, k, v, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux)
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_kv = max_seqlen_kv
ctx.qkv_dtype = qkv_dtype
ctx.attn_scale = attn_scale
ctx.dropout_p = dropout_p
ctx.set_zero = set_zero
ctx.qkv_layout = qkv_layout
ctx.attn_bias_type = attn_bias_type
ctx.attn_mask_type = attn_mask_type
ctx.fused_attention_backend = fused_attention_backend
return out
@staticmethod
def backward(ctx, d_out):
"""Backward function for FusedAttention with separate Q, K, V tensors"""
q, k, v, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux = ctx.saved_tensor()
dq, dk, dv, *rest = fused_attn_bwd(q, k, v, cu_seqlens_q, cu_seqlens_kv, rng_state, out,
d_out, softmax_aux, ctx.fused_attention_backend,
ctx.max_seqlen_q, ctx.max_seqlen_kv, ctx.qkv_dtype,
ctx.attn_scale, ctx.dropout_p, ctx.set_zero,
ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
# if no_bias, return dq, dk, dv
if ctx.attn_bias_type == "no_bias":
return (dq, dk, dv, None, None)
# else, return (dq, dk, dv, dbias)
return (dq, dk, dv, None, None, rest[0])
class DotProductAttention(paddle.nn.Layer):
"""
Allows the model to jointly attend to information from different
......@@ -143,31 +203,51 @@ class DotProductAttention(paddle.nn.Layer):
Parameters
----------
norm_factor : float
normalization factor for the attention scores.
num_attention_heads: int
number of attention heads in the transformer layer.
kv_channels: int
number of channels in the key and value tensors.
num_gqa_groups : Optional[int] = None
number of GQA groups in the transformer layer.
Grouped Query Attention is described in
`this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
This only affects the keys and values, not the queries.
GQA-1 is equivalent to Multi-Query Attention
(`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
attention_dropout: float, default = 0.1
dropout probability for the dropout op during multi-head attention.
attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
type of attention mask passed into softmax operation.
attention_type: {'self', 'cross'}, default = `self`
type of attention operation.
tp_group : ProcessGroup, default = `None`
tensor parallel process group.
backend: {'transformer_engine', 'paddle'}, default = `transformer_engine`
backend to use for attention operation.
"""
def __init__(self,
norm_factor: float,
num_attention_heads: int,
kv_channels: int,
num_gqa_groups: Optional[int] = None,
attention_dropout: float = 0.1,
attn_mask_type: str = "causal",
attention_type: str = "self",
tp_size: int = 1,
backend: str = 'transformer_engine') -> None:
super().__init__()
self.norm_factor = norm_factor
self.attn_mask_type = attn_mask_type
self.attention_dropout = attention_dropout
self.attention_type = attention_type
self.qkv_layout = "bs3hd" if attention_type == "self" else "bshd_bs2hd"
self.qkv_layout = "bshd_bshd_bshd"
self.hidden_size_per_attention_head = kv_channels
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
self.tp_size = tp_size
self.num_gqa_groups = (num_attention_heads if num_gqa_groups is None else num_gqa_groups)
self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size)
self.num_queries_per_key_value = num_attention_heads // self.num_gqa_groups
self.backend = backend
......@@ -185,7 +265,8 @@ class DotProductAttention(paddle.nn.Layer):
def forward(
self,
query_layer: paddle.Tensor,
key_value_layer: paddle.Tensor = None,
key_layer: paddle.Tensor,
value_layer: paddle.Tensor,
attention_mask: Optional[paddle.Tensor] = None,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[paddle.Tensor] = None,
......@@ -199,26 +280,15 @@ class DotProductAttention(paddle.nn.Layer):
Argument :attr:`attention_mask` will be ignored when :attr:`attn_mask_type`
is set to `"causal"`.
.. note::
For self attention, :attr:`query_layer` is the `[query, key, value]` tensor
stacked along the 2nd dimension, which must be of shape (:attr:`batch_size`,
:attr:`seq_length`, 3, :attr:`num_attention_heads`, :attr:`size_per_head`).
And :attr:`key_value_layer` is `None`.
For cross attention, :attr:`query_layer` is the `[query]` tensor, which must
be of shape (:attr:`batch_size`, :attr:`seq_length`, :attr:`num_attention_heads`,
:attr:`size_per_head`). And :attr:`key_value_layer` is the `[key, value]` tensor,
which must be of shape (:attr:`batch_size`, :attr:`seq_length`, 2,
:attr:`num_attention_heads`, :attr:`size_per_head`).
Parameters
----------
query_layer : paddle.Tensor
Query tensor.
key_value_layer : paddle.Tensor
Key tensor.
key_layer : paddle.Tensor
Key tensor.
value_layer : paddle.Tensor
Value tensor.
attention_mask : Optional[paddle.Tensor], default = `None`
Boolean tensor used to mask out softmax input when not using attention.
core_attention_bias_type: str, default = `no_bias`
......@@ -231,21 +301,25 @@ class DotProductAttention(paddle.nn.Layer):
backend = self.backend
assert (key_layer.shape == value_layer.shape), "Keys and values must have the same shape!"
assert (key_layer.shape[-2] == self.num_gqa_groups_per_partition
), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!"
if backend == 'transformer_engine':
max_s_q = query_layer.shape[1]
max_s_kv = max_s_q if self.attention_type == "self" else key_value_layer.shape[1]
max_s_kv = max_s_q if self.attention_type == "self" else key_layer.shape[1]
self.fused_attention_backend = tex.get_fused_attn_backend(
TE_DType[query_layer.dtype], TE_DType[query_layer.dtype],
tex.get_nvte_qkv_layout(self.qkv_layout), AttnBiasType[core_attention_bias_type],
AttnMaskType[self.attn_mask_type], self.attention_dropout, query_layer.shape[-2],
key_value_layer.shape[-2] if key_value_layer is not None else query_layer.shape[-2],
max_s_q, max_s_kv, query_layer.shape[-1])
key_layer.shape[-2] if key_layer is not None else query_layer.shape[-2], max_s_q,
max_s_kv, query_layer.shape[-1])
is_backend_avail = (self.fused_attention_backend in [
FusedAttnBackend["F16_max512_seqlen"], FusedAttnBackend["F16_arbitrary_seqlen"]
])
if is_backend_avail and self.use_fused_attention:
return self._te_forward(query_layer, key_value_layer, attention_mask,
return self._te_forward(query_layer, key_layer, value_layer, attention_mask,
core_attention_bias_type, core_attention_bias, set_zero)
warnings.warn("Fused attention is not enabled, falling back to Paddle backend")
backend = 'paddle'
......@@ -256,13 +330,14 @@ class DotProductAttention(paddle.nn.Layer):
if core_attention_bias_type != "no_bias":
warnings.warn("Paddle backend dot product attention does not support bias yet. "
"Bias will be ignored.")
return self._pd_forward(query_layer, key_value_layer, attention_mask)
return self._pd_forward(query_layer, key_layer, value_layer, attention_mask)
raise AttributeError(f"Backend {backend} is not supported.")
def _te_forward(
self,
query_layer: paddle.Tensor,
key_value_layer: paddle.Tensor = None,
key_layer: paddle.Tensor,
value_layer: paddle.Tensor,
attention_mask: Optional[paddle.Tensor] = None,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[paddle.Tensor] = None,
......@@ -270,10 +345,10 @@ class DotProductAttention(paddle.nn.Layer):
) -> paddle.Tensor:
if self.attention_type == "self":
# self attention - q: [b, s, 3, h, d] kv: None
assert (len(query_layer.shape) == 5 and query_layer.shape[2] == 3
and key_value_layer is None
), "query shape must be [b, s, 3, h, d] for dot product self attention"
# self attention - q: [b, s, h, d] kv: None
assert (len(query_layer.shape) == 4 and len(key_layer.shape) == 4
and len(value_layer.shape)
== 4), "q,k,v shape must be [b, s, h, d] for dot product self attention"
max_seqlen = query_layer.shape[1]
if self.attn_mask_type == "causal" or attention_mask is None:
cu_seqlens = paddle.arange(0, (query_layer.shape[0] + 1) * query_layer.shape[1],
......@@ -283,32 +358,33 @@ class DotProductAttention(paddle.nn.Layer):
cu_seqlens, _ = mask_to_cu_seqlens(attention_mask, need_kv=False)
qkv_dtype = TE_DType[query_layer.dtype]
output = FusedAttnFuncPackedQKV.apply(query_layer, cu_seqlens, core_attention_bias,
max_seqlen, 1.0 / self.norm_factor, qkv_dtype,
self.attention_dropout if self.training else 0.0,
set_zero, self.qkv_layout,
core_attention_bias_type, self.attn_mask_type,
self.training, self.fused_attention_backend)
output = FusedAttnFunc.apply(query_layer, key_layer, value_layer, cu_seqlens,
cu_seqlens, core_attention_bias, max_seqlen, max_seqlen,
1.0 / self.norm_factor, qkv_dtype,
self.attention_dropout if self.training else 0.0, set_zero,
self.qkv_layout, core_attention_bias_type,
self.attn_mask_type, self.training,
self.fused_attention_backend)
elif self.attention_type == "cross":
# cross attention - q: [b, s_q, h, d] kv: [b, s_kv, 2, h, d]
# cross attention - q: [b, s_q, h, d] k,v: [b, s_kv, h, d]
assert (
len(query_layer.shape) == 4 and len(key_value_layer.shape) == 5
and key_value_layer.shape[2] == 2
), "query shape must be [b, s, h, d] and key shape must be [b, s, 2, h, d]" \
len(query_layer.shape) == 4 and len(key_layer.shape) == 4
and len(value_layer.shape) == 4
), "query shape must be [b, s_q, h, d] and key shape must be [b, s_kv, h, d]" \
"for dot product cross attention"
assert (attention_mask
is not None), "attention_mask must be provided for cross attention"
max_seqlen_q = query_layer.shape[1]
max_seqlen_kv = key_value_layer.shape[1]
max_seqlen_kv = key_layer.shape[1]
cu_seqlens_q, cu_seqlens_kv = mask_to_cu_seqlens(attention_mask, need_kv=True)
qkv_dtype = TE_DType[query_layer.dtype]
output = FusedAttnFuncPackedKV.apply(query_layer, key_value_layer, cu_seqlens_q,
cu_seqlens_kv, core_attention_bias, max_seqlen_q,
max_seqlen_kv, 1.0 / self.norm_factor, qkv_dtype,
self.attention_dropout if self.training else 0.0,
set_zero, self.qkv_layout,
core_attention_bias_type, self.attn_mask_type,
self.training, self.fused_attention_backend)
output = FusedAttnFunc.apply(query_layer, key_layer, value_layer, cu_seqlens_q,
cu_seqlens_kv, core_attention_bias, max_seqlen_q,
max_seqlen_kv, 1.0 / self.norm_factor, qkv_dtype,
self.attention_dropout if self.training else 0.0, set_zero,
self.qkv_layout, core_attention_bias_type,
self.attn_mask_type, self.training,
self.fused_attention_backend)
else:
raise ValueError("attention_type must be one of ['self', 'cross']")
return output
......@@ -316,28 +392,14 @@ class DotProductAttention(paddle.nn.Layer):
def _pd_forward(
self,
query_layer: paddle.Tensor,
key_value_layer: paddle.Tensor = None,
key_layer: paddle.Tensor,
value_layer: paddle.Tensor,
attention_mask: Optional[paddle.Tensor] = None,
) -> paddle.Tensor:
if self.attention_type == "self":
# self attention - q: [b, s, 3, h, d] k: None
assert (len(query_layer.shape) == 5 and query_layer.shape[2] == 3
and key_value_layer is None
), "query shape must be [b, s, 3, h, d] for dot product self attention"
q = query_layer[:, :, 0]
k = query_layer[:, :, 1]
v = query_layer[:, :, 2]
elif self.attention_type == "cross":
# cross attention - q: [b, s, h, d] kv: [b, s, 2, h, d]
assert (
len(query_layer.shape) == 4 and len(key_value_layer.shape) == 5
and key_value_layer.shape[2] == 2
), f"query shape must be [b, s, h, d] and key_value shape must be [b, s, 2, h, d]" \
f"for dot product cross attention. The actual shape is q: {query_layer.shape}" \
f"kv: {key_value_layer.shape}"
q = query_layer
k = key_value_layer[:, :, 0]
v = key_value_layer[:, :, 1]
q = query_layer
k = repeat_kv(key_layer, self.num_queries_per_key_value)
v = repeat_kv(value_layer, self.num_queries_per_key_value)
q = paddle.transpose(x=q, perm=[0, 2, 1, 3])
k = paddle.transpose(x=k, perm=[0, 2, 1, 3])
......@@ -404,6 +466,14 @@ class MultiHeadAttention(paddle.nn.Layer):
if set to `True`, uses sequence parallelism.
tp_group : ProcessGroup, default = `None`
tensor parallel process group.
num_gqa_groups : int, default = `None`
number of GQA groups in the transformer layer.
Grouped Query Attention is described in
`this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
This only affects the keys and values, not the querys.
GQA-1 is equivalent to Multi-Query Attention
(`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
rng_state_name : str, default = `local_seed`
Controls the rng state used for dropout on attention probs. The
specified rng should be set different seeds for different TP ranks.
......@@ -430,6 +500,7 @@ class MultiHeadAttention(paddle.nn.Layer):
set_parallel_mode: bool = False,
sequence_parallel: bool = False,
tp_group: Optional[dist_group_type] = None,
num_gqa_groups: Optional[int] = None,
rng_state_name: str = 'local_seed',
backend: str = 'transformer_engine',
) -> None:
......@@ -450,19 +521,25 @@ class MultiHeadAttention(paddle.nn.Layer):
self.sequence_parallel = self.tensor_parallel and sequence_parallel
self.hidden_size_per_attention_head = hidden_size // num_attention_heads
self.num_attention_heads = num_attention_heads
norm_factor = math.sqrt(self.hidden_size_per_attention_head)
self.set_parallel_mode = set_parallel_mode
self.rng_state_name = rng_state_name
self.backend = backend
self.num_attention_heads_per_partition = divide(self.num_attention_heads, self.tp_size)
self.num_gqa_groups = (num_attention_heads if num_gqa_groups is None else num_gqa_groups)
assert (self.num_attention_heads % self.num_gqa_groups == 0
), "The number of attention heads must be divisible by the number of GQA groups!"
assert (self.num_gqa_groups % self.tp_size == 0
), "The number of GQA groups must be divisible by tensor parallel size!"
self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size)
self.hidden_size_kv = int(hidden_size * self.num_gqa_groups // self.num_attention_heads)
qkv_parallel_mode = "column" if set_parallel_mode else None
if self.attention_type == "self":
if self.input_layernorm:
self.layernorm_qkv = LayerNormLinear(
hidden_size,
3 * hidden_size,
hidden_size + 2 * self.hidden_size_kv,
eps=layernorm_epsilon,
weight_attr=self.weight_attr,
bias_attr=self.bias_attr,
......@@ -476,7 +553,7 @@ class MultiHeadAttention(paddle.nn.Layer):
else:
self.qkv = Linear(
hidden_size,
3 * hidden_size,
hidden_size + 2 * self.hidden_size_kv,
self.weight_attr,
self.bias_attr,
parallel_mode=qkv_parallel_mode,
......@@ -513,7 +590,7 @@ class MultiHeadAttention(paddle.nn.Layer):
)
self.key_value = Linear(
hidden_size,
2 * hidden_size,
2 * self.hidden_size_kv,
self.weight_attr,
self.bias_attr,
parallel_mode=qkv_parallel_mode,
......@@ -524,10 +601,13 @@ class MultiHeadAttention(paddle.nn.Layer):
# Attention.
self.core_attention = DotProductAttention(
norm_factor,
self.num_attention_heads,
self.hidden_size_per_attention_head,
self.num_gqa_groups,
attention_dropout,
attn_mask_type=attn_mask_type,
attention_type=self.attention_type,
tp_size=self.tp_size,
backend=self.backend,
)
......@@ -619,18 +699,37 @@ class MultiHeadAttention(paddle.nn.Layer):
is_first_microbatch=is_first_microbatch,
)
# [b, s_q, 3 * hidden_size] --> [b, s_q, 3, num_heads, head_size]
num_queries_per_key_value = (self.num_attention_heads_per_partition //
self.num_gqa_groups_per_partition)
# [b, s_q, hidden_size+2*hidden_size_kv] --> [b, s_q, (h/ng+2), ng, d]
mixed_qkv_layer = mixed_qkv_layer.reshape(shape=[
-1, max_seq_len, 3, self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head
-1, max_seq_len, (
num_queries_per_key_value +
2), self.num_gqa_groups_per_partition, self.hidden_size_per_attention_head
])
# [b, s_q, (h/ng+2), ng, d]
# --> [b, s_q, (h/ng), ng, d] [b, s_q, 1, ng, d] [b, s_q, 1, ng, d]
query_layer, key_layer, value_layer = paddle.split(
mixed_qkv_layer,
num_or_sections=(num_queries_per_key_value, 1, 1),
axis=2,
)
# query: -> [b, s, h, d]
# key, value: -> [b, s, ng, d]
query_layer, key_layer, value_layer = (x.reshape(
shape=[x.shape[0], x.shape[1], -1, self.hidden_size_per_attention_head])
for x in (query_layer, key_layer, value_layer))
with track_rng_state(enable=self.tensor_parallel, name=self.rng_state_name):
if recompute_core_attention:
context_layer = recompute(
self.core_attention,
mixed_qkv_layer,
None,
query_layer,
key_layer,
value_layer,
attention_mask,
core_attention_bias_type,
core_attention_bias,
......@@ -639,8 +738,9 @@ class MultiHeadAttention(paddle.nn.Layer):
)
else:
context_layer = self.core_attention(
query_layer=mixed_qkv_layer,
key_value_layer=None,
query_layer=query_layer,
key_layer=key_layer,
value_layer=value_layer,
attention_mask=attention_mask,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
......@@ -654,10 +754,17 @@ class MultiHeadAttention(paddle.nn.Layer):
)
# [b, s_kv, 2 * hidden_size] --> [b, s_kv, 2, num_heads, head_size]
mixed_kv_layer = mixed_kv_layer.reshape(shape=[
-1, max_seq_len, 2, self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head
0, 0, 2 * self.num_gqa_groups_per_partition, self.hidden_size_per_attention_head
])
# [b, s_kv, 2 * ng, head_size]
# --> 2 [b, s_kv, ng, head_size]
key_layer, value_layer = paddle.split(
mixed_kv_layer,
num_or_sections=2,
axis=2,
)
if self.input_layernorm:
layernorm_query_outputs = self.layernorm_query(
hidden_states,
......@@ -673,6 +780,7 @@ class MultiHeadAttention(paddle.nn.Layer):
is_first_microbatch=is_first_microbatch,
)
# [b, s, hidden_size] --> [b, s, h, d]
query_layer = query_layer.reshape(shape=[
-1, max_seq_len, self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head
......@@ -682,7 +790,8 @@ class MultiHeadAttention(paddle.nn.Layer):
context_layer = recompute(
self.core_attention,
query_layer,
mixed_kv_layer,
key_layer,
value_layer,
attention_mask,
core_attention_bias_type,
core_attention_bias,
......@@ -692,7 +801,8 @@ class MultiHeadAttention(paddle.nn.Layer):
else:
context_layer = self.core_attention(
query_layer=query_layer,
key_value_layer=mixed_kv_layer,
key_layer=key_layer,
value_layer=value_layer,
attention_mask=attention_mask,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
......
......@@ -27,6 +27,14 @@ class TransformerLayer(paddle.nn.Layer):
intermediate size to which input samples are projected.
num_attention_heads : int
number of attention heads in the transformer layer.
num_gqa_groups : Optional[int], default = `None`
number of GQA groups in the transformer layer.
Grouped Query Attention is described in
`this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
This only affects the keys and values, not the queries.
GQA-1 is equivalent to Multi-Query Attention
(`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
layernorm_epsilon : float, default = 1e-5
a value added to the denominator of layer normalization
for numerical stability.
......@@ -97,6 +105,7 @@ class TransformerLayer(paddle.nn.Layer):
hidden_size: int,
ffn_hidden_size: int,
num_attention_heads: int,
num_gqa_groups: Optional[int] = None,
layernorm_epsilon: float = 1e-5,
hidden_dropout: float = 0.1,
attention_dropout: float = 0.1,
......@@ -153,6 +162,7 @@ class TransformerLayer(paddle.nn.Layer):
"set_parallel_mode": set_parallel_mode,
"sequence_parallel": self.sequence_parallel,
"tp_group": tp_group,
"num_gqa_groups": num_gqa_groups,
"rng_state_name": attention_dropout_rng_state_name,
"backend": backend,
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment