Unverified Commit bd7fd0a6 authored by Shijie Wang, committed by GitHub

[Paddle] Support GQA (#595)



* use separate qkv
Signed-off-by: jaywan <jaywan@nvidia.com>

* add support for GQA
Signed-off-by: jaywan <jaywan@nvidia.com>

* minor changes
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* change rtol
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* fix reshape issue
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

---------
Signed-off-by: jaywan <jaywan@nvidia.com>
Signed-off-by: Shijie Wang <jaywan@nvidia.com>
parent e531cd2f
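Usage sketch (not part of this commit's diff): a minimal illustration of the GQA-enabled API that
the changes below add, assuming the usual `import transformer_engine.paddle as te` and purely
illustrative shapes; K and V now carry `num_gqa_groups` heads while Q keeps `num_heads`.

    import paddle
    import transformer_engine.paddle as te

    b, s, hidden_size, num_heads, num_gqa_groups = 2, 512, 1024, 16, 4
    head_size = hidden_size // num_heads

    # TransformerLayer gains a num_gqa_groups argument; the K/V projection shrinks accordingly.
    layer = te.TransformerLayer(hidden_size, 4096, num_heads,
                                num_gqa_groups=num_gqa_groups,
                                hidden_dropout=0.0,
                                attention_dropout=0.0)

    # DotProductAttention now takes separate q, k, v tensors instead of packed qkv/kv.
    attn = te.DotProductAttention(num_heads, head_size,
                                  num_gqa_groups=num_gqa_groups,
                                  attention_dropout=0.0,
                                  attn_mask_type='causal')
    q = paddle.normal(shape=(b, s, num_heads, head_size)).astype('bfloat16')
    k = paddle.normal(shape=(b, s, num_gqa_groups, head_size)).astype('bfloat16')
    v = paddle.normal(shape=(b, s, num_gqa_groups, head_size)).astype('bfloat16')
    out = attn(q, k, v)  # causal mask, so no attention_mask is needed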
@@ -3,7 +3,6 @@
# See LICENSE for license information.
"""Test TE Paddle Layer-level APIs"""
-import math
import os
from utils import assert_allclose, is_fused_attention_supported
@@ -785,7 +784,7 @@ class TestLayerNormMLP:
@pytest.mark.parametrize('bs', [1, 2, 8])
@pytest.mark.parametrize('hidden_size, num_heads', [[1024, 16], [768, 12]])
-@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[128, 128], [512, 512]])
+@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[512, 512], [1024, 1024]])
@pytest.mark.parametrize('attn_type', ['self', 'cross'])
@pytest.mark.parametrize('mask_type', ['causal', 'padding'])
@pytest.mark.parametrize('math_dtype', ['bfloat16', 'float16'])
@@ -808,24 +807,18 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen,
            head_size=head_size,
            dtype=math_dtype,
            dropout=0.0,
-            qkv_layout="bs3hd" if attn_type == "self" else "bshd_bs2hd",
+            qkv_layout="bshd_bshd_bshd",
            bias_type="no_bias",
            mask_type=mask_type,
    ):
        pytest.skip("cuDNN fused attention is not supported")
-    self_attn_qkv_input = paddle.normal(mean=0.0,
-                                        std=0.02,
-                                        shape=(bs, q_seqlen, 3, num_heads,
-                                               head_size)).astype(math_dtype)
-    cross_attn_q_input = paddle.normal(mean=0.0,
-                                       std=0.02,
-                                       shape=(bs, q_seqlen, num_heads,
-                                              head_size)).astype(math_dtype)
-    cross_attn_kv_input = paddle.normal(mean=0.0,
-                                        std=0.02,
-                                        shape=(bs, kv_seqlen, 2, num_heads,
-                                               head_size)).astype(math_dtype)
+    attn_q_input = paddle.normal(mean=0.0, std=0.02,
+                                 shape=(bs, q_seqlen, num_heads, head_size)).astype(math_dtype)
+    attn_k_input = paddle.normal(mean=0.0, std=0.02,
+                                 shape=(bs, kv_seqlen, num_heads, head_size)).astype(math_dtype)
+    attn_v_input = paddle.normal(mean=0.0, std=0.02,
+                                 shape=(bs, kv_seqlen, num_heads, head_size)).astype(math_dtype)
    q_actual_seqlen = paddle.randint(low=20, high=q_seqlen, shape=(bs,), dtype='int32')
    kv_actual_seqlen = paddle.randint(low=20, high=kv_seqlen, shape=(bs,),
@@ -841,58 +834,37 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen,
    for i in range(0, bs):
        attn_mask[i, 0, 0:q_actual_seqlen[i], 0:kv_actual_seqlen[i]] = False
-    norm_factor = math.sqrt(hidden_size // num_heads)
+    head_size = hidden_size // num_heads
-    layer_te = te.DotProductAttention(norm_factor,
+    layer_te = te.DotProductAttention(num_heads,
+                                      head_size,
                                      attention_dropout=0.0,
                                      attn_mask_type=mask_type,
                                      attention_type=attn_type,
                                      backend='transformer_engine')
-    layer_pd = te.DotProductAttention(norm_factor,
+    layer_pd = te.DotProductAttention(num_heads,
+                                      head_size,
                                      attention_dropout=0.0,
                                      attn_mask_type=mask_type,
                                      attention_type=attn_type,
                                      backend='paddle')
-    def calc_attn_output_and_grad(layer, q, kv, mask, dout):
+    def calc_attn_output_and_grad(layer, q, k, v, mask, dout):
        _q = paddle.to_tensor(q, stop_gradient=False)
-        _kv = paddle.to_tensor(kv, stop_gradient=False) if kv is not None else None
+        _k = paddle.to_tensor(k, stop_gradient=False)
+        _v = paddle.to_tensor(v, stop_gradient=False)
-        out = layer(_q, _kv, mask)
+        out = layer(_q, _k, _v, mask)
        out.backward(dout)
-        return out, _q.grad, _kv.grad if _kv is not None else None
+        return out, _q.grad, _k.grad, _v.grad
-    if attn_type == 'self':
-        out, qkv_grad, _ = calc_attn_output_and_grad(layer_te, self_attn_qkv_input, None,
-                                                     attn_mask, grad_out)
-        out_ref, qkv_grad_ref, _ = calc_attn_output_and_grad(layer_pd, self_attn_qkv_input, None,
-                                                             attn_mask, grad_out)
-        valid_out_ref = paddle.full_like(out_ref, 0)
-        for i in range(0, bs):
-            valid_out_ref[i, 0:q_actual_seqlen[i], :, :] = out_ref[i, 0:q_actual_seqlen[i], :, :]
-        q_grad = qkv_grad[:, :, 0]
-        k_grad = qkv_grad[:, :, 1]
-        v_grad = qkv_grad[:, :, 2]
-        q_grad_ref = qkv_grad_ref[:, :, 0]
-        k_grad_ref = qkv_grad_ref[:, :, 1]
-        v_grad_ref = qkv_grad_ref[:, :, 2]
-    else:
-        out, q_grad, kv_grad = calc_attn_output_and_grad(layer_te, cross_attn_q_input,
-                                                         cross_attn_kv_input, attn_mask, grad_out)
-        out_ref, q_grad_ref, kv_grad_ref = calc_attn_output_and_grad(layer_pd, cross_attn_q_input,
-                                                                     cross_attn_kv_input, attn_mask,
-                                                                     grad_out)
-        valid_out_ref = paddle.full_like(out_ref, 0)
-        for i in range(0, bs):
-            valid_out_ref[i, 0:q_actual_seqlen[i], :, :] = out_ref[i, 0:q_actual_seqlen[i], :, :]
-        k_grad = kv_grad[:, :, 0]
-        v_grad = kv_grad[:, :, 1]
-        k_grad_ref = kv_grad_ref[:, :, 0]
-        v_grad_ref = kv_grad_ref[:, :, 1]
+    out, q_grad, k_grad, v_grad = calc_attn_output_and_grad(layer_te, attn_q_input, attn_k_input,
+                                                            attn_v_input, attn_mask, grad_out)
+    out_ref, q_grad_ref, k_grad_ref, v_grad_ref = calc_attn_output_and_grad(
+        layer_pd, attn_q_input, attn_k_input, attn_v_input, attn_mask, grad_out)
+    valid_out_ref = paddle.full_like(out_ref, 0)
+    for i in range(0, bs):
+        valid_out_ref[i, 0:q_actual_seqlen[i], :, :] = out_ref[i, 0:q_actual_seqlen[i], :, :]
    valid_q_grad_ref = paddle.full_like(q_grad_ref, 0)
    valid_k_grad_ref = paddle.full_like(k_grad_ref, 0)
    valid_v_grad_ref = paddle.full_like(v_grad_ref, 0)
@@ -910,17 +882,18 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen,
@pytest.mark.parametrize('bs', [1, 2, 8])
+@pytest.mark.parametrize('num_gqa_groups', [1, 4, 16])
@pytest.mark.parametrize('hidden_size, num_heads, ffn_hidden_size', [[1024, 16, 4096]])
-@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[128, 128], [512, 512]])
+@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[512, 512], [1024, 1024]])
@pytest.mark.parametrize('has_bias, no_dbias', [[False, True], [True, True], [True, False]])
@pytest.mark.parametrize('no_wgrad', [True, False])
@pytest.mark.parametrize('mask_type', ['causal', 'padding'])
@pytest.mark.parametrize('math_dtype', ['bfloat16', 'float16'])
@pytest.mark.parametrize('output_layernorm', [True, False])
@pytest.mark.parametrize('return_layernorm_output', [True, False])
-def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size, has_bias, no_dbias,
-                                   no_wgrad, q_seqlen, kv_seqlen, mask_type, math_dtype,
-                                   output_layernorm, return_layernorm_output):
+def test_transformer_encoder_layer(bs, hidden_size, num_heads, num_gqa_groups, ffn_hidden_size,
+                                   has_bias, no_dbias, no_wgrad, q_seqlen, kv_seqlen, mask_type,
+                                   math_dtype, output_layernorm, return_layernorm_output):
""" """
Test Transformer Encoder Layer Test Transformer Encoder Layer
""" """
...@@ -932,13 +905,13 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size, ...@@ -932,13 +905,13 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
# Skip if cuDNN fused attention is not supported # Skip if cuDNN fused attention is not supported
if not is_fused_attention_supported( if not is_fused_attention_supported(
num_heads=num_heads, num_heads=num_heads,
num_gqa_groups=num_heads, num_gqa_groups=num_gqa_groups,
q_seqlen=q_seqlen, q_seqlen=q_seqlen,
kv_seqlen=kv_seqlen, kv_seqlen=kv_seqlen,
head_size=hidden_size // num_heads, head_size=hidden_size // num_heads,
dtype=math_dtype, dtype=math_dtype,
dropout=0.0, dropout=0.0,
qkv_layout="bs3hd", qkv_layout="bshd_bshd_bshd",
bias_type="no_bias", bias_type="no_bias",
mask_type=mask_type, mask_type=mask_type,
): ):
@@ -962,6 +935,7 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
    layer_te = te.TransformerLayer(hidden_size,
                                   ffn_hidden_size,
                                   num_heads,
+                                  num_gqa_groups=num_gqa_groups,
                                   layernorm_epsilon=eps,
                                   hidden_dropout=0.0,
                                   attention_dropout=0.0,
@@ -975,6 +949,7 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
    layer_pd = te.TransformerLayer(hidden_size,
                                   ffn_hidden_size,
                                   num_heads,
+                                  num_gqa_groups=num_gqa_groups,
                                   layernorm_epsilon=eps,
                                   hidden_dropout=0.0,
                                   attention_dropout=0.0,
@@ -1088,8 +1063,9 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
@pytest.mark.parametrize('bs', [1, 2, 8])
+@pytest.mark.parametrize('num_gqa_groups', [1, 4, 16])
@pytest.mark.parametrize('hidden_size, num_heads, ffn_hidden_size', [[1024, 16, 4096]])
-@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[128, 128], [512, 512]])
+@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[512, 512], [1024, 1024]])
@pytest.mark.parametrize('has_bias, no_dbias', [[False, True], [True, True], [True, False]])
@pytest.mark.parametrize('no_wgrad', [True, False])
@pytest.mark.parametrize('mask_type', ['causal', 'padding'])
@@ -1097,9 +1073,9 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
@pytest.mark.parametrize('output_layernorm', [True, False])
@pytest.mark.parametrize('return_layernorm_output', [True, False])
@pytest.mark.parametrize('recompute_core_attention', [True, False])
-def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size, has_bias, no_dbias,
-                                   no_wgrad, q_seqlen, kv_seqlen, mask_type, math_dtype,
-                                   output_layernorm, return_layernorm_output,
+def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, ffn_hidden_size,
+                                   has_bias, no_dbias, no_wgrad, q_seqlen, kv_seqlen, mask_type,
+                                   math_dtype, output_layernorm, return_layernorm_output,
                                   recompute_core_attention):
    """
    Test Transformer Decoder Layer
@@ -1112,39 +1088,35 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
    # Skip if cuDNN fused attention is not supported
    if not is_fused_attention_supported(
            num_heads=num_heads,
-            num_gqa_groups=num_heads,
+            num_gqa_groups=num_gqa_groups,
            q_seqlen=q_seqlen,
            kv_seqlen=kv_seqlen,
            head_size=hidden_size // num_heads,
            dtype=math_dtype,
            dropout=0.0,
-            qkv_layout="bs3hd",
-            bias_type="no_bias",
-            mask_type=mask_type,
-    ):
-        pytest.skip("cuDNN fused attention is not supported")
-    if not is_fused_attention_supported(
-            head_size=hidden_size // num_heads,
-            num_heads=num_heads,
-            num_gqa_groups=num_heads,
-            q_seqlen=q_seqlen,
-            kv_seqlen=kv_seqlen,
-            dtype=math_dtype,
-            dropout=0.0,
-            qkv_layout="bshd_bs2hd",
+            qkv_layout="bshd_bshd_bshd",
            bias_type="no_bias",
            mask_type=mask_type,
    ):
        pytest.skip("cuDNN fused attention is not supported")
-    encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype)
-    encoder_output = paddle.uniform(shape=(bs, kv_seqlen, hidden_size), dtype=math_dtype)
+    encoder_input = paddle.normal(mean=0.0, std=0.1,
+                                  shape=(bs, q_seqlen, hidden_size)).astype(math_dtype)
+    encoder_output = paddle.normal(mean=0.0, std=0.1,
+                                   shape=(bs, kv_seqlen, hidden_size)).astype(math_dtype)
    q_actual_seqlen = paddle.ones(shape=(bs,), dtype='int32') * q_seqlen
    kv_actual_seqlen = q_actual_seqlen
    attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype='bool')
-    grad_out = paddle.normal(mean=0.0, std=0.2, shape=(bs, q_seqlen, hidden_size)).astype('float32')
+    grad_out = paddle.normal(mean=0.0, std=0.01,
+                             shape=(bs, q_seqlen, hidden_size)).astype('float32')
+    # rounding to avoid numerical issues
+    encoder_input = paddle.round(encoder_input * 1000) / 1000
+    encoder_output = paddle.round(encoder_output * 1000) / 1000
+    grad_out = paddle.round(grad_out * 1000) / 1000
    for i in range(0, bs):
        grad_out[i, q_actual_seqlen[i]:, :] = 0
    grad_out = grad_out.astype(math_dtype)
@@ -1155,6 +1127,7 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
    layer_te = te.TransformerLayer(hidden_size,
                                   ffn_hidden_size,
                                   num_heads,
+                                  num_gqa_groups=num_gqa_groups,
                                   layernorm_epsilon=eps,
                                   hidden_dropout=0.0,
                                   attention_dropout=0.0,
@@ -1168,6 +1141,7 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
    layer_pd = te.TransformerLayer(hidden_size,
                                   ffn_hidden_size,
                                   num_heads,
+                                  num_gqa_groups=num_gqa_groups,
                                   layernorm_epsilon=eps,
                                   hidden_dropout=0.0,
                                   attention_dropout=0.0,
@@ -1319,7 +1293,7 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
        assert_allclose(layer_te.self_attention.layernorm_qkv.weight.grad,
                        layer_pd.self_attention.layernorm_qkv.weight.grad.T,
                        rtol=rtol,
-                        atol=0.1)
+                        atol=atol)
        assert_allclose(layer_te.inter_attention.layernorm_query.weight.grad,
                        layer_pd.inter_attention.layernorm_query.weight.grad.T,
                        rtol=rtol,
@@ -1328,7 +1302,7 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
    if output_layernorm:
        assert_allclose(layer_te.self_attention.qkv.bias.grad,
                        layer_pd.self_attention.qkv.bias.grad,
-                        rtol=0.01,
+                        rtol=0.5,
                        atol=0.6)
    else:
        assert_allclose(layer_te.self_attention.layernorm_qkv.bias.grad,
...
@@ -5,6 +5,12 @@
import struct
+from utils import (
+    assert_allclose,
+    create_fp8_meta,
+    get_fused_attention_backend,
+    is_fused_attention_supported,
+)
import numpy as np
import paddle
import paddle.nn.functional as F
@@ -39,6 +45,8 @@ from transformer_engine.paddle.cpp_extensions import (
    fused_attn_bwd_qkvpacked,
    fused_attn_fwd_kvpacked,
    fused_attn_bwd_kvpacked,
+    fused_attn_fwd,
+    fused_attn_bwd,
    scaled_softmax_forward,
    scaled_softmax_backward,
    scaled_masked_softmax_forward,
@@ -594,6 +602,7 @@ class TestFusedAttn:
        self.q = _random(self.q_shape)
        if self.attn_mode == "self_attn":
+            assert self.q_seqlen == self.kv_seqlen, "self attention requires q_seqlen == kv_seqlen"
            self.kv = self.q
        else:
            self.kv = _random(self.kv_shape)
@@ -774,6 +783,70 @@ class TestFusedAttn:
        return out, q_grad, k_grad, v_grad
def _get_fused_attention_with_separate_qkv(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
q_tensor = paddle.to_tensor(self.q, stop_gradient=False)
k_tensor = paddle.to_tensor(self.kv, stop_gradient=False)
v_tensor = paddle.to_tensor(self.kv, stop_gradient=False)
q_cu_seqlen_tensor = paddle.to_tensor(self.q_cu_seqlen, dtype="int32", stop_gradient=True)
kv_cu_seqlen_tensor = paddle.to_tensor(self.kv_cu_seqlen, dtype="int32", stop_gradient=True)
qkv_layout = "bshd_bshd_bshd"
fused_attention_backend = get_fused_attention_backend(
num_heads=self.num_heads,
num_gqa_groups=self.num_heads,
q_seqlen=self.q_seqlen,
kv_seqlen=self.kv_seqlen,
head_size=self.head_size,
dtype=self.dtype,
dropout=self.dropout_prob,
qkv_layout=qkv_layout,
bias_type="no_bias",
mask_type="causal" if self.is_causal_masking else "padding",
)
qkv_dtype = tex.DType.kBFloat16 if self.dtype == "bfloat16" else tex.DType.kFloat16
out, softmax_aux_tensor, rng_state = fused_attn_fwd(
q_tensor,
k_tensor,
v_tensor,
q_cu_seqlen_tensor,
kv_cu_seqlen_tensor,
is_training=True,
max_seqlen_q=self.q_seqlen,
max_seqlen_kv=self.kv_seqlen,
qkv_dtype=qkv_dtype,
fused_attention_backend=fused_attention_backend,
Bias=None,
attn_scale=self.scaling_factor,
dropout=self.dropout_prob,
set_zero=False,
qkv_layout=qkv_layout,
attn_mask_type="causal" if self.is_causal_masking else "padding")
dq, dk, dv, _ = fused_attn_bwd(
q_tensor,
k_tensor,
v_tensor,
q_cu_seqlen_tensor,
kv_cu_seqlen_tensor,
rng_state,
out,
self.dout,
softmax_aux_tensor,
fused_attention_backend=fused_attention_backend,
max_seqlen_q=self.q_seqlen,
max_seqlen_kv=self.kv_seqlen,
qkv_dtype=qkv_dtype,
attn_scale=self.scaling_factor,
dropout=self.dropout_prob,
set_zero=False,
qkv_layout=qkv_layout,
attn_mask_type="causal" if self.is_causal_masking else "padding")
return out, dq, dk, dv
@pytest.mark.parametrize('b, s, h, d', SELF_ATTN_CASES)
@pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
@pytest.mark.parametrize('is_causal_masking', [True, False])
@@ -857,6 +930,35 @@ class TestFusedAttn:
        assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2)
        assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2)
@pytest.mark.parametrize('b, s, h, d', FLASH_ATTN_CASES)
@pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
@pytest.mark.parametrize('is_causal_masking', [False, True])
def test_fused_attn_with_separate_qkv_forward_backward(self, b, s, h, d, dtype,
is_causal_masking):
"""
test flash attention forward + backward with separate qkv inputs
"""
if not is_fused_attention_supported(
num_heads=h,
num_gqa_groups=h,
q_seqlen=s,
kv_seqlen=s,
head_size=d,
dtype=dtype,
dropout=0.0,
qkv_layout="bshd_bshd_bshd",
bias_type="no_bias",
mask_type="causal" if is_causal_masking else "padding",
):
pytest.skip("cuDNN fused attention is not supported")
self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking)
reference_out, q_grad_ref, k_grad_ref, v_grad_ref = self._get_reference_out()
fused_attention_out, q_grad, k_grad, v_grad = self._get_fused_attention_with_separate_qkv()
assert_allclose(reference_out, fused_attention_out, rtol=1e-3, atol=1e-2)
assert_allclose(q_grad_ref, q_grad, rtol=1e-3, atol=1e-2)
assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2)
assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2)
class TestSoftmax:
    """
...
@@ -792,6 +792,189 @@ def fused_attn_bwd_kvpacked(
    return dq, dkv, dbias
def fused_attn_fwd(
q: paddle.Tensor,
k: paddle.Tensor,
v: paddle.Tensor,
cu_seqlens_q: paddle.Tensor,
cu_seqlens_kv: paddle.Tensor,
is_training: bool,
max_seqlen_q: int,
max_seqlen_kv: int,
qkv_dtype: tex.DType,
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
Bias: paddle.Tensor = None,
attn_scale: float = None,
dropout: float = 0.0,
set_zero: bool = True,
qkv_layout: str = "bshd_bshd_bshd",
bias_type: str = "no_bias",
attn_mask_type: str = "padding",
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Fused Attention FWD for unpacked QKV input"""
assert (qkv_dtype in (tex.DType.kBFloat16,
tex.DType.kFloat16)), "Only support bf16/fp16 for fused attention."
assert (cu_seqlens_q.shape == cu_seqlens_kv.shape
), "cu_seqlens_q and cu_seqlens_kv must have the same shape"
assert (qkv_layout == "bshd_bshd_bshd"
), "Only support bshd_bshd_bshd layout for unpacked QKV input for now."
b = cu_seqlens_q.shape[0] - 1
h = q.shape[-2]
d = q.shape[-1]
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
if bias_type != "no_bias":
assert Bias is not None, "bias tensor cannot be None when bias_type is not no_bias."
assert (Bias.shape == [
1, h, max_seqlen_q, max_seqlen_kv
]), "bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape."
assert (Bias.dtype == q.dtype), "bias tensor must be in the same dtype as qkv."
assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
), "Fused attention does not support this input combination."
# BF16/FP16 fused attention API from fmha_v1 apex
if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
rng_elts_per_thread = (max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_THREADS_PER_CTA -
1) // BACKEND_F16m512_THREADS_PER_CTA
# BF16/FP16 fused attention API from fmha_v2
if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS
if set_zero:
out = paddle.full(shape=[b, max_seqlen_q, h, d], fill_value=0, dtype=q.dtype)
else:
out = paddle.empty(shape=[b, max_seqlen_q, h, d], dtype=q.dtype)
if is_training:
if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, 1], dtype='float32')
else:
raise ValueError("Unsupported fused attention backend.")
else:
softmax_aux = None
rng_state = paddle.empty(shape=[
2,
], dtype=paddle.int64)
# execute kernel
tex.te_fused_attn_fwd(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
Bias,
out,
softmax_aux,
rng_state,
b,
h,
d,
max_seqlen_q,
max_seqlen_kv,
is_training,
attn_scale,
dropout,
qkv_layout,
bias_type,
attn_mask_type,
int(qkv_dtype),
rng_elts_per_thread,
)
return out, softmax_aux, rng_state
def fused_attn_bwd(
q: paddle.Tensor,
k: paddle.Tensor,
v: paddle.Tensor,
cu_seqlens_q: paddle.Tensor,
cu_seqlens_kv: paddle.Tensor,
rng_state: paddle.Tensor,
o: paddle.Tensor,
d_o: paddle.Tensor,
softmax_aux: paddle.Tensor,
fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
max_seqlen_q: int,
max_seqlen_kv: int,
qkv_dtype: tex.DType,
attn_scale: float = None,
dropout: float = 0.0,
set_zero: bool = True,
qkv_layout: str = "bshd_bshd_bshd",
bias_type: str = "no_bias",
attn_mask_type: str = "padding",
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Fused Attention BWD for separate QKV input"""
assert (qkv_dtype in (tex.DType.kBFloat16,
tex.DType.kFloat16)), "Only support bf16/fp16 for fused attention."
assert (cu_seqlens_q.shape == cu_seqlens_kv.shape
), "cu_seqlens_q and cu_seqlens_kv must have the same shape"
assert (qkv_layout == "bshd_bshd_bshd"
), "Only support bshd_bshd_bshd layout for unpacked QKV input for now."
b = cu_seqlens_q.shape[0] - 1
h = q.shape[-2]
d = q.shape[-1]
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
), "Fused attention does not support this input combination."
if set_zero:
dq = paddle.full(shape=q.shape, fill_value=0, dtype=q.dtype)
dk = paddle.full(shape=k.shape, fill_value=0, dtype=k.dtype)
dv = paddle.full(shape=v.shape, fill_value=0, dtype=v.dtype)
else:
dq = paddle.empty(shape=q.shape, dtype=q.dtype)
dk = paddle.empty(shape=k.shape, dtype=k.dtype)
dv = paddle.empty(shape=v.shape, dtype=v.dtype)
if bias_type != "no_bias":
dbias = paddle.empty(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
else:
dbias = None
# execute kernel
tex.te_fused_attn_bwd(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_kv,
o,
d_o,
softmax_aux,
dq,
dk,
dv,
dbias,
rng_state,
b,
h,
d,
max_seqlen_q,
max_seqlen_kv,
attn_scale,
dropout,
qkv_layout,
bias_type,
attn_mask_type,
int(qkv_dtype),
)
return dq, dk, dv, dbias
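# --- Illustrative note (not part of this diff) --------------------------------------------------
# A minimal sketch of how the cu_seqlens tensors consumed by fused_attn_fwd/fused_attn_bwd can be
# built for the "bshd_bshd_bshd" layout; the variable names below are hypothetical. The wrappers
# only require int32 tensors of shape [batch + 1] holding cumulative sequence lengths, since they
# recover the batch size as cu_seqlens_q.shape[0] - 1.
#
#     actual_seqlens = paddle.to_tensor([37, 128, 64], dtype='int32')   # per-sequence lengths
#     cu_seqlens = paddle.concat(
#         [paddle.zeros([1], dtype='int32'), paddle.cumsum(actual_seqlens)]).astype('int32')
#     # cu_seqlens == [0, 37, 165, 229]
# -------------------------------------------------------------------------------------------------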
def scaled_softmax_forward(
    inp: paddle.Tensor,
    scale_factor: float,
...
@@ -864,6 +864,183 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K
    nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
}
void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const paddle::Tensor &V,
const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &cu_seqlens_kv,
const paddle::optional<paddle::Tensor> &Bias,
paddle::Tensor &O, // NOLINT
paddle::optional<paddle::Tensor> &softmax_aux, // NOLINT
paddle::Tensor &rng_state, // NOLINT
int64_t b, int64_t h, int64_t d, int64_t max_seqlen_q, int64_t max_seqlen_kv,
bool is_training, float attn_scale, float p_dropout,
const std::string &qkv_layout, const std::string &bias_type,
const std::string &attn_mask_type, const int64_t qkv_type,
int64_t rng_elts_per_thread) {
if (is_training && !softmax_aux) {
NVTE_ERROR("softmax_aux must be provided when training. \n");
}
auto qkv_dtype = Int2NvteDType(qkv_type);
// construct NVTE tensors
TensorWrapper te_Q, te_K, te_V, te_S, te_O, te_Bias, te_cu_seqlens_q, te_cu_seqlens_kv;
if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) {
// BF16 or FP16
te_Q = MakeNvteTensor(Q);
te_K = MakeNvteTensor(K);
te_V = MakeNvteTensor(V);
te_S = MakeNvteTensor(nullptr, std::vector<size_t>{0}, DType::kFloat32);
te_O = MakeNvteTensor(O);
} else { // TODO: support fp8
NVTE_ERROR("Fused attention only supports BF16/FP16 data types. \n");
}
if ((bias_type != "no_bias") && Bias) {
auto bias_shape = Bias->shape();
std::vector<size_t> shape{bias_shape.begin(), bias_shape.end()};
te_Bias = MakeNvteTensor(GetOptionalDataPtr(Bias), shape, DType::kFloat32);
}
te_cu_seqlens_q =
MakeNvteTensor(cu_seqlens_q.data(), {static_cast<size_t>(b + 1)}, DType::kInt32);
te_cu_seqlens_kv =
MakeNvteTensor(cu_seqlens_kv.data(), {static_cast<size_t>(b + 1)}, DType::kInt32);
// convert strings to enums
NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout);
NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type);
NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type);
// extract random number generator seed and offset
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(Q.place());
auto gen_cuda = dev_ctx->GetGenerator();
auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread);
set_rng_state<<<1, 1, 0, Q.stream()>>>(seed_offset, static_cast<int64_t *>(rng_state.data()));
auto te_rng_state = MakeNvteTensor(rng_state);
// create auxiliary output tensors
NVTETensorPack nvte_aux_tensor_pack;
nvte_tensor_pack_create(&nvte_aux_tensor_pack);
// create workspace
TensorWrapper workspace;
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(),
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv,
is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, workspace.data(), Q.stream());
// allocate memory for workspace and auxiliary output tensors
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place());
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
auto *output_s =
reinterpret_cast<transformer_engine::Tensor *>(nvte_aux_tensor_pack.tensors[0]);
output_s->data.dptr = GetOptionalDataPtr(softmax_aux);
// execute the kernel
nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(),
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv,
is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
attn_mask_type_enum, workspace.data(), Q.stream());
// destroy tensor wrappers, but not allocated memory
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
}
void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const paddle::Tensor &V,
const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &cu_seqlens_kv,
const paddle::Tensor &O, const paddle::Tensor &dO,
const paddle::Tensor &softmax_aux,
paddle::Tensor &dQ, // NOLINT
paddle::Tensor &dK, // NOLINT
paddle::Tensor &dV, // NOLINT
paddle::optional<paddle::Tensor> &dBias, // NOLINT
paddle::Tensor &rng_state, // NOLINT
int64_t b, int64_t h, int64_t d, int64_t max_seqlen_q, int64_t max_seqlen_kv,
float attn_scale, float p_dropout, const std::string &qkv_layout,
const std::string &bias_type, const std::string &attn_mask_type,
int64_t qkv_type) {
TensorWrapper te_dBias;
if (bias_type != "no_bias" && dBias) {
auto bias_shape = dBias->shape();
std::vector<size_t> shape{bias_shape.begin(), bias_shape.end()};
te_dBias = MakeNvteTensor(GetOptionalDataPtr(dBias), shape, DType::kFloat32);
}
auto qkv_dtype = Int2NvteDType(qkv_type);
// construct NVTE tensors
TensorWrapper te_Q, te_K, te_V, te_O, te_dO, te_S, te_dP, te_dQ, te_dK, te_dV;
if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) {
// BF16 or FP16
te_Q = MakeNvteTensor(Q);
te_K = MakeNvteTensor(K);
te_V = MakeNvteTensor(V);
te_O = MakeNvteTensor(O);
te_dO = MakeNvteTensor(dO);
te_S = MakeNvteTensor(nullptr, std::vector<size_t>(0), DType::kFloat32);
te_dP = MakeNvteTensor(nullptr, std::vector<size_t>(0), DType::kFloat32);
te_dQ = MakeNvteTensor(dQ);
te_dK = MakeNvteTensor(dK);
te_dV = MakeNvteTensor(dV);
} else {
NVTE_ERROR("Fused attention only supports BF16/FP16 data types. \n");
}
// convert strings to enums
NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout);
NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type);
NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type);
// convert auxiliary tensors from forward into NVTETensors
NVTETensorPack nvte_aux_tensor_pack;
nvte_tensor_pack_create(&nvte_aux_tensor_pack);
nvte_aux_tensor_pack.size = 2;
auto *output_s = reinterpret_cast<Tensor *>(nvte_aux_tensor_pack.tensors[0]);
auto *fwd_rng_state = reinterpret_cast<Tensor *>(nvte_aux_tensor_pack.tensors[1]);
output_s->data.shape = std::vector<size_t>({static_cast<size_t>(b), static_cast<size_t>(h),
static_cast<size_t>(max_seqlen_q),
static_cast<size_t>(max_seqlen_kv)});
output_s->data.dptr = const_cast<void *>(softmax_aux.data());
fwd_rng_state->data.shape = std::vector<size_t>({2});
fwd_rng_state->data.dptr = const_cast<void *>(rng_state.data());
// create cu_seqlens tensorwrappers
TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv;
te_cu_seqlens_q =
MakeNvteTensor(cu_seqlens_q.data(), {static_cast<size_t>(b + 1)}, DType::kInt32);
te_cu_seqlens_kv =
MakeNvteTensor(cu_seqlens_kv.data(), {static_cast<size_t>(b + 1)}, DType::kInt32);
// create workspace
TensorWrapper workspace;
// populate tensors with appropriate shapes and dtypes
nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(),
te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(),
te_dK.data(), te_dV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(),
Q.stream());
// allocate memory for workspace
auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place());
workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());
// execute kernel
nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(),
te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(),
te_dK.data(), te_dV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
te_cu_seqlens_kv.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(),
Q.stream());
// destroy tensor wrappers
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
}
std::vector<paddle::Tensor> te_scaled_softmax_forward(const paddle::Tensor &input,
                                                       float scale_factor) {
    NVTE_CHECK(input.shape().size() == 4, "expected 4D tensor");
@@ -1316,6 +1493,33 @@ PD_BUILD_OP(te_fused_attn_bwd_kvpacked)
                    {paddle::Optional("_dBias"), paddle::Optional("dBias")}})
    .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_bwd_kvpacked));
PD_BUILD_OP(te_fused_attn_fwd)
.Inputs({"Q", "K", "V", "cu_seqlens_q", "cu_seqlens_kv", paddle::Optional("Bias"), "_O",
paddle::Optional("_softmax_aux"), "_rng_state"})
.Outputs({"O", paddle::Optional("softmax_aux"), "rng_state"})
.Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "max_seqlen_q: int64_t",
"max_seqlen_kv: int64_t", "is_training: bool", "attn_scale: float", "p_dropout: float",
"qkv_layout: std::string", "bias_type: std::string", "attn_mask_type: std::string",
"qkv_type: int64_t", "rng_elts_per_thread: int64_t"})
.SetInplaceMap({{"_O", "O"},
{paddle::Optional("_softmax_aux"), paddle::Optional("softmax_aux")},
{"_rng_state", "rng_state"}})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_fwd));
PD_BUILD_OP(te_fused_attn_bwd)
.Inputs({"Q", "K", "V", "cu_seqlens_q", "cu_seqlens_kv", "O", "dO", "softmax_aux", "_dQ", "_dK",
"_dV", paddle::Optional("_dBias"), "rng_state"})
.Outputs({"dQ", "dK", "dV", paddle::Optional("dBias")})
.Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "max_seqlen_q: int64_t",
"max_seqlen_kv: int64_t", "attn_scale: float", "p_dropout: float",
"qkv_layout: std::string", "bias_type: std::string", "attn_mask_type: std::string",
"qkv_type: int64_t"})
.SetInplaceMap({{"_dQ", "dQ"},
{"_dK", "dK"},
{"_dV", "dV"},
{paddle::Optional("_dBias"), paddle::Optional("dBias")}})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_bwd));
PD_BUILD_OP(te_scaled_softmax_forward)
    .Inputs({"input"})
    .Outputs({"softmax_results"})
...
@@ -22,6 +22,8 @@ from ..cpp_extensions import (
    fused_attn_bwd_qkvpacked,
    fused_attn_fwd_kvpacked,
    fused_attn_bwd_kvpacked,
+    fused_attn_fwd,
+    fused_attn_bwd,
    mask_to_cu_seqlens,
)
from ..distributed import get_tp_group_and_world_size, track_rng_state
@@ -31,6 +33,20 @@ from ..recompute import recompute
__all__ = ["DotProductAttention", "MultiHeadAttention"]
def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor:
"""
Used to repeat the key and value states for GQA.
The hidden states go from (batch, seqlen, num_gqa_groups, head_size)
to (batch, seqlen, num_heads, head_size)
"""
batch, seqlen, num_gqa_groups, head_size = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states.unsqueeze(-2).tile([1, 1, 1, n_rep, 1])
return hidden_states.reshape([batch, seqlen, num_gqa_groups * n_rep, head_size])
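# --- Illustrative note (not part of this diff) --------------------------------------------------
# Hypothetical shape check for repeat_kv: with num_gqa_groups = 4 and n_rep = 4, a key/value
# tensor of shape [2, 128, 4, 64] expands to [2, 128, 16, 64], matching the query head count so
# the Paddle (non-fused) backend can run plain multi-head attention on the repeated K/V.
#
#     kv = paddle.ones([2, 128, 4, 64])
#     assert repeat_kv(kv, 4).shape == [2, 128, 16, 64]
# -------------------------------------------------------------------------------------------------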
class FusedAttnFuncPackedQKV(paddle.autograd.PyLayer):
    """Function for FusedAttention with packed QKV input"""
@@ -130,6 +146,50 @@ class FusedAttnFuncPackedKV(paddle.autograd.PyLayer):
        return (dq, dkv, None, None, rest[0])
class FusedAttnFunc(paddle.autograd.PyLayer):
"""Function for FusedAttention with separate Q, K, V tensors"""
@staticmethod
def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_kv, attn_bias, max_seqlen_q, max_seqlen_kv,
attn_scale, qkv_dtype, dropout_p, set_zero, qkv_layout, attn_bias_type,
attn_mask_type, is_training, fused_attention_backend):
"""Forward function for FusedAttention with separate Q, K, V tensors"""
out, softmax_aux, rng_state = fused_attn_fwd(q, k, v, cu_seqlens_q, cu_seqlens_kv,
is_training, max_seqlen_q, max_seqlen_kv,
qkv_dtype, fused_attention_backend, attn_bias,
attn_scale, dropout_p, set_zero, qkv_layout,
attn_bias_type, attn_mask_type)
ctx.save_for_backward(q, k, v, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux)
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_kv = max_seqlen_kv
ctx.qkv_dtype = qkv_dtype
ctx.attn_scale = attn_scale
ctx.dropout_p = dropout_p
ctx.set_zero = set_zero
ctx.qkv_layout = qkv_layout
ctx.attn_bias_type = attn_bias_type
ctx.attn_mask_type = attn_mask_type
ctx.fused_attention_backend = fused_attention_backend
return out
@staticmethod
def backward(ctx, d_out):
"""Backward function for FusedAttention with separate Q, K, V tensors"""
q, k, v, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux = ctx.saved_tensor()
dq, dk, dv, *rest = fused_attn_bwd(q, k, v, cu_seqlens_q, cu_seqlens_kv, rng_state, out,
d_out, softmax_aux, ctx.fused_attention_backend,
ctx.max_seqlen_q, ctx.max_seqlen_kv, ctx.qkv_dtype,
ctx.attn_scale, ctx.dropout_p, ctx.set_zero,
ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type)
# if no_bias, return dq, dk, dv
if ctx.attn_bias_type == "no_bias":
return (dq, dk, dv, None, None)
# else, return (dq, dk, dv, dbias)
return (dq, dk, dv, None, None, rest[0])
class DotProductAttention(paddle.nn.Layer):
    """
    Allows the model to jointly attend to information from different
@@ -143,31 +203,51 @@ class DotProductAttention(paddle.nn.Layer):
    Parameters
    ----------
-    norm_factor : float
-                 normalization factor for the attention scores.
+    num_attention_heads: int
+                         number of attention heads in the transformer layer.
+    kv_channels: int
+                 number of channels in the key and value tensors.
+    num_gqa_groups : Optional[int] = None
+                    number of GQA groups in the transformer layer.
+                    Grouped Query Attention is described in
+                    `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
+                    This only affects the keys and values, not the queries.
+                    GQA-1 is equivalent to Multi-Query Attention
+                    (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
+                    is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
    attention_dropout: float, default = 0.1
                      dropout probability for the dropout op during multi-head attention.
    attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
                   type of attention mask passed into softmax operation.
    attention_type: {'self', 'cross'}, default = `self`
                   type of attention operation.
+    tp_group : ProcessGroup, default = `None`
+              tensor parallel process group.
    backend: {'transformer_engine', 'paddle'}, default = `transformer_engine`
            backend to use for attention operation.
    """

    def __init__(self,
-                 norm_factor: float,
+                 num_attention_heads: int,
+                 kv_channels: int,
+                 num_gqa_groups: Optional[int] = None,
                 attention_dropout: float = 0.1,
                 attn_mask_type: str = "causal",
                 attention_type: str = "self",
+                 tp_size: int = 1,
                 backend: str = 'transformer_engine') -> None:
        super().__init__()
-        self.norm_factor = norm_factor
        self.attn_mask_type = attn_mask_type
        self.attention_dropout = attention_dropout
        self.attention_type = attention_type
-        self.qkv_layout = "bs3hd" if attention_type == "self" else "bshd_bs2hd"
+        self.qkv_layout = "bshd_bshd_bshd"
+        self.hidden_size_per_attention_head = kv_channels
+        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
+        self.tp_size = tp_size
+        self.num_gqa_groups = (num_attention_heads if num_gqa_groups is None else num_gqa_groups)
+        self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size)
+        self.num_queries_per_key_value = num_attention_heads // self.num_gqa_groups
        self.backend = backend
@@ -185,7 +265,8 @@ class DotProductAttention(paddle.nn.Layer):
    def forward(
            self,
            query_layer: paddle.Tensor,
-            key_value_layer: paddle.Tensor = None,
+            key_layer: paddle.Tensor,
+            value_layer: paddle.Tensor,
            attention_mask: Optional[paddle.Tensor] = None,
            core_attention_bias_type: str = "no_bias",
            core_attention_bias: Optional[paddle.Tensor] = None,
@@ -199,26 +280,15 @@ class DotProductAttention(paddle.nn.Layer):
        Argument :attr:`attention_mask` will be ignored when :attr:`attn_mask_type`
        is set to `"causal"`.
-        .. note::
-
-            For self attention, :attr:`query_layer` is the `[query, key, value]` tensor
-            stacked along the 2nd dimension, which must be of shape (:attr:`batch_size`,
-            :attr:`seq_length`, 3, :attr:`num_attention_heads`, :attr:`size_per_head`).
-            And :attr:`key_value_layer` is `None`.
-            For cross attention, :attr:`query_layer` is the `[query]` tensor, which must
-            be of shape (:attr:`batch_size`, :attr:`seq_length`, :attr:`num_attention_heads`,
-            :attr:`size_per_head`). And :attr:`key_value_layer` is the `[key, value]` tensor,
-            which must be of shape (:attr:`batch_size`, :attr:`seq_length`, 2,
-            :attr:`num_attention_heads`, :attr:`size_per_head`).
        Parameters
        ----------
        query_layer : paddle.Tensor
                     Query tensor.
-        key_value_layer : paddle.Tensor
+        key_layer : paddle.Tensor
                     Key tensor.
+        value_layer : paddle.Tensor
+                     Value tensor.
        attention_mask : Optional[paddle.Tensor], default = `None`
                        Boolean tensor used to mask out softmax input when not using attention.
        core_attention_bias_type: str, default = `no_bias`
@@ -231,21 +301,25 @@ class DotProductAttention(paddle.nn.Layer):
        backend = self.backend
+        assert (key_layer.shape == value_layer.shape), "Keys and values must have the same shape!"
+        assert (key_layer.shape[-2] == self.num_gqa_groups_per_partition
+                ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!"
        if backend == 'transformer_engine':
            max_s_q = query_layer.shape[1]
-            max_s_kv = max_s_q if self.attention_type == "self" else key_value_layer.shape[1]
+            max_s_kv = max_s_q if self.attention_type == "self" else key_layer.shape[1]
            self.fused_attention_backend = tex.get_fused_attn_backend(
                TE_DType[query_layer.dtype], TE_DType[query_layer.dtype],
                tex.get_nvte_qkv_layout(self.qkv_layout), AttnBiasType[core_attention_bias_type],
                AttnMaskType[self.attn_mask_type], self.attention_dropout, query_layer.shape[-2],
-                key_value_layer.shape[-2] if key_value_layer is not None else query_layer.shape[-2],
-                max_s_q, max_s_kv, query_layer.shape[-1])
+                key_layer.shape[-2] if key_layer is not None else query_layer.shape[-2], max_s_q,
+                max_s_kv, query_layer.shape[-1])
            is_backend_avail = (self.fused_attention_backend in [
                FusedAttnBackend["F16_max512_seqlen"], FusedAttnBackend["F16_arbitrary_seqlen"]
            ])
            if is_backend_avail and self.use_fused_attention:
-                return self._te_forward(query_layer, key_value_layer, attention_mask,
+                return self._te_forward(query_layer, key_layer, value_layer, attention_mask,
                                        core_attention_bias_type, core_attention_bias, set_zero)
            warnings.warn("Fused attention is not enabled, falling back to Paddle backend")
            backend = 'paddle'
@@ -256,13 +330,14 @@ class DotProductAttention(paddle.nn.Layer):
            if core_attention_bias_type != "no_bias":
                warnings.warn("Paddle backend dot product attention does not support bias yet. "
                              "Bias will be ignored.")
-            return self._pd_forward(query_layer, key_value_layer, attention_mask)
+            return self._pd_forward(query_layer, key_layer, value_layer, attention_mask)
        raise AttributeError(f"Backend {backend} is not supported.")
    def _te_forward(
            self,
            query_layer: paddle.Tensor,
-            key_value_layer: paddle.Tensor = None,
+            key_layer: paddle.Tensor,
+            value_layer: paddle.Tensor,
            attention_mask: Optional[paddle.Tensor] = None,
            core_attention_bias_type: str = "no_bias",
            core_attention_bias: Optional[paddle.Tensor] = None,
@@ -270,10 +345,10 @@ class DotProductAttention(paddle.nn.Layer):
    ) -> paddle.Tensor:
        if self.attention_type == "self":
-            # self attention - q: [b, s, 3, h, d]   kv: None
-            assert (len(query_layer.shape) == 5 and query_layer.shape[2] == 3
-                    and key_value_layer is None
-                    ), "query shape must be [b, s, 3, h, d] for dot product self attention"
+            # self attention - q: [b, s, h, d]   kv: None
+            assert (len(query_layer.shape) == 4 and len(key_layer.shape) == 4
+                    and len(value_layer.shape)
+                    == 4), "q,k,v shape must be [b, s, h, d] for dot product self attention"
            max_seqlen = query_layer.shape[1]
            if self.attn_mask_type == "causal" or attention_mask is None:
                cu_seqlens = paddle.arange(0, (query_layer.shape[0] + 1) * query_layer.shape[1],
@@ -283,32 +358,33 @@ class DotProductAttention(paddle.nn.Layer):
                cu_seqlens, _ = mask_to_cu_seqlens(attention_mask, need_kv=False)
            qkv_dtype = TE_DType[query_layer.dtype]
-            output = FusedAttnFuncPackedQKV.apply(query_layer, cu_seqlens, core_attention_bias,
-                                                  max_seqlen, 1.0 / self.norm_factor, qkv_dtype,
-                                                  self.attention_dropout if self.training else 0.0,
-                                                  set_zero, self.qkv_layout,
-                                                  core_attention_bias_type, self.attn_mask_type,
-                                                  self.training, self.fused_attention_backend)
+            output = FusedAttnFunc.apply(query_layer, key_layer, value_layer, cu_seqlens,
+                                         cu_seqlens, core_attention_bias, max_seqlen, max_seqlen,
+                                         1.0 / self.norm_factor, qkv_dtype,
+                                         self.attention_dropout if self.training else 0.0, set_zero,
+                                         self.qkv_layout, core_attention_bias_type,
+                                         self.attn_mask_type, self.training,
+                                         self.fused_attention_backend)
        elif self.attention_type == "cross":
-            # cross attention - q: [b, s_q, h, d]   kv: [b, s_kv, 2, h, d]
+            # cross attention - q: [b, s_q, h, d]   k,v: [b, s_kv, h, d]
            assert (
-                len(query_layer.shape) == 4 and len(key_value_layer.shape) == 5
-                and key_value_layer.shape[2] == 2
-            ), "query shape must be [b, s, h, d] and key shape must be [b, s, 2, h, d]" \
+                len(query_layer.shape) == 4 and len(key_layer.shape) == 4
+                and len(value_layer.shape) == 4
+            ), "query shape must be [b, s_q, h, d] and key shape must be [b, s_kv, h, d]" \
               "for dot product cross attention"
            assert (attention_mask
                    is not None), "attention_mask must be provided for cross attention"
            max_seqlen_q = query_layer.shape[1]
-            max_seqlen_kv = key_value_layer.shape[1]
+            max_seqlen_kv = key_layer.shape[1]
            cu_seqlens_q, cu_seqlens_kv = mask_to_cu_seqlens(attention_mask, need_kv=True)
            qkv_dtype = TE_DType[query_layer.dtype]
-            output = FusedAttnFuncPackedKV.apply(query_layer, key_value_layer, cu_seqlens_q,
-                                                 cu_seqlens_kv, core_attention_bias, max_seqlen_q,
-                                                 max_seqlen_kv, 1.0 / self.norm_factor, qkv_dtype,
-                                                 self.attention_dropout if self.training else 0.0,
-                                                 set_zero, self.qkv_layout,
-                                                 core_attention_bias_type, self.attn_mask_type,
-                                                 self.training, self.fused_attention_backend)
+            output = FusedAttnFunc.apply(query_layer, key_layer, value_layer, cu_seqlens_q,
+                                         cu_seqlens_kv, core_attention_bias, max_seqlen_q,
+                                         max_seqlen_kv, 1.0 / self.norm_factor, qkv_dtype,
+                                         self.attention_dropout if self.training else 0.0, set_zero,
+                                         self.qkv_layout, core_attention_bias_type,
+                                         self.attn_mask_type, self.training,
+                                         self.fused_attention_backend)
        else:
            raise ValueError("attention_type must be one of ['self', 'cross']")
        return output
@@ -316,28 +392,14 @@ class DotProductAttention(paddle.nn.Layer):
    def _pd_forward(
            self,
            query_layer: paddle.Tensor,
-            key_value_layer: paddle.Tensor = None,
+            key_layer: paddle.Tensor,
+            value_layer: paddle.Tensor,
            attention_mask: Optional[paddle.Tensor] = None,
    ) -> paddle.Tensor:
-        if self.attention_type == "self":
-            # self attention - q: [b, s, 3, h, d]   k: None
-            assert (len(query_layer.shape) == 5 and query_layer.shape[2] == 3
-                    and key_value_layer is None
-                    ), "query shape must be [b, s, 3, h, d] for dot product self attention"
-            q = query_layer[:, :, 0]
-            k = query_layer[:, :, 1]
-            v = query_layer[:, :, 2]
-        elif self.attention_type == "cross":
-            # cross attention - q: [b, s, h, d]   kv: [b, s, 2, h, d]
-            assert (
-                len(query_layer.shape) == 4 and len(key_value_layer.shape) == 5
-                and key_value_layer.shape[2] == 2
-            ), f"query shape must be [b, s, h, d] and key_value shape must be [b, s, 2, h, d]" \
-               f"for dot product cross attention. The actual shape is q: {query_layer.shape}" \
-               f"kv: {key_value_layer.shape}"
-            q = query_layer
-            k = key_value_layer[:, :, 0]
-            v = key_value_layer[:, :, 1]
+        q = query_layer
+        k = repeat_kv(key_layer, self.num_queries_per_key_value)
+        v = repeat_kv(value_layer, self.num_queries_per_key_value)
        q = paddle.transpose(x=q, perm=[0, 2, 1, 3])
        k = paddle.transpose(x=k, perm=[0, 2, 1, 3])
@@ -404,6 +466,14 @@ class MultiHeadAttention(paddle.nn.Layer):
                       if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
num_gqa_groups : int, default = `None`
number of GQA groups in the transformer layer.
Grouped Query Attention is described in
`this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
 This only affects the keys and values, not the queries.
GQA-1 is equivalent to Multi-Query Attention
(`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
rng_state_name : str, default = `local_seed` rng_state_name : str, default = `local_seed`
Controls the rng state used for dropout on attention probs. The Controls the rng state used for dropout on attention probs. The
specified rng should be set different seeds for different TP ranks. specified rng should be set different seeds for different TP ranks.
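The new `num_gqa_groups` argument documented above controls how many key/value heads the layer allocates: the query projection keeps `num_attention_heads` heads while the K and V projections shrink to `num_gqa_groups` heads each. A small, illustrative arithmetic sketch of the resulting projection widths (the numbers are made up; only the `hidden_size_kv` computation mirrors this PR):

```python
# Illustrative arithmetic only -- mirrors the hidden_size_kv computation added in this PR.
hidden_size = 1024
num_attention_heads = 16
num_gqa_groups = 4                   # GQA-4; 1 would be MQA, 16 would be plain MHA

head_size = hidden_size // num_attention_heads        # 64
hidden_size_kv = head_size * num_gqa_groups           # 256
qkv_out_features = hidden_size + 2 * hidden_size_kv   # 1536 instead of 3 * 1024

num_queries_per_key_value = num_attention_heads // num_gqa_groups   # 4
print(qkv_out_features, num_queries_per_key_value)
```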
...@@ -430,6 +500,7 @@ class MultiHeadAttention(paddle.nn.Layer): ...@@ -430,6 +500,7 @@ class MultiHeadAttention(paddle.nn.Layer):
set_parallel_mode: bool = False, set_parallel_mode: bool = False,
sequence_parallel: bool = False, sequence_parallel: bool = False,
tp_group: Optional[dist_group_type] = None, tp_group: Optional[dist_group_type] = None,
num_gqa_groups: Optional[int] = None,
rng_state_name: str = 'local_seed', rng_state_name: str = 'local_seed',
backend: str = 'transformer_engine', backend: str = 'transformer_engine',
) -> None: ) -> None:
...@@ -450,19 +521,25 @@ class MultiHeadAttention(paddle.nn.Layer): ...@@ -450,19 +521,25 @@ class MultiHeadAttention(paddle.nn.Layer):
self.sequence_parallel = self.tensor_parallel and sequence_parallel self.sequence_parallel = self.tensor_parallel and sequence_parallel
self.hidden_size_per_attention_head = hidden_size // num_attention_heads self.hidden_size_per_attention_head = hidden_size // num_attention_heads
self.num_attention_heads = num_attention_heads self.num_attention_heads = num_attention_heads
norm_factor = math.sqrt(self.hidden_size_per_attention_head)
self.set_parallel_mode = set_parallel_mode self.set_parallel_mode = set_parallel_mode
self.rng_state_name = rng_state_name self.rng_state_name = rng_state_name
self.backend = backend self.backend = backend
self.num_attention_heads_per_partition = divide(self.num_attention_heads, self.tp_size) self.num_attention_heads_per_partition = divide(self.num_attention_heads, self.tp_size)
self.num_gqa_groups = (num_attention_heads if num_gqa_groups is None else num_gqa_groups)
assert (self.num_attention_heads % self.num_gqa_groups == 0
), "The number of attention heads must be divisible by the number of GQA groups!"
assert (self.num_gqa_groups % self.tp_size == 0
), "The number of GQA groups must be divisible by tensor parallel size!"
self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size)
self.hidden_size_kv = int(hidden_size * self.num_gqa_groups // self.num_attention_heads)
qkv_parallel_mode = "column" if set_parallel_mode else None qkv_parallel_mode = "column" if set_parallel_mode else None
if self.attention_type == "self": if self.attention_type == "self":
if self.input_layernorm: if self.input_layernorm:
self.layernorm_qkv = LayerNormLinear( self.layernorm_qkv = LayerNormLinear(
hidden_size, hidden_size,
3 * hidden_size, hidden_size + 2 * self.hidden_size_kv,
eps=layernorm_epsilon, eps=layernorm_epsilon,
weight_attr=self.weight_attr, weight_attr=self.weight_attr,
bias_attr=self.bias_attr, bias_attr=self.bias_attr,
...@@ -476,7 +553,7 @@ class MultiHeadAttention(paddle.nn.Layer): ...@@ -476,7 +553,7 @@ class MultiHeadAttention(paddle.nn.Layer):
else: else:
self.qkv = Linear( self.qkv = Linear(
hidden_size, hidden_size,
3 * hidden_size, hidden_size + 2 * self.hidden_size_kv,
self.weight_attr, self.weight_attr,
self.bias_attr, self.bias_attr,
parallel_mode=qkv_parallel_mode, parallel_mode=qkv_parallel_mode,
...@@ -513,7 +590,7 @@ class MultiHeadAttention(paddle.nn.Layer): ...@@ -513,7 +590,7 @@ class MultiHeadAttention(paddle.nn.Layer):
) )
self.key_value = Linear( self.key_value = Linear(
hidden_size, hidden_size,
2 * hidden_size, 2 * self.hidden_size_kv,
self.weight_attr, self.weight_attr,
self.bias_attr, self.bias_attr,
parallel_mode=qkv_parallel_mode, parallel_mode=qkv_parallel_mode,
...@@ -524,10 +601,13 @@ class MultiHeadAttention(paddle.nn.Layer): ...@@ -524,10 +601,13 @@ class MultiHeadAttention(paddle.nn.Layer):
# Attention. # Attention.
self.core_attention = DotProductAttention( self.core_attention = DotProductAttention(
norm_factor, self.num_attention_heads,
self.hidden_size_per_attention_head,
self.num_gqa_groups,
attention_dropout, attention_dropout,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
attention_type=self.attention_type, attention_type=self.attention_type,
tp_size=self.tp_size,
backend=self.backend, backend=self.backend,
) )
...@@ -619,18 +699,37 @@ class MultiHeadAttention(paddle.nn.Layer): ...@@ -619,18 +699,37 @@ class MultiHeadAttention(paddle.nn.Layer):
is_first_microbatch=is_first_microbatch, is_first_microbatch=is_first_microbatch,
) )
# [b, s_q, 3 * hidden_size] --> [b, s_q, 3, num_heads, head_size] num_queries_per_key_value = (self.num_attention_heads_per_partition //
self.num_gqa_groups_per_partition)
# [b, s_q, hidden_size+2*hidden_size_kv] --> [b, s_q, (h/ng+2), ng, d]
mixed_qkv_layer = mixed_qkv_layer.reshape(shape=[ mixed_qkv_layer = mixed_qkv_layer.reshape(shape=[
-1, max_seq_len, 3, self.num_attention_heads_per_partition, -1, max_seq_len, (
self.hidden_size_per_attention_head num_queries_per_key_value +
2), self.num_gqa_groups_per_partition, self.hidden_size_per_attention_head
]) ])
# [b, s_q, (h/ng+2), ng, d]
# --> [b, s_q, (h/ng), ng, d] [b, s_q, 1, ng, d] [b, s_q, 1, ng, d]
query_layer, key_layer, value_layer = paddle.split(
mixed_qkv_layer,
num_or_sections=(num_queries_per_key_value, 1, 1),
axis=2,
)
# query: -> [b, s, h, d]
# key, value: -> [b, s, ng, d]
query_layer, key_layer, value_layer = (x.reshape(
shape=[x.shape[0], x.shape[1], -1, self.hidden_size_per_attention_head])
for x in (query_layer, key_layer, value_layer))
with track_rng_state(enable=self.tensor_parallel, name=self.rng_state_name): with track_rng_state(enable=self.tensor_parallel, name=self.rng_state_name):
if recompute_core_attention: if recompute_core_attention:
context_layer = recompute( context_layer = recompute(
self.core_attention, self.core_attention,
mixed_qkv_layer, query_layer,
None, key_layer,
value_layer,
attention_mask, attention_mask,
core_attention_bias_type, core_attention_bias_type,
core_attention_bias, core_attention_bias,
...@@ -639,8 +738,9 @@ class MultiHeadAttention(paddle.nn.Layer): ...@@ -639,8 +738,9 @@ class MultiHeadAttention(paddle.nn.Layer):
) )
else: else:
context_layer = self.core_attention( context_layer = self.core_attention(
query_layer=mixed_qkv_layer, query_layer=query_layer,
key_value_layer=None, key_layer=key_layer,
value_layer=value_layer,
attention_mask=attention_mask, attention_mask=attention_mask,
core_attention_bias_type=core_attention_bias_type, core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias, core_attention_bias=core_attention_bias,
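The self-attention path above packs Q, K and V into a single projection and then splits along the new group axis. A standalone shape walkthrough of that reshape-and-split, with example per-partition sizes (values are illustrative and not part of the layer code):

```python
# Standalone shape check for the GQA split used in the self-attention forward path.
import paddle

b, s, h, ng, d = 2, 128, 16, 4, 64            # batch, seq, heads, GQA groups, head size
num_queries_per_key_value = h // ng            # 4

mixed = paddle.randn([b, s, h * d + 2 * ng * d], dtype='float32')     # [2, 128, 1536]
mixed = mixed.reshape(shape=[b, s, num_queries_per_key_value + 2, ng, d])

q, k, v = paddle.split(mixed, num_or_sections=(num_queries_per_key_value, 1, 1), axis=2)
q = q.reshape(shape=[b, s, -1, d])             # [2, 128, 16, 64]
k = k.reshape(shape=[b, s, -1, d])             # [2, 128, 4, 64]
v = v.reshape(shape=[b, s, -1, d])             # [2, 128, 4, 64]
print(q.shape, k.shape, v.shape)
```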
...@@ -654,10 +754,17 @@ class MultiHeadAttention(paddle.nn.Layer): ...@@ -654,10 +754,17 @@ class MultiHeadAttention(paddle.nn.Layer):
) )
# [b, s_kv, 2 * hidden_size] --> [b, s_kv, 2, num_heads, head_size] # [b, s_kv, 2 * hidden_size_kv] --> [b, s_kv, 2 * ng, head_size]
mixed_kv_layer = mixed_kv_layer.reshape(shape=[ mixed_kv_layer = mixed_kv_layer.reshape(shape=[
-1, max_seq_len, 2, self.num_attention_heads_per_partition, 0, 0, 2 * self.num_gqa_groups_per_partition, self.hidden_size_per_attention_head
self.hidden_size_per_attention_head
]) ])
# [b, s_kv, 2 * ng, head_size]
# --> 2 [b, s_kv, ng, head_size]
key_layer, value_layer = paddle.split(
mixed_kv_layer,
num_or_sections=2,
axis=2,
)
if self.input_layernorm: if self.input_layernorm:
layernorm_query_outputs = self.layernorm_query( layernorm_query_outputs = self.layernorm_query(
hidden_states, hidden_states,
...@@ -673,6 +780,7 @@ class MultiHeadAttention(paddle.nn.Layer): ...@@ -673,6 +780,7 @@ class MultiHeadAttention(paddle.nn.Layer):
is_first_microbatch=is_first_microbatch, is_first_microbatch=is_first_microbatch,
) )
# [b, s, hidden_size] --> [b, s, h, d]
query_layer = query_layer.reshape(shape=[ query_layer = query_layer.reshape(shape=[
-1, max_seq_len, self.num_attention_heads_per_partition, -1, max_seq_len, self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head self.hidden_size_per_attention_head
...@@ -682,7 +790,8 @@ class MultiHeadAttention(paddle.nn.Layer): ...@@ -682,7 +790,8 @@ class MultiHeadAttention(paddle.nn.Layer):
context_layer = recompute( context_layer = recompute(
self.core_attention, self.core_attention,
query_layer, query_layer,
mixed_kv_layer, key_layer,
value_layer,
attention_mask, attention_mask,
core_attention_bias_type, core_attention_bias_type,
core_attention_bias, core_attention_bias,
...@@ -692,7 +801,8 @@ class MultiHeadAttention(paddle.nn.Layer): ...@@ -692,7 +801,8 @@ class MultiHeadAttention(paddle.nn.Layer):
else: else:
context_layer = self.core_attention( context_layer = self.core_attention(
query_layer=query_layer, query_layer=query_layer,
key_value_layer=mixed_kv_layer, key_layer=key_layer,
value_layer=value_layer,
attention_mask=attention_mask, attention_mask=attention_mask,
core_attention_bias_type=core_attention_bias_type, core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias, core_attention_bias=core_attention_bias,
......
...@@ -27,6 +27,14 @@ class TransformerLayer(paddle.nn.Layer): ...@@ -27,6 +27,14 @@ class TransformerLayer(paddle.nn.Layer):
intermediate size to which input samples are projected. intermediate size to which input samples are projected.
num_attention_heads : int num_attention_heads : int
number of attention heads in the transformer layer. number of attention heads in the transformer layer.
num_gqa_groups : Optional[int], default = `None`
number of GQA groups in the transformer layer.
Grouped Query Attention is described in
`this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
This only affects the keys and values, not the queries.
GQA-1 is equivalent to Multi-Query Attention
(`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
layernorm_epsilon : float, default = 1e-5 layernorm_epsilon : float, default = 1e-5
a value added to the denominator of layer normalization a value added to the denominator of layer normalization
for numerical stability. for numerical stability.
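Since `num_gqa_groups` is now plumbed through `TransformerLayer`, a GQA-enabled layer can be built by passing that one extra argument. A hedged usage sketch follows; it relies only on the constructor arguments visible in this diff, and the forward-call signature and dtype choice are assumptions.

```python
# Hypothetical usage sketch; arguments beyond those shown in this diff are left at defaults.
import paddle
import transformer_engine.paddle as te

layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    num_gqa_groups=4,            # 4 KV heads shared across 16 query heads
    attention_dropout=0.0,
    hidden_dropout=0.0,
)

x = paddle.randn([2, 128, 1024], dtype='float16')   # [batch, seq, hidden]
y = layer(x)                                         # assumed forward signature
```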
...@@ -97,6 +105,7 @@ class TransformerLayer(paddle.nn.Layer): ...@@ -97,6 +105,7 @@ class TransformerLayer(paddle.nn.Layer):
hidden_size: int, hidden_size: int,
ffn_hidden_size: int, ffn_hidden_size: int,
num_attention_heads: int, num_attention_heads: int,
num_gqa_groups: Optional[int] = None,
layernorm_epsilon: float = 1e-5, layernorm_epsilon: float = 1e-5,
hidden_dropout: float = 0.1, hidden_dropout: float = 0.1,
attention_dropout: float = 0.1, attention_dropout: float = 0.1,
...@@ -153,6 +162,7 @@ class TransformerLayer(paddle.nn.Layer): ...@@ -153,6 +162,7 @@ class TransformerLayer(paddle.nn.Layer):
"set_parallel_mode": set_parallel_mode, "set_parallel_mode": set_parallel_mode,
"sequence_parallel": self.sequence_parallel, "sequence_parallel": self.sequence_parallel,
"tp_group": tp_group, "tp_group": tp_group,
"num_gqa_groups": num_gqa_groups,
"rng_state_name": attention_dropout_rng_state_name, "rng_state_name": attention_dropout_rng_state_name,
"backend": backend, "backend": backend,
} }
......