Unverified commit 5f60f82f authored by Shijie, committed by GitHub

[Paddle] Refactoring and optimization of fused attention (#411)



* fix mask conversion and rng_state
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* refactor fused attn
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* use CUB to do prefix sum
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* fuse dropout add
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* minor changes
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* optimize kernel
Signed-off-by: Shijie Wang <jaywan@nvidia.com>

* Debug merge errors
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Shijie Wang <jaywan@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 30d51226
@@ -667,18 +667,20 @@ class TestFusedAttn:
q_cu_seqlen_tensor = paddle.to_tensor(self.q_cu_seqlen, dtype="int32", stop_gradient=True)
kv_cu_seqlen_tensor = paddle.to_tensor(self.kv_cu_seqlen, dtype="int32", stop_gradient=True)
- rng_state = paddle.zeros((2,), dtype=np.int64)
+ fused_attention_backend = tex.NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen if (
+ self.q_seqlen <= 512
+ and self.kv_seqlen <= 512) else tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen
qkv_dtype = tex.DType.kBFloat16 if self.dtype == "bfloat16" else tex.DType.kFloat16
out, softmax_aux_tensor, q_grad, k_grad, v_grad = None, None, None, None, None
if self.attn_mode == 'self_attn':
- out, softmax_aux_tensor = fused_attn_fwd_qkvpacked(
+ out, softmax_aux_tensor, rng_state = fused_attn_fwd_qkvpacked(
qkv_tensor,
q_cu_seqlen_tensor,
- rng_state,
is_training=True,
max_seqlen=self.q_seqlen,
qkv_dtype=qkv_dtype,
+ fused_attention_backend=fused_attention_backend,
Bias=None,
attn_scale=self.scaling_factor,
dropout=self.dropout_prob,
@@ -693,6 +695,7 @@ class TestFusedAttn:
softmax_aux_tensor,
max_seqlen=self.q_seqlen,
qkv_dtype=qkv_dtype,
+ fused_attention_backend=fused_attention_backend,
attn_scale=self.scaling_factor,
dropout=self.dropout_prob,
set_zero=False,
@@ -701,15 +704,16 @@ class TestFusedAttn:
k_grad = dqkv[:, :, 1, :, :]
v_grad = dqkv[:, :, 2, :, :]
else: # attn_mode == 'cross_attn'
- out, softmax_aux_tensor = fused_attn_fwd_kvpacked(q_tensor,
+ out, softmax_aux_tensor, rng_state = fused_attn_fwd_kvpacked(
+ q_tensor,
kv_tensor,
q_cu_seqlen_tensor,
kv_cu_seqlen_tensor,
- rng_state,
is_training=True,
max_seqlen_q=self.q_seqlen,
max_seqlen_kv=self.kv_seqlen,
qkv_dtype=qkv_dtype,
+ fused_attention_backend=fused_attention_backend,
Bias=None,
attn_scale=self.scaling_factor,
dropout=self.dropout_prob,
@@ -722,6 +726,7 @@ class TestFusedAttn:
out,
self.dout,
softmax_aux_tensor,
+ fused_attention_backend=fused_attention_backend,
max_seqlen_q=self.q_seqlen,
max_seqlen_kv=self.kv_seqlen,
qkv_dtype=qkv_dtype,
...
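Note: the test picks between the two F16 fused-attention backends purely by sequence length. A minimal sketch of that rule as a helper (hypothetical function, shown only for illustration):

import transformer_engine_paddle as tex

def pick_f16_backend(q_seqlen, kv_seqlen):
    # The max512 backend only covers sequences up to 512 tokens on both the
    # query and key/value side; longer sequences use the arbitrary-seqlen backend.
    if q_seqlen <= 512 and kv_seqlen <= 512:
        return tex.NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen
    return tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen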
@@ -52,3 +52,27 @@ GemmParallelModes = ("row", "column", None)
dist_group_type = paddle.distributed.collective.Group
RecomputeFunctionNames = ('unpack', 'backward')
+ QKVLayout = {
+ "not_interleaved": tex.NVTE_QKV_Layout.NVTE_NOT_INTERLEAVED,
+ "qkv_interleaved": tex.NVTE_QKV_Layout.NVTE_QKV_INTERLEAVED,
+ "kv_interleaved": tex.NVTE_QKV_Layout.NVTE_KV_INTERLEAVED,
+ }
+ AttnBiasType = {
+ "no_bias": tex.NVTE_Bias_Type.NVTE_NO_BIAS,
+ "pre_scale_bias": tex.NVTE_Bias_Type.NVTE_PRE_SCALE_BIAS,
+ "post_scale_bias": tex.NVTE_Bias_Type.NVTE_POST_SCALE_BIAS,
+ }
+ AttnMaskType = {
+ "no_mask": tex.NVTE_Mask_Type.NVTE_NO_MASK,
+ "padding": tex.NVTE_Mask_Type.NVTE_PADDING_MASK,
+ "causal": tex.NVTE_Mask_Type.NVTE_CAUSAL_MASK,
+ }
+ FusedAttnBackend = {
+ "F16_max512_seqlen": tex.NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen,
+ "F16_arbitrary_seqlen": tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
+ "No_Backend": tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend,
+ }
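Note: these mappings keep the Python-facing API on plain strings while the extension receives the tex enums. A tiny illustrative lookup using the names defined in the hunk above:

# Resolve the user-facing strings to the enum values the C++ extension expects.
layout = QKVLayout["qkv_interleaved"]
bias_type = AttnBiasType["no_bias"]
mask_type = AttnMaskType["causal"]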
@@ -7,9 +7,12 @@ import math
from typing import Optional, Tuple, Union
import paddle
import transformer_engine_paddle as tex
- from .constants import TE_DType, FP8FwdTensors, FP8BwdTensors
+ from .constants import TE_DType, FusedAttnBackend, FP8FwdTensors, FP8BwdTensors
from .fp8 import FP8TensorMeta
+ BACKEND_F16m512_THREADS_PER_CTA = 128
+ BACKEND_F16arb_ELTS_PER_THREADS = 16
def gemm(
A: paddle.Tensor,
@@ -400,13 +403,30 @@ def rmsnorm_bwd(
return tex.te_rmsnorm_bwd(dz, x, rsigma, gamma, sm_margin)
+ def mask_to_cu_seqlens(
+ mask: paddle.Tensor,
+ need_kv: bool = False,
+ ) -> paddle.Tensor:
+ """Convert mask to cu_seqlens"""
+ # mask shape: [b, 1, s_q, s_kv]
+ q_seqlen, kv_seqlen = mask.shape[2], mask.shape[3]
+ q_cu_seqlens = paddle.empty(shape=[mask.shape[0] + 1], dtype=paddle.int32)
+ q_cu_seqlens[0] = 0
+ kv_cu_seqlens = None
+ if need_kv:
+ kv_cu_seqlens = paddle.empty(shape=[mask.shape[0] + 1], dtype=paddle.int32)
+ kv_cu_seqlens[0] = 0
+ tex.mask_to_cu_seqlens(mask, q_cu_seqlens, kv_cu_seqlens, q_seqlen, kv_seqlen, need_kv)
+ return q_cu_seqlens, kv_cu_seqlens
def fused_attn_fwd_qkvpacked(
qkv: paddle.Tensor,
cu_seqlens: paddle.Tensor,
- rng_state: paddle.Tensor,
is_training: bool,
max_seqlen: int,
qkv_dtype: tex.DType,
+ fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
Bias: paddle.Tensor = None,
attn_scale: float = None,
dropout: float = 0.0,
@@ -434,6 +454,18 @@ def fused_attn_fwd_qkvpacked(
]), "bias tensor must be in [1, h, max_seqlen, max_seqlen] shape."
assert (Bias.dtype == qkv.dtype), "bias tensor must be in the same dtype as qkv."
+ assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
+ ), "Fused attention does not support this input combination."
+ # BF16/FP16 fused attention API from fmha_v1 apex
+ if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
+ rng_elts_per_thread = (max_seqlen * max_seqlen + BACKEND_F16m512_THREADS_PER_CTA -
+ 1) // BACKEND_F16m512_THREADS_PER_CTA
+ # BF16/FP16 fused attention API from fmha_v2
+ if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
+ rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS
if set_zero:
out = paddle.full(shape=[b, max_seqlen, h, d], fill_value=0, dtype=qkv.dtype)
else:
@@ -444,6 +476,10 @@ def fused_attn_fwd_qkvpacked(
else:
softmax_aux = None
+ rng_state = paddle.empty(shape=[
+ 2,
+ ], dtype=paddle.int64)
# execute kernel
tex.te_fused_attn_fwd_qkvpacked(
qkv,
@@ -464,9 +500,9 @@ def fused_attn_fwd_qkvpacked(
bias_type,
attn_mask_type,
int(qkv_dtype),
+ rng_elts_per_thread,
)
- return out, softmax_aux
+ return out, softmax_aux, rng_state
def fused_attn_bwd_qkvpacked(
@@ -476,6 +512,7 @@ def fused_attn_bwd_qkvpacked(
o: paddle.Tensor,
d_o: paddle.Tensor,
softmax_aux: paddle.Tensor,
+ fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
max_seqlen: int,
qkv_dtype: tex.DType,
attn_scale: float = None,
@@ -498,6 +535,9 @@ def fused_attn_bwd_qkvpacked(
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
+ assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
+ ), "Fused attention does not support this input combination."
if set_zero:
dqkv = paddle.full(shape=qkv.shape, fill_value=0, dtype=qkv.dtype)
else:
@@ -538,11 +578,11 @@ def fused_attn_fwd_kvpacked(
kv: paddle.Tensor,
cu_seqlens_q: paddle.Tensor,
cu_seqlens_kv: paddle.Tensor,
- rng_state: paddle.Tensor,
is_training: bool,
max_seqlen_q: int,
max_seqlen_kv: int,
qkv_dtype: tex.DType,
+ fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
Bias: paddle.Tensor = None,
attn_scale: float = None,
dropout: float = 0.0,
@@ -573,6 +613,18 @@ def fused_attn_fwd_kvpacked(
]), "bias tensor must be in [1, h, max_seqlen, max_seqlen] shape."
assert (Bias.dtype == q.dtype), "bias tensor must be in the same dtype as q and kv."
+ assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
+ ), "Fused attention does not support this input combination."
+ # BF16/FP16 fused attention API from fmha_v1 apex
+ if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]:
+ rng_elts_per_thread = (max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_THREADS_PER_CTA -
+ 1) // BACKEND_F16m512_THREADS_PER_CTA
+ # BF16/FP16 fused attention API from fmha_v2
+ if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]:
+ rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS
if set_zero:
out = paddle.full(shape=[b, max_seqlen_q, h, d], fill_value=0, dtype=q.dtype)
else:
@@ -583,6 +635,10 @@ def fused_attn_fwd_kvpacked(
else:
softmax_aux = None
+ rng_state = paddle.empty(shape=[
+ 2,
+ ], dtype=paddle.int64)
# execute kernel
tex.te_fused_attn_fwd_kvpacked(
q,
@@ -607,9 +663,10 @@ def fused_attn_fwd_kvpacked(
bias_type,
attn_mask_type,
int(qkv_dtype),
+ rng_elts_per_thread,
)
- return out, softmax_aux
+ return out, softmax_aux, rng_state
def fused_attn_bwd_kvpacked(
@@ -621,6 +678,7 @@ def fused_attn_bwd_kvpacked(
o: paddle.Tensor,
d_o: paddle.Tensor,
softmax_aux: paddle.Tensor,
+ fused_attention_backend: tex.NVTE_Fused_Attn_Backend,
max_seqlen_q: int,
max_seqlen_kv: int,
qkv_dtype: tex.DType,
@@ -647,6 +705,9 @@ def fused_attn_bwd_kvpacked(
if attn_scale is None:
attn_scale = 1.0 / math.sqrt(d)
+ assert (fused_attention_backend != FusedAttnBackend["No_Backend"]
+ ), "Fused attention does not support this input combination."
if set_zero:
dq = paddle.full(shape=q.shape, fill_value=0, dtype=q.dtype)
dkv = paddle.full(shape=kv.shape, fill_value=0, dtype=kv.dtype)
...
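Note: rng_elts_per_thread sizes the RNG offset increment for dropout, so each forward call reserves enough random numbers per thread. A quick worked example of the arithmetic introduced above (illustrative values only):

BACKEND_F16m512_THREADS_PER_CTA = 128
BACKEND_F16arb_ELTS_PER_THREADS = 16

max_seqlen = 512
# max512 backend: ceil(max_seqlen * max_seqlen / threads_per_cta) elements per thread
rng_elts = (max_seqlen * max_seqlen + BACKEND_F16m512_THREADS_PER_CTA - 1) // BACKEND_F16m512_THREADS_PER_CTA
print(rng_elts)                          # 2048
# arbitrary-seqlen backend: a fixed per-thread budget
print(BACKEND_F16arb_ELTS_PER_THREADS)   # 16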
@@ -20,6 +20,7 @@
#include <vector>
#include "paddle/extension.h"
+ #include "paddle/phi/backends/all_context.h"
namespace transformer_engine {
namespace paddle_ext {
@@ -122,6 +123,17 @@ inline DType Int2NvteDType(int64_t dtype) {
}
}
+ // get the fused attention backend
+ inline NVTE_Fused_Attn_Backend get_fused_attn_backend(
+ const transformer_engine::DType q_dtype, const transformer_engine::DType kv_dtype,
+ NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
+ float p_dropout, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim) {
+ NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
+ static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
+ attn_mask_type, p_dropout, max_seqlen_q, max_seqlen_kv, head_dim);
+ return fused_attention_backend;
+ }
// CUDA Utils
class cudaDevicePropertiesManager {
public:
...
@@ -4,6 +4,7 @@
* See LICENSE for license information.
************************************************************************/
+ #include <cub/cub.cuh>
#include <vector>
#include "../common.h"
#include "common.h"
@@ -542,6 +543,11 @@ std::vector<paddle::Tensor> te_rmsnorm_bwd(const paddle::Tensor &dz, const paddl
return {dx, dgamma};
}
+ __global__ void set_rng_state(std::pair<uint64_t, uint64_t> seed_offset, int64_t *rng_state_ptr) {
+ rng_state_ptr[0] = static_cast<int64_t>(seed_offset.first);
+ rng_state_ptr[1] = static_cast<int64_t>(seed_offset.second);
+ }
void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor &cu_seqlens,
const paddle::optional<paddle::Tensor> &Bias,
paddle::Tensor &O, // NOLINT
@@ -551,7 +557,7 @@ void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
int64_t max_seqlen, bool is_training, float attn_scale,
float p_dropout, const std::string &qkv_layout,
const std::string &bias_type, const std::string &attn_mask_type,
- const int64_t qkv_type) {
+ const int64_t qkv_type, int64_t rng_elts_per_thread) {
if (is_training && !softmax_aux) {
NVTE_ERROR("softmax_aux must be provided when training. \n");
}
@@ -580,6 +586,11 @@ void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type);
// extract random number generator seed and offset
+ auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(QKV.place());
+ auto gen_cuda = dev_ctx->GetGenerator();
+ auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread);
+ set_rng_state<<<1, 1, 0, QKV.stream()>>>(seed_offset, static_cast<int64_t *>(rng_state.data()));
auto te_rng_state = MakeNvteTensor(rng_state);
// create auxiliary output tensors
@@ -692,18 +703,16 @@ void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
}
- void te_fused_attn_fwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &KV,
- const paddle::Tensor &cu_seqlens_q,
- const paddle::Tensor &cu_seqlens_kv,
- const paddle::optional<paddle::Tensor> &Bias,
+ void te_fused_attn_fwd_kvpacked(
+ const paddle::Tensor &Q, const paddle::Tensor &KV, const paddle::Tensor &cu_seqlens_q,
+ const paddle::Tensor &cu_seqlens_kv, const paddle::optional<paddle::Tensor> &Bias,
paddle::Tensor &O, // NOLINT
paddle::optional<paddle::Tensor> &softmax_aux, // NOLINT
paddle::Tensor &rng_state, // NOLINT
- int64_t b, int64_t h, int64_t d, int64_t total_seqs_q,
- int64_t total_seqs_kv, int64_t max_seqlen_q, int64_t max_seqlen_kv,
- bool is_training, float attn_scale, float p_dropout,
- const std::string &qkv_layout, const std::string &bias_type,
- const std::string &attn_mask_type, const int64_t qkv_type) {
+ int64_t b, int64_t h, int64_t d, int64_t total_seqs_q, int64_t total_seqs_kv,
+ int64_t max_seqlen_q, int64_t max_seqlen_kv, bool is_training, float attn_scale,
+ float p_dropout, const std::string &qkv_layout, const std::string &bias_type,
+ const std::string &attn_mask_type, const int64_t qkv_type, int64_t rng_elts_per_thread) {
if (is_training && !softmax_aux) {
NVTE_ERROR("softmax_aux must be provided when training. \n");
}
@@ -747,6 +756,10 @@ void te_fused_attn_fwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &K
NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type);
NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type);
+ auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(Q.place());
+ auto gen_cuda = dev_ctx->GetGenerator();
+ auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread);
+ set_rng_state<<<1, 1, 0, Q.stream()>>>(seed_offset, static_cast<int64_t *>(rng_state.data()));
auto te_rng_state = MakeNvteTensor(rng_state);
// create auxiliary output tensors
@@ -1040,6 +1053,8 @@ __global__ void UpdateFP8MetaKernel(const float *amax, const float *rolled_amax_
}
}
+ constexpr int BLOCK_SIZE = 512;
void amax_and_scale_update_inplace(paddle::Tensor &amax_history, // NOLINT
paddle::Tensor &scale, // NOLINT
paddle::Tensor &scale_inv, // NOLINT
@@ -1057,7 +1072,6 @@ void amax_and_scale_update_inplace(paddle::Tensor &amax_history, // NOLINT
const auto rolled_amax_history = amax_history.roll({-1}, {0});
auto size = amax_history.numel();
- constexpr int BLOCK_SIZE = 256;
size_t num_blocks = (size + BLOCK_SIZE - 1) / BLOCK_SIZE;
UpdateFP8MetaKernel<<<num_blocks, BLOCK_SIZE, 0, amax_history.stream()>>>(
amax.data<float>(), rolled_amax_history.data<float>(), amax_history.data<float>(),
@@ -1074,6 +1088,87 @@ void update_latest_amax_history_inplace(paddle::Tensor &history, // NOLINT
amax.stream()));
}
+ __global__ __launch_bounds__(BLOCK_SIZE) void mask_to_actual_seqlens_kernel(
+ const bool *mask, int32_t *q_actual_seqlen, int32_t *kv_actual_seqlen, int q_seqlen,
+ int kv_seqlen, bool need_kv) {
+ typedef cub::BlockReduce<int, BLOCK_SIZE> BlockReduce;
+ __shared__ typename BlockReduce::TempStorage q_smem;
+ __shared__ typename BlockReduce::TempStorage kv_smem;
+ unsigned int tid = threadIdx.x;
+ unsigned int batch_offset = blockIdx.x * q_seqlen * kv_seqlen;
+ // load mask, convert to 1/0, do accumulation
+ int q = 0, kv = 0;
+ for (unsigned int q_idx = tid * kv_seqlen; q_idx < q_seqlen * kv_seqlen;
+ q_idx += BLOCK_SIZE * kv_seqlen) {
+ q += (mask[q_idx + batch_offset] ? 0 : 1);
+ }
+ if (need_kv) {
+ for (unsigned int kv_idx = tid; kv_idx < kv_seqlen; kv_idx += BLOCK_SIZE) {
+ kv += (mask[kv_idx + batch_offset] ? 0 : 1);
+ }
+ }
+ __syncthreads();
+ // compute cub::BlockReduce
+ int q_sum, kv_sum;
+ q_sum = BlockReduce(q_smem).Sum(q);
+ if (need_kv) kv_sum = BlockReduce(kv_smem).Sum(kv);
+ // write result for this block to global mem
+ if (tid == 0) {
+ q_actual_seqlen[blockIdx.x + 1] = q_sum;
+ if (need_kv) {
+ kv_actual_seqlen[blockIdx.x + 1] = kv_sum;
+ }
+ }
+ }
+ __global__ __launch_bounds__(BLOCK_SIZE) void block_prefix_sum_inplace(int32_t *x, int n) {
+ typedef cub::BlockScan<int32_t, BLOCK_SIZE> BlockScan;
+ __shared__ typename BlockScan::TempStorage smem;
+ // +1 to ignore the first element
+ int i = blockIdx.x * blockDim.x + threadIdx.x + 1;
+ // load data
+ int32_t thread_data[1];
+ thread_data[0] = i < n ? x[i] : 0;
+ __syncthreads();
+ // CUB block prefix sum
+ BlockScan(smem).InclusiveSum(thread_data, thread_data);
+ __syncthreads();
+ // write result
+ if (i < n) {
+ x[i] = thread_data[0];
+ }
+ }
+ void mask_to_cu_seqlens(const paddle::Tensor &mask,
+ paddle::Tensor &q_cu_seqlen, // NOLINT
+ paddle::optional<paddle::Tensor> &kv_cu_seqlen, // NOLINT
+ int q_seqlen, int kv_seqlen, bool need_kv) {
+ if (need_kv) {
+ NVTE_CHECK(GetOptionalDataPtr(kv_cu_seqlen) != nullptr,
+ "kv_cu_seqlen must be provided when need_kv is true");
+ }
+ mask_to_actual_seqlens_kernel<<<mask.shape()[0], BLOCK_SIZE, 0, mask.stream()>>>(
+ mask.data<bool>(), q_cu_seqlen.data<int32_t>(),
+ reinterpret_cast<int32_t *>(GetOptionalDataPtr(kv_cu_seqlen)), q_seqlen, kv_seqlen,
+ need_kv);
+ // q_cu_seqlen shape: [bs+1], assume bs is not too large (<=512), so we can use a single block
+ // to do prefix sum
+ NVTE_CHECK(q_cu_seqlen.numel() - 1 <= BLOCK_SIZE, "batch size too large, kernel may fail");
+ block_prefix_sum_inplace<<<1, BLOCK_SIZE, 0, mask.stream()>>>(q_cu_seqlen.data<int32_t>(),
+ q_cu_seqlen.numel());
+ if (need_kv) {
+ block_prefix_sum_inplace<<<1, BLOCK_SIZE, 0, mask.stream()>>>(
+ reinterpret_cast<int32_t *>(GetOptionalDataPtr(kv_cu_seqlen)), kv_cu_seqlen->numel());
+ }
+ }
} // namespace paddle_ext
} // namespace transformer_engine
@@ -1192,7 +1287,8 @@ PD_BUILD_OP(te_fused_attn_fwd_qkvpacked)
.Outputs({"O", paddle::Optional("softmax_aux")})
.Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "total_seqs: int64_t", "max_seqlen: int64_t",
"is_training: bool", "attn_scale: float", "p_dropout: float", "qkv_layout: std::string",
- "bias_type: std::string", "attn_mask_type: std::string", "qkv_type: int64_t"})
+ "bias_type: std::string", "attn_mask_type: std::string", "qkv_type: int64_t",
+ "rng_elts_per_thread: int64_t"})
.SetInplaceMap({{"_O", "O"},
{paddle::Optional("_softmax_aux"), paddle::Optional("softmax_aux")}})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_fwd_qkvpacked));
@@ -1214,7 +1310,8 @@ PD_BUILD_OP(te_fused_attn_fwd_kvpacked)
.Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "total_seqs_q: int64_t",
"total_seqs_kv: int64_t", "max_seqlen_q: int64_t", "max_seqlen_kv: int64_t",
"is_training: bool", "attn_scale: float", "p_dropout: float", "qkv_layout: std::string",
- "bias_type: std::string", "attn_mask_type: std::string", "qkv_type: int64_t"})
+ "bias_type: std::string", "attn_mask_type: std::string", "qkv_type: int64_t",
+ "rng_elts_per_thread: int64_t"})
.SetInplaceMap({{"_O", "O"},
{paddle::Optional("_softmax_aux"), paddle::Optional("softmax_aux")}})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_fwd_kvpacked));
@@ -1287,3 +1384,11 @@ PD_BUILD_OP(update_latest_amax_history_inplace)
.Outputs({"history"})
.SetInplaceMap({{"_history", "history"}})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::update_latest_amax_history_inplace));
+ PD_BUILD_OP(mask_to_cu_seqlens)
+ .Inputs({"mask", "_q_cu_seqlen", paddle::Optional("_kv_cu_seqlen")})
+ .Outputs({"q_cu_seqlen", paddle::Optional("kv_cu_seqlen")})
+ .Attrs({"q_seqlen: int", "kv_seqlen: int", "need_kv: bool"})
+ .SetInplaceMap({{"_q_cu_seqlen", "q_cu_seqlen"},
+ {paddle::Optional("_kv_cu_seqlen"), paddle::Optional("kv_cu_seqlen")}})
+ .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::mask_to_cu_seqlens));
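Note: the cu_seqlens conversion above is a two-step computation: count the unmasked positions per batch, then take an inclusive prefix sum over the [b + 1] array. A NumPy sketch of the same result, assuming a boolean mask of shape [b, 1, s_q, s_kv] where True marks a padded position (this mirrors the kernels, which count entries along the first mask column for q and the first mask row for kv):

import numpy as np

def mask_to_cu_seqlens_reference(mask):
    # per-batch actual lengths: unmasked entries in the first column (q) / first row (kv)
    q_actual = (~mask[:, 0, :, 0]).sum(axis=-1)
    kv_actual = (~mask[:, 0, 0, :]).sum(axis=-1)
    # cumulative lengths with a leading zero, matching the [b + 1] int32 output buffers
    q_cu = np.concatenate([[0], np.cumsum(q_actual)]).astype(np.int32)
    kv_cu = np.concatenate([[0], np.cumsum(kv_actual)]).astype(np.int32)
    return q_cu, kv_cu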
@@ -14,6 +14,7 @@ size_t get_cublasLt_version() { return cublasLtGetVersion(); }
PYBIND11_MODULE(transformer_engine_paddle, m) {
// Misc
m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version");
+ m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend");
// Data structures
py::enum_<DType>(m, "DType", py::module_local())
.value("kByte", DType::kByte)
@@ -23,6 +24,27 @@ PYBIND11_MODULE(transformer_engine_paddle, m) {
.value("kBFloat16", DType::kBFloat16)
.value("kFloat8E4M3", DType::kFloat8E4M3)
.value("kFloat8E5M2", DType::kFloat8E5M2);
+ py::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type")
+ .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS)
+ .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS)
+ .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
+ py::enum_<NVTE_Mask_Type>(m, "NVTE_Mask_Type")
+ .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK)
+ .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK)
+ .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK);
+ py::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout")
+ .value("NVTE_NOT_INTERLEAVED", NVTE_QKV_Layout::NVTE_NOT_INTERLEAVED)
+ .value("NVTE_QKV_INTERLEAVED", NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)
+ .value("NVTE_KV_INTERLEAVED", NVTE_QKV_Layout::NVTE_KV_INTERLEAVED);
+ py::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend", py::module_local())
+ .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)
+ .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen)
+ .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8)
+ .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend);
}
} // namespace paddle_ext
} // namespace transformer_engine
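Note: with get_fused_attn_backend and the enums exposed to Python, callers can probe backend support before dispatching to the fused kernels, which is what DotProductAttention does below. An illustrative check (example values only):

import transformer_engine_paddle as tex

backend = tex.get_fused_attn_backend(
    tex.DType.kBFloat16, tex.DType.kBFloat16,      # q and kv dtypes
    tex.NVTE_QKV_Layout.NVTE_QKV_INTERLEAVED,
    tex.NVTE_Bias_Type.NVTE_NO_BIAS,
    tex.NVTE_Mask_Type.NVTE_CAUSAL_MASK,
    0.1,                                           # dropout probability
    512, 512,                                      # max_seqlen_q, max_seqlen_kv
    64)                                            # head_dim
if backend == tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend:
    print("this configuration has no fused attention backend")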
@@ -10,19 +10,22 @@ from typing import Optional, Tuple, Union
import paddle
import paddle.nn.functional as F
+ import transformer_engine_paddle as tex
from .layernorm_linear import LayerNormLinear
from .linear import Linear
from .softmax import FusedScaleMaskSoftmax
- from ..constants import AttnTypes, TE_DType, dist_group_type
+ from ..constants import (AttnTypes, TE_DType, QKVLayout, AttnBiasType, AttnMaskType,
+ FusedAttnBackend, dist_group_type)
from ..cpp_extensions import (
fused_attn_fwd_qkvpacked,
fused_attn_bwd_qkvpacked,
fused_attn_fwd_kvpacked,
fused_attn_bwd_kvpacked,
+ mask_to_cu_seqlens,
)
from ..distributed import get_tp_group_and_world_size, track_rng_state
- from ..utils import attention_mask_func, divide, mask_to_cu_seqlens
+ from ..utils import attention_mask_func, divide
from ..recompute import recompute
@@ -30,16 +33,17 @@ class FusedAttnFuncPackedQKV(paddle.autograd.PyLayer):
"""Function for FusedAttention with packed QKV input"""
@staticmethod
- def forward(ctx, qkv, cu_seqlens, attn_bias, rng_state, max_seqlen, attn_scale, qkv_dtype,
- dropout_p, set_zero, qkv_layout, attn_bias_type, attn_mask_type, is_training):
+ def forward(ctx, qkv, cu_seqlens, attn_bias, max_seqlen, attn_scale, qkv_dtype, dropout_p,
+ set_zero, qkv_layout, attn_bias_type, attn_mask_type, is_training,
+ fused_attention_backend):
"""Forward function for FusedAttention with packed QKV input"""
- out, aux_ctx_tensors = fused_attn_fwd_qkvpacked(
+ out, softmax_aux, rng_state = fused_attn_fwd_qkvpacked(
qkv,
cu_seqlens,
- rng_state,
is_training,
max_seqlen,
qkv_dtype,
+ fused_attention_backend,
attn_bias,
attn_scale,
dropout_p,
@@ -49,7 +53,7 @@ class FusedAttnFuncPackedQKV(paddle.autograd.PyLayer):
attn_mask_type,
)
- ctx.save_for_backward(qkv, out, cu_seqlens, rng_state, aux_ctx_tensors)
+ ctx.save_for_backward(qkv, out, cu_seqlens, rng_state, softmax_aux)
ctx.max_seqlen = max_seqlen
ctx.qkv_dtype = qkv_dtype
ctx.attn_scale = attn_scale
@@ -58,41 +62,41 @@ class FusedAttnFuncPackedQKV(paddle.autograd.PyLayer):
ctx.qkv_layout = qkv_layout
ctx.attn_bias_type = attn_bias_type
ctx.attn_mask_type = attn_mask_type
+ ctx.fused_attention_backend = fused_attention_backend
return out
@staticmethod
def backward(ctx, d_out):
"""Backward function for FusedAttention with packed QKV input"""
- qkv, out, cu_seqlens, rng_state, aux_ctx_tensors = ctx.saved_tensor()
+ qkv, out, cu_seqlens, rng_state, softmax_aux = ctx.saved_tensor()
- dqkv, *rest = fused_attn_bwd_qkvpacked(qkv, cu_seqlens, rng_state, out, d_out,
- aux_ctx_tensors, ctx.max_seqlen, ctx.qkv_dtype,
- ctx.attn_scale, ctx.dropout_p, ctx.set_zero,
- ctx.qkv_layout, ctx.attn_bias_type,
+ dqkv, *rest = fused_attn_bwd_qkvpacked(qkv, cu_seqlens, rng_state, out, d_out, softmax_aux,
+ ctx.fused_attention_backend, ctx.max_seqlen,
+ ctx.qkv_dtype, ctx.attn_scale, ctx.dropout_p,
+ ctx.set_zero, ctx.qkv_layout, ctx.attn_bias_type,
ctx.attn_mask_type)
# if no_bias, return dqkv
if ctx.attn_bias_type == "no_bias":
- return (dqkv, None, None)
+ return (dqkv, None)
# else, return (dqkv, dbias)
- return (dqkv, None, rest[0], None)
+ return (dqkv, None, rest[0])
class FusedAttnFuncPackedKV(paddle.autograd.PyLayer):
"""Function for FusedAttention with packed KV input"""
@staticmethod
- def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_kv, attn_bias, rng_state, max_seqlen_q,
- max_seqlen_kv, attn_scale, qkv_dtype, dropout_p, set_zero, qkv_layout,
- attn_bias_type, attn_mask_type, is_training):
+ def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_kv, attn_bias, max_seqlen_q, max_seqlen_kv,
+ attn_scale, qkv_dtype, dropout_p, set_zero, qkv_layout, attn_bias_type,
+ attn_mask_type, is_training, fused_attention_backend):
"""Forward function for FusedAttention with packed KV input"""
- out, aux_ctx_tensors = fused_attn_fwd_kvpacked(q, kv, cu_seqlens_q, cu_seqlens_kv,
- rng_state, is_training, max_seqlen_q,
- max_seqlen_kv, qkv_dtype, attn_bias,
- attn_scale, dropout_p, set_zero, qkv_layout,
+ out, softmax_aux, rng_state = fused_attn_fwd_kvpacked(
+ q, kv, cu_seqlens_q, cu_seqlens_kv, is_training, max_seqlen_q, max_seqlen_kv, qkv_dtype,
+ fused_attention_backend, attn_bias, attn_scale, dropout_p, set_zero, qkv_layout,
attn_bias_type, attn_mask_type)
- ctx.save_for_backward(q, kv, out, cu_seqlens_q, cu_seqlens_kv, rng_state, aux_ctx_tensors)
+ ctx.save_for_backward(q, kv, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux)
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_kv = max_seqlen_kv
ctx.qkv_dtype = qkv_dtype
@@ -102,24 +106,26 @@ class FusedAttnFuncPackedKV(paddle.autograd.PyLayer):
ctx.qkv_layout = qkv_layout
ctx.attn_bias_type = attn_bias_type
ctx.attn_mask_type = attn_mask_type
+ ctx.fused_attention_backend = fused_attention_backend
return out
@staticmethod
def backward(ctx, d_out):
"""Backward function for FusedAttention with packed KV input"""
- q, kv, out, cu_seqlens_q, cu_seqlens_kv, rng_state, aux_ctx_tensors = ctx.saved_tensor()
+ q, kv, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux = ctx.saved_tensor()
dq, dkv, *rest = fused_attn_bwd_kvpacked(q, kv, cu_seqlens_q, cu_seqlens_kv, rng_state, out,
- d_out, aux_ctx_tensors, ctx.max_seqlen_q,
- ctx.max_seqlen_kv, ctx.qkv_dtype, ctx.attn_scale,
- ctx.dropout_p, ctx.set_zero, ctx.qkv_layout,
- ctx.attn_bias_type, ctx.attn_mask_type)
+ d_out, softmax_aux, ctx.fused_attention_backend,
+ ctx.max_seqlen_q, ctx.max_seqlen_kv, ctx.qkv_dtype,
+ ctx.attn_scale, ctx.dropout_p, ctx.set_zero,
+ ctx.qkv_layout, ctx.attn_bias_type,
+ ctx.attn_mask_type)
# if no_bias, return dq, dkv
if ctx.attn_bias_type == "no_bias":
- return (dq, dkv, None, None, None)
+ return (dq, dkv, None, None)
# else, return (dq, dkv, dbias)
- return (dq, dkv, None, None, rest[0], None)
+ return (dq, dkv, None, None, rest[0])
class DotProductAttention(paddle.nn.Layer):
@@ -160,18 +166,17 @@ class DotProductAttention(paddle.nn.Layer):
self.attn_mask_type = attn_mask_type
self.attention_dropout = attention_dropout
self.attention_type = attention_type
- self.rng_state = paddle.zeros((2,), dtype='int64')
- self.rng_state.persistable = True
+ self.qkv_layout = "qkv_interleaved" if attention_type == "self" else "kv_interleaved"
self.backend = backend
arch = paddle.device.cuda.get_device_capability()
self.is_fused_attn_supported = arch in ((8, 0), (9, 0))
- self.enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN",
- "0")) and self.is_fused_attn_supported
+ self.use_fused_attention = (int(os.getenv("NVTE_FUSED_ATTN", "1"))
+ and self.is_fused_attn_supported)
- if not self.enable_fused_attn and backend == 'transformer_engine':
- # FMHA is not enabled, falling back to Paddle backend
+ if not self.use_fused_attention and backend == 'transformer_engine':
+ warnings.warn("Fused attention is not enabled, falling back to Paddle backend")
self.backend = 'paddle'
if self.backend != 'transformer_engine':
@@ -227,8 +232,25 @@ class DotProductAttention(paddle.nn.Layer):
"""
if self.backend == 'transformer_engine':
+ max_s_q = query_layer.shape[1]
+ max_s_kv = max_s_q if self.attention_type == "self" else key_value_layer.shape[1]
+ self.fused_attention_backend = tex.get_fused_attn_backend(
+ TE_DType[query_layer.dtype], TE_DType[query_layer.dtype],
+ QKVLayout[self.qkv_layout], AttnBiasType[core_attention_bias_type],
+ AttnMaskType[self.attn_mask_type], self.attention_dropout, max_s_q, max_s_kv,
+ query_layer.shape[-1])
+ is_backend_avail = (self.fused_attention_backend in [
+ FusedAttnBackend["F16_max512_seqlen"], FusedAttnBackend["F16_arbitrary_seqlen"]
+ ])
+ if is_backend_avail and self.use_fused_attention:
return self._te_forward(query_layer, key_value_layer, attention_mask,
core_attention_bias_type, core_attention_bias, set_zero)
+ warnings.warn("Fused attention is not enabled, falling back to Paddle backend")
+ self.backend = 'paddle'
+ self.scale_mask_softmax = FusedScaleMaskSoftmax(self.attn_mask_type,
+ attention_mask_func,
+ backend=self.backend)
if self.backend == 'paddle':
if core_attention_bias_type != "no_bias":
warnings.warn("Paddle backend dot product attention does not support bias yet. "
@@ -246,33 +268,26 @@
set_zero: bool = True,
) -> paddle.Tensor:
- gen_state = paddle.get_rng_state()[0].__getstate__()
- self.rng_state[0], self.rng_state[1] = gen_state[1], gen_state[2] # [seed, offset]
if self.attention_type == "self":
# self attention - q: [b, s, 3, h, d] kv: None
assert (len(query_layer.shape) == 5 and query_layer.shape[2] == 3
and key_value_layer is None
), "query shape must be [b, s, 3, h, d] for dot product self attention"
max_seqlen = query_layer.shape[1]
- cu_seqlens, _ = mask_to_cu_seqlens(attention_mask)
+ if self.attn_mask_type == "causal" or attention_mask is None:
+ cu_seqlens = paddle.arange(0, (query_layer.shape[0] + 1) * query_layer.shape[1],
+ step=query_layer.shape[1],
+ dtype='int32')
+ else:
+ cu_seqlens, _ = mask_to_cu_seqlens(attention_mask, need_kv=False)
qkv_dtype = TE_DType[query_layer.dtype]
- qkv_layout = "qkv_interleaved"
- output = FusedAttnFuncPackedQKV.apply(
- query_layer,
- cu_seqlens,
- core_attention_bias,
- self.rng_state,
- max_seqlen,
- 1.0 / self.norm_factor,
- qkv_dtype,
+ output = FusedAttnFuncPackedQKV.apply(query_layer, cu_seqlens, core_attention_bias,
+ max_seqlen, 1.0 / self.norm_factor, qkv_dtype,
self.attention_dropout if self.training else 0.0,
- set_zero,
- qkv_layout,
- core_attention_bias_type,
- self.attn_mask_type,
- self.training,
- )
+ set_zero, self.qkv_layout,
+ core_attention_bias_type, self.attn_mask_type,
+ self.training, self.fused_attention_backend)
elif self.attention_type == "cross":
# cross attention - q: [b, s_q, h, d] kv: [b, s_kv, 2, h, d]
assert (
@@ -280,29 +295,19 @@ class DotProductAttention(paddle.nn.Layer):
and key_value_layer.shape[2] == 2
), "query shape must be [b, s, h, d] and key shape must be [b, s, 2, h, d]" \
"for dot product cross attention"
+ assert (attention_mask
+ is not None), "attention_mask must be provided for cross attention"
max_seqlen_q = query_layer.shape[1]
max_seqlen_kv = key_value_layer.shape[1]
cu_seqlens_q, cu_seqlens_kv = mask_to_cu_seqlens(attention_mask, need_kv=True)
qkv_dtype = TE_DType[query_layer.dtype]
- qkv_layout = "kv_interleaved"
- output = FusedAttnFuncPackedKV.apply(
- query_layer,
- key_value_layer,
- cu_seqlens_q,
- cu_seqlens_kv,
- core_attention_bias,
- self.rng_state,
- max_seqlen_q,
- max_seqlen_kv,
- 1.0 / self.norm_factor,
- qkv_dtype,
+ output = FusedAttnFuncPackedKV.apply(query_layer, key_value_layer, cu_seqlens_q,
+ cu_seqlens_kv, core_attention_bias, max_seqlen_q,
+ max_seqlen_kv, 1.0 / self.norm_factor, qkv_dtype,
self.attention_dropout if self.training else 0.0,
- set_zero,
- qkv_layout,
- core_attention_bias_type,
- self.attn_mask_type,
- self.training,
- )
+ set_zero, self.qkv_layout,
+ core_attention_bias_type, self.attn_mask_type,
+ self.training, self.fused_attention_backend)
else:
raise ValueError("attention_type must be one of ['self', 'cross']")
return output
...
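Note: for causal self-attention, or when no mask is given, every sequence is treated as full length, so the cumulative sequence lengths reduce to an arange with the padded length as the step. A quick check of the expression used above (illustrative batch of 2 sequences of length 4):

import paddle

b, s = 2, 4
cu_seqlens = paddle.arange(0, (b + 1) * s, step=s, dtype='int32')
print(cu_seqlens.numpy())  # [0 4 8]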
@@ -6,6 +6,7 @@
from typing import Optional, Union
import paddle
+ from paddle.incubate.nn.layer.fused_dropout_add import FusedDropoutAdd
from . import LayerNormMLP, LayerNorm, MultiHeadAttention
from ..constants import AttnMaskTypes, LayerTypes, dist_group_type
@@ -182,6 +183,11 @@ class TransformerLayer(paddle.nn.Layer):
backend=backend,
)
+ self.fused_dropout_add1 = FusedDropoutAdd(self.hidden_dropout, mode="upscale_in_train")
+ if self.layer_type == "decoder":
+ self.fused_dropout_add2 = FusedDropoutAdd(self.hidden_dropout, mode="upscale_in_train")
+ self.fused_dropout_add3 = FusedDropoutAdd(self.hidden_dropout, mode="upscale_in_train")
def forward(
self,
hidden_states: paddle.Tensor,
@@ -249,12 +255,7 @@
# dropoout add.
with track_rng_state(enable=self.tensor_parallel, name=self.hidden_dropout_rng_state_name):
- out = paddle.nn.functional.dropout(
- attention_output,
- p=self.hidden_dropout,
- training=True,
- )
- bda_output = residual + out
+ bda_output = self.fused_dropout_add1(attention_output, residual)
# Cross attention.
if self.layer_type == "decoder":
@@ -275,12 +276,7 @@
with track_rng_state(enable=self.tensor_parallel,
name=self.hidden_dropout_rng_state_name):
- out = paddle.nn.functional.dropout(
- attention_output,
- p=self.hidden_dropout,
- training=True,
- )
- bda_output = residual + out
+ bda_output = self.fused_dropout_add2(attention_output, residual)
# MLP.
mlp_outputs = self.layernorm_mlp(bda_output)
@@ -292,8 +288,7 @@
# dropoout add.
with track_rng_state(enable=self.tensor_parallel, name=self.hidden_dropout_rng_state_name):
- out = paddle.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=True)
- output = residual + out
+ output = self.fused_dropout_add3(mlp_output, residual)
# For BERT like architectures.
if self.output_layernorm:
...
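Note: FusedDropoutAdd folds the dropout and the residual add into a single kernel launch; functionally it matches the unfused pattern it replaces (results differ only in which random numbers are drawn). A short sketch, assuming Paddle's incubate API from the import above:

import paddle
from paddle.incubate.nn.layer.fused_dropout_add import FusedDropoutAdd

x = paddle.randn([2, 8, 16])
residual = paddle.randn([2, 8, 16])

fused = FusedDropoutAdd(0.1, mode="upscale_in_train")
out_fused = fused(x, residual)

# unfused reference that the fused module replaces
out_ref = residual + paddle.nn.functional.dropout(x, p=0.1, training=True)
print(out_fused.shape, out_ref.shape)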