Unverified Commit 749a7f69 authored by Shijie, committed by GitHub
Browse files

[Paddle] Support flash attention (#330)



* add flash attn tests
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* update flash attn
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* fix random seed
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

---------
Signed-off-by: Shijie Wang <jaywan@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 23f4864d
...@@ -42,6 +42,7 @@ from transformer_engine.paddle.cpp_extensions import ( ...@@ -42,6 +42,7 @@ from transformer_engine.paddle.cpp_extensions import (
) )
from transformer_engine.paddle.fp8 import is_fp8_available from transformer_engine.paddle.fp8 import is_fp8_available
np.random.seed(10)
paddle.seed(10) paddle.seed(10)
GEMM_CASES = [(256, 256, 512), (32, 32, 32), (16384, 1024, 2816), (16384, 2816, 1024), GEMM_CASES = [(256, 256, 512), (32, 32, 32), (16384, 1024, 2816), (16384, 2816, 1024),
(16384, 1024, 1024)] (16384, 1024, 1024)]
...@@ -49,6 +50,7 @@ is_fp8_supported, reason = is_fp8_available() ...@@ -49,6 +50,7 @@ is_fp8_supported, reason = is_fp8_available()
SELF_ATTN_CASES = [(32, 512, 16, 64), (32, 128, 16, 64)] SELF_ATTN_CASES = [(32, 512, 16, 64), (32, 128, 16, 64)]
CROSS_ATTN_CASES = [(32, 128, 512, 16, 64)] CROSS_ATTN_CASES = [(32, 128, 512, 16, 64)]
FLASH_ATTN_CASES = [(4, 1024, 16, 64), (2, 2048, 16, 128)]
ATTN_DTYPES = [tex.DType.kFloat16, tex.DType.kBFloat16] ATTN_DTYPES = [tex.DType.kFloat16, tex.DType.kBFloat16]
...@@ -641,6 +643,7 @@ class TestFusedAttn: ...@@ -641,6 +643,7 @@ class TestFusedAttn:
dqkv, _ = fused_attn_bwd_qkvpacked( dqkv, _ = fused_attn_bwd_qkvpacked(
qkv_tensor, qkv_tensor,
q_cu_seqlen_tensor, q_cu_seqlen_tensor,
rng_state,
out, out,
self.dout, self.dout,
softmax_aux_tensor, softmax_aux_tensor,
...@@ -671,6 +674,7 @@ class TestFusedAttn: ...@@ -671,6 +674,7 @@ class TestFusedAttn:
kv_tensor, kv_tensor,
q_cu_seqlen_tensor, q_cu_seqlen_tensor,
kv_cu_seqlen_tensor, kv_cu_seqlen_tensor,
rng_state,
out, out,
self.dout, self.dout,
softmax_aux_tensor, softmax_aux_tensor,
...@@ -721,6 +725,23 @@ class TestFusedAttn: ...@@ -721,6 +725,23 @@ class TestFusedAttn:
assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2) assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2)
assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2) assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2)
@pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
                    reason="cuDNN fMHA requires Ampere+ GPU")
@pytest.mark.parametrize('b, s, h, d', FLASH_ATTN_CASES)
@pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
@pytest.mark.parametrize('is_causal_masking', [True])
def test_flash_attn_forward_backward(self, b, s, h, d, dtype, is_causal_masking):
    """
    Check the fused attention output and gradients against the reference path.

    Runs once per (b, s, h, d) entry in FLASH_ATTN_CASES and per dtype, with
    causal masking enabled. Skipped when the device capability is below (8, 0).
    """
    # `s` is passed for both length arguments of set_input -- presumably the
    # query and key/value sequence lengths (TODO confirm against set_input).
    self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking)

    # Each helper returns (output, q_grad, k_grad, v_grad); the reference path
    # is evaluated first, then the fused path.
    ref_out, ref_dq, ref_dk, ref_dv = self._get_reference_out()
    fused_out, fused_dq, fused_dk, fused_dv = self._get_fused_attention_out()

    # Compare in the order: forward output, then q / k / v gradients.
    comparisons = (
        (ref_out, fused_out),
        (ref_dq, fused_dq),
        (ref_dk, fused_dk),
        (ref_dv, fused_dv),
    )
    for expected, actual in comparisons:
        assert_allclose(expected, actual, rtol=1e-3, atol=1e-2)
