Unverified Commit 5c58beaa authored by cyanguwa, committed by GitHub

Add long sequence support for fused attention (#237)



* add long sequence support and unify three backends for fused attention
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update cudnn-frontend to v0.9.1
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace cpu_float2half_rn with __float2half_rn
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix backend selection and NVTEDType
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix CI
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* make cudnn plan caches thread_local
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix CI
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace cuDNN throw with NVTE_CHECK
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix replacement of cuDNN throw with NVTE_CHECK
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* force dropout probability to 0 in inference mode
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* change negInfinity to be consistent with m512 fused attn
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove float2half conversion for scale_dropout
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add back runtime api for sm detection
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add gemm3 to enums FP8Fwd/BwdTensors
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* change dropout from no to yes for fmha_v1
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove output_rng_state in m512 kernels
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix elts_per_thread calculation in kvpacked fwd
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove dropout=0.0 restriction for m512 fused attn
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove output_rng_state completely from m512 kernels
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 4330e025
Subproject commit e7f64390e9bb4a3db622ffe11c973834f572b609
Subproject commit a4f05c1edcef453f5fd52f96218c29c7d420e511
......@@ -12,9 +12,10 @@ list(APPEND transformer_engine_SOURCES
transpose/transpose_fusion.cu
transpose/multi_cast_transpose.cu
activation/gelu.cu
fused_attn/fused_attn_f16_max512_seqlen.cu
fused_attn/fused_attn_f16_arbitrary_seqlen.cu
activation/relu.cu
activation/swiglu.cu
fused_attn/fused_attn_fp16_bf16_max_seqlen_512.cu
fused_attn/fused_attn_fp8.cu
fused_attn/fused_attn.cpp
fused_attn/utils.cu
......
......@@ -7,8 +7,80 @@
#include "transformer_engine/fused_attn.h"
#include "../common.h"
#include "utils.h"
#include "fused_attn_fp16_bf16_max_seqlen_512.h"
#include "fused_attn_f16_max512_seqlen.h"
#include "fused_attn_f16_arbitrary_seqlen.h"
#include "fused_attn_fp8.h"
#include "../util/cuda_runtime.h"
// select a backend for fused attention
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTEDType q_dtype,
NVTEDType kv_dtype,
NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
float dropout, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim) {
using namespace transformer_engine;
NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
const int device_id = cuda::current_device();
const int sm_arch_ = cuda::sm_arch(device_id);
NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type.");
  if (((q_dtype == NVTEDType::kNVTEFloat8E4M3) || (q_dtype == NVTEDType::kNVTEFloat8E5M2))
          && (sm_arch_ >= 90)
&& (max_seqlen_q == max_seqlen_kv)
&& (max_seqlen_q <= 512)
&& (head_dim == 64)
&& (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
&& (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)
&& (qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)) {
backend = NVTE_Fused_Attn_Backend::NVTE_FP8;
} else if ((q_dtype == NVTEDType::kNVTEFloat16) || (q_dtype == NVTEDType::kNVTEBFloat16)) {
bool flag_m512 = false;
bool flag_arb = false;
if ((sm_arch_ >= 80)
&& (head_dim == 64)
&& ((bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
|| (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS))
&& ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
|| (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)
|| (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK))
&& ((qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)
|| (qkv_layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED))) {
flag_m512 = true;
}
if ((sm_arch_ >= 80)
&& (max_seqlen_q == max_seqlen_kv)
&& ((head_dim == 64) || (head_dim == 128))
&& (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
&& (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
&& (qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)) {
flag_arb = true;
}
if (((max_seqlen_q > 512) || (max_seqlen_kv > 512))
&& (flag_arb == true)) {
backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
}
if ((max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) {
if (flag_m512 == true) {
backend = NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen;
} else if ((flag_m512 == false) && (flag_arb == true)) {
backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
}
}
const char* env_backend = std::getenv("NVTE_FUSED_ATTN_BACKEND");
if ((max_seqlen_q <= 512) && (max_seqlen_kv <= 512)
&& (flag_arb == true)
&& (env_backend != nullptr)
&& (std::string(env_backend) == std::to_string(
NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen))) {
backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
}
} else {
backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
}
return backend;
}
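For orientation, a minimal caller-side sketch of the selection logic above; the wrapper function and its parameter values are illustrative, not part of this diff:

// Sketch: query which backend would service a BF16, head_dim=64, causal,
// QKV-interleaved problem. With max_seqlen <= 512 this resolves to
// NVTE_F16_max512_seqlen unless NVTE_FUSED_ATTN_BACKEND=1 is set in the
// environment; with max_seqlen > 512 it resolves to
// NVTE_F16_arbitrary_seqlen, provided flag_arb's conditions hold.
NVTE_Fused_Attn_Backend pick_backend(size_t max_seqlen) {
  return nvte_get_fused_attn_backend(
      NVTEDType::kNVTEBFloat16, NVTEDType::kNVTEBFloat16,  // q_dtype, kv_dtype
      NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED,
      NVTE_Bias_Type::NVTE_NO_BIAS,
      NVTE_Mask_Type::NVTE_CAUSAL_MASK,
      /*dropout=*/0.1f, max_seqlen, max_seqlen, /*head_dim=*/64);
}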
// NVTE fused attention FWD FP8 with packed QKV
void nvte_fused_attn_fwd_qkvpacked(
......@@ -16,7 +88,7 @@ void nvte_fused_attn_fwd_qkvpacked(
const NVTETensor Bias,
NVTETensor S,
NVTETensor O,
NVTETensorPack* Aux_Output_Tensors,
NVTETensorPack* Aux_CTX_Tensors,
const NVTETensor cu_seqlens,
const NVTETensor rng_state,
size_t max_seqlen,
......@@ -43,54 +115,56 @@ void nvte_fused_attn_fwd_qkvpacked(
size_t d = input_QKV->data.shape[ndim - 1];
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const DType QKV_type = input_QKV->data.dtype;
const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2))
&& (max_seqlen <= 512)) {
NVTE_Fused_Attn_Backend fused_attention_backend =
nvte_get_fused_attn_backend(
QKV_type, QKV_type,
qkv_layout, bias_type, attn_mask_type,
dropout, max_seqlen, max_seqlen, d);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
fused_attn_max_512_fwd_qkvpacked(
b, max_seqlen, h, d,
is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_QKV, input_Bias, output_O,
Aux_CTX_Tensors,
input_cu_seqlens,
input_rng_state,
wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n");
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900)
// FP8 API doesn't use input_Bias, bias_type or attn_mask_type
fused_attn_fwd_fp8_qkvpacked(
fused_attn_arbitrary_seqlen_fwd_qkvpacked(
b, max_seqlen, h, d,
is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_QKV, input_Bias, output_O,
Aux_CTX_Tensors,
input_cu_seqlens,
input_rng_state,
wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
#if (CUDNN_VERSION >= 8900)
fused_attn_fp8_fwd_qkvpacked(
b, max_seqlen, h, d,
is_training, attn_scale, dropout, qkv_layout,
input_QKV, input_output_S, output_O,
Aux_Output_Tensors,
Aux_CTX_Tensors,
input_cu_seqlens,
input_rng_state,
wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9 is required to run FP8 fused attention. \n");
NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n");
#endif
} else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16))
&& (max_seqlen <= 512)) {
#if (CUDNN_VERSION >= 8901)
fused_attn_max_512_fwd_qkvpacked(
b,
max_seqlen,
h,
d,
is_training,
attn_scale,
dropout,
qkv_layout,
bias_type,
attn_mask_type,
input_QKV,
input_Bias,
output_O,
Aux_Output_Tensors,
input_cu_seqlens,
input_rng_state,
wkspace,
stream,
handle);
#else
NVTE_ERROR(
"cuDNN 8.9.1 is required to run BF16/FP16 fused attention with max_seqlen<=512. \n");
#endif
} else if (max_seqlen > 512) {
NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n");
} else {
NVTE_ERROR("Invalid combination of data type and sequence length! \n");
NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
}
}
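A hedged sketch of the two-phase Aux_CTX_Tensors protocol implied by the size == 0 / size == N branches in this file, assuming the pack helpers nvte_tensor_pack_create/destroy from transformer_engine/fused_attn.h; allocate_like is a hypothetical allocation helper, and in practice phase 1 coincides with the usual workspace-size query:

void allocate_like(transformer_engine::Tensor *t);  // hypothetical helper

void fwd_with_aux_query(const NVTETensor QKV, const NVTETensor Bias,
                        NVTETensor S, NVTETensor O,
                        const NVTETensor cu_seqlens, const NVTETensor rng_state,
                        size_t max_seqlen, bool is_training, float attn_scale,
                        float dropout, NVTE_QKV_Layout qkv_layout,
                        NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
                        NVTETensor workspace, cudaStream_t stream) {
  NVTETensorPack aux;
  nvte_tensor_pack_create(&aux);  // aux.size starts at 0
  // Phase 1: with an empty pack, the call records each aux tensor's
  // shape/dtype but leaves the data pointers null.
  nvte_fused_attn_fwd_qkvpacked(QKV, Bias, S, O, &aux, cu_seqlens, rng_state,
                                max_seqlen, is_training, attn_scale, dropout,
                                qkv_layout, bias_type, mask_type,
                                workspace, stream);
  // Phase 2: allocate buffers matching the reported metadata, then re-call.
  for (size_t i = 0; i < aux.size; ++i)
    allocate_like(reinterpret_cast<transformer_engine::Tensor *>(aux.tensors[i]));
  nvte_fused_attn_fwd_qkvpacked(QKV, Bias, S, O, &aux, cu_seqlens, rng_state,
                                max_seqlen, is_training, attn_scale, dropout,
                                qkv_layout, bias_type, mask_type,
                                workspace, stream);
  nvte_tensor_pack_destroy(&aux);
}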
// NVTE fused attention BWD FP8 with packed QKV
......@@ -130,18 +204,52 @@ void nvte_fused_attn_bwd_qkvpacked(
size_t d = input_QKV->data.shape[ndim - 1];
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const DType QKV_type = input_QKV->data.dtype;
const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2))
&& (max_seqlen <= 512)) {
NVTE_Fused_Attn_Backend fused_attention_backend =
nvte_get_fused_attn_backend(
QKV_type, QKV_type,
qkv_layout, bias_type, attn_mask_type,
dropout, max_seqlen, max_seqlen, d);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
fused_attn_max_512_bwd_qkvpacked(
b, max_seqlen, h, d,
attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_QKV, input_dO,
output_S,
output_dQKV, output_dBias,
input_cu_seqlens,
wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n");
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900)
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[1]);
fused_attn_arbitrary_seqlen_bwd_qkvpacked(
b, max_seqlen, h, d,
attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_QKV, input_O, input_dO,
output_S,
output_dQKV, output_dBias,
input_cu_seqlens, input_rng_state,
wkspace, stream, handle);
#else
const char *err_msg =
"cuDNN 8.9.0 is required for BF16/FP16 fused attention "
"with arbitrary sequence length. \n";
NVTE_ERROR(err_msg);
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
#if (CUDNN_VERSION >= 8900)
// Aux_CTX_Tensors contain [M, ZInv, rng_state] generated by the forward pass
const Tensor *input_M = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[0]);
const Tensor *input_ZInv = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[1]);
const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[2]);
// FP8 API doesn't use input_dBias, bias_type or attn_mask_type
fused_attn_bwd_fp8_qkvpacked(
fused_attn_fp8_bwd_qkvpacked(
b, max_seqlen, h, d,
attn_scale, dropout, qkv_layout,
input_QKV, input_O, input_dO,
......@@ -152,38 +260,10 @@ void nvte_fused_attn_bwd_qkvpacked(
input_rng_state,
wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9 is required to run FP8 fused attention. \n");
#endif
} else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16))
&& (max_seqlen <= 512)) {
#if (CUDNN_VERSION >= 8901)
fused_attn_max_512_bwd_qkvpacked(
b,
max_seqlen,
h,
d,
attn_scale,
dropout,
qkv_layout,
bias_type,
attn_mask_type,
input_QKV,
input_dO,
Aux_CTX_Tensors,
output_dQKV,
output_dBias,
input_cu_seqlens,
wkspace,
stream,
handle);
#else
NVTE_ERROR(
"cuDNN 8.9.1 is required to run BF16/FP16 fused attention with max_seqlen<=512. \n");
NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n");
#endif
} else if (max_seqlen > 512) {
NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n");
} else {
NVTE_ERROR("Invalid combination of data type and sequence length! \n");
NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
}
}
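For quick reference, the positional pack layout the three backward branches above assume; this is a summary of the code, not a new contract:

// Aux_CTX_Tensors layout consumed by nvte_fused_attn_bwd_qkvpacked:
//   NVTE_F16_max512_seqlen:    [0] S
//   NVTE_F16_arbitrary_seqlen: [0] S, [1] rng_state
//   NVTE_FP8:                  [0] M, [1] ZInv, [2] rng_state
// e.g. the FP8 branch unpacks:
//   const Tensor *input_M = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[0]);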
// NVTE fused attention FWD FP8 with packed KV
......@@ -193,7 +273,7 @@ void nvte_fused_attn_fwd_kvpacked(
const NVTETensor Bias,
NVTETensor S,
NVTETensor O,
NVTETensorPack* Aux_Output_Tensors,
NVTETensorPack* Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv,
const NVTETensor rng_state,
......@@ -223,45 +303,37 @@ void nvte_fused_attn_fwd_kvpacked(
size_t d = input_Q->data.shape[ndim - 1];
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const DType QKV_type = input_Q->data.dtype;
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2))
&& (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) {
NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n");
} else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16))
&& (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) {
NVTE_Fused_Attn_Backend fused_attention_backend =
nvte_get_fused_attn_backend(
Q_type, KV_type,
qkv_layout, bias_type, attn_mask_type,
dropout, max_seqlen_q, max_seqlen_kv, d);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
fused_attn_max_512_fwd_kvpacked(
b,
max_seqlen_q,
max_seqlen_kv,
h,
d,
is_training,
attn_scale,
dropout,
qkv_layout,
bias_type,
attn_mask_type,
input_Q,
input_KV,
input_Bias,
output_O,
Aux_Output_Tensors,
input_cu_seqlens_q,
input_cu_seqlens_kv,
input_rng_state,
wkspace,
stream,
handle);
fused_attn_max_512_fwd_kvpacked(
b, max_seqlen_q, max_seqlen_kv, h, d,
is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_Q, input_KV, input_Bias, output_O,
Aux_CTX_Tensors,
input_cu_seqlens_q, input_cu_seqlens_kv,
input_rng_state,
wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.1 is required to run BF16/FP16 fused attention with max_seqlen<=512. \n");
NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n");
#endif
} else if ((max_seqlen_q > 512) || (max_seqlen_kv > 512)) {
NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n");
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
const char* err_msg =
"The FP16/BF16 fused attention (arbitrary seqlen) currently "
"only supports packed QKV input.\n";
NVTE_ERROR(err_msg);
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n");
} else {
NVTE_ERROR("Invalid combination of data type and sequence length! \n");
NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
}
}
// NVTE fused attention BWD FP8 with packed KV
......@@ -307,44 +379,37 @@ void nvte_fused_attn_bwd_kvpacked(
size_t d = input_Q->data.shape[ndim - 1];
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const DType QKV_type = input_Q->data.dtype;
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2))
&& (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) {
NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n");
} else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16))
&& (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) {
NVTE_Fused_Attn_Backend fused_attention_backend =
nvte_get_fused_attn_backend(
Q_type, KV_type,
qkv_layout, bias_type, attn_mask_type,
dropout, max_seqlen_q, max_seqlen_kv, d);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
fused_attn_max_512_bwd_kvpacked(
b,
max_seqlen_q,
max_seqlen_kv,
h,
d,
attn_scale,
dropout,
qkv_layout,
bias_type,
attn_mask_type,
input_Q,
input_KV,
input_dO,
Aux_CTX_Tensors,
output_dQ,
output_dKV,
output_dBias,
input_cu_seqlens_q,
input_cu_seqlens_kv,
wkspace,
stream,
handle);
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
fused_attn_max_512_bwd_kvpacked(
b, max_seqlen_q, max_seqlen_kv, h, d,
attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_Q, input_KV, input_dO,
output_S,
output_dQ, output_dKV, output_dBias,
input_cu_seqlens_q, input_cu_seqlens_kv,
wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.1 is required to run BF16/FP16 fused attention with max_seqlen<=512. \n");
NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n");
#endif
} else if ((max_seqlen_q > 512) || (max_seqlen_kv > 512)) {
NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n");
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
const char* err_msg =
"The FP16/BF16 fused attention (arbitrary seqlen) currently "
"only supports packed QKV input.\n";
NVTE_ERROR(err_msg);
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n");
} else {
NVTE_ERROR("Invalid combination of data type and sequence length! \n");
NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
}
}
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file fused_attn_f16_arbitrary_seqlen.h
* \brief Functions for fused attention with seqlen > 512
*/
#ifndef TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_ARBITRARY_SEQLEN_H_
#define TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_ARBITRARY_SEQLEN_H_
#include "transformer_engine/fused_attn.h"
#include <cudnn.h>
#include "common/common.h"
namespace transformer_engine {
#if (CUDNN_VERSION >= 8900)
void fused_attn_arbitrary_seqlen_fwd_qkvpacked(size_t batch, size_t max_seqlen, size_t num_head,
size_t head_size, bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_QKV, const Tensor *input_Bias,
Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t num_head,
size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV,
const Tensor *input_O,
const Tensor *input_dO, Tensor *output_S,
Tensor *output_dQKV, Tensor *output_dBias,
const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
#endif // CUDNN_VERSION >= 8900
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_ARBITRARY_SEQLEN_H_
......@@ -4,7 +4,7 @@
* See LICENSE for license information.
************************************************************************/
#include "fused_attn_fp16_bf16_max_seqlen_512.h"
#include "fused_attn_f16_max512_seqlen.h"
#include <cuda_bf16.h>
#include <cuda_fp16.h>
......@@ -1239,7 +1239,7 @@ void fused_attn_max_512_fwd_qkvpacked(
size_t batch, size_t max_seqlen, size_t num_head, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_Output_Tensors, const Tensor *cu_seqlens, const Tensor *rng_state,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
......@@ -1260,14 +1260,14 @@ void fused_attn_max_512_fwd_qkvpacked(
void *devPtrS = nullptr;
if (Aux_Output_Tensors->size == 0) {
Aux_Output_Tensors->size = 1;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_Output_Tensors->tensors[0]);
if (Aux_CTX_Tensors->size == 0) {
Aux_CTX_Tensors->size = 1;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
output_S->data.shape = {batch, num_head, max_seqlen, max_seqlen};
output_S->data.dtype = input_QKV->data.dtype;
} else if (Aux_Output_Tensors->size == 1) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_Output_Tensors->tensors[0]);
} else if (Aux_CTX_Tensors->size == 1) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
}
......@@ -1307,7 +1307,7 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_Output_Tensors, const Tensor *q_cu_seqlens,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens,
const Tensor *kv_cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
......@@ -1336,14 +1336,14 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
const DType kv_type = input_KV->data.dtype;
NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV.");
if (Aux_Output_Tensors->size == 0) {
Aux_Output_Tensors->size = 1;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_Output_Tensors->tensors[0]);
if (Aux_CTX_Tensors->size == 0) {
Aux_CTX_Tensors->size = 1;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
output_S->data.shape = {batch, num_head, q_max_seqlen, kv_max_seqlen};
output_S->data.dtype = q_type;
} else if (Aux_Output_Tensors->size == 1) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_Output_Tensors->tensors[0]);
} else if (Aux_CTX_Tensors->size == 1) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
}
......@@ -1381,7 +1381,7 @@ void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu
size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV,
const Tensor *input_dO, const NVTETensorPack *Aux_CTX_Tensors,
const Tensor *input_dO, Tensor *output_S,
Tensor *output_dQKV, Tensor *output_dBias,
const Tensor *cu_seqlens, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle) {
......@@ -1408,12 +1408,8 @@ void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu
void *devPtrdBias = output_dBias->data.dptr;
NVTE_CHECK(Aux_CTX_Tensors->size == 1);
void *devPtrS = nullptr;
if (Aux_CTX_Tensors->size == 1) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
}
void *devPtrS = output_S->data.dptr;
// devPtrdS reuses the memory of devPtrS
void *devPtrdS = devPtrS;
......@@ -1446,7 +1442,7 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_dO, const NVTETensorPack *Aux_CTX_Tensors,
const Tensor *input_dO, Tensor *output_S,
Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias,
const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
......@@ -1472,12 +1468,8 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
void *devPtrdBias = output_dBias->data.dptr;
NVTE_CHECK(Aux_CTX_Tensors->size == 1);
void *devPtrS = nullptr;
if (Aux_CTX_Tensors->size == 1) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
}
void *devPtrS = output_S->data.dptr;
// devPtrdS reuses the memory of devPtrS
void *devPtrdS = devPtrS;
......
......@@ -24,7 +24,7 @@ void fused_attn_max_512_fwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_QKV, const Tensor *input_Bias,
Tensor *output_O, NVTETensorPack *Aux_Output_Tensors,
Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
......@@ -34,7 +34,7 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_Output_Tensors, const Tensor *q_cu_seqlens,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens,
const Tensor *kv_cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
......@@ -42,7 +42,7 @@ void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu
size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV,
const Tensor *input_dO, const NVTETensorPack *Aux_CTX_Tensors,
const Tensor *input_dO, Tensor *output_S,
Tensor *output_dQKV, Tensor *output_dBias,
const Tensor *cu_seqlens, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
......@@ -52,7 +52,7 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_dO, const NVTETensorPack *Aux_CTX_Tensors,
const Tensor *input_dO, Tensor *output_S,
Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias,
const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
......
......@@ -991,7 +991,7 @@ static cudnn_frontend::Tensor createdSQBMM(
}
// fused attention FWD FP8
void fa_fwd_fp8(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d,
void fused_attn_fp8_fwd_impl(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d,
bool isTraining, float attnScale,
float dropoutProbability, NVTE_QKV_Layout layout,
void* devPtrQ, void* devPtrK, void* devPtrV,
......@@ -1303,7 +1303,7 @@ void fa_fwd_fp8(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d,
}
// fused attention BWD FP8
void fa_bwd_fp8(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d,
void fused_attn_fp8_bwd_impl(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d,
float attnScale, float dropoutProbability, NVTE_QKV_Layout layout,
void* devPtrQ, void* devPtrK, void* devPtrV,
void* devPtrM, void* devPtrZInv,
......@@ -1858,7 +1858,7 @@ void fa_bwd_fp8(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d,
#if (CUDNN_VERSION >= 8900)
// fused attention FWD FP8 with packed QKV
void fused_attn_fwd_fp8_qkvpacked(
void fused_attn_fp8_fwd_qkvpacked(
size_t b, size_t max_seqlen,
size_t h, size_t d,
bool is_training, float attn_scale,
......@@ -1866,7 +1866,7 @@ void fused_attn_fwd_fp8_qkvpacked(
const Tensor *input_QKV,
Tensor *input_output_S,
Tensor *output_O,
NVTETensorPack* Aux_Output_Tensors,
NVTETensorPack* Aux_CTX_Tensors,
const Tensor *cu_seqlens,
const Tensor *rng_state,
Tensor *workspace,
......@@ -1888,23 +1888,29 @@ void fused_attn_fwd_fp8_qkvpacked(
void* devPtrM = nullptr;
void* devPtrZInv = nullptr;
if (Aux_Output_Tensors->size == 0) {
if (Aux_CTX_Tensors->size == 0) {
if (is_training) {
Aux_Output_Tensors->size = 2;
Tensor *output_M = reinterpret_cast<Tensor*>(Aux_Output_Tensors->tensors[0]);
Tensor *output_ZInv = reinterpret_cast<Tensor*>(Aux_Output_Tensors->tensors[1]);
Aux_CTX_Tensors->size = 3;
Tensor *output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
Tensor *output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
Tensor *output_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
output_M->data.dptr = nullptr;
output_M->data.shape = {b, h, max_seqlen, 1};
output_M->data.dtype = DType::kFloat32;
output_ZInv->data.dptr = nullptr;
output_ZInv->data.shape = {b, h, max_seqlen, 1};
output_ZInv->data.dtype = DType::kFloat32;
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
}
} else if (Aux_Output_Tensors->size == 2) {
Tensor *output_M = reinterpret_cast<Tensor*>(Aux_Output_Tensors->tensors[0]);
Tensor *output_ZInv = reinterpret_cast<Tensor*>(Aux_Output_Tensors->tensors[1]);
} else if (Aux_CTX_Tensors->size == 3) {
Tensor *output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
Tensor *output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
Tensor *output_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
devPtrM = output_M->data.dptr;
devPtrZInv = output_ZInv->data.dptr;
output_rng_state->data.dptr = rng_state->data.dptr;
}
void* devPtrAmaxS = input_output_S->amax.dptr;
......@@ -1921,7 +1927,7 @@ void fused_attn_fwd_fp8_qkvpacked(
const DType QKV_type = input_QKV->data.dtype;
size_t workspace_size = 0;
fused_attn::fa_fwd_fp8(
fused_attn::fused_attn_fp8_fwd_impl(
b, max_seqlen, max_seqlen, h, d,
is_training, attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV,
......@@ -1948,7 +1954,7 @@ void fused_attn_fwd_fp8_qkvpacked(
}
}
// fused attention BWD FP8 with packed QKV
void fused_attn_bwd_fp8_qkvpacked(
void fused_attn_fp8_bwd_qkvpacked(
size_t b, size_t max_seqlen,
size_t h, size_t d,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
......@@ -2011,7 +2017,7 @@ void fused_attn_bwd_fp8_qkvpacked(
const DType QKV_type = input_QKV->data.dtype;
size_t workspace_size = 0;
fused_attn::fa_bwd_fp8(
fused_attn::fused_attn_fp8_bwd_impl(
b, max_seqlen, max_seqlen, h, d,
attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV,
......
......@@ -13,7 +13,7 @@
namespace transformer_engine {
#if (CUDNN_VERSION >= 8900)
// fused attention FWD FP8 with packed QKV
void fused_attn_fwd_fp8_qkvpacked(
void fused_attn_fp8_fwd_qkvpacked(
size_t b, size_t max_seqlen,
size_t h, size_t d,
bool is_training, float attn_scale,
......@@ -21,7 +21,7 @@ void fused_attn_fwd_fp8_qkvpacked(
const Tensor *input_QKV,
Tensor *input_output_S,
Tensor *output_O,
NVTETensorPack* Aux_Output_Tensors,
NVTETensorPack* Aux_CTX_Tensors,
const Tensor *cu_seqlens,
const Tensor *rng_state,
Tensor *workspace,
......@@ -29,7 +29,7 @@ void fused_attn_fwd_fp8_qkvpacked(
cudnnHandle_t handle);
// fused attention BWD FP8 with packed QKV
void fused_attn_bwd_fp8_qkvpacked(
void fused_attn_fp8_bwd_qkvpacked(
size_t b, size_t max_seqlen,
size_t h, size_t d,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
......
......@@ -249,7 +249,6 @@ __global__ void cu_seqlens_to_actual_seqlens(size_t b,
kv_seqlens[tid] = kv_cu_seqlens[tid + 1] - kv_cu_seqlens[tid];
}
}
} // namespace fused_attn
// get cuDNN data type
......
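The cu_seqlens_to_actual_seqlens kernel in the hunk above recovers per-sequence lengths from cumulative offsets. A host-side analogue for intuition (plain C++, not part of the diff):

#include <cstddef>
#include <vector>

// cu_seqlens = [0, 3, 7, 12] describes three sequences of lengths [3, 4, 5].
std::vector<size_t> to_actual_seqlens(const std::vector<size_t> &cu_seqlens) {
  std::vector<size_t> seqlens(cu_seqlens.size() - 1);
  for (size_t i = 0; i + 1 < cu_seqlens.size(); ++i)
    seqlens[i] = cu_seqlens[i + 1] - cu_seqlens[i];  // same arithmetic as the kernel
  return seqlens;
}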
......@@ -94,6 +94,38 @@ enum NVTE_Mask_Type {
NVTE_CAUSAL_MASK = 2,
};
enum NVTE_Fused_Attn_Backend {
/*! No supported backend */
NVTE_No_Backend = -1,
/*! cuDNN-based FP16/BF16 fused attention for <= 512 sequence length */
NVTE_F16_max512_seqlen = 0,
/*! cuDNN-based FP16/BF16 fused attention for any sequence length */
NVTE_F16_arbitrary_seqlen = 1,
/*! cuDNN-based FP8 fused attention for <= 512 sequence length */
NVTE_FP8 = 2,
};
/*! \brief Get fused attention backend based on input parameters.
*
* \param[in] q_dtype The data type of Tensor Q.
* \param[in] kv_dtype The data type of Tensors K, V.
* \param[in] qkv_layout The layout of Tensors Q, K, V.
* \param[in] bias_type The attention bias type.
* \param[in] attn_mask_type The attention mask type.
* \param[in] dropout The dropout probability.
* \param[in] max_seqlen_q The sequence length of Q.
* \param[in] max_seqlen_kv The sequence length of K, V.
* \param[in] head_dim The head dimension of Q, K, V.
*/
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTEDType q_dtype,
NVTEDType kv_dtype,
NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
float dropout, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim);
/*! \brief Compute dot product attention with packed QKV input.
*
* Computes:
......@@ -104,36 +136,38 @@ enum NVTE_Mask_Type {
*
* Support Matrix:
\verbatim
| precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING | Yes | <= 512 | 64 |
| FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 |
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 |
| 1 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS | CAUSAL | Yes | > 512 | 64, 128 |
| 2 | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING | Yes | <= 512 | 64 |
\endverbatim
*
* \param[in] QKV The QKV tensor in packed format,
* [total_seqs, 3, num_heads, head_dim].
* \param[in] Bias The Bias tensor.
* \param[in,out] S The S tensor.
* \param[out] O The output O tensor.
* \param[out] Aux_Output_Tensors Auxiliary output tensors when training, e.g. M, ZInv.
* \param[in] cu_seqlens Accumulative sequence lengths, [batch_size + 1].
* \param[in] rng_state Seed and offset of CUDA random number generator.
* \param[in] max_seqlen Max sequence length used for computing.
* It may be >= max(cu_seqlens).
* \param[in] is_training Whether this is in training mode or inference.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
* \param[in] QKV The QKV tensor in packed format,
* [total_seqs, 3, num_heads, head_dim].
* \param[in] Bias The Bias tensor.
* \param[in,out] S The S tensor.
* \param[out] O The output O tensor.
* \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
* e.g. M, ZInv, rng_state.
* \param[in] cu_seqlens Accumulative sequence lengths, [batch_size + 1].
* \param[in] rng_state Seed and offset of CUDA random number generator.
* \param[in] max_seqlen Max sequence length used for computing,
* it may be >= max(cu_seqlens).
* \param[in] is_training Whether this is in training mode or inference.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_fwd_qkvpacked(
const NVTETensor QKV,
const NVTETensor Bias,
NVTETensor S,
NVTETensor O,
NVTETensorPack* Aux_Output_Tensors,
NVTETensorPack* Aux_CTX_Tensors,
const NVTETensor cu_seqlens,
const NVTETensor rng_state,
size_t max_seqlen,
......@@ -147,30 +181,32 @@ void nvte_fused_attn_fwd_qkvpacked(
*
* Support Matrix:
\verbatim
| precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING | Yes | <= 512 | 64 |
| FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 |
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 |
| 1 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS | CAUSAL | Yes | > 512 | 64, 128 |
| 2 | FP8 | QKV_INTERLEAVED | NO_BIAS | PADDING | Yes | <= 512 | 64 |
\endverbatim
*
* \param[in] QKV The QKV tensor in packed format,
* [total_seqs, 3, num_heads, head_dim].
* \param[in] O The O tensor from forward.
* \param[in] dO The gradient of the O tensor.
* \param[in] S The S tensor.
* \param[in,out] dP The gradient of the P tensor.
* \param[in] Aux_CTX_Tensors Auxiliary tensors from forward when in training mode.
* \param[out] dQKV The gradient of the QKV tensor.
* \param[out] dBias The gradient of the Bias tensor.
* \param[in] cu_seqlens Accumulative sequence lengths, [batch_size + 1].
* \param[in] max_seqlen Max sequence length used for computing.
* It may be >= max(cu_seqlens).
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
* \param[in] QKV The QKV tensor in packed format,
* [total_seqs, 3, num_heads, head_dim].
* \param[in] O The O tensor from forward.
* \param[in] dO The gradient of the O tensor.
* \param[in] S The S tensor.
* \param[in,out] dP The gradient of the P tensor.
* \param[in] Aux_CTX_Tensors Auxiliary tensors from context when in training mode,
* e.g. M, ZInv, rng_state.
* \param[out] dQKV The gradient of the QKV tensor.
* \param[out] dBias The gradient of the Bias tensor.
* \param[in] cu_seqlens Accumulative sequence lengths, [batch_size + 1].
* \param[in] max_seqlen Max sequence length used for computing,
* it may be >= max(cu_seqlens).
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_bwd_qkvpacked(
const NVTETensor QKV,
......@@ -199,31 +235,32 @@ void nvte_fused_attn_bwd_qkvpacked(
*
* Support Matrix:
\verbatim
| precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 |
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 |
\endverbatim
*
* \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim].
* \param[in] KV The KV tensor, [total_seqs_kv, 2, num_heads, head_dim].
* \param[in] Bias The Bias tensor.
* \param[in,out] S The S tensor.
* \param[out] O The output O tensor.
* \param[out] Aux_Output_Tensors Auxiliary output tensors when training, e.g. M, ZInv.
* \param[in] cu_seqlens_q Accumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Accumulative sequence lengths for KV, [batch_size + 1].
* \param[in] rng_state Seed and offset of CUDA random number generator.
* \param[in] max_seqlen_q Max sequence length used for computing
* for Q. It may be >= max(cu_seqlens_q).
* \param[in] max_seqlen_kv Max sequence length used for computing
* for KV. It may be >= max(cu_seqlens_kv).
* \param[in] is_training Whether this is in training mode or inference.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
* \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim].
* \param[in] KV The KV tensor, [total_seqs_kv, 2, num_heads, head_dim].
* \param[in] Bias The Bias tensor.
* \param[in,out] S The S tensor.
* \param[out] O The output O tensor.
* \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
* e.g. M, ZInv, rng_state.
* \param[in] cu_seqlens_q Accumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Accumulative sequence lengths for KV, [batch_size + 1].
* \param[in] rng_state Seed and offset of CUDA random number generator.
* \param[in] max_seqlen_q Max sequence length used for computing for Q.
* it may be >= max(cu_seqlens_q).
* \param[in] max_seqlen_kv Max sequence length used for computing for KV.
* it may be >= max(cu_seqlens_kv).
* \param[in] is_training Whether this is in training mode or inference.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_fwd_kvpacked(
const NVTETensor Q,
......@@ -231,7 +268,7 @@ void nvte_fused_attn_fwd_kvpacked(
const NVTETensor Bias,
NVTETensor S,
NVTETensor O,
NVTETensorPack* Aux_Output_Tensors,
NVTETensorPack* Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv,
const NVTETensor rng_state,
......@@ -246,33 +283,34 @@ void nvte_fused_attn_fwd_kvpacked(
*
* Support Matrix:
\verbatim
| precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 |
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
| 0 | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes | <= 512 | 64 |
\endverbatim
*
* \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim].
* \param[in] KV The KV tensor, [total_seqs_kv, 2, num_heads, head_dim].
* \param[in] O The O tensor from forward.
* \param[in] dO The gradient of the O tensor.
* \param[in] S The S tensor.
* \param[in,out] dP The gradient of the P tensor.
* \param[in] Aux_CTX_Tensors Auxiliary tensors from forward when in training mode.
* \param[out] dQ The gradient of the Q tensor.
* \param[out] dKV The gradient of the KV tensor.
* \param[out] dBias The gradient of the Bias tensor.
* \param[in] cu_seqlens_q Accumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Accumulative sequence lengths for KV, [batch_size + 1].
* \param[in] max_seqlen_q Max sequence length used for computing
* for Q. It may be >= max(cu_seqlens_q).
* \param[in] max_seqlen_kv Max sequence length used for computing
* for KV. It may be >= max(cu_seqlens_kv).
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
* \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim].
* \param[in] KV The KV tensor, [total_seqs_kv, 2, num_heads, head_dim].
* \param[in] O The O tensor from forward.
* \param[in] dO The gradient of the O tensor.
* \param[in] S The S tensor.
* \param[in,out] dP The gradient of the P tensor.
* \param[in] Aux_CTX_Tensors Auxiliary tensors from context when in training mode,
* e.g. M, ZInv, rng_state.
* \param[out] dQ The gradient of the Q tensor.
* \param[out] dKV The gradient of the KV tensor.
* \param[out] dBias The gradient of the Bias tensor.
* \param[in] cu_seqlens_q Accumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Accumulative sequence lengths for KV, [batch_size + 1].
* \param[in] max_seqlen_q Max sequence length used for computing for Q.
* it may be >= max(cu_seqlens_q).
* \param[in] max_seqlen_kv Max sequence length used for computing for KV.
* it may be >= max(cu_seqlens_kv).
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_bwd_kvpacked(
const NVTETensor Q,
......
......@@ -22,7 +22,7 @@ TE_DType = {
torch.bfloat16: tex.DType.kBFloat16,
}
AttnMaskTypes = ("causal", "padding")
AttnMaskTypes = ("causal", "padding", "no_mask")
AttnTypes = ("self", "cross")
......
......@@ -58,7 +58,10 @@ enum FP8FwdTensors {
GEMM1_OUTPUT = 2,
GEMM2_INPUT = 3,
GEMM2_WEIGHT = 4,
GEMM2_OUTPUT = 5
GEMM2_OUTPUT = 5,
GEMM3_INPUT = 6,
GEMM3_WEIGHT = 7,
GEMM3_OUTPUT = 8
};
// Used as named indices on the `scale`, `scale_inv`,
......@@ -67,7 +70,9 @@ enum FP8BwdTensors {
GRAD_OUTPUT1 = 0,
GRAD_INPUT1 = 1,
GRAD_OUTPUT2 = 2,
GRAD_INPUT2 = 3
GRAD_INPUT2 = 3,
GRAD_OUTPUT3 = 4,
GRAD_INPUT3 = 5
};
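The GEMM3 additions extend the named offsets used to index per-layer FP8 metadata (scale, scale_inv, amax buffers), presumably reserving slots for the attention GEMM this PR fuses. A schematic lookup, with the fixed-size arrays purely illustrative:

// Illustrative only: the enums act as indices into per-layer FP8 meta arrays.
float gemm3_input_scale(const float (&fwd_scale)[9]) {
  return fwd_scale[FP8FwdTensors::GEMM3_INPUT];   // index 6
}
float grad_output3_scale(const float (&bwd_scale)[6]) {
  return bwd_scale[FP8BwdTensors::GRAD_OUTPUT3];  // index 4
}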
......@@ -81,6 +86,9 @@ transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
inline at::ScalarType GetATenDType(transformer_engine::DType t) {
switch (t) {
case transformer_engine::DType::kInt32:
return torch::kInt32;
case transformer_engine::DType::kInt64:
return torch::kInt64;
case transformer_engine::DType::kFloat32:
return at::kFloat;
case transformer_engine::DType::kFloat16:
......
......@@ -7,17 +7,22 @@
#include "common.h"
#include "../common.h"
NVTE_QKV_Layout get_nvte_qkv_layout(const std::string qkv_layout);
NVTE_Bias_Type get_nvte_bias_type(const std::string bias_type);
NVTE_Mask_Type get_nvte_mask_type(const std::string mask_type);
NVTE_Fused_Attn_Backend get_fused_attn_backend(
const transformer_engine::DType q_dtype,
const transformer_engine::DType kv_dtype,
NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
float p_dropout, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim);
std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
size_t b, size_t max_seqlen, size_t total_seqs,
size_t h, size_t d,
bool is_training, float attn_scale, float p_dropout, bool set_zero,
std::string qkv_layout, std::string bias_type, std::string attn_mask_type,
size_t h, size_t d, bool is_training,
float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
const at::Tensor cu_seqlens,
const at::Tensor QKV,
const transformer_engine::DType qkv_type,
......@@ -27,13 +32,16 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
c10::optional<at::Tensor> amax_S,
c10::optional<at::Tensor> amax_O,
const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen);
const c10::optional<at::Generator> rng_gen,
size_t rng_elts_per_thread);
std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
size_t b, size_t max_seqlen, size_t total_seqs,
size_t h, size_t d,
float attn_scale, float p_dropout, bool set_zero,
std::string qkv_layout, std::string bias_type, std::string attn_mask_type,
size_t h, size_t d, float attn_scale,
float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
const at::Tensor cu_seqlens,
const at::Tensor QKV,
const at::Tensor O,
......@@ -53,9 +61,11 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
std::vector<at::Tensor> fused_attn_fwd_kvpacked(
size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t total_seqs_q, size_t total_seqs_kv,
size_t h, size_t d,
bool is_training, float attn_scale, float p_dropout, bool set_zero,
std::string qkv_layout, std::string bias_type, std::string attn_mask_type,
size_t h, size_t d, bool is_training,
float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv,
const at::Tensor Q,
......@@ -67,14 +77,17 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
c10::optional<at::Tensor> amax_S,
c10::optional<at::Tensor> amax_O,
const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen);
const c10::optional<at::Generator> rng_gen,
size_t rng_elts_per_thread);
std::vector<at::Tensor> fused_attn_bwd_kvpacked(
size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t total_seqs_q, size_t total_seqs_kv,
size_t h, size_t d,
float attn_scale, float p_dropout, bool set_zero,
std::string qkv_layout, std::string bias_type, std::string attn_mask_type,
size_t h, size_t d, float attn_scale,
float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv,
const at::Tensor Q,
......
......@@ -403,6 +403,9 @@ class TransformerLayer(torch.nn.Module):
checkpoint_core_attention: bool = False,
inference_params: Optional[Any] = None,
rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None,
fast_zero_fill: bool = True,
) -> torch.Tensor:
"""
Transformer Layer: attention block and a feedforward network (MLP)
......@@ -445,6 +448,12 @@ class TransformerLayer(torch.nn.Module):
rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
Embeddings for query and key tensors for applying rotary position
embedding. By default no input embedding is applied.
core_attention_bias_type: str, default = `no_bias`
Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`}
core_attention_bias: Optional[torch.Tensor], default = `None`
Bias tensor for Q * K.T
fast_zero_fill: bool, default = `True`
Whether to set output tensors to 0 before use.
"""
hidden_states = hidden_states.contiguous()
......@@ -473,6 +482,9 @@ class TransformerLayer(torch.nn.Module):
is_first_microbatch=is_first_microbatch,
checkpoint_core_attention=checkpoint_core_attention,
rotary_pos_emb=rotary_pos_emb,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
fast_zero_fill=fast_zero_fill,
)
if self.apply_residual_connection_post_layernorm and not self.output_layernorm:
......@@ -516,6 +528,9 @@ class TransformerLayer(torch.nn.Module):
encoder_output=encoder_output,
is_first_microbatch=is_first_microbatch,
checkpoint_core_attention=checkpoint_core_attention,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
fast_zero_fill=fast_zero_fill,
)
if self.apply_residual_connection_post_layernorm:
attention_output, attention_bias, residual = inter_attention_outputs
......