Unverified Commit bd7fd0a6 authored by Shijie's avatar Shijie Committed by GitHub
Browse files

[Paddle] Support GQA (#595)



* use separate qkv
Signed-off-by: default avatarjaywan <jaywan@nvidia.com>

* add support for GQA
Signed-off-by: default avatarjaywan <jaywan@nvidia.com>

* minor changes
Signed-off-by: default avatarShijie Wang <jaywan@nvidia.com>

* change rtol
Signed-off-by: default avatarShijie Wang <jaywan@nvidia.com>

* fix reshape issue
Signed-off-by: default avatarShijie Wang <jaywan@nvidia.com>

---------
Signed-off-by: default avatarjaywan <jaywan@nvidia.com>
Signed-off-by: default avatarShijie Wang <jaywan@nvidia.com>
parent e531cd2f
......@@ -3,7 +3,6 @@
# See LICENSE for license information.
"""Test TE Paddle Layer-level APIs"""
import math
import os
from utils import assert_allclose, is_fused_attention_supported
......@@ -785,7 +784,7 @@ class TestLayerNormMLP:
@pytest.mark.parametrize('bs', [1, 2, 8])
@pytest.mark.parametrize('hidden_size, num_heads', [[1024, 16], [768, 12]])
@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[128, 128], [512, 512]])
@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[512, 512], [1024, 1024]])
@pytest.mark.parametrize('attn_type', ['self', 'cross'])
@pytest.mark.parametrize('mask_type', ['causal', 'padding'])
@pytest.mark.parametrize('math_dtype', ['bfloat16', 'float16'])
......@@ -808,24 +807,18 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen,
head_size=head_size,
dtype=math_dtype,
dropout=0.0,
qkv_layout="bs3hd" if attn_type == "self" else "bshd_bs2hd",
qkv_layout="bshd_bshd_bshd",
bias_type="no_bias",
mask_type=mask_type,
):
pytest.skip("cuDNN fused attention is not supported")
self_attn_qkv_input = paddle.normal(mean=0.0,
std=0.02,
shape=(bs, q_seqlen, 3, num_heads,
head_size)).astype(math_dtype)
cross_attn_q_input = paddle.normal(mean=0.0,
std=0.02,
shape=(bs, q_seqlen, num_heads,
head_size)).astype(math_dtype)
cross_attn_kv_input = paddle.normal(mean=0.0,
std=0.02,
shape=(bs, kv_seqlen, 2, num_heads,
head_size)).astype(math_dtype)
attn_q_input = paddle.normal(mean=0.0, std=0.02,
shape=(bs, q_seqlen, num_heads, head_size)).astype(math_dtype)
attn_k_input = paddle.normal(mean=0.0, std=0.02,
shape=(bs, kv_seqlen, num_heads, head_size)).astype(math_dtype)
attn_v_input = paddle.normal(mean=0.0, std=0.02,
shape=(bs, kv_seqlen, num_heads, head_size)).astype(math_dtype)
q_actual_seqlen = paddle.randint(low=20, high=q_seqlen, shape=(bs,), dtype='int32')
kv_actual_seqlen = paddle.randint(low=20, high=kv_seqlen, shape=(bs,),
......@@ -841,57 +834,36 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen,
for i in range(0, bs):
attn_mask[i, 0, 0:q_actual_seqlen[i], 0:kv_actual_seqlen[i]] = False
norm_factor = math.sqrt(hidden_size // num_heads)
layer_te = te.DotProductAttention(norm_factor,
head_size = hidden_size // num_heads
layer_te = te.DotProductAttention(num_heads,
head_size,
attention_dropout=0.0,
attn_mask_type=mask_type,
attention_type=attn_type,
backend='transformer_engine')
layer_pd = te.DotProductAttention(norm_factor,
layer_pd = te.DotProductAttention(num_heads,
head_size,
attention_dropout=0.0,
attn_mask_type=mask_type,
attention_type=attn_type,
backend='paddle')
def calc_attn_output_and_grad(layer, q, kv, mask, dout):
def calc_attn_output_and_grad(layer, q, k, v, mask, dout):
_q = paddle.to_tensor(q, stop_gradient=False)
_kv = paddle.to_tensor(kv, stop_gradient=False) if kv is not None else None
_k = paddle.to_tensor(k, stop_gradient=False)
_v = paddle.to_tensor(v, stop_gradient=False)
out = layer(_q, _kv, mask)
out = layer(_q, _k, _v, mask)
out.backward(dout)
return out, _q.grad, _kv.grad if _kv is not None else None
if attn_type == 'self':
out, qkv_grad, _ = calc_attn_output_and_grad(layer_te, self_attn_qkv_input, None, attn_mask,
grad_out)
out_ref, qkv_grad_ref, _ = calc_attn_output_and_grad(layer_pd, self_attn_qkv_input, None,
attn_mask, grad_out)
valid_out_ref = paddle.full_like(out_ref, 0)
for i in range(0, bs):
valid_out_ref[i, 0:q_actual_seqlen[i], :, :] = out_ref[i, 0:q_actual_seqlen[i], :, :]
q_grad = qkv_grad[:, :, 0]
k_grad = qkv_grad[:, :, 1]
v_grad = qkv_grad[:, :, 2]
q_grad_ref = qkv_grad_ref[:, :, 0]
k_grad_ref = qkv_grad_ref[:, :, 1]
v_grad_ref = qkv_grad_ref[:, :, 2]
else:
out, q_grad, kv_grad = calc_attn_output_and_grad(layer_te, cross_attn_q_input,
cross_attn_kv_input, attn_mask, grad_out)
out_ref, q_grad_ref, kv_grad_ref = calc_attn_output_and_grad(layer_pd, cross_attn_q_input,
cross_attn_kv_input, attn_mask,
grad_out)
valid_out_ref = paddle.full_like(out_ref, 0)
for i in range(0, bs):
valid_out_ref[i, 0:q_actual_seqlen[i], :, :] = out_ref[i, 0:q_actual_seqlen[i], :, :]
return out, _q.grad, _k.grad, _v.grad
k_grad = kv_grad[:, :, 0]
v_grad = kv_grad[:, :, 1]
k_grad_ref = kv_grad_ref[:, :, 0]
v_grad_ref = kv_grad_ref[:, :, 1]
out, q_grad, k_grad, v_grad = calc_attn_output_and_grad(layer_te, attn_q_input, attn_k_input,
attn_v_input, attn_mask, grad_out)
out_ref, q_grad_ref, k_grad_ref, v_grad_ref = calc_attn_output_and_grad(
layer_pd, attn_q_input, attn_k_input, attn_v_input, attn_mask, grad_out)
valid_out_ref = paddle.full_like(out_ref, 0)
for i in range(0, bs):
valid_out_ref[i, 0:q_actual_seqlen[i], :, :] = out_ref[i, 0:q_actual_seqlen[i], :, :]
valid_q_grad_ref = paddle.full_like(q_grad_ref, 0)
valid_k_grad_ref = paddle.full_like(k_grad_ref, 0)
......@@ -910,17 +882,18 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen,
@pytest.mark.parametrize('bs', [1, 2, 8])
@pytest.mark.parametrize('num_gqa_groups', [1, 4, 16])
@pytest.mark.parametrize('hidden_size, num_heads, ffn_hidden_size', [[1024, 16, 4096]])
@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[128, 128], [512, 512]])
@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[512, 512], [1024, 1024]])
@pytest.mark.parametrize('has_bias, no_dbias', [[False, True], [True, True], [True, False]])
@pytest.mark.parametrize('no_wgrad', [True, False])
@pytest.mark.parametrize('mask_type', ['causal', 'padding'])
@pytest.mark.parametrize('math_dtype', ['bfloat16', 'float16'])
@pytest.mark.parametrize('output_layernorm', [True, False])
@pytest.mark.parametrize('return_layernorm_output', [True, False])
def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size, has_bias, no_dbias,
no_wgrad, q_seqlen, kv_seqlen, mask_type, math_dtype,
output_layernorm, return_layernorm_output):
def test_transformer_encoder_layer(bs, hidden_size, num_heads, num_gqa_groups, ffn_hidden_size,
has_bias, no_dbias, no_wgrad, q_seqlen, kv_seqlen, mask_type,
math_dtype, output_layernorm, return_layernorm_output):
"""
Test Transformer Encoder Layer
"""
......@@ -932,13 +905,13 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
# Skip if cuDNN fused attention is not supported
if not is_fused_attention_supported(
num_heads=num_heads,
num_gqa_groups=num_heads,
num_gqa_groups=num_gqa_groups,
q_seqlen=q_seqlen,
kv_seqlen=kv_seqlen,
head_size=hidden_size // num_heads,
dtype=math_dtype,
dropout=0.0,
qkv_layout="bs3hd",
qkv_layout="bshd_bshd_bshd",
bias_type="no_bias",
mask_type=mask_type,
):
......@@ -962,6 +935,7 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
layer_te = te.TransformerLayer(hidden_size,
ffn_hidden_size,
num_heads,
num_gqa_groups=num_gqa_groups,
layernorm_epsilon=eps,
hidden_dropout=0.0,
attention_dropout=0.0,
......@@ -975,6 +949,7 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
layer_pd = te.TransformerLayer(hidden_size,
ffn_hidden_size,
num_heads,
num_gqa_groups=num_gqa_groups,
layernorm_epsilon=eps,
hidden_dropout=0.0,
attention_dropout=0.0,
......@@ -1088,8 +1063,9 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
@pytest.mark.parametrize('bs', [1, 2, 8])
@pytest.mark.parametrize('num_gqa_groups', [1, 4, 16])
@pytest.mark.parametrize('hidden_size, num_heads, ffn_hidden_size', [[1024, 16, 4096]])
@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[128, 128], [512, 512]])
@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[512, 512], [1024, 1024]])
@pytest.mark.parametrize('has_bias, no_dbias', [[False, True], [True, True], [True, False]])
@pytest.mark.parametrize('no_wgrad', [True, False])
@pytest.mark.parametrize('mask_type', ['causal', 'padding'])
......@@ -1097,9 +1073,9 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
@pytest.mark.parametrize('output_layernorm', [True, False])
@pytest.mark.parametrize('return_layernorm_output', [True, False])
@pytest.mark.parametrize('recompute_core_attention', [True, False])
def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size, has_bias, no_dbias,
no_wgrad, q_seqlen, kv_seqlen, mask_type, math_dtype,
output_layernorm, return_layernorm_output,
def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, ffn_hidden_size,
has_bias, no_dbias, no_wgrad, q_seqlen, kv_seqlen, mask_type,
math_dtype, output_layernorm, return_layernorm_output,
recompute_core_attention):
"""
Test Transformer Decoder Layer
......@@ -1112,39 +1088,35 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
# Skip if cuDNN fused attention is not supported
if not is_fused_attention_supported(
num_heads=num_heads,
num_gqa_groups=num_heads,
num_gqa_groups=num_gqa_groups,
q_seqlen=q_seqlen,
kv_seqlen=kv_seqlen,
head_size=hidden_size // num_heads,
dtype=math_dtype,
dropout=0.0,
qkv_layout="bs3hd",
bias_type="no_bias",
mask_type=mask_type,
):
pytest.skip("cuDNN fused attention is not supported")
if not is_fused_attention_supported(
head_size=hidden_size // num_heads,
num_heads=num_heads,
num_gqa_groups=num_heads,
q_seqlen=q_seqlen,
kv_seqlen=kv_seqlen,
dtype=math_dtype,
dropout=0.0,
qkv_layout="bshd_bs2hd",
qkv_layout="bshd_bshd_bshd",
bias_type="no_bias",
mask_type=mask_type,
):
pytest.skip("cuDNN fused attention is not supported")
encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype)
encoder_output = paddle.uniform(shape=(bs, kv_seqlen, hidden_size), dtype=math_dtype)
encoder_input = paddle.normal(mean=0.0, std=0.1,
shape=(bs, q_seqlen, hidden_size)).astype(math_dtype)
encoder_output = paddle.normal(mean=0.0, std=0.1,
shape=(bs, kv_seqlen, hidden_size)).astype(math_dtype)
q_actual_seqlen = paddle.ones(shape=(bs,), dtype='int32') * q_seqlen
kv_actual_seqlen = q_actual_seqlen
attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype='bool')
grad_out = paddle.normal(mean=0.0, std=0.2, shape=(bs, q_seqlen, hidden_size)).astype('float32')
grad_out = paddle.normal(mean=0.0, std=0.01,
shape=(bs, q_seqlen, hidden_size)).astype('float32')
# rounding to avoid numerical issues
encoder_input = paddle.round(encoder_input * 1000) / 1000
encoder_output = paddle.round(encoder_output * 1000) / 1000
grad_out = paddle.round(grad_out * 1000) / 1000
for i in range(0, bs):
grad_out[i, q_actual_seqlen[i]:, :] = 0
grad_out = grad_out.astype(math_dtype)
......@@ -1155,6 +1127,7 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
layer_te = te.TransformerLayer(hidden_size,
ffn_hidden_size,
num_heads,
num_gqa_groups=num_gqa_groups,
layernorm_epsilon=eps,
hidden_dropout=0.0,
attention_dropout=0.0,
......@@ -1168,6 +1141,7 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
layer_pd = te.TransformerLayer(hidden_size,
ffn_hidden_size,
num_heads,
num_gqa_groups=num_gqa_groups,
layernorm_epsilon=eps,
hidden_dropout=0.0,
attention_dropout=0.0,
......@@ -1319,7 +1293,7 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
assert_allclose(layer_te.self_attention.layernorm_qkv.weight.grad,
layer_pd.self_attention.layernorm_qkv.weight.grad.T,
rtol=rtol,
atol=0.1)
atol=atol)
assert_allclose(layer_te.inter_attention.layernorm_query.weight.grad,
layer_pd.inter_attention.layernorm_query.weight.grad.T,
rtol=rtol,
......@@ -1328,7 +1302,7 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, ffn_hidden_size,
if output_layernorm:
assert_allclose(layer_te.self_attention.qkv.bias.grad,
layer_pd.self_attention.qkv.bias.grad,
rtol=0.01,
rtol=0.5,
atol=0.6)
else:
assert_allclose(layer_te.self_attention.layernorm_qkv.bias.grad,
......
......@@ -5,6 +5,12 @@
import struct
from utils import (
assert_allclose,
create_fp8_meta,
get_fused_attention_backend,
is_fused_attention_supported,
)
import numpy as np
import paddle
import paddle.nn.functional as F
......@@ -39,6 +45,8 @@ from transformer_engine.paddle.cpp_extensions import (
fused_attn_bwd_qkvpacked,
fused_attn_fwd_kvpacked,
fused_attn_bwd_kvpacked,
fused_attn_fwd,
fused_attn_bwd,
scaled_softmax_forward,
scaled_softmax_backward,
scaled_masked_softmax_forward,
......@@ -594,6 +602,7 @@ class TestFusedAttn:
self.q = _random(self.q_shape)
if self.attn_mode == "self_attn":
assert self.q_seqlen == self.kv_seqlen, "self attention requires q_seqlen == kv_seqlen"
self.kv = self.q
else:
self.kv = _random(self.kv_shape)
......@@ -774,6 +783,70 @@ class TestFusedAttn:
return out, q_grad, k_grad, v_grad
def _get_fused_attention_with_separate_qkv(self):
    """Run the fused attention fwd + bwd kernels with separate Q, K, V inputs.

    Returns the forward output and the gradients w.r.t. q, k and v.
    """
    paddle.disable_static(place=paddle.CUDAPlace(0))

    qkv_layout = "bshd_bshd_bshd"
    mask = "causal" if self.is_causal_masking else "padding"

    # K and V are built from the same source buffer as prepared by set_input
    # (self.kv aliases self.q in self-attention mode).
    q_tensor = paddle.to_tensor(self.q, stop_gradient=False)
    k_tensor = paddle.to_tensor(self.kv, stop_gradient=False)
    v_tensor = paddle.to_tensor(self.kv, stop_gradient=False)
    q_cu_seqlen_tensor = paddle.to_tensor(self.q_cu_seqlen, dtype="int32", stop_gradient=True)
    kv_cu_seqlen_tensor = paddle.to_tensor(self.kv_cu_seqlen, dtype="int32", stop_gradient=True)

    fused_attention_backend = get_fused_attention_backend(
        num_heads=self.num_heads,
        num_gqa_groups=self.num_heads,
        q_seqlen=self.q_seqlen,
        kv_seqlen=self.kv_seqlen,
        head_size=self.head_size,
        dtype=self.dtype,
        dropout=self.dropout_prob,
        qkv_layout=qkv_layout,
        bias_type="no_bias",
        mask_type=mask,
    )

    qkv_dtype = tex.DType.kBFloat16 if self.dtype == "bfloat16" else tex.DType.kFloat16

    # Forward pass; softmax_aux and rng_state are needed by the backward call.
    out, softmax_aux_tensor, rng_state = fused_attn_fwd(
        q_tensor,
        k_tensor,
        v_tensor,
        q_cu_seqlen_tensor,
        kv_cu_seqlen_tensor,
        is_training=True,
        max_seqlen_q=self.q_seqlen,
        max_seqlen_kv=self.kv_seqlen,
        qkv_dtype=qkv_dtype,
        fused_attention_backend=fused_attention_backend,
        Bias=None,
        attn_scale=self.scaling_factor,
        dropout=self.dropout_prob,
        set_zero=False,
        qkv_layout=qkv_layout,
        attn_mask_type=mask)

    # Backward pass; dbias is discarded (no_bias configuration).
    dq, dk, dv, _ = fused_attn_bwd(
        q_tensor,
        k_tensor,
        v_tensor,
        q_cu_seqlen_tensor,
        kv_cu_seqlen_tensor,
        rng_state,
        out,
        self.dout,
        softmax_aux_tensor,
        fused_attention_backend=fused_attention_backend,
        max_seqlen_q=self.q_seqlen,
        max_seqlen_kv=self.kv_seqlen,
        qkv_dtype=qkv_dtype,
        attn_scale=self.scaling_factor,
        dropout=self.dropout_prob,
        set_zero=False,
        qkv_layout=qkv_layout,
        attn_mask_type=mask)

    return out, dq, dk, dv
@pytest.mark.parametrize('b, s, h, d', SELF_ATTN_CASES)
@pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
@pytest.mark.parametrize('is_causal_masking', [True, False])
......@@ -857,6 +930,35 @@ class TestFusedAttn:
assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2)
assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2)
@pytest.mark.parametrize('b, s, h, d', FLASH_ATTN_CASES)
@pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
@pytest.mark.parametrize('is_causal_masking', [False, True])
def test_fused_attn_with_separate_qkv_forward_backward(self, b, s, h, d, dtype,
                                                       is_causal_masking):
    """
    test flash attention forward + backward with separate qkv inputs
    """
    # Skip when the current device / cuDNN build has no fused-attention
    # backend for this configuration (unpacked bshd_bshd_bshd layout).
    if not is_fused_attention_supported(
            num_heads=h,
            num_gqa_groups=h,
            q_seqlen=s,
            kv_seqlen=s,
            head_size=d,
            dtype=dtype,
            dropout=0.0,
            qkv_layout="bshd_bshd_bshd",
            bias_type="no_bias",
            mask_type="causal" if is_causal_masking else "padding",
    ):
        pytest.skip("cuDNN fused attention is not supported")
    # Self-attention: q_seqlen == kv_seqlen == s. set_input prepares the
    # shared q/kv buffers used by both the reference and fused paths.
    self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking)
    reference_out, q_grad_ref, k_grad_ref, v_grad_ref = self._get_reference_out()
    fused_attention_out, q_grad, k_grad, v_grad = self._get_fused_attention_with_separate_qkv()
    # Loose tolerances on purpose: half-precision math in the fused kernel.
    assert_allclose(reference_out, fused_attention_out, rtol=1e-3, atol=1e-2)
    assert_allclose(q_grad_ref, q_grad, rtol=1e-3, atol=1e-2)
    assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2)
    assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2)
class TestSoftmax:
"""
......
......@@ -792,6 +792,189 @@ def fused_attn_bwd_kvpacked(
return dq, dkv, dbias
def fused_attn_fwd(
    q: paddle.Tensor,
    k: paddle.Tensor,
    v: paddle.Tensor,
    cu_seqlens_q: paddle.Tensor,
    cu_seqlens_kv: paddle.Tensor,
    is_training: bool,
    max_seqlen_q: int,
    max_seqlen_kv: int,
    qkv_dtype: tex.DType,
    fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
    Bias: paddle.Tensor = None,
    attn_scale: float = None,
    dropout: float = 0.0,
    set_zero: bool = True,
    qkv_layout: str = "bshd_bshd_bshd",
    bias_type: str = "no_bias",
    attn_mask_type: str = "padding",
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
    """Fused Attention FWD for unpacked QKV input.

    Parameters
    ----------
    q, k, v : paddle.Tensor
        separate query/key/value tensors in bshd layout ([b, s, h, d]).
    cu_seqlens_q, cu_seqlens_kv : paddle.Tensor
        cumulative sequence-length tensors of shape [b + 1] (int32).
    is_training : bool
        when True, the softmax auxiliary tensor needed by the backward
        pass is allocated and returned; otherwise it is None.
    qkv_dtype : tex.DType
        data type of q/k/v; only kBFloat16 / kFloat16 are supported.
    Bias : paddle.Tensor, optional
        attention bias of shape [1, h, max_seqlen_q, max_seqlen_kv];
        required iff ``bias_type != "no_bias"``.
    attn_scale : float, optional
        softmax scaling factor; defaults to 1/sqrt(head_size).
    set_zero : bool
        zero-initialize the output buffer instead of leaving it
        uninitialized (needed when padding rows are read back).

    Returns
    -------
    (out, softmax_aux, rng_state)
        attention output, softmax auxiliary tensor (None when not
        training), and the RNG state holding the (seed, offset) pair.
    """
    assert (qkv_dtype in (tex.DType.kBFloat16,
                          tex.DType.kFloat16)), "Only support bf16/fp16 for fused attention."
    assert (cu_seqlens_q.shape == cu_seqlens_kv.shape
            ), "cu_seqlens_q and cu_seqlens_kv must have the same shape"
    assert (qkv_layout == "bshd_bshd_bshd"
            ), "Only support bshd_bshd_bshd layout for unpacked QKV input for now."

    b = cu_seqlens_q.shape[0] - 1
    h = q.shape[-2]
    d = q.shape[-1]

    if attn_scale is None:
        attn_scale = 1.0 / math.sqrt(d)

    if bias_type != "no_bias":
        assert Bias is not None, "bias tensor cannot be None when bias_type is not no_bias."
        assert (Bias.shape == [
            1, h, max_seqlen_q, max_seqlen_kv
        ]), "bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape."
        assert (Bias.dtype == q.dtype), "bias tensor must be in the same dtype as qkv."

    assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
            ), "Fused attention does not support this input combination."

    # Number of RNG elements consumed per CUDA thread depends on the backend.
    if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
        # BF16/FP16 fused attention API from fmha_v1 apex
        rng_elts_per_thread = (max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_THREADS_PER_CTA -
                               1) // BACKEND_F16m512_THREADS_PER_CTA
    elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
        # BF16/FP16 fused attention API from fmha_v2
        rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS
    else:
        # Fail early: the original code would hit an unbound-variable error
        # below for any other backend.
        raise ValueError("Unsupported fused attention backend.")

    if set_zero:
        out = paddle.full(shape=[b, max_seqlen_q, h, d], fill_value=0, dtype=q.dtype)
    else:
        out = paddle.empty(shape=[b, max_seqlen_q, h, d], dtype=q.dtype)

    # softmax_aux shape differs per backend: full S matrix for max512,
    # per-row statistics for the arbitrary-seqlen backend.
    if is_training:
        if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
            softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
        else:  # F16_arbitrary_seqlen (only remaining backend after the check above)
            softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, 1], dtype='float32')
    else:
        softmax_aux = None

    # Holds the (seed, offset) pair written by the kernel.
    rng_state = paddle.empty(shape=[
        2,
    ], dtype=paddle.int64)

    # execute kernel
    tex.te_fused_attn_fwd(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_kv,
        Bias,
        out,
        softmax_aux,
        rng_state,
        b,
        h,
        d,
        max_seqlen_q,
        max_seqlen_kv,
        is_training,
        attn_scale,
        dropout,
        qkv_layout,
        bias_type,
        attn_mask_type,
        int(qkv_dtype),
        rng_elts_per_thread,
    )

    return out, softmax_aux, rng_state
def fused_attn_bwd(
    q: paddle.Tensor,
    k: paddle.Tensor,
    v: paddle.Tensor,
    cu_seqlens_q: paddle.Tensor,
    cu_seqlens_kv: paddle.Tensor,
    rng_state: paddle.Tensor,
    o: paddle.Tensor,
    d_o: paddle.Tensor,
    softmax_aux: paddle.Tensor,
    fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
    max_seqlen_q: int,
    max_seqlen_kv: int,
    qkv_dtype: tex.DType,
    attn_scale: float = None,
    dropout: float = 0.0,
    set_zero: bool = True,
    qkv_layout: str = "bshd_bshd_bshd",
    bias_type: str = "no_bias",
    attn_mask_type: str = "padding",
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
    """Fused Attention BWD for unpacked QKV input.

    Parameters
    ----------
    q, k, v : paddle.Tensor
        the separate query/key/value tensors used in the forward pass
        (bshd layout).
    rng_state : paddle.Tensor
        (seed, offset) pair returned by :func:`fused_attn_fwd`.
    o, d_o : paddle.Tensor
        forward output and its incoming gradient.
    softmax_aux : paddle.Tensor
        softmax auxiliary tensor returned by :func:`fused_attn_fwd`.
    set_zero : bool
        zero-initialize the gradient buffers instead of leaving them
        uninitialized.

    Returns
    -------
    (dq, dk, dv, dbias)
        gradients w.r.t. q, k, v and bias (dbias is None when
        ``bias_type == "no_bias"``).
    """
    assert (qkv_dtype in (tex.DType.kBFloat16,
                          tex.DType.kFloat16)), "Only support bf16/fp16 for fused attention."
    assert (cu_seqlens_q.shape == cu_seqlens_kv.shape
            ), "cu_seqlens_q and cu_seqlens_kv must have the same shape"
    assert (qkv_layout == "bshd_bshd_bshd"
            ), "Only support bshd_bshd_bshd layout for unpacked QKV input for now."

    b = cu_seqlens_q.shape[0] - 1
    h = q.shape[-2]
    d = q.shape[-1]

    if attn_scale is None:
        attn_scale = 1.0 / math.sqrt(d)

    assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
            ), "Fused attention does not support this input combination."

    # Pre-allocate gradient buffers that the extension op fills in place.
    if set_zero:
        dq = paddle.full(shape=q.shape, fill_value=0, dtype=q.dtype)
        dk = paddle.full(shape=k.shape, fill_value=0, dtype=k.dtype)
        dv = paddle.full(shape=v.shape, fill_value=0, dtype=v.dtype)
    else:
        dq = paddle.empty(shape=q.shape, dtype=q.dtype)
        dk = paddle.empty(shape=k.shape, dtype=k.dtype)
        dv = paddle.empty(shape=v.shape, dtype=v.dtype)

    if bias_type != "no_bias":
        dbias = paddle.empty(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype)
    else:
        dbias = None

    # execute kernel
    tex.te_fused_attn_bwd(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_kv,
        o,
        d_o,
        softmax_aux,
        dq,
        dk,
        dv,
        dbias,
        rng_state,
        b,
        h,
        d,
        max_seqlen_q,
        max_seqlen_kv,
        attn_scale,
        dropout,
        qkv_layout,
        bias_type,
        attn_mask_type,
        int(qkv_dtype),
    )

    return dq, dk, dv, dbias