class TestSoftmax: class TestSoftmax:
""" """
......
...@@ -397,8 +397,10 @@ def fused_attn_fwd_qkvpacked( ...@@ -397,8 +397,10 @@ def fused_attn_fwd_qkvpacked(
) -> Tuple[paddle.Tensor, paddle.Tensor]: ) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Fused Attention FWD for packed QKV input""" """Fused Attention FWD for packed QKV input"""
b = cu_seqlens.shape[0] - 1 assert (qkv_dtype in (tex.DType.kBFloat16,
tex.DType.kFloat16)), "Only support bf16/fp16 for fused attention."
b = cu_seqlens.shape[0] - 1
total_seqs = qkv.shape[0] * qkv.shape[1] total_seqs = qkv.shape[0] * qkv.shape[1]
h = qkv.shape[3] h = qkv.shape[3]
d = qkv.shape[4] d = qkv.shape[4]
...@@ -412,16 +414,6 @@ def fused_attn_fwd_qkvpacked( ...@@ -412,16 +414,6 @@ def fused_attn_fwd_qkvpacked(
]), "bias tensor must be in [1, h, max_seqlen, max_seqlen] shape." ]), "bias tensor must be in [1, h, max_seqlen, max_seqlen] shape."
assert (Bias.dtype == qkv.dtype), "bias tensor must be in the same dtype as qkv." assert (Bias.dtype == qkv.dtype), "bias tensor must be in the same dtype as qkv."
# BF16/FP16 fused attention API
if (qkv_dtype in (tex.DType.kBFloat16, tex.DType.kFloat16)) and (max_seqlen <= 512) and (d
== 64):
assert (qkv_layout == "qkv_interleaved" and bias_type == "no_bias"
and (attn_mask_type in ("padding", "causal"))
), """The fused attention currently only supports qkv_interleaved layout,
no_bias type, and padding/causal attention mask type."""
else:
assert False, "No support for this dtype and max_seqlen combination."
if set_zero: if set_zero:
out = paddle.full(shape=[total_seqs, h, d], fill_value=0, dtype=qkv.dtype) out = paddle.full(shape=[total_seqs, h, d], fill_value=0, dtype=qkv.dtype)
else: else:
...@@ -460,6 +452,7 @@ def fused_attn_fwd_qkvpacked( ...@@ -460,6 +452,7 @@ def fused_attn_fwd_qkvpacked(
def fused_attn_bwd_qkvpacked( def fused_attn_bwd_qkvpacked(
qkv: paddle.Tensor, qkv: paddle.Tensor,
cu_seqlens: paddle.Tensor, cu_seqlens: paddle.Tensor,
rng_state: paddle.Tensor,
o: paddle.Tensor, o: paddle.Tensor,
d_o: paddle.Tensor, d_o: paddle.Tensor,
softmax_aux: paddle.Tensor, softmax_aux: paddle.Tensor,
...@@ -472,10 +465,12 @@ def fused_attn_bwd_qkvpacked( ...@@ -472,10 +465,12 @@ def fused_attn_bwd_qkvpacked(
bias_type: str = "no_bias", bias_type: str = "no_bias",
attn_mask_type: str = "padding", attn_mask_type: str = "padding",
) -> Tuple[paddle.Tensor, paddle.Tensor]: ) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Fused Attention FWD for packed QKV input""" """Fused Attention BWD for packed QKV input"""
b = cu_seqlens.shape[0] - 1 assert (qkv_dtype in (tex.DType.kBFloat16,
tex.DType.kFloat16)), "Only support bf16/fp16 for fused attention."
b = cu_seqlens.shape[0] - 1
total_seqs = qkv.shape[0] * qkv.shape[1] total_seqs = qkv.shape[0] * qkv.shape[1]
h = qkv.shape[3] h = qkv.shape[3]
d = qkv.shape[4] d = qkv.shape[4]
...@@ -483,16 +478,6 @@ def fused_attn_bwd_qkvpacked( ...@@ -483,16 +478,6 @@ def fused_attn_bwd_qkvpacked(
if attn_scale is None: if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d) attn_scale = 1.0 / math.sqrt(d)
# BF16/FP16 fused attention API
if (qkv_dtype in (tex.DType.kBFloat16, tex.DType.kFloat16)) and (max_seqlen <= 512) and (d
== 64):
assert (qkv_layout == "qkv_interleaved" and bias_type == "no_bias"
and (attn_mask_type in ("padding", "causal"))
), """The fused attention currently only supports qkv_interleaved layout,
no_bias type, and padding attention mask type."""
else:
assert False, "No support for this dtype and max_seqlen combination."
if set_zero: if set_zero:
dqkv = paddle.full(shape=qkv.shape, fill_value=0, dtype=qkv.dtype) dqkv = paddle.full(shape=qkv.shape, fill_value=0, dtype=qkv.dtype)
else: else:
...@@ -511,6 +496,7 @@ def fused_attn_bwd_qkvpacked( ...@@ -511,6 +496,7 @@ def fused_attn_bwd_qkvpacked(
softmax_aux, softmax_aux,
dqkv, dqkv,
dbias, dbias,
rng_state,
b, b,
h, h,
d, d,
...@@ -547,10 +533,12 @@ def fused_attn_fwd_kvpacked( ...@@ -547,10 +533,12 @@ def fused_attn_fwd_kvpacked(
) -> Tuple[paddle.Tensor, paddle.Tensor]: ) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Fused Attention FWD for packed KV input""" """Fused Attention FWD for packed KV input"""
assert (qkv_dtype in (tex.DType.kBFloat16,
tex.DType.kFloat16)), "Only support bf16/fp16 for fused attention."
assert (cu_seqlens_q.shape == cu_seqlens_kv.shape assert (cu_seqlens_q.shape == cu_seqlens_kv.shape
), "cu_seqlens_q and cu_seqlens_kv must have the same shape" ), "cu_seqlens_q and cu_seqlens_kv must have the same shape"
b = cu_seqlens_q.shape[0] - 1
b = cu_seqlens_q.shape[0] - 1
total_seqs_q = q.shape[0] * q.shape[1] total_seqs_q = q.shape[0] * q.shape[1]
total_seqs_kv = kv.shape[0] * kv.shape[1] total_seqs_kv = kv.shape[0] * kv.shape[1]
h = q.shape[2] h = q.shape[2]
...@@ -565,16 +553,6 @@ def fused_attn_fwd_kvpacked( ...@@ -565,16 +553,6 @@ def fused_attn_fwd_kvpacked(
]), "bias tensor must be in [1, h, max_seqlen, max_seqlen] shape." ]), "bias tensor must be in [1, h, max_seqlen, max_seqlen] shape."
assert (Bias.dtype == q.dtype), "bias tensor must be in the same dtype as q and kv." assert (Bias.dtype == q.dtype), "bias tensor must be in the same dtype as q and kv."
# BF16/FP16 fused attention API
if (qkv_dtype in (tex.DType.kBFloat16, tex.DType.kFloat16)) and (max_seqlen_q <= 512) and (
max_seqlen_kv <= 512) and (d == 64):
assert (qkv_layout == "kv_interleaved" and bias_type == "no_bias"
and (attn_mask_type in ("padding", "causal"))
), """The fused attention currently only supports kv_interleaved layout,
no_bias type, and padding attention mask type."""
else:
assert False, "No support for this dtype and max_seqlen combination."
if set_zero: if set_zero:
out = paddle.full(shape=[total_seqs_q, h, d], fill_value=0, dtype=q.dtype) out = paddle.full(shape=[total_seqs_q, h, d], fill_value=0, dtype=q.dtype)
else: else:
...@@ -619,6 +597,7 @@ def fused_attn_bwd_kvpacked( ...@@ -619,6 +597,7 @@ def fused_attn_bwd_kvpacked(
kv: paddle.Tensor, kv: paddle.Tensor,
cu_seqlens_q: paddle.Tensor, cu_seqlens_q: paddle.Tensor,
cu_seqlens_kv: paddle.Tensor, cu_seqlens_kv: paddle.Tensor,
rng_state: paddle.Tensor,
o: paddle.Tensor, o: paddle.Tensor,
d_o: paddle.Tensor, d_o: paddle.Tensor,
softmax_aux: paddle.Tensor, softmax_aux: paddle.Tensor,
...@@ -632,10 +611,14 @@ def fused_attn_bwd_kvpacked( ...@@ -632,10 +611,14 @@ def fused_attn_bwd_kvpacked(
bias_type: str = "no_bias", bias_type: str = "no_bias",
attn_mask_type: str = "padding", attn_mask_type: str = "padding",
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Fused Attention FWD for packed KV input""" """Fused Attention BWD for packed KV input"""
b = cu_seqlens_q.shape[0] - 1 assert (qkv_dtype in (tex.DType.kBFloat16,
tex.DType.kFloat16)), "Only support bf16/fp16 for fused attention."
assert (cu_seqlens_q.shape == cu_seqlens_kv.shape
), "cu_seqlens_q and cu_seqlens_kv must have the same shape"
b = cu_seqlens_q.shape[0] - 1
total_seqs_q = q.shape[0] * q.shape[1] total_seqs_q = q.shape[0] * q.shape[1]
total_seqs_kv = kv.shape[0] * kv.shape[1] total_seqs_kv = kv.shape[0] * kv.shape[1]
h = q.shape[2] h = q.shape[2]
...@@ -644,16 +627,6 @@ def fused_attn_bwd_kvpacked( ...@@ -644,16 +627,6 @@ def fused_attn_bwd_kvpacked(
if attn_scale is None: if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d) attn_scale = 1.0 / math.sqrt(d)
# BF16/FP16 fused attention API
if (qkv_dtype in (tex.DType.kBFloat16, tex.DType.kFloat16)) and (max_seqlen_q <= 512) and (
max_seqlen_kv <= 512) and (d == 64):
assert (qkv_layout == "kv_interleaved" and bias_type == "no_bias"
and (attn_mask_type in ("padding", "causal"))
), """The fused attention currently only supports kv_interleaved layout,
no_bias type, and padding attention mask type."""
else:
assert False, "No support for this dtype and max_seqlen combination."
if set_zero: if set_zero:
dq = paddle.full(shape=q.shape, fill_value=0, dtype=q.dtype) dq = paddle.full(shape=q.shape, fill_value=0, dtype=q.dtype)
dkv = paddle.full(shape=kv.shape, fill_value=0, dtype=kv.dtype) dkv = paddle.full(shape=kv.shape, fill_value=0, dtype=kv.dtype)
...@@ -676,6 +649,7 @@ def fused_attn_bwd_kvpacked( ...@@ -676,6 +649,7 @@ def fused_attn_bwd_kvpacked(
dq, dq,
dkv, dkv,
dbias, dbias,
rng_state,
b, b,
h, h,
d, d,
......
...@@ -575,6 +575,7 @@ void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor ...@@ -575,6 +575,7 @@ void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
const paddle::Tensor &softmax_aux, const paddle::Tensor &softmax_aux,
paddle::Tensor &dQKV, // NOLINT paddle::Tensor &dQKV, // NOLINT
paddle::optional<paddle::Tensor> &dBias, // NOLINT paddle::optional<paddle::Tensor> &dBias, // NOLINT
paddle::Tensor &rng_state, // NOLINT
int64_t b, int64_t h, int64_t d, int64_t total_seqs, int64_t b, int64_t h, int64_t d, int64_t total_seqs,
int64_t max_seqlen, float attn_scale, float p_dropout, int64_t max_seqlen, float attn_scale, float p_dropout,
const std::string &qkv_layout, const std::string &bias_type, const std::string &qkv_layout, const std::string &bias_type,
...@@ -610,12 +611,15 @@ void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor ...@@ -610,12 +611,15 @@ void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
NVTETensorPack nvte_aux_tensor_pack; NVTETensorPack nvte_aux_tensor_pack;
nvte_tensor_pack_create(&nvte_aux_tensor_pack); nvte_tensor_pack_create(&nvte_aux_tensor_pack);
nvte_aux_tensor_pack.size = 1; nvte_aux_tensor_pack.size = 2; // 1. softmax_aux 2. rng_state
auto *output_s = reinterpret_cast<Tensor *>(nvte_aux_tensor_pack.tensors[0]); auto *output_s = reinterpret_cast<Tensor *>(nvte_aux_tensor_pack.tensors[0]);
auto *fwd_rng_state = reinterpret_cast<Tensor *>(nvte_aux_tensor_pack.tensors[1]);
output_s->data.shape = output_s->data.shape =
std::vector<size_t>({static_cast<size_t>(b), static_cast<size_t>(h), std::vector<size_t>({static_cast<size_t>(b), static_cast<size_t>(h),
static_cast<size_t>(max_seqlen), static_cast<size_t>(max_seqlen)}); static_cast<size_t>(max_seqlen), static_cast<size_t>(max_seqlen)});
output_s->data.dptr = const_cast<void *>(softmax_aux.data()); output_s->data.dptr = const_cast<void *>(softmax_aux.data());
fwd_rng_state->data.shape = std::vector<size_t>({2});
fwd_rng_state->data.dptr = const_cast<void *>(rng_state.data());
// create cu_seqlens tensorwrappers // create cu_seqlens tensorwrappers
TensorWrapper te_cu_seqlens; TensorWrapper te_cu_seqlens;
...@@ -742,6 +746,7 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K ...@@ -742,6 +746,7 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K
paddle::Tensor &dQ, // NOLINT paddle::Tensor &dQ, // NOLINT
paddle::Tensor &dKV, // NOLINT paddle::Tensor &dKV, // NOLINT
paddle::optional<paddle::Tensor> &dBias, // NOLINT paddle::optional<paddle::Tensor> &dBias, // NOLINT
paddle::Tensor &rng_state, // NOLINT
int64_t b, int64_t h, int64_t d, int64_t total_seqs_q, int64_t b, int64_t h, int64_t d, int64_t total_seqs_q,
int64_t total_seqs_kv, int64_t max_seqlen_q, int64_t max_seqlen_kv, int64_t total_seqs_kv, int64_t max_seqlen_q, int64_t max_seqlen_kv,
float attn_scale, float p_dropout, const std::string &qkv_layout, float attn_scale, float p_dropout, const std::string &qkv_layout,
...@@ -780,12 +785,15 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K ...@@ -780,12 +785,15 @@ void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K
NVTETensorPack nvte_aux_tensor_pack; NVTETensorPack nvte_aux_tensor_pack;
nvte_tensor_pack_create(&nvte_aux_tensor_pack); nvte_tensor_pack_create(&nvte_aux_tensor_pack);
nvte_aux_tensor_pack.size = 1; nvte_aux_tensor_pack.size = 2;
auto *output_s = reinterpret_cast<Tensor *>(nvte_aux_tensor_pack.tensors[0]); auto *output_s = reinterpret_cast<Tensor *>(nvte_aux_tensor_pack.tensors[0]);
auto *fwd_rng_state = reinterpret_cast<Tensor *>(nvte_aux_tensor_pack.tensors[1]);
output_s->data.shape = std::vector<size_t>({static_cast<size_t>(b), static_cast<size_t>(h), output_s->data.shape = std::vector<size_t>({static_cast<size_t>(b), static_cast<size_t>(h),
static_cast<size_t>(max_seqlen_q), static_cast<size_t>(max_seqlen_q),
static_cast<size_t>(max_seqlen_kv)}); static_cast<size_t>(max_seqlen_kv)});
output_s->data.dptr = const_cast<void *>(softmax_aux.data()); output_s->data.dptr = const_cast<void *>(softmax_aux.data());
fwd_rng_state->data.shape = std::vector<size_t>({2});
fwd_rng_state->data.dptr = const_cast<void *>(rng_state.data());
// create cu_seqlens tensorwrappers // create cu_seqlens tensorwrappers
TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv;
...@@ -1084,7 +1092,8 @@ PD_BUILD_OP(te_fused_attn_fwd_qkvpacked) ...@@ -1084,7 +1092,8 @@ PD_BUILD_OP(te_fused_attn_fwd_qkvpacked)
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_fwd_qkvpacked)); .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_fwd_qkvpacked));
PD_BUILD_OP(te_fused_attn_bwd_qkvpacked) PD_BUILD_OP(te_fused_attn_bwd_qkvpacked)
.Inputs({"QKV", "cu_seqlens", "O", "dO", "softmax_aux", "_dQKV", paddle::Optional("_dBias")}) .Inputs({"QKV", "cu_seqlens", "O", "dO", "softmax_aux", "_dQKV", paddle::Optional("_dBias"),
"rng_state"})
.Outputs({"dQKV", paddle::Optional("dBias")}) .Outputs({"dQKV", paddle::Optional("dBias")})
.Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "total_seqs: int64_t", "max_seqlen: int64_t", .Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "total_seqs: int64_t", "max_seqlen: int64_t",
"attn_scale: float", "p_dropout: float", "qkv_layout: std::string", "attn_scale: float", "p_dropout: float", "qkv_layout: std::string",
...@@ -1106,7 +1115,7 @@ PD_BUILD_OP(te_fused_attn_fwd_kvpacked) ...@@ -1106,7 +1115,7 @@ PD_BUILD_OP(te_fused_attn_fwd_kvpacked)
PD_BUILD_OP(te_fused_attn_bwd_kvpacked) PD_BUILD_OP(te_fused_attn_bwd_kvpacked)
.Inputs({"Q", "KV", "cu_seqlens_q", "cu_seqlens_kv", "O", "dO", "softmax_aux", "_dQ", "_dKV", .Inputs({"Q", "KV", "cu_seqlens_q", "cu_seqlens_kv", "O", "dO", "softmax_aux", "_dQ", "_dKV",
paddle::Optional("_dBias")}) paddle::Optional("_dBias"), "rng_state"})
.Outputs({"dQ", "dKV", paddle::Optional("dBias")}) .Outputs({"dQ", "dKV", paddle::Optional("dBias")})
.Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "total_seqs_q: int64_t", .Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "total_seqs_q: int64_t",
"total_seqs_kv: int64_t", "max_seqlen_q: int64_t", "max_seqlen_kv: int64_t", "total_seqs_kv: int64_t", "max_seqlen_q: int64_t", "max_seqlen_kv: int64_t",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment