"vscode:/vscode.git/clone" did not exist on "ea55f1f52c489535f0d3b583c81529762c9cb5ea"
Unverified commit 749a7f69 authored by Shijie, committed by GitHub

[Paddle] Support flash attention (#330)



* add flash attn tests
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* update flash attn
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* fix random seed
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

---------
Signed-off-by: Shijie Wang <jaywan@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 23f4864d
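
For context, the substance of this change: the fused-attention backward wrappers now receive the RNG state recorded by the forward pass, so the dropout pattern can be replayed when attention is recomputed in the backward kernel. A minimal sketch of the updated call, with argument names taken from the test code below; where rng_state comes from and the elided trailing arguments are assumptions rather than the exact API:

    from transformer_engine.paddle.cpp_extensions import fused_attn_bwd_qkvpacked

    # rng_state is a 2-element tensor (the C++ glue below gives it shape {2});
    # it carries the dropout RNG state from the forward pass to the backward pass.
    dqkv, _ = fused_attn_bwd_qkvpacked(
        qkv_tensor,           # packed QKV input
        q_cu_seqlen_tensor,   # cumulative sequence lengths, shape [b + 1]
        rng_state,            # new argument introduced by this commit
        out,                  # forward output
        dout,                 # gradient of the loss w.r.t. the forward output
        softmax_aux_tensor,   # softmax statistics saved by the forward pass
        ...,                  # remaining arguments (scale, dropout, layout, ...) unchanged
    )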
@@ -42,6 +42,7 @@ from transformer_engine.paddle.cpp_extensions import (
)
from transformer_engine.paddle.fp8 import is_fp8_available
np.random.seed(10)
paddle.seed(10)
GEMM_CASES = [(256, 256, 512), (32, 32, 32), (16384, 1024, 2816), (16384, 2816, 1024),
(16384, 1024, 1024)]
@@ -49,6 +50,7 @@ is_fp8_supported, reason = is_fp8_available()
SELF_ATTN_CASES = [(32, 512, 16, 64), (32, 128, 16, 64)]
CROSS_ATTN_CASES = [(32, 128, 512, 16, 64)]
FLASH_ATTN_CASES = [(4, 1024, 16, 64), (2, 2048, 16, 128)]
ATTN_DTYPES = [tex.DType.kFloat16, tex.DType.kBFloat16]
@@ -641,6 +643,7 @@ class TestFusedAttn:
dqkv, _ = fused_attn_bwd_qkvpacked(
qkv_tensor,
q_cu_seqlen_tensor,
rng_state,
out,
self.dout,
softmax_aux_tensor,
@@ -671,6 +674,7 @@ class TestFusedAttn:
kv_tensor,
q_cu_seqlen_tensor,
kv_cu_seqlen_tensor,
rng_state,
out,
self.dout,
softmax_aux_tensor,
@@ -721,6 +725,23 @@ class TestFusedAttn:
assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2)
assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2)
@pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
reason="cuDNN fMHA requires Ampere+ GPU")
@pytest.mark.parametrize('b, s, h, d', FLASH_ATTN_CASES)
@pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
@pytest.mark.parametrize('is_causal_masking', [True])
def test_flash_attn_forward_backward(self, b, s, h, d, dtype, is_causal_masking):
"""
test flash attention forward + backward
"""
self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking)
reference_out, q_grad_ref, k_grad_ref, v_grad_ref = self._get_reference_out()
fused_attention_out, q_grad, k_grad, v_grad = self._get_fused_attention_out()
assert_allclose(reference_out, fused_attention_out, rtol=1e-3, atol=1e-2)
assert_allclose(q_grad_ref, q_grad, rtol=1e-3, atol=1e-2)
assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2)
assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2)
class TestSoftmax:
"""
......
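
The two new FLASH_ATTN_CASES deliberately go past the limits the wrapper assertions below used to enforce (max_seqlen <= 512 and head dim 64), which is what exercises the flash-attention path. A small illustration of the shapes involved; the [b, s, 3, h, d] packed-QKV layout is inferred from the shape indexing in the wrappers, so treat it as an assumption:

    # (batch, seqlen, heads, head_dim) for the new flash-attention tests
    FLASH_ATTN_CASES = [(4, 1024, 16, 64), (2, 2048, 16, 128)]
    for b, s, h, d in FLASH_ATTN_CASES:
        # both cases exceed the old max_seqlen <= 512 restriction, and the
        # second also uses head_dim 128 instead of the previously required 64
        print(f"b={b} s={s} h={h} d={d} -> packed qkv shape [{b}, {s}, 3, {h}, {d}]")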
@@ -397,8 +397,10 @@ def fused_attn_fwd_qkvpacked(
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Fused Attention FWD for packed QKV input"""
b = cu_seqlens.shape[0] - 1
assert (qkv_dtype in (tex.DType.kBFloat16,
tex.DType.kFloat16)), "Only support bf16/fp16 for fused attention."
b = cu_seqlens.shape[0] - 1
total_seqs = qkv.shape[0] * qkv.shape[1]
h = qkv.shape[3]
d = qkv.shape[4]
@@ -412,16 +414,6 @@ def fused_attn_fwd_qkvpacked(
]), "bias tensor must be in [1, h, max_seqlen, max_seqlen] shape."
assert (Bias.dtype == qkv.dtype), "bias tensor must be in the same dtype as qkv."
# BF16/FP16 fused attention API
if (qkv_dtype in (tex.DType.kBFloat16, tex.DType.kFloat16)) and (max_seqlen <= 512) and (d
== 64):
assert (qkv_layout == "qkv_interleaved" and bias_type == "no_bias"
and (attn_mask_type in ("padding", "causal"))
), """The fused attention currently only supports qkv_interleaved layout,
no_bias type, and padding/causal attention mask type."""
else:
assert False, "No support for this dtype and max_seqlen combination."
if set_zero:
out = paddle.full(shape=[total_seqs, h, d], fill_value=0, dtype=qkv.dtype)
else:
@@ -460,6 +452,7 @@ def fused_attn_fwd_qkvpacked(
def fused_attn_bwd_qkvpacked(
qkv: paddle.Tensor,
cu_seqlens: paddle.Tensor,
rng_state: paddle.Tensor,
o: paddle.Tensor,
d_o: paddle.Tensor,
softmax_aux: paddle.Tensor,
@@ -472,10 +465,12 @@ def fused_attn_bwd_qkvpacked(
bias_type: str = "no_bias",
attn_mask_type: str = "padding",
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Fused Attention FWD for packed QKV input"""
"""Fused Attention BWD for packed QKV input"""
b = cu_seqlens.shape[0] - 1
assert (qkv_dtype in (tex.DType.kBFloat16,
tex.DType.kFloat16)), "Only support bf16/fp16 for fused attention."
b = cu_seqlens.shape[0] - 1
total_seqs = qkv.shape[0] * qkv.shape[1]
h = qkv.shape[3]
d = qkv.shape[4]
@@ -483,16 +478,6 @@ def fused_attn_bwd_qkvpacked(
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
# BF16/FP16 fused attention API
if (qkv_dtype in (tex.DType.kBFloat16, tex.DType.kFloat16)) and (max_seqlen <= 512) and (d
== 64):
assert (qkv_layout == "qkv_interleaved" and bias_type == "no_bias"
and (attn_mask_type in ("padding", "causal"))
), """The fused attention currently only supports qkv_interleaved layout,
no_bias type, and padding attention mask type."""
else:
assert False, "No support for this dtype and max_seqlen combination."
if set_zero:
dqkv = paddle.full(shape=qkv.shape, fill_value=0, dtype=qkv.dtype)
else:
@@ -511,6 +496,7 @@ def fused_attn_bwd_qkvpacked(
softmax_aux,
dqkv,
dbias,
rng_state,
b,
h,
d,
@@ -547,10 +533,12 @@ def fused_attn_fwd_kvpacked(
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Fused Attention FWD for packed KV input"""
assert (qkv_dtype in (tex.DType.kBFloat16,
tex.DType.kFloat16)), "Only support bf16/fp16 for fused attention."
assert (cu_seqlens_q.shape == cu_seqlens_kv.shape
), "cu_seqlens_q and cu_seqlens_kv must have the same shape"
b = cu_seqlens_q.shape[0] - 1
b = cu_seqlens_q.shape[0] - 1
total_seqs_q = q.shape[0] * q.shape[1]
total_seqs_kv = kv.shape[0] * kv.shape[1]
h = q.shape[2]
@@ -565,16 +553,6 @@ def fused_attn_fwd_kvpacked(
]), "bias tensor must be in [1, h, max_seqlen, max_seqlen] shape."
assert (Bias.dtype == q.dtype), "bias tensor must be in the same dtype as q and kv."
# BF16/FP16 fused attention API
if (qkv_dtype in (tex.DType.kBFloat16, tex.DType.kFloat16)) and (max_seqlen_q <= 512) and (
max_seqlen_kv <= 512) and (d == 64):
assert (qkv_layout == "kv_interleaved" and bias_type == "no_bias"
and (attn_mask_type in ("padding", "causal"))
), """The fused attention currently only supports kv_interleaved layout,
no_bias type, and padding attention mask type."""
else:
assert False, "No support for this dtype and max_seqlen combination."
if set_zero:
out = paddle.full(shape=[total_seqs_q, h, d], fill_value=0, dtype=q.dtype)
else:
@@ -619,6 +597,7 @@ def fused_attn_bwd_kvpacked(
kv: paddle.Tensor,
cu_seqlens_q: paddle.Tensor,
cu_seqlens_kv: paddle.Tensor,
rng_state: paddle.Tensor,
o: paddle.Tensor,
d_o: paddle.Tensor,
softmax_aux: paddle.Tensor,
@@ -632,10 +611,14 @@ def fused_attn_bwd_kvpacked(
bias_type: str = "no_bias",
attn_mask_type: str = "padding",
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Fused Attention FWD for packed KV input"""
"""Fused Attention BWD for packed KV input"""
b = cu_seqlens_q.shape[0] - 1
assert (qkv_dtype in (tex.DType.kBFloat16,
tex.DType.kFloat16)), "Only support bf16/fp16 for fused attention."
assert (cu_seqlens_q.shape == cu_seqlens_kv.shape
), "cu_seqlens_q and cu_seqlens_kv must have the same shape"
b = cu_seqlens_q.shape[0] - 1
total_seqs_q = q.shape[0] * q.shape[1]
total_seqs_kv = kv.shape[0] * kv.shape[1]
h = q.shape[2]
@@ -644,16 +627,6 @@ def fused_attn_bwd_kvpacked(
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
# BF16/FP16 fused attention API
if (qkv_dtype in (tex.DType.kBFloat16, tex.DType.kFloat16)) and (max_seqlen_q <= 512) and (
max_seqlen_kv <= 512) and (d == 64):
assert (qkv_layout == "kv_interleaved" and bias_type == "no_bias"
and (attn_mask_type in ("padding", "causal"))
), """The fused attention currently only supports kv_interleaved layout,
no_bias type, and padding attention mask type."""
else:
assert False, "No support for this dtype and max_seqlen combination."
if set_zero:
dq = paddle.full(shape=q.shape, fill_value=0, dtype=q.dtype)
dkv = paddle.full(shape=kv.shape, fill_value=0, dtype=kv.dtype)
@@ -676,6 +649,7 @@ def fused_attn_bwd_kvpacked(
dq,
dkv,
dbias,
rng_state,
b,
h,
d,
......
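
The wrappers recover everything they need from their inputs: the batch size as b = cu_seqlens.shape[0] - 1 and the default scale as 1/sqrt(d) when attn_scale is None. A small, self-contained sketch of those conventions (the helper name is made up for illustration):

    import math
    import numpy as np

    def make_cu_seqlens(seqlens):
        """Cumulative sequence lengths with a leading 0, shape [b + 1]."""
        return np.concatenate([[0], np.cumsum(seqlens)]).astype(np.int32)

    cu_seqlens = make_cu_seqlens([512, 384, 510, 128])
    b = cu_seqlens.shape[0] - 1        # batch size, as in the wrappers
    d = 64
    attn_scale = 1.0 / math.sqrt(d)    # default used when attn_scale is None
    print(b, attn_scale, cu_seqlens)   # 4 0.125 [0 512 896 1406 1534]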
@@ -575,6 +575,7 @@ void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
const paddle::Tensor &softmax_aux,
paddle::Tensor &dQKV, // NOLINT
paddle::optional<paddle::Tensor> &dBias, // NOLINT
paddle::Tensor &rng_state, // NOLINT
int64_t b, int64_t h, int64_t d, int64_t total_seqs,
int64_t max_seqlen, float attn_scale, float p_dropout,
const std::string &qkv_layout, const std::string &bias_type,
@@ -610,12 +611,15 @@ void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
NVTETensorPack nvte_aux_tensor_pack;
nvte_tensor_pack_create(&nvte_aux_tensor_pack);
nvte_aux_tensor_pack.size = 1;
nvte_aux_tensor_pack.size = 2; // 1. softmax_aux 2. rng_state
auto *output_s = reinterpret_cast<Tensor *>(nvte_aux_tensor_pack.tensors[0]);
auto *fwd_rng_state = reinterpret_cast<Tensor *>(nvte_aux_tensor_pack.tensors[1]);
output_s->data.shape =
std::vector<size_t>({static_cast<size_t>(b), static_cast<size_t>(h),
static_cast<size_t>(max_seqlen), static_cast<size_t>(max_seqlen)});
output_s->data.dptr = const_cast<void *>(softmax_aux.data());
fwd_rng_state->data.shape = std::vector<size_t>({2});
fwd_rng_state->data.dptr = const_cast<void *>(rng_state.data());
// create cu_seqlens tensorwrappers
TensorWrapper te_cu_seqlens;
@@ -742,6 +746,7 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K
paddle::Tensor &dQ, // NOLINT
paddle::Tensor &dKV, // NOLINT
paddle::optional<paddle::Tensor> &dBias, // NOLINT
paddle::Tensor &rng_state, // NOLINT
int64_t b, int64_t h, int64_t d, int64_t total_seqs_q,
int64_t total_seqs_kv, int64_t max_seqlen_q, int64_t max_seqlen_kv,
float attn_scale, float p_dropout, const std::string &qkv_layout,
@@ -780,12 +785,15 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K
NVTETensorPack nvte_aux_tensor_pack;
nvte_tensor_pack_create(&nvte_aux_tensor_pack);
nvte_aux_tensor_pack.size = 1;
nvte_aux_tensor_pack.size = 2;
auto *output_s = reinterpret_cast<Tensor *>(nvte_aux_tensor_pack.tensors[0]);
auto *fwd_rng_state = reinterpret_cast<Tensor *>(nvte_aux_tensor_pack.tensors[1]);
output_s->data.shape = std::vector<size_t>({static_cast<size_t>(b), static_cast<size_t>(h),
static_cast<size_t>(max_seqlen_q),
static_cast<size_t>(max_seqlen_kv)});
output_s->data.dptr = const_cast<void *>(softmax_aux.data());
fwd_rng_state->data.shape = std::vector<size_t>({2});
fwd_rng_state->data.dptr = const_cast<void *>(rng_state.data());
// create cu_seqlens tensorwrappers
TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv;
@@ -1084,7 +1092,8 @@ PD_BUILD_OP(te_fused_attn_fwd_qkvpacked)
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_fwd_qkvpacked));
PD_BUILD_OP(te_fused_attn_bwd_qkvpacked)
.Inputs({"QKV", "cu_seqlens", "O", "dO", "softmax_aux", "_dQKV", paddle::Optional("_dBias")})
.Inputs({"QKV", "cu_seqlens", "O", "dO", "softmax_aux", "_dQKV", paddle::Optional("_dBias"),
"rng_state"})
.Outputs({"dQKV", paddle::Optional("dBias")})
.Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "total_seqs: int64_t", "max_seqlen: int64_t",
"attn_scale: float", "p_dropout: float", "qkv_layout: std::string",
@@ -1106,7 +1115,7 @@ PD_BUILD_OP(te_fused_attn_fwd_kvpacked)
PD_BUILD_OP(te_fused_attn_bwd_kvpacked)
.Inputs({"Q", "KV", "cu_seqlens_q", "cu_seqlens_kv", "O", "dO", "softmax_aux", "_dQ", "_dKV",
paddle::Optional("_dBias")})
paddle::Optional("_dBias"), "rng_state"})
.Outputs({"dQ", "dKV", paddle::Optional("dBias")})
.Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "total_seqs_q: int64_t",
"total_seqs_kv: int64_t", "max_seqlen_q: int64_t", "max_seqlen_kv: int64_t",
......
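
On the C++ side, the backward ops now declare an aux tensor pack of size 2: entry 0 is softmax_aux and entry 1 is the forward rng_state, a 2-element tensor (presumably the Philox seed and offset). A Python-level sketch of what goes into that slot; the dtype here is an assumption, and in practice the tensor is filled by the forward kernel rather than created by hand:

    import paddle

    # 2-element RNG state matching fwd_rng_state->data.shape = {2} in the C++ glue
    rng_state = paddle.zeros(shape=[2], dtype='int64')
    print(rng_state.shape)  # [2]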