def scaled_softmax_forward(
inp: paddle.Tensor,
scale_factor: float,
......
......@@ -864,6 +864,183 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
}
// Forward pass of fused attention for separate (unpacked) Q, K and V tensors.
// Fills the pre-allocated output O, the softmax auxiliary tensor (required
// when is_training, consumed by the backward pass) and rng_state with the
// (seed, offset) pair used for dropout. b/h/d are batch size, head count and
// head size; qkv_layout/bias_type/attn_mask_type are string-encoded enums.
void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const paddle::Tensor &V,
                       const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &cu_seqlens_kv,
                       const paddle::optional<paddle::Tensor> &Bias,
                       paddle::Tensor &O,                             // NOLINT
                       paddle::optional<paddle::Tensor> &softmax_aux, // NOLINT
                       paddle::Tensor &rng_state,                     // NOLINT
                       int64_t b, int64_t h, int64_t d, int64_t max_seqlen_q, int64_t max_seqlen_kv,
                       bool is_training, float attn_scale, float p_dropout,
                       const std::string &qkv_layout, const std::string &bias_type,
                       const std::string &attn_mask_type, const int64_t qkv_type,
                       int64_t rng_elts_per_thread) {
  if (is_training && !softmax_aux) {
    NVTE_ERROR("softmax_aux must be provided when training. \n");
  }

  auto qkv_dtype = Int2NvteDType(qkv_type);
  // construct NVTE tensors
  TensorWrapper te_Q, te_K, te_V, te_S, te_O, te_Bias, te_cu_seqlens_q, te_cu_seqlens_kv;
  if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) {
    // BF16 or FP16
    te_Q = MakeNvteTensor(Q);
    te_K = MakeNvteTensor(K);
    te_V = MakeNvteTensor(V);
    // S carries no host buffer; its shape/dtype are filled in by the library.
    te_S = MakeNvteTensor(nullptr, std::vector<size_t>{0}, DType::kFloat32);
    te_O = MakeNvteTensor(O);
  } else {  // TODO: support fp8
    NVTE_ERROR("Fused attention only supports BF16/FP16 data types. \n");
  }
  if ((bias_type != "no_bias") && Bias) {
    auto bias_shape = Bias->shape();
    std::vector<size_t> shape{bias_shape.begin(), bias_shape.end()};
    te_Bias = MakeNvteTensor(GetOptionalDataPtr(Bias), shape, DType::kFloat32);
  }
  te_cu_seqlens_q =
      MakeNvteTensor(cu_seqlens_q.data(), {static_cast<size_t>(b + 1)}, DType::kInt32);
  te_cu_seqlens_kv =
      MakeNvteTensor(cu_seqlens_kv.data(), {static_cast<size_t>(b + 1)}, DType::kInt32);

  // convert strings to enums
  NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout);
  NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type);
  NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type);

  // extract random number generator seed and offset, then record the pair
  // into rng_state on the same stream as the attention kernel
  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(Q.place());
  auto gen_cuda = dev_ctx->GetGenerator();
  auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread);
  set_rng_state<<<1, 1, 0, Q.stream()>>>(seed_offset, static_cast<int64_t *>(rng_state.data()));
  auto te_rng_state = MakeNvteTensor(rng_state);

  // create auxiliary output tensors
  NVTETensorPack nvte_aux_tensor_pack;
  nvte_tensor_pack_create(&nvte_aux_tensor_pack);

  // create workspace
  TensorWrapper workspace;

  // First call is a query only: populates the aux-tensor pack and workspace
  // with the shapes and dtypes required, without running the kernel.
  nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(),
                      te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
                      te_cu_seqlens_kv.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv,
                      is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
                      attn_mask_type_enum, workspace.data(), Q.stream());

  // allocate memory for workspace and auxiliary output tensors
  auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place());
  workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());

  // point the first aux tensor (softmax statistics) at the caller's buffer
  auto *output_s =
      reinterpret_cast<transformer_engine::Tensor *>(nvte_aux_tensor_pack.tensors[0]);
  output_s->data.dptr = GetOptionalDataPtr(softmax_aux);

  // execute the kernel
  nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(),
                      te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(),
                      te_cu_seqlens_kv.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv,
                      is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum,
                      attn_mask_type_enum, workspace.data(), Q.stream());

  // destroy tensor wrappers, but not allocated memory
  nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
}
// Backward pass of fused attention for separate (unpacked) Q, K and V
// tensors. Consumes the forward output O, its gradient dO, the softmax
// auxiliary tensor and the forward rng_state, and fills the pre-allocated
// gradient buffers dQ, dK, dV (and dBias when bias_type != "no_bias").
void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const paddle::Tensor &V,
                       const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &cu_seqlens_kv,
                       const paddle::Tensor &O, const paddle::Tensor &dO,
                       const paddle::Tensor &softmax_aux,
                       paddle::Tensor &dQ,                      // NOLINT
                       paddle::Tensor &dK,                      // NOLINT
                       paddle::Tensor &dV,                      // NOLINT
                       paddle::optional<paddle::Tensor> &dBias, // NOLINT
                       paddle::Tensor &rng_state,               // NOLINT
                       int64_t b, int64_t h, int64_t d, int64_t max_seqlen_q, int64_t max_seqlen_kv,
                       float attn_scale, float p_dropout, const std::string &qkv_layout,
                       const std::string &bias_type, const std::string &attn_mask_type,
                       int64_t qkv_type) {
  TensorWrapper te_dBias;
  if (bias_type != "no_bias" && dBias) {
    auto bias_shape = dBias->shape();
    std::vector<size_t> shape{bias_shape.begin(), bias_shape.end()};
    te_dBias = MakeNvteTensor(GetOptionalDataPtr(dBias), shape, DType::kFloat32);
  }

  auto qkv_dtype = Int2NvteDType(qkv_type);
  // construct NVTE tensors
  TensorWrapper te_Q, te_K, te_V, te_O, te_dO, te_S, te_dP, te_dQ, te_dK, te_dV;
  if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) {
    // BF16 or FP16
    te_Q = MakeNvteTensor(Q);
    te_K = MakeNvteTensor(K);
    te_V = MakeNvteTensor(V);
    te_O = MakeNvteTensor(O);
    te_dO = MakeNvteTensor(dO);
    // S and dP carry no host buffers; the library manages them internally.
    te_S = MakeNvteTensor(nullptr, std::vector<size_t>(0), DType::kFloat32);
    te_dP = MakeNvteTensor(nullptr, std::vector<size_t>(0), DType::kFloat32);
    te_dQ = MakeNvteTensor(dQ);
    te_dK = MakeNvteTensor(dK);
    te_dV = MakeNvteTensor(dV);
  } else {
    NVTE_ERROR("Fused attention only supports BF16/FP16 data types. \n");
  }

  // convert strings to enums
  NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout);
  NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type);
  NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type);

  // convert auxiliary tensors from forward into NVTETensors:
  // tensors[0] = softmax statistics, tensors[1] = forward RNG state
  NVTETensorPack nvte_aux_tensor_pack;
  nvte_tensor_pack_create(&nvte_aux_tensor_pack);
  nvte_aux_tensor_pack.size = 2;
  auto *output_s = reinterpret_cast<Tensor *>(nvte_aux_tensor_pack.tensors[0]);
  auto *fwd_rng_state = reinterpret_cast<Tensor *>(nvte_aux_tensor_pack.tensors[1]);
  output_s->data.shape = std::vector<size_t>({static_cast<size_t>(b), static_cast<size_t>(h),
                                              static_cast<size_t>(max_seqlen_q),
                                              static_cast<size_t>(max_seqlen_kv)});
  output_s->data.dptr = const_cast<void *>(softmax_aux.data());
  fwd_rng_state->data.shape = std::vector<size_t>({2});
  fwd_rng_state->data.dptr = const_cast<void *>(rng_state.data());

  // create cu_seqlens tensorwrappers
  TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv;
  te_cu_seqlens_q =
      MakeNvteTensor(cu_seqlens_q.data(), {static_cast<size_t>(b + 1)}, DType::kInt32);
  te_cu_seqlens_kv =
      MakeNvteTensor(cu_seqlens_kv.data(), {static_cast<size_t>(b + 1)}, DType::kInt32);

  // create workspace
  TensorWrapper workspace;

  // First call is a query only: populates the workspace with the required
  // shape and dtype, without running the kernel.
  nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(),
                      te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(),
                      te_dK.data(), te_dV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
                      te_cu_seqlens_kv.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
                      qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(),
                      Q.stream());

  // allocate memory for workspace
  auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place());
  workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype());

  // execute kernel
  nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(),
                      te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(),
                      te_dK.data(), te_dV.data(), te_dBias.data(), te_cu_seqlens_q.data(),
                      te_cu_seqlens_kv.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout,
                      qkv_layout_enum, bias_type_enum, attn_mask_type_enum, workspace.data(),
                      Q.stream());

  // destroy tensor wrappers
  nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
}
std::vector<paddle::Tensor> te_scaled_softmax_forward(const paddle::Tensor &input,
float scale_factor) {
NVTE_CHECK(input.shape().size() == 4, "expected 4D tensor");
......@@ -1316,6 +1493,33 @@ PD_BUILD_OP(te_fused_attn_bwd_kvpacked)
{paddle::Optional("_dBias"), paddle::Optional("dBias")}})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_bwd_kvpacked));
// Register the unpacked-QKV fused attention forward op with Paddle.
// Inputs prefixed with '_' are caller-allocated buffers the kernel fills in
// place; SetInplaceMap maps each of them to the corresponding output.
PD_BUILD_OP(te_fused_attn_fwd)
    .Inputs({"Q", "K", "V", "cu_seqlens_q", "cu_seqlens_kv", paddle::Optional("Bias"), "_O",
             paddle::Optional("_softmax_aux"), "_rng_state"})
    .Outputs({"O", paddle::Optional("softmax_aux"), "rng_state"})
    .Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "max_seqlen_q: int64_t",
            "max_seqlen_kv: int64_t", "is_training: bool", "attn_scale: float", "p_dropout: float",
            "qkv_layout: std::string", "bias_type: std::string", "attn_mask_type: std::string",
            "qkv_type: int64_t", "rng_elts_per_thread: int64_t"})
    .SetInplaceMap({{"_O", "O"},
                    {paddle::Optional("_softmax_aux"), paddle::Optional("softmax_aux")},
                    {"_rng_state", "rng_state"}})
    .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_fwd));

// Register the unpacked-QKV fused attention backward op; dQ/dK/dV (and the
// optional dBias) are written into the pre-allocated '_'-prefixed buffers.
PD_BUILD_OP(te_fused_attn_bwd)
    .Inputs({"Q", "K", "V", "cu_seqlens_q", "cu_seqlens_kv", "O", "dO", "softmax_aux", "_dQ", "_dK",
             "_dV", paddle::Optional("_dBias"), "rng_state"})
    .Outputs({"dQ", "dK", "dV", paddle::Optional("dBias")})
    .Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "max_seqlen_q: int64_t",
            "max_seqlen_kv: int64_t", "attn_scale: float", "p_dropout: float",
            "qkv_layout: std::string", "bias_type: std::string", "attn_mask_type: std::string",
            "qkv_type: int64_t"})
    .SetInplaceMap({{"_dQ", "dQ"},
                    {"_dK", "dK"},
                    {"_dV", "dV"},
                    {paddle::Optional("_dBias"), paddle::Optional("dBias")}})
    .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_bwd));
PD_BUILD_OP(te_scaled_softmax_forward)
.Inputs({"input"})
.Outputs({"softmax_results"})
......
......@@ -27,6 +27,14 @@ class TransformerLayer(paddle.nn.Layer):
intermediate size to which input samples are projected.
num_attention_heads : int
number of attention heads in the transformer layer.
num_gqa_groups : Optional[int], default = `None`
number of GQA groups in the transformer layer.
Grouped Query Attention is described in
`this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
This only affects the keys and values, not the queries.
GQA-1 is equivalent to Multi-Query Attention
(`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
layernorm_epsilon : float, default = 1e-5
a value added to the denominator of layer normalization
for numerical stability.
......@@ -97,6 +105,7 @@ class TransformerLayer(paddle.nn.Layer):
hidden_size: int,
ffn_hidden_size: int,
num_attention_heads: int,
num_gqa_groups: Optional[int] = None,
layernorm_epsilon: float = 1e-5,
hidden_dropout: float = 0.1,
attention_dropout: float = 0.1,
......@@ -153,6 +162,7 @@ class TransformerLayer(paddle.nn.Layer):
"set_parallel_mode": set_parallel_mode,
"sequence_parallel": self.sequence_parallel,
"tp_group": tp_group,
"num_gqa_groups": num_gqa_groups,
"rng_state_name": attention_dropout_rng_state_name,
"backend": backend,
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment