Unverified Commit 5c58beaa authored by cyanguwa, committed by GitHub

Add long sequence support for fused attention (#237)



* add long sequence support and unify three backends for fused attention
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update cudnn-frontend to v0.9.1
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace cpu_float2half_rn with __float2half_rn
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix backend selection and NVTEDType
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix ci
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* make cudnn plan caches thread_local
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix CI
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* replace cuDNN throw with NVTE_CHECK
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix replacement of cuDNN throw with NVTE_CHECK
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* force dropout probability to 0 in inference mode
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* change negInfinity to be consistent with m512 fused attn
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove float2half conversion for scale_dropout
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add back runtime api for sm detection
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add gemm3 to enums FP8Fwd/BwdTensors
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* change dropout from no to yes for fmha_v1
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove output_rng_state in m512 kernels
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix elts_per_thread calculation in kvpacked fwd
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove dropout=0.0 restriction for m512 fused attn
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove output_rng_state completely from m512 kernels
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 4330e025
Subproject commit e7f64390e9bb4a3db622ffe11c973834f572b609 Subproject commit a4f05c1edcef453f5fd52f96218c29c7d420e511
...@@ -12,9 +12,10 @@ list(APPEND transformer_engine_SOURCES ...@@ -12,9 +12,10 @@ list(APPEND transformer_engine_SOURCES
transpose/transpose_fusion.cu transpose/transpose_fusion.cu
transpose/multi_cast_transpose.cu transpose/multi_cast_transpose.cu
activation/gelu.cu activation/gelu.cu
fused_attn/fused_attn_f16_max512_seqlen.cu
fused_attn/fused_attn_f16_arbitrary_seqlen.cu
activation/relu.cu activation/relu.cu
activation/swiglu.cu activation/swiglu.cu
fused_attn/fused_attn_fp16_bf16_max_seqlen_512.cu
fused_attn/fused_attn_fp8.cu fused_attn/fused_attn_fp8.cu
fused_attn/fused_attn.cpp fused_attn/fused_attn.cpp
fused_attn/utils.cu fused_attn/utils.cu
......
...@@ -7,8 +7,80 @@ ...@@ -7,8 +7,80 @@
#include "transformer_engine/fused_attn.h" #include "transformer_engine/fused_attn.h"
#include "../common.h" #include "../common.h"
#include "utils.h" #include "utils.h"
#include "fused_attn_fp16_bf16_max_seqlen_512.h" #include "fused_attn_f16_max512_seqlen.h"
#include "fused_attn_f16_arbitrary_seqlen.h"
#include "fused_attn_fp8.h" #include "fused_attn_fp8.h"
#include "../util/cuda_runtime.h"
// select a backend for fused attention
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTEDType q_dtype,
NVTEDType kv_dtype,
NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
float dropout, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim) {
using namespace transformer_engine;
NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
const int device_id = cuda::current_device();
const int sm_arch_ = cuda::sm_arch(device_id);
NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type.");
if (((q_dtype == NVTEDType::kNVTEFloat8E4M3) || (q_dtype == NVTEDType::kNVTEFloat8E5M2))
&& (sm_arch_ >= 90)
&& (max_seqlen_q == max_seqlen_kv)
&& (max_seqlen_q <= 512)
&& (head_dim == 64)
&& (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
&& (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)
&& (qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)) {
backend = NVTE_Fused_Attn_Backend::NVTE_FP8;
} else if ((q_dtype == NVTEDType::kNVTEFloat16) || (q_dtype == NVTEDType::kNVTEBFloat16)) {
bool flag_m512 = false;
bool flag_arb = false;
if ((sm_arch_ >= 80)
&& (head_dim == 64)
&& ((bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
|| (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS))
&& ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
|| (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)
|| (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK))
&& ((qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)
|| (qkv_layout == NVTE_QKV_Layout::NVTE_KV_INTERLEAVED))) {
flag_m512 = true;
}
if ((sm_arch_ >= 80)
&& (max_seqlen_q == max_seqlen_kv)
&& ((head_dim == 64) || (head_dim == 128))
&& (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
&& (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
&& (qkv_layout == NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED)) {
flag_arb = true;
}
if (((max_seqlen_q > 512) || (max_seqlen_kv > 512))
&& (flag_arb == true)) {
backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
}
if ((max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) {
if (flag_m512 == true) {
backend = NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen;
} else if ((flag_m512 == false) && (flag_arb == true)) {
backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
}
}
const char* env_backend = std::getenv("NVTE_FUSED_ATTN_BACKEND");
if ((max_seqlen_q <= 512) && (max_seqlen_kv <= 512)
&& (flag_arb == true)
&& (env_backend != nullptr)
&& (std::string(env_backend) == std::to_string(
NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen))) {
backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
}
} else {
backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
}
return backend;
}
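For illustration only, a minimal caller-side sketch of querying this selection before dispatching to a kernel. The argument values are assumptions, not part of this change; on sm80+ hardware, a causal, no-bias, interleaved, head_dim-64 case with sequence length > 512 resolves to the arbitrary-seqlen backend per the logic above.

    // Sketch: ask which fused-attention backend (if any) supports a given configuration.
    NVTE_Fused_Attn_Backend backend = nvte_get_fused_attn_backend(
        NVTEDType::kNVTEBFloat16, NVTEDType::kNVTEBFloat16,   // Q and KV dtypes
        NVTE_QKV_Layout::NVTE_QKV_INTERLEAVED,
        NVTE_Bias_Type::NVTE_NO_BIAS,
        NVTE_Mask_Type::NVTE_CAUSAL_MASK,
        0.1f,                                                 // dropout probability
        2048, 2048,                                           // max_seqlen_q, max_seqlen_kv
        64);                                                  // head_dim
    if (backend == NVTE_Fused_Attn_Backend::NVTE_No_Backend) {
      // no fused kernel applies; fall back to an unfused attention path
    }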
// NVTE fused attention FWD FP8 with packed QKV // NVTE fused attention FWD FP8 with packed QKV
void nvte_fused_attn_fwd_qkvpacked( void nvte_fused_attn_fwd_qkvpacked(
...@@ -16,7 +88,7 @@ void nvte_fused_attn_fwd_qkvpacked( ...@@ -16,7 +88,7 @@ void nvte_fused_attn_fwd_qkvpacked(
const NVTETensor Bias, const NVTETensor Bias,
NVTETensor S, NVTETensor S,
NVTETensor O, NVTETensor O,
NVTETensorPack* Aux_Output_Tensors, NVTETensorPack* Aux_CTX_Tensors,
const NVTETensor cu_seqlens, const NVTETensor cu_seqlens,
const NVTETensor rng_state, const NVTETensor rng_state,
size_t max_seqlen, size_t max_seqlen,
...@@ -43,54 +115,56 @@ void nvte_fused_attn_fwd_qkvpacked( ...@@ -43,54 +115,56 @@ void nvte_fused_attn_fwd_qkvpacked(
size_t d = input_QKV->data.shape[ndim - 1]; size_t d = input_QKV->data.shape[ndim - 1];
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const DType QKV_type = input_QKV->data.dtype; const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2)) NVTE_Fused_Attn_Backend fused_attention_backend =
&& (max_seqlen <= 512)) { nvte_get_fused_attn_backend(
QKV_type, QKV_type,
qkv_layout, bias_type, attn_mask_type,
dropout, max_seqlen, max_seqlen, d);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
fused_attn_max_512_fwd_qkvpacked(
b, max_seqlen, h, d,
is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_QKV, input_Bias, output_O,
Aux_CTX_Tensors,
input_cu_seqlens,
input_rng_state,
wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n");
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900) #if (CUDNN_VERSION >= 8900)
// FP8 API doesn't use input_Bias, bias_type or attn_mask_type fused_attn_arbitrary_seqlen_fwd_qkvpacked(
fused_attn_fwd_fp8_qkvpacked( b, max_seqlen, h, d,
is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_QKV, input_Bias, output_O,
Aux_CTX_Tensors,
input_cu_seqlens,
input_rng_state,
wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
#if (CUDNN_VERSION >= 8900)
fused_attn_fp8_fwd_qkvpacked(
b, max_seqlen, h, d, b, max_seqlen, h, d,
is_training, attn_scale, dropout, qkv_layout, is_training, attn_scale, dropout, qkv_layout,
input_QKV, input_output_S, output_O, input_QKV, input_output_S, output_O,
Aux_Output_Tensors, Aux_CTX_Tensors,
input_cu_seqlens, input_cu_seqlens,
input_rng_state, input_rng_state,
wkspace, stream, handle); wkspace, stream, handle);
#else #else
NVTE_ERROR("cuDNN 8.9 is required to run FP8 fused attention. \n"); NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n");
#endif #endif
} else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16))
&& (max_seqlen <= 512)) {
#if (CUDNN_VERSION >= 8901)
fused_attn_max_512_fwd_qkvpacked(
b,
max_seqlen,
h,
d,
is_training,
attn_scale,
dropout,
qkv_layout,
bias_type,
attn_mask_type,
input_QKV,
input_Bias,
output_O,
Aux_Output_Tensors,
input_cu_seqlens,
input_rng_state,
wkspace,
stream,
handle);
#else
NVTE_ERROR(
"cuDNN 8.9.1 is required to run BF16/FP16 fused attention with max_seqlen<=512. \n");
#endif
} else if (max_seqlen > 512) {
NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n");
} else { } else {
NVTE_ERROR("Invalid combination of data type and sequence length! \n"); NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
} }
} }
// NVTE fused attention BWD FP8 with packed QKV // NVTE fused attention BWD FP8 with packed QKV
...@@ -130,18 +204,52 @@ void nvte_fused_attn_bwd_qkvpacked( ...@@ -130,18 +204,52 @@ void nvte_fused_attn_bwd_qkvpacked(
size_t d = input_QKV->data.shape[ndim - 1]; size_t d = input_QKV->data.shape[ndim - 1];
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const DType QKV_type = input_QKV->data.dtype; const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2)) NVTE_Fused_Attn_Backend fused_attention_backend =
&& (max_seqlen <= 512)) { nvte_get_fused_attn_backend(
QKV_type, QKV_type,
qkv_layout, bias_type, attn_mask_type,
dropout, max_seqlen, max_seqlen, d);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
fused_attn_max_512_bwd_qkvpacked(
b, max_seqlen, h, d,
attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_QKV, input_dO,
output_S,
output_dQKV, output_dBias,
input_cu_seqlens,
wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n");
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900)
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[1]);
fused_attn_arbitrary_seqlen_bwd_qkvpacked(
b, max_seqlen, h, d,
attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_QKV, input_O, input_dO,
output_S,
output_dQKV, output_dBias,
input_cu_seqlens, input_rng_state,
wkspace, stream, handle);
#else
const char *err_msg =
"cuDNN 8.9.0 is required for BF16/FP16 fused attention "
"with arbitrary sequence length. \n";
NVTE_ERROR(err_msg);
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
#if (CUDNN_VERSION >= 8900) #if (CUDNN_VERSION >= 8900)
// Aux_CTX_Tensors contain [M, ZInv, rng_state] generated by the forward pass
const Tensor *input_M = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[0]); const Tensor *input_M = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[0]);
const Tensor *input_ZInv = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[1]); const Tensor *input_ZInv = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[1]);
const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[2]); const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[2]);
fused_attn_fp8_bwd_qkvpacked(
// FP8 API doesn't use input_dBias, bias_type or attn_mask_type
fused_attn_bwd_fp8_qkvpacked(
b, max_seqlen, h, d, b, max_seqlen, h, d,
attn_scale, dropout, qkv_layout, attn_scale, dropout, qkv_layout,
input_QKV, input_O, input_dO, input_QKV, input_O, input_dO,
...@@ -152,38 +260,10 @@ void nvte_fused_attn_bwd_qkvpacked( ...@@ -152,38 +260,10 @@ void nvte_fused_attn_bwd_qkvpacked(
input_rng_state, input_rng_state,
wkspace, stream, handle); wkspace, stream, handle);
#else #else
NVTE_ERROR("cuDNN 8.9 is required to run FP8 fused attention. \n"); NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n");
#endif
} else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16))
&& (max_seqlen <= 512)) {
#if (CUDNN_VERSION >= 8901)
fused_attn_max_512_bwd_qkvpacked(
b,
max_seqlen,
h,
d,
attn_scale,
dropout,
qkv_layout,
bias_type,
attn_mask_type,
input_QKV,
input_dO,
Aux_CTX_Tensors,
output_dQKV,
output_dBias,
input_cu_seqlens,
wkspace,
stream,
handle);
#else
NVTE_ERROR(
"cuDNN 8.9.1 is required to run BF16/FP16 fused attention with max_seqlen<=512. \n");
#endif #endif
} else if (max_seqlen > 512) {
NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n");
} else { } else {
NVTE_ERROR("Invalid combination of data type and sequence length! \n"); NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
} }
} }
// NVTE fused attention FWD FP8 with packed KV // NVTE fused attention FWD FP8 with packed KV
...@@ -193,7 +273,7 @@ void nvte_fused_attn_fwd_kvpacked( ...@@ -193,7 +273,7 @@ void nvte_fused_attn_fwd_kvpacked(
const NVTETensor Bias, const NVTETensor Bias,
NVTETensor S, NVTETensor S,
NVTETensor O, NVTETensor O,
NVTETensorPack* Aux_Output_Tensors, NVTETensorPack* Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_kv,
const NVTETensor rng_state, const NVTETensor rng_state,
...@@ -223,45 +303,37 @@ void nvte_fused_attn_fwd_kvpacked( ...@@ -223,45 +303,37 @@ void nvte_fused_attn_fwd_kvpacked(
size_t d = input_Q->data.shape[ndim - 1]; size_t d = input_Q->data.shape[ndim - 1];
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const DType QKV_type = input_Q->data.dtype; const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2)) NVTE_Fused_Attn_Backend fused_attention_backend =
&& (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) { nvte_get_fused_attn_backend(
NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n"); Q_type, KV_type,
} else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16)) qkv_layout, bias_type, attn_mask_type,
&& (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) { dropout, max_seqlen_q, max_seqlen_kv, d);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901) #if (CUDNN_VERSION >= 8901)
fused_attn_max_512_fwd_kvpacked( fused_attn_max_512_fwd_kvpacked(
b, b, max_seqlen_q, max_seqlen_kv, h, d,
max_seqlen_q, is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
max_seqlen_kv, input_Q, input_KV, input_Bias, output_O,
h, Aux_CTX_Tensors,
d, input_cu_seqlens_q, input_cu_seqlens_kv,
is_training, input_rng_state,
attn_scale, wkspace, stream, handle);
dropout,
qkv_layout,
bias_type,
attn_mask_type,
input_Q,
input_KV,
input_Bias,
output_O,
Aux_Output_Tensors,
input_cu_seqlens_q,
input_cu_seqlens_kv,
input_rng_state,
wkspace,
stream,
handle);
#else #else
NVTE_ERROR( NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n");
"cuDNN 8.9.1 is required to run BF16/FP16 fused attention with max_seqlen<=512. \n");
#endif #endif
} else if ((max_seqlen_q > 512) || (max_seqlen_kv > 512)) { } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n"); const char* err_msg =
"The FP16/BF16 fused attention (arbitrary seqlen) currently "
"only supports packed QKV input.\n";
NVTE_ERROR(err_msg);
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n");
} else { } else {
NVTE_ERROR("Invalid combination of data type and sequence length! \n"); NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
} }
} }
// NVTE fused attention BWD FP8 with packed KV // NVTE fused attention BWD FP8 with packed KV
...@@ -307,44 +379,37 @@ void nvte_fused_attn_bwd_kvpacked( ...@@ -307,44 +379,37 @@ void nvte_fused_attn_bwd_kvpacked(
size_t d = input_Q->data.shape[ndim - 1]; size_t d = input_Q->data.shape[ndim - 1];
auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const DType QKV_type = input_Q->data.dtype; const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
if (((QKV_type == DType::kFloat8E4M3) || (QKV_type == DType::kFloat8E5M2)) NVTE_Fused_Attn_Backend fused_attention_backend =
&& (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) { nvte_get_fused_attn_backend(
NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n"); Q_type, KV_type,
} else if (((QKV_type == DType::kFloat16) || (QKV_type == DType::kBFloat16)) qkv_layout, bias_type, attn_mask_type,
&& (max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) { dropout, max_seqlen_q, max_seqlen_kv, d);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901) #if (CUDNN_VERSION >= 8901)
fused_attn_max_512_bwd_kvpacked( Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
b, fused_attn_max_512_bwd_kvpacked(
max_seqlen_q, b, max_seqlen_q, max_seqlen_kv, h, d,
max_seqlen_kv, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
h, input_Q, input_KV, input_dO,
d, output_S,
attn_scale, output_dQ, output_dKV, output_dBias,
dropout, input_cu_seqlens_q, input_cu_seqlens_kv,
qkv_layout, wkspace, stream, handle);
bias_type,
attn_mask_type,
input_Q,
input_KV,
input_dO,
Aux_CTX_Tensors,
output_dQ,
output_dKV,
output_dBias,
input_cu_seqlens_q,
input_cu_seqlens_kv,
wkspace,
stream,
handle);
#else #else
NVTE_ERROR( NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n");
"cuDNN 8.9.1 is required to run BF16/FP16 fused attention with max_seqlen<=512. \n");
#endif #endif
} else if ((max_seqlen_q > 512) || (max_seqlen_kv > 512)) { } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
NVTE_ERROR("TBD: No support for fused attention with >512 seqlence length currently. \n"); const char* err_msg =
"The FP16/BF16 fused attention (arbitrary seqlen) currently "
"only supports packed QKV input.\n";
NVTE_ERROR(err_msg);
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n");
} else { } else {
NVTE_ERROR("Invalid combination of data type and sequence length! \n"); NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
} }
} }
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file fused_attn_arbitrary_seqlen.h
* \brief Functions for fused attention with seqlen > 512
*/
#ifndef TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_ARBITRARY_SEQLEN_H_
#define TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_ARBITRARY_SEQLEN_H_
#include "transformer_engine/fused_attn.h"
#include <cudnn.h>
#include "common/common.h"
namespace transformer_engine {
#if (CUDNN_VERSION >= 8900)
void fused_attn_arbitrary_seqlen_fwd_qkvpacked(size_t batch, size_t max_seqlen, size_t num_head,
size_t head_size, bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_QKV, const Tensor *input_Bias,
Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t num_head,
size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV,
const Tensor *input_O,
const Tensor *input_dO, Tensor *output_S,
Tensor *output_dQKV, Tensor *output_dBias,
const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
#endif // CUDNN_VERSION >= 8900
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_ARBITRARY_SEQLEN_H_
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include "fused_attn_fp16_bf16_max_seqlen_512.h" #include "fused_attn_f16_max512_seqlen.h"
#include <cuda_bf16.h> #include <cuda_bf16.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
...@@ -1239,7 +1239,7 @@ void fused_attn_max_512_fwd_qkvpacked( ...@@ -1239,7 +1239,7 @@ void fused_attn_max_512_fwd_qkvpacked(
size_t batch, size_t max_seqlen, size_t num_head, size_t head_dim, bool is_training, size_t batch, size_t max_seqlen, size_t num_head, size_t head_dim, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_Output_Tensors, const Tensor *cu_seqlens, const Tensor *rng_state, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine; using namespace transformer_engine;
...@@ -1260,14 +1260,14 @@ void fused_attn_max_512_fwd_qkvpacked( ...@@ -1260,14 +1260,14 @@ void fused_attn_max_512_fwd_qkvpacked(
void *devPtrS = nullptr; void *devPtrS = nullptr;
if (Aux_Output_Tensors->size == 0) { if (Aux_CTX_Tensors->size == 0) {
Aux_Output_Tensors->size = 1; Aux_CTX_Tensors->size = 1;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_Output_Tensors->tensors[0]); Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr; output_S->data.dptr = nullptr;
output_S->data.shape = {batch, num_head, max_seqlen, max_seqlen}; output_S->data.shape = {batch, num_head, max_seqlen, max_seqlen};
output_S->data.dtype = input_QKV->data.dtype; output_S->data.dtype = input_QKV->data.dtype;
} else if (Aux_Output_Tensors->size == 1) { } else if (Aux_CTX_Tensors->size == 1) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_Output_Tensors->tensors[0]); Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr; devPtrS = output_S->data.dptr;
} }
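A note on the Aux_CTX_Tensors branches above: size == 0 is a shape-query pass (the kernel publishes the softmax tensor's shape and dtype without allocating), and size == 1 is the execute pass with a caller-provided buffer. The following toy stand-in sketches that protocol under that assumption; it is not Transformer Engine code.

    #include <cstddef>
    #include <vector>

    // Toy analogue of the size==0 / size==1 handshake used by the fused-attn forward paths.
    struct ToyPack { size_t size = 0; std::vector<size_t> shape; float* data = nullptr; };

    void toy_fwd(ToyPack* aux, size_t b, size_t h, size_t s) {
      if (aux->size == 0) {            // query pass: describe the softmax buffer S
        aux->size = 1;
        aux->shape = {b, h, s, s};
      } else if (aux->size == 1) {     // execute pass: consume the caller-allocated buffer
        // ... kernel would write softmax results into aux->data here ...
      }
    }

The caller invokes the forward entry point once with an empty pack, allocates a buffer matching the published shape and dtype, then invokes it again to run the kernel.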
...@@ -1307,7 +1307,7 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k ...@@ -1307,7 +1307,7 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_Bias, Tensor *output_O, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_Output_Tensors, const Tensor *q_cu_seqlens, NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens,
const Tensor *kv_cu_seqlens, const Tensor *rng_state, const Tensor *kv_cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine; using namespace transformer_engine;
...@@ -1336,14 +1336,14 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k ...@@ -1336,14 +1336,14 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
const DType kv_type = input_KV->data.dtype; const DType kv_type = input_KV->data.dtype;
NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV."); NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV.");
if (Aux_Output_Tensors->size == 0) { if (Aux_CTX_Tensors->size == 0) {
Aux_Output_Tensors->size = 1; Aux_CTX_Tensors->size = 1;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_Output_Tensors->tensors[0]); Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr; output_S->data.dptr = nullptr;
output_S->data.shape = {batch, num_head, q_max_seqlen, kv_max_seqlen}; output_S->data.shape = {batch, num_head, q_max_seqlen, kv_max_seqlen};
output_S->data.dtype = q_type; output_S->data.dtype = q_type;
} else if (Aux_Output_Tensors->size == 1) { } else if (Aux_CTX_Tensors->size == 1) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_Output_Tensors->tensors[0]); Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr; devPtrS = output_S->data.dptr;
} }
...@@ -1381,7 +1381,7 @@ void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu ...@@ -1381,7 +1381,7 @@ void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu
size_t head_dim, float attn_scale, float p_dropout, size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV, NVTE_Mask_Type mask_type, const Tensor *input_QKV,
const Tensor *input_dO, const NVTETensorPack *Aux_CTX_Tensors, const Tensor *input_dO, Tensor *output_S,
Tensor *output_dQKV, Tensor *output_dBias, Tensor *output_dQKV, Tensor *output_dBias,
const Tensor *cu_seqlens, Tensor *workspace, const Tensor *cu_seqlens, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle) { cudaStream_t stream, cudnnHandle_t handle) {
...@@ -1408,12 +1408,8 @@ void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu ...@@ -1408,12 +1408,8 @@ void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu
void *devPtrdBias = output_dBias->data.dptr; void *devPtrdBias = output_dBias->data.dptr;
NVTE_CHECK(Aux_CTX_Tensors->size == 1); void *devPtrS = output_S->data.dptr;
void *devPtrS = nullptr;
if (Aux_CTX_Tensors->size == 1) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
}
// devPtrdS reuses the memory of devPtrS // devPtrdS reuses the memory of devPtrS
void *devPtrdS = devPtrS; void *devPtrdS = devPtrS;
...@@ -1446,7 +1442,7 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k ...@@ -1446,7 +1442,7 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
float p_dropout, NVTE_QKV_Layout qkv_layout, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_dO, const NVTETensorPack *Aux_CTX_Tensors, const Tensor *input_dO, Tensor *output_S,
Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias, Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias,
const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens, const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
...@@ -1472,12 +1468,8 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k ...@@ -1472,12 +1468,8 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
void *devPtrdBias = output_dBias->data.dptr; void *devPtrdBias = output_dBias->data.dptr;
NVTE_CHECK(Aux_CTX_Tensors->size == 1); void *devPtrS = output_S->data.dptr;
void *devPtrS = nullptr;
if (Aux_CTX_Tensors->size == 1) {
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
}
// devPtrdS reuses the memory of devPtrS // devPtrdS reuses the memory of devPtrS
void *devPtrdS = devPtrS; void *devPtrdS = devPtrS;
......
...@@ -24,7 +24,7 @@ void fused_attn_max_512_fwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu ...@@ -24,7 +24,7 @@ void fused_attn_max_512_fwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu
float p_dropout, NVTE_QKV_Layout qkv_layout, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_QKV, const Tensor *input_Bias, const Tensor *input_QKV, const Tensor *input_Bias,
Tensor *output_O, NVTETensorPack *Aux_Output_Tensors, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens, const Tensor *rng_state, const Tensor *cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
...@@ -34,7 +34,7 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k ...@@ -34,7 +34,7 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_Bias, Tensor *output_O, const Tensor *input_Bias, Tensor *output_O,
NVTETensorPack *Aux_Output_Tensors, const Tensor *q_cu_seqlens, NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens,
const Tensor *kv_cu_seqlens, const Tensor *rng_state, const Tensor *kv_cu_seqlens, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
...@@ -42,7 +42,7 @@ void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu ...@@ -42,7 +42,7 @@ void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t max_seqlen, size_t nu
size_t head_dim, float attn_scale, float p_dropout, size_t head_dim, float attn_scale, float p_dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, const Tensor *input_QKV, NVTE_Mask_Type mask_type, const Tensor *input_QKV,
const Tensor *input_dO, const NVTETensorPack *Aux_CTX_Tensors, const Tensor *input_dO, Tensor *output_S,
Tensor *output_dQKV, Tensor *output_dBias, Tensor *output_dQKV, Tensor *output_dBias,
const Tensor *cu_seqlens, Tensor *workspace, const Tensor *cu_seqlens, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle); cudaStream_t stream, cudnnHandle_t handle);
...@@ -52,7 +52,7 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k ...@@ -52,7 +52,7 @@ void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t q_max_seqlen, size_t k
float p_dropout, NVTE_QKV_Layout qkv_layout, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Q, const Tensor *input_KV,
const Tensor *input_dO, const NVTETensorPack *Aux_CTX_Tensors, const Tensor *input_dO, Tensor *output_S,
Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias, Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias,
const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens, const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
......
...@@ -991,7 +991,7 @@ static cudnn_frontend::Tensor createdSQBMM( ...@@ -991,7 +991,7 @@ static cudnn_frontend::Tensor createdSQBMM(
} }
// fused attention FWD FP8 // fused attention FWD FP8
void fa_fwd_fp8(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d, void fused_attn_fp8_fwd_impl(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d,
bool isTraining, float attnScale, bool isTraining, float attnScale,
float dropoutProbability, NVTE_QKV_Layout layout, float dropoutProbability, NVTE_QKV_Layout layout,
void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrQ, void* devPtrK, void* devPtrV,
...@@ -1303,7 +1303,7 @@ void fa_fwd_fp8(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d, ...@@ -1303,7 +1303,7 @@ void fa_fwd_fp8(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d,
} }
// fused attention BWD FP8 // fused attention BWD FP8
void fa_bwd_fp8(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d, void fused_attn_fp8_bwd_impl(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d,
float attnScale, float dropoutProbability, NVTE_QKV_Layout layout, float attnScale, float dropoutProbability, NVTE_QKV_Layout layout,
void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrQ, void* devPtrK, void* devPtrV,
void* devPtrM, void* devPtrZInv, void* devPtrM, void* devPtrZInv,
...@@ -1858,7 +1858,7 @@ void fa_bwd_fp8(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d, ...@@ -1858,7 +1858,7 @@ void fa_bwd_fp8(int64_t b, int64_t s_q, int64_t s_kv, int64_t h, int64_t d,
#if (CUDNN_VERSION >= 8900) #if (CUDNN_VERSION >= 8900)
// fused attention FWD FP8 with packed QKV // fused attention FWD FP8 with packed QKV
void fused_attn_fwd_fp8_qkvpacked( void fused_attn_fp8_fwd_qkvpacked(
size_t b, size_t max_seqlen, size_t b, size_t max_seqlen,
size_t h, size_t d, size_t h, size_t d,
bool is_training, float attn_scale, bool is_training, float attn_scale,
...@@ -1866,7 +1866,7 @@ void fused_attn_fwd_fp8_qkvpacked( ...@@ -1866,7 +1866,7 @@ void fused_attn_fwd_fp8_qkvpacked(
const Tensor *input_QKV, const Tensor *input_QKV,
Tensor *input_output_S, Tensor *input_output_S,
Tensor *output_O, Tensor *output_O,
NVTETensorPack* Aux_Output_Tensors, NVTETensorPack* Aux_CTX_Tensors,
const Tensor *cu_seqlens, const Tensor *cu_seqlens,
const Tensor *rng_state, const Tensor *rng_state,
Tensor *workspace, Tensor *workspace,
...@@ -1888,23 +1888,29 @@ void fused_attn_fwd_fp8_qkvpacked( ...@@ -1888,23 +1888,29 @@ void fused_attn_fwd_fp8_qkvpacked(
void* devPtrM = nullptr; void* devPtrM = nullptr;
void* devPtrZInv = nullptr; void* devPtrZInv = nullptr;
if (Aux_Output_Tensors->size == 0) { if (Aux_CTX_Tensors->size == 0) {
if (is_training) { if (is_training) {
Aux_Output_Tensors->size = 2; Aux_CTX_Tensors->size = 3;
Tensor *output_M = reinterpret_cast<Tensor*>(Aux_Output_Tensors->tensors[0]); Tensor *output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
Tensor *output_ZInv = reinterpret_cast<Tensor*>(Aux_Output_Tensors->tensors[1]); Tensor *output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
Tensor *output_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
output_M->data.dptr = nullptr; output_M->data.dptr = nullptr;
output_M->data.shape = {b, h, max_seqlen, 1}; output_M->data.shape = {b, h, max_seqlen, 1};
output_M->data.dtype = DType::kFloat32; output_M->data.dtype = DType::kFloat32;
output_ZInv->data.dptr = nullptr; output_ZInv->data.dptr = nullptr;
output_ZInv->data.shape = {b, h, max_seqlen, 1}; output_ZInv->data.shape = {b, h, max_seqlen, 1};
output_ZInv->data.dtype = DType::kFloat32; output_ZInv->data.dtype = DType::kFloat32;
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
} }
} else if (Aux_Output_Tensors->size == 2) { } else if (Aux_CTX_Tensors->size == 3) {
Tensor *output_M = reinterpret_cast<Tensor*>(Aux_Output_Tensors->tensors[0]); Tensor *output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
Tensor *output_ZInv = reinterpret_cast<Tensor*>(Aux_Output_Tensors->tensors[1]); Tensor *output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
Tensor *output_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
devPtrM = output_M->data.dptr; devPtrM = output_M->data.dptr;
devPtrZInv = output_ZInv->data.dptr; devPtrZInv = output_ZInv->data.dptr;
output_rng_state->data.dptr = rng_state->data.dptr;
} }
void* devPtrAmaxS = input_output_S->amax.dptr; void* devPtrAmaxS = input_output_S->amax.dptr;
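The third auxiliary tensor added above carries the RNG state as a 2-element kInt64 tensor (seed and offset of the CUDA random number generator, per the shapes set in this hunk and the doc comment below), so the backward pass can regenerate the same dropout mask. A hedged sketch of reading it back on the host for inspection; variable names here are illustrative, and the kernels themselves consume the device pointer directly.

    // Sketch: copy the {seed, offset} pair of a saved rng_state tensor to the host.
    void* rng_state_dptr = /* device pointer of the rng_state aux tensor */ nullptr;
    int64_t host_rng_state[2] = {0, 0};
    cudaMemcpyAsync(host_rng_state, rng_state_dptr, 2 * sizeof(int64_t),
                    cudaMemcpyDeviceToHost, stream);
    cudaStreamSynchronize(stream);
    // host_rng_state[0] == seed, host_rng_state[1] == offset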
...@@ -1921,7 +1927,7 @@ void fused_attn_fwd_fp8_qkvpacked( ...@@ -1921,7 +1927,7 @@ void fused_attn_fwd_fp8_qkvpacked(
const DType QKV_type = input_QKV->data.dtype; const DType QKV_type = input_QKV->data.dtype;
size_t workspace_size = 0; size_t workspace_size = 0;
fused_attn::fa_fwd_fp8( fused_attn::fused_attn_fp8_fwd_impl(
b, max_seqlen, max_seqlen, h, d, b, max_seqlen, max_seqlen, h, d,
is_training, attn_scale, p_dropout, qkv_layout, is_training, attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV, devPtrQ, devPtrK, devPtrV,
...@@ -1948,7 +1954,7 @@ void fused_attn_fwd_fp8_qkvpacked( ...@@ -1948,7 +1954,7 @@ void fused_attn_fwd_fp8_qkvpacked(
} }
} }
// fused attention BWD FP8 with packed QKV // fused attention BWD FP8 with packed QKV
void fused_attn_bwd_fp8_qkvpacked( void fused_attn_fp8_bwd_qkvpacked(
size_t b, size_t max_seqlen, size_t b, size_t max_seqlen,
size_t h, size_t d, size_t h, size_t d,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
...@@ -2011,7 +2017,7 @@ void fused_attn_bwd_fp8_qkvpacked( ...@@ -2011,7 +2017,7 @@ void fused_attn_bwd_fp8_qkvpacked(
const DType QKV_type = input_QKV->data.dtype; const DType QKV_type = input_QKV->data.dtype;
size_t workspace_size = 0; size_t workspace_size = 0;
fused_attn::fa_bwd_fp8( fused_attn::fused_attn_fp8_bwd_impl(
b, max_seqlen, max_seqlen, h, d, b, max_seqlen, max_seqlen, h, d,
attn_scale, p_dropout, qkv_layout, attn_scale, p_dropout, qkv_layout,
devPtrQ, devPtrK, devPtrV, devPtrQ, devPtrK, devPtrV,
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
namespace transformer_engine { namespace transformer_engine {
#if (CUDNN_VERSION >= 8900) #if (CUDNN_VERSION >= 8900)
// fused attention FWD FP8 with packed QKV // fused attention FWD FP8 with packed QKV
void fused_attn_fwd_fp8_qkvpacked( void fused_attn_fp8_fwd_qkvpacked(
size_t b, size_t max_seqlen, size_t b, size_t max_seqlen,
size_t h, size_t d, size_t h, size_t d,
bool is_training, float attn_scale, bool is_training, float attn_scale,
...@@ -21,7 +21,7 @@ void fused_attn_fwd_fp8_qkvpacked( ...@@ -21,7 +21,7 @@ void fused_attn_fwd_fp8_qkvpacked(
const Tensor *input_QKV, const Tensor *input_QKV,
Tensor *input_output_S, Tensor *input_output_S,
Tensor *output_O, Tensor *output_O,
NVTETensorPack* Aux_Output_Tensors, NVTETensorPack* Aux_CTX_Tensors,
const Tensor *cu_seqlens, const Tensor *cu_seqlens,
const Tensor *rng_state, const Tensor *rng_state,
Tensor *workspace, Tensor *workspace,
...@@ -29,7 +29,7 @@ void fused_attn_fwd_fp8_qkvpacked( ...@@ -29,7 +29,7 @@ void fused_attn_fwd_fp8_qkvpacked(
cudnnHandle_t handle); cudnnHandle_t handle);
// fused attention BWD FP8 with packed QKV // fused attention BWD FP8 with packed QKV
void fused_attn_bwd_fp8_qkvpacked( void fused_attn_fp8_bwd_qkvpacked(
size_t b, size_t max_seqlen, size_t b, size_t max_seqlen,
size_t h, size_t d, size_t h, size_t d,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
......
...@@ -249,7 +249,6 @@ __global__ void cu_seqlens_to_actual_seqlens(size_t b, ...@@ -249,7 +249,6 @@ __global__ void cu_seqlens_to_actual_seqlens(size_t b,
kv_seqlens[tid] = kv_cu_seqlens[tid + 1] - kv_cu_seqlens[tid]; kv_seqlens[tid] = kv_cu_seqlens[tid + 1] - kv_cu_seqlens[tid];
} }
} }
} // namespace fused_attn } // namespace fused_attn
// get cuDNN data type // get cuDNN data type
......
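For context on the cu_seqlens convention used by the kernel above and throughout the API: cu_seqlens holds batch_size + 1 cumulative offsets, and per-sequence lengths are the adjacent differences. A small worked sketch (values are illustrative):

    // cu_seqlens = {0, 128, 384, 512} for a batch of 3 sequences:
    //   seqlens[0] = 128 - 0   = 128
    //   seqlens[1] = 384 - 128 = 256
    //   seqlens[2] = 512 - 384 = 128
    int cu_seqlens[4] = {0, 128, 384, 512};
    int seqlens[3];
    for (int i = 0; i < 3; ++i) seqlens[i] = cu_seqlens[i + 1] - cu_seqlens[i];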
...@@ -94,6 +94,38 @@ enum NVTE_Mask_Type { ...@@ -94,6 +94,38 @@ enum NVTE_Mask_Type {
NVTE_CAUSAL_MASK = 2, NVTE_CAUSAL_MASK = 2,
}; };
enum NVTE_Fused_Attn_Backend {
/*! No supported backend */
NVTE_No_Backend = -1,
/*! cuDNN-based FP16/BF16 fused attention for <= 512 sequence length */
NVTE_F16_max512_seqlen = 0,
/*! cuDNN-based FP16/BF16 fused attention for any sequence length */
NVTE_F16_arbitrary_seqlen = 1,
/*! cuDNN-based FP8 fused attention for <= 512 sequence length */
NVTE_FP8 = 2,
};
/*! \brief Get fused attention backend based on input parameters.
*
* \param[in] q_dtype The data type of Tensor Q.
* \param[in] kv_dtype The data type of Tensors K, V.
* \param[in] qkv_layout The layout of Tensors Q, K, V.
* \param[in] bias_type The attention bias type.
* \param[in] attn_mask_type The attention mask type.
* \param[in] dropout The dropout probability.
* \param[in] max_seqlen_q The sequence length of Q.
* \param[in] max_seqlen_kv The sequence length of K, V.
* \param[in] head_dim The head dimension of Q, K, V.
*/
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTEDType q_dtype,
NVTEDType kv_dtype,
NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
float dropout, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim);
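As the selection logic in fused_attn.cpp shows, configurations that satisfy both F16 backends default to the max-512 one; the NVTE_FUSED_ATTN_BACKEND environment variable can opt into the arbitrary-seqlen backend when its constraints are also met. A hedged sketch of setting it from C++ (exporting the variable in the shell before launch works the same way):

    #include <cstdlib>
    #include <string>
    // Sketch: request backend 1 (F16 arbitrary seqlen) even when max_seqlen_q/kv <= 512,
    // for cases that also satisfy that backend's constraints.
    setenv("NVTE_FUSED_ATTN_BACKEND",
           std::to_string(NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen).c_str(),
           /*overwrite=*/1);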
/*! \brief Compute dot product attention with packed QKV input. /*! \brief Compute dot product attention with packed QKV input.
* *
* Computes: * Computes:
...@@ -104,36 +136,38 @@ enum NVTE_Mask_Type { ...@@ -104,36 +136,38 @@ enum NVTE_Mask_Type {
* *
* Support Matrix: * Support Matrix:
 \verbatim
 | backend | precision | qkv layout      | bias                    | mask           | dropout | sequence length | head_dim |
 | 0       | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes     | <= 512          | 64       |
 | 1       | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS                 | CAUSAL         | Yes     | > 512           | 64, 128  |
 | 2       | FP8       | QKV_INTERLEAVED | NO_BIAS                 | PADDING        | Yes     | <= 512          | 64       |
 \endverbatim
* *
* \param[in] QKV The QKV tensor in packed format, * \param[in] QKV The QKV tensor in packed format,
* [total_seqs, 3, num_heads, head_dim]. * [total_seqs, 3, num_heads, head_dim].
* \param[in] Bias The Bias tensor. * \param[in] Bias The Bias tensor.
* \param[in,out] S The S tensor. * \param[in,out] S The S tensor.
* \param[out] O The output O tensor. * \param[out] O The output O tensor.
* \param[out] Aux_Output_Tensors Auxiliary output tensors when training, e.g. M, ZInv. * \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
* \param[in] cu_seqlens Accumulative sequence lengths, [batch_size + 1]. * e.g. M, ZInv, rng_state.
* \param[in] rng_state Seed and offset of CUDA random number generator. * \param[in] cu_seqlens Accumulative sequence lengths, [batch_size + 1].
* \param[in] max_seqlen Max sequence length used for computing. * \param[in] rng_state Seed and offset of CUDA random number generator.
* It may be >= max(cu_seqlens). * \param[in] max_seqlen Max sequence length used for computing,
* \param[in] is_training Whether this is in training mode or inference. * it may be >= max(cu_seqlens).
* \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] is_training Whether this is in training mode or inference.
* \param[in] dropout Dropout probability. * \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] qkv_layout QKV tensor's layout. * \param[in] dropout Dropout probability.
* \param[in] bias_type Bias type. * \param[in] qkv_layout QKV tensor's layout.
* \param[in] attn_mask_type Attention mask type. * \param[in] bias_type Bias type.
* \param[in] workspace Workspace tensor. * \param[in] attn_mask_type Attention mask type.
* \param[in] stream CUDA stream used for this operation. * \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/ */
void nvte_fused_attn_fwd_qkvpacked( void nvte_fused_attn_fwd_qkvpacked(
const NVTETensor QKV, const NVTETensor QKV,
const NVTETensor Bias, const NVTETensor Bias,
NVTETensor S, NVTETensor S,
NVTETensor O, NVTETensor O,
NVTETensorPack* Aux_Output_Tensors, NVTETensorPack* Aux_CTX_Tensors,
const NVTETensor cu_seqlens, const NVTETensor cu_seqlens,
const NVTETensor rng_state, const NVTETensor rng_state,
size_t max_seqlen, size_t max_seqlen,
...@@ -147,30 +181,32 @@ void nvte_fused_attn_fwd_qkvpacked( ...@@ -147,30 +181,32 @@ void nvte_fused_attn_fwd_qkvpacked(
* *
* Support Matrix: * Support Matrix:
 \verbatim
 | backend | precision | qkv layout      | bias                    | mask           | dropout | sequence length | head_dim |
 | 0       | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes     | <= 512          | 64       |
 | 1       | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS                 | CAUSAL         | Yes     | > 512           | 64, 128  |
 | 2       | FP8       | QKV_INTERLEAVED | NO_BIAS                 | PADDING        | Yes     | <= 512          | 64       |
 \endverbatim
* *
* \param[in] QKV The QKV tensor in packed format, * \param[in] QKV The QKV tensor in packed format,
* [total_seqs, 3, num_heads, head_dim]. * [total_seqs, 3, num_heads, head_dim].
* \param[in] O The O tensor from forward. * \param[in] O The O tensor from forward.
* \param[in] dO The gradient of the O tensor. * \param[in] dO The gradient of the O tensor.
* \param[in] S The S tensor. * \param[in] S The S tensor.
* \param[in,out] dP The gradient of the P tensor. * \param[in,out] dP The gradient of the P tensor.
* \param[in] Aux_CTX_Tensors Auxiliary tensors from forward when in training mode. * \param[in] Aux_CTX_Tensors Auxiliary tensors from context when in training mode,
* \param[out] dQKV The gradient of the QKV tensor. * e.g. M, ZInv, rng_state.
* \param[out] dBias The gradient of the Bias tensor. * \param[out] dQKV The gradient of the QKV tensor.
* \param[in] cu_seqlens Accumulative sequence lengths, [batch_size + 1]. * \param[out] dBias The gradient of the Bias tensor.
* \param[in] max_seqlen Max sequence length used for computing. * \param[in] cu_seqlens Accumulative sequence lengths, [batch_size + 1].
* It may be >= max(cu_seqlens). * \param[in] max_seqlen Max sequence length used for computing,
* \param[in] attn_scale Scaling factor for Q * K.T. * it may be >= max(cu_seqlens).
* \param[in] dropout Dropout probability. * \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] qkv_layout QKV tensor's layout. * \param[in] dropout Dropout probability.
* \param[in] bias_type Bias type. * \param[in] qkv_layout QKV tensor's layout.
* \param[in] attn_mask_type Attention mask type. * \param[in] bias_type Bias type.
* \param[in] workspace Workspace tensor. * \param[in] attn_mask_type Attention mask type.
* \param[in] stream CUDA stream used for this operation. * \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/ */
void nvte_fused_attn_bwd_qkvpacked( void nvte_fused_attn_bwd_qkvpacked(
const NVTETensor QKV, const NVTETensor QKV,
...@@ -199,31 +235,32 @@ void nvte_fused_attn_bwd_qkvpacked( ...@@ -199,31 +235,32 @@ void nvte_fused_attn_bwd_qkvpacked(
* *
* Support Matrix: * Support Matrix:
 \verbatim
 | backend | precision | qkv layout      | bias                    | mask           | dropout | sequence length | head_dim |
 | 0       | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes     | <= 512          | 64       |
 \endverbatim
* *
* \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim]. * \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim].
* \param[in] KV The KV tensor, [total_seqs_kv, 2, num_heads, head_dim]. * \param[in] KV The KV tensor, [total_seqs_kv, 2, num_heads, head_dim].
* \param[in] Bias The Bias tensor. * \param[in] Bias The Bias tensor.
* \param[in,out] S The S tensor. * \param[in,out] S The S tensor.
* \param[out] O The output O tensor. * \param[out] O The output O tensor.
* \param[out] Aux_Output_Tensors Auxiliary output tensors when training, e.g. M, ZInv. * \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
* \param[in] cu_seqlens_q Accumulative sequence lengths for Q, [batch_size + 1]. * e.g. M, ZInv, rng_state.
* \param[in] cu_seqlens_kv Accumulative sequence lengths for KV, [batch_size + 1]. * \param[in] cu_seqlens_q Accumulative sequence lengths for Q, [batch_size + 1].
* \param[in] rng_state Seed and offset of CUDA random number generator. * \param[in] cu_seqlens_kv Accumulative sequence lengths for KV, [batch_size + 1].
* \param[in] max_seqlen_q Max sequence length used for computing * \param[in] rng_state Seed and offset of CUDA random number generator.
* for Q. It may be >= max(cu_seqlens_q). * \param[in] max_seqlen_q Max sequence length used for computing for Q.
* \param[in] max_seqlen_kv Max sequence length used for computing * it may be >= max(cu_seqlens_q).
* for KV. It may be >= max(cu_seqlens_kv). * \param[in] max_seqlen_kv Max sequence length used for computing for KV.
* \param[in] is_training Whether this is in training mode or inference. * it may be >= max(cu_seqlens_kv).
* \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] is_training Whether this is in training mode or inference.
* \param[in] dropout Dropout probability. * \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] qkv_layout QKV tensor's layout. * \param[in] dropout Dropout probability.
* \param[in] bias_type Bias type. * \param[in] qkv_layout QKV tensor's layout.
* \param[in] attn_mask_type Attention mask type. * \param[in] bias_type Bias type.
* \param[in] workspace Workspace tensor. * \param[in] attn_mask_type Attention mask type.
* \param[in] stream CUDA stream used for this operation. * \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/ */
void nvte_fused_attn_fwd_kvpacked( void nvte_fused_attn_fwd_kvpacked(
const NVTETensor Q, const NVTETensor Q,
...@@ -231,7 +268,7 @@ void nvte_fused_attn_fwd_kvpacked( ...@@ -231,7 +268,7 @@ void nvte_fused_attn_fwd_kvpacked(
const NVTETensor Bias, const NVTETensor Bias,
NVTETensor S, NVTETensor S,
NVTETensor O, NVTETensor O,
NVTETensorPack* Aux_Output_Tensors, NVTETensorPack* Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_kv,
const NVTETensor rng_state, const NVTETensor rng_state,
...@@ -246,33 +283,34 @@ void nvte_fused_attn_fwd_kvpacked( ...@@ -246,33 +283,34 @@ void nvte_fused_attn_fwd_kvpacked(
* *
* Support Matrix: * Support Matrix:
 \verbatim
 | backend | precision | qkv layout      | bias                    | mask           | dropout | sequence length | head_dim |
 | 0       | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | Yes     | <= 512          | 64       |
 \endverbatim
* *
* \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim]. * \param[in] Q The Q tensor, [total_seqs_q, num_heads, head_dim].
* \param[in] KV The KV tensor, [total_seqs_kv, 2, num_heads, head_dim]. * \param[in] KV The KV tensor, [total_seqs_kv, 2, num_heads, head_dim].
* \param[in] O The O tensor from forward. * \param[in] O The O tensor from forward.
* \param[in] dO The gradient of the O tensor. * \param[in] dO The gradient of the O tensor.
* \param[in] S The S tensor. * \param[in] S The S tensor.
* \param[in,out] dP The gradient of the P tensor. * \param[in,out] dP The gradient of the P tensor.
* \param[in] Aux_CTX_Tensors Auxiliary tensors from forward when in training mode. * \param[in] Aux_CTX_Tensors Auxiliary tensors from context when in training mode,
* \param[out] dQ The gradient of the Q tensor. * e.g. M, ZInv, rng_state.
* \param[out] dKV The gradient of the KV tensor. * \param[out] dQ The gradient of the Q tensor.
* \param[out] dBias The gradient of the Bias tensor. * \param[out] dKV The gradient of the KV tensor.
* \param[in] cu_seqlens_q Accumulative sequence lengths for Q, [batch_size + 1]. * \param[out] dBias The gradient of the Bias tensor.
* \param[in] cu_seqlens_kv Accumulative sequence lengths for KV, [batch_size + 1]. * \param[in] cu_seqlens_q Accumulative sequence lengths for Q, [batch_size + 1].
* \param[in] max_seqlen_q Max sequence length used for computing * \param[in] cu_seqlens_kv Accumulative sequence lengths for KV, [batch_size + 1].
* for Q. It may be >= max(cu_seqlens_q). * \param[in] max_seqlen_q Max sequence length used for computing for Q.
* \param[in] max_seqlen_kv Max sequence length used for computing * it may be >= max(cu_seqlens_q).
* for KV. It may be >= max(cu_seqlens_kv). * \param[in] max_seqlen_kv Max sequence length used for computing for KV.
* \param[in] attn_scale Scaling factor for Q * K.T. * it may be >= max(cu_seqlens_kv).
* \param[in] dropout Dropout probability. * \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] qkv_layout QKV tensor's layout. * \param[in] dropout Dropout probability.
* \param[in] bias_type Bias type. * \param[in] qkv_layout QKV tensor's layout.
* \param[in] attn_mask_type Attention mask type. * \param[in] bias_type Bias type.
* \param[in] workspace Workspace tensor. * \param[in] attn_mask_type Attention mask type.
* \param[in] stream CUDA stream used for this operation. * \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/ */
void nvte_fused_attn_bwd_kvpacked( void nvte_fused_attn_bwd_kvpacked(
const NVTETensor Q, const NVTETensor Q,
......
...@@ -22,7 +22,7 @@ TE_DType = { ...@@ -22,7 +22,7 @@ TE_DType = {
torch.bfloat16: tex.DType.kBFloat16, torch.bfloat16: tex.DType.kBFloat16,
} }
AttnMaskTypes = ("causal", "padding") AttnMaskTypes = ("causal", "padding", "no_mask")
AttnTypes = ("self", "cross") AttnTypes = ("self", "cross")
......
...@@ -58,7 +58,10 @@ enum FP8FwdTensors { ...@@ -58,7 +58,10 @@ enum FP8FwdTensors {
GEMM1_OUTPUT = 2, GEMM1_OUTPUT = 2,
GEMM2_INPUT = 3, GEMM2_INPUT = 3,
GEMM2_WEIGHT = 4, GEMM2_WEIGHT = 4,
GEMM2_OUTPUT = 5 GEMM2_OUTPUT = 5,
GEMM3_INPUT = 6,
GEMM3_WEIGHT = 7,
GEMM3_OUTPUT = 8
}; };
// Used as named indices on the `scale`, `scale_inv`, // Used as named indices on the `scale`, `scale_inv`,
...@@ -67,7 +70,9 @@ enum FP8BwdTensors { ...@@ -67,7 +70,9 @@ enum FP8BwdTensors {
GRAD_OUTPUT1 = 0, GRAD_OUTPUT1 = 0,
GRAD_INPUT1 = 1, GRAD_INPUT1 = 1,
GRAD_OUTPUT2 = 2, GRAD_OUTPUT2 = 2,
GRAD_INPUT2 = 3 GRAD_INPUT2 = 3,
GRAD_OUTPUT3 = 4,
GRAD_INPUT3 = 5
}; };
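These enums act as named indices into the per-module FP8 scale / scale_inv / amax buffers, and the new GEMM3_* entries reserve slots for the extra GEMM used by fused attention. A rough Python sketch of that indexing pattern, with local stand-in enum definitions and made-up buffer names (assumptions, not the library's actual storage):

```python
import torch
from enum import IntEnum

# Local stand-in for the C++ enum above; values must match the C++ side.
class FP8FwdTensors(IntEnum):
    GEMM1_INPUT = 0
    GEMM1_WEIGHT = 1
    GEMM1_OUTPUT = 2
    GEMM2_INPUT = 3
    GEMM2_WEIGHT = 4
    GEMM2_OUTPUT = 5
    GEMM3_INPUT = 6
    GEMM3_WEIGHT = 7
    GEMM3_OUTPUT = 8

# One scale / amax slot per named tensor; the GEMM3_* entries give the
# fused-attention GEMM its own slots instead of reusing GEMM1/GEMM2 ones.
num_fp8_tensors = len(FP8FwdTensors)
scale = torch.ones(num_fp8_tensors, dtype=torch.float32)
amax = torch.zeros(num_fp8_tensors, dtype=torch.float32)

qkv_scale = scale[FP8FwdTensors.GEMM3_INPUT]  # index by name, not magic number
```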
...@@ -81,6 +86,9 @@ transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, ...@@ -81,6 +86,9 @@ transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
inline at::ScalarType GetATenDType(transformer_engine::DType t) { inline at::ScalarType GetATenDType(transformer_engine::DType t) {
switch (t) { switch (t) {
case transformer_engine::DType::kInt32: case transformer_engine::DType::kInt32:
return torch::kInt32;
case transformer_engine::DType::kInt64:
return torch::kInt64;
case transformer_engine::DType::kFloat32: case transformer_engine::DType::kFloat32:
return at::kFloat; return at::kFloat;
case transformer_engine::DType::kFloat16: case transformer_engine::DType::kFloat16:
......
...@@ -7,17 +7,22 @@ ...@@ -7,17 +7,22 @@
#include "common.h" #include "common.h"
#include "../common.h" #include "../common.h"
NVTE_QKV_Layout get_nvte_qkv_layout(const std::string qkv_layout); NVTE_Fused_Attn_Backend get_fused_attn_backend(
const transformer_engine::DType q_dtype,
NVTE_Bias_Type get_nvte_bias_type(const std::string bias_type); const transformer_engine::DType kv_dtype,
NVTE_QKV_Layout qkv_layout,
NVTE_Mask_Type get_nvte_mask_type(const std::string mask_type); NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
float p_dropout, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim);
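get_fused_attn_backend replaces the string-based helpers and chooses a backend from the dtypes, QKV layout, bias and mask types, dropout, sequence lengths, and head size. Below is a toy Python approximation of that dispatch, derived only from the support matrix shown earlier; the threshold checks and the numbering of the non-max-512 backend are assumptions, not the exact C++ logic:

```python
# Toy stand-in for get_fused_attn_backend(): only the backend-0 row of the
# support matrix above is grounded; the existence and numbering of a backend
# for longer sequences (returned as 1 here) is an assumption.
def pick_fused_attn_backend(max_seqlen_q: int, max_seqlen_kv: int,
                            head_dim: int) -> int:
    if max_seqlen_q <= 512 and max_seqlen_kv <= 512 and head_dim == 64:
        return 0  # backend 0: F16/BF16, seqlen <= 512, head_dim 64
    return 1      # assumed: backend handling longer sequences

assert pick_fused_attn_backend(512, 512, 64) == 0
assert pick_fused_attn_backend(2048, 2048, 64) == 1
```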
std::vector<at::Tensor> fused_attn_fwd_qkvpacked( std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
size_t b, size_t max_seqlen, size_t total_seqs, size_t b, size_t max_seqlen, size_t total_seqs,
size_t h, size_t d, size_t h, size_t d, bool is_training,
bool is_training, float attn_scale, float p_dropout, bool set_zero, float attn_scale, float p_dropout, bool set_zero,
std::string qkv_layout, std::string bias_type, std::string attn_mask_type, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
const at::Tensor cu_seqlens, const at::Tensor cu_seqlens,
const at::Tensor QKV, const at::Tensor QKV,
const transformer_engine::DType qkv_type, const transformer_engine::DType qkv_type,
...@@ -27,13 +32,16 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked( ...@@ -27,13 +32,16 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
c10::optional<at::Tensor> amax_S, c10::optional<at::Tensor> amax_S,
c10::optional<at::Tensor> amax_O, c10::optional<at::Tensor> amax_O,
const c10::optional<at::Tensor> Bias, const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen); const c10::optional<at::Generator> rng_gen,
size_t rng_elts_per_thread);
std::vector<at::Tensor> fused_attn_bwd_qkvpacked( std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
size_t b, size_t max_seqlen, size_t total_seqs, size_t b, size_t max_seqlen, size_t total_seqs,
size_t h, size_t d, size_t h, size_t d, float attn_scale,
float attn_scale, float p_dropout, bool set_zero, float p_dropout, bool set_zero,
std::string qkv_layout, std::string bias_type, std::string attn_mask_type, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
const at::Tensor cu_seqlens, const at::Tensor cu_seqlens,
const at::Tensor QKV, const at::Tensor QKV,
const at::Tensor O, const at::Tensor O,
...@@ -53,9 +61,11 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked( ...@@ -53,9 +61,11 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
std::vector<at::Tensor> fused_attn_fwd_kvpacked( std::vector<at::Tensor> fused_attn_fwd_kvpacked(
size_t b, size_t max_seqlen_q, size_t max_seqlen_kv, size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t total_seqs_q, size_t total_seqs_kv, size_t total_seqs_q, size_t total_seqs_kv,
size_t h, size_t d, size_t h, size_t d, bool is_training,
bool is_training, float attn_scale, float p_dropout, bool set_zero, float attn_scale, float p_dropout, bool set_zero,
std::string qkv_layout, std::string bias_type, std::string attn_mask_type, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv, const at::Tensor cu_seqlens_kv,
const at::Tensor Q, const at::Tensor Q,
...@@ -67,14 +77,17 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked( ...@@ -67,14 +77,17 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
c10::optional<at::Tensor> amax_S, c10::optional<at::Tensor> amax_S,
c10::optional<at::Tensor> amax_O, c10::optional<at::Tensor> amax_O,
const c10::optional<at::Tensor> Bias, const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen); const c10::optional<at::Generator> rng_gen,
size_t rng_elts_per_thread);
std::vector<at::Tensor> fused_attn_bwd_kvpacked( std::vector<at::Tensor> fused_attn_bwd_kvpacked(
size_t b, size_t max_seqlen_q, size_t max_seqlen_kv, size_t b, size_t max_seqlen_q, size_t max_seqlen_kv,
size_t total_seqs_q, size_t total_seqs_kv, size_t total_seqs_q, size_t total_seqs_kv,
size_t h, size_t d, size_t h, size_t d, float attn_scale,
float attn_scale, float p_dropout, bool set_zero, float p_dropout, bool set_zero,
std::string qkv_layout, std::string bias_type, std::string attn_mask_type, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type,
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv, const at::Tensor cu_seqlens_kv,
const at::Tensor Q, const at::Tensor Q,
......
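The *_kvpacked entry points take K and V packed into a single [total_seqs_kv, 2, num_heads, head_dim] tensor, as in the parameter docs earlier in this change. An illustrative packing sketch with made-up sizes:

```python
import torch

# Illustrative shapes only (sizes are made up): pack separate K and V
# activations into the [total_seqs_kv, 2, num_heads, head_dim] layout
# expected by the *_kvpacked functions above.
total_seqs_kv, num_heads, head_dim = 392, 16, 64
k = torch.randn(total_seqs_kv, num_heads, head_dim, dtype=torch.float16)
v = torch.randn(total_seqs_kv, num_heads, head_dim, dtype=torch.float16)

kv_packed = torch.stack([k, v], dim=1)
assert kv_packed.shape == (total_seqs_kv, 2, num_heads, head_dim)
```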
...@@ -403,6 +403,9 @@ class TransformerLayer(torch.nn.Module): ...@@ -403,6 +403,9 @@ class TransformerLayer(torch.nn.Module):
checkpoint_core_attention: bool = False, checkpoint_core_attention: bool = False,
inference_params: Optional[Any] = None, inference_params: Optional[Any] = None,
rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None,
fast_zero_fill: bool = True,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Transformer Layer: attention block and a feedforward network (MLP) Transformer Layer: attention block and a feedforward network (MLP)
...@@ -445,6 +448,12 @@ class TransformerLayer(torch.nn.Module): ...@@ -445,6 +448,12 @@ class TransformerLayer(torch.nn.Module):
rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None` rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
Embeddings for query and key tensors for applying rotary position Embeddings for query and key tensors for applying rotary position
embedding. By default no input embedding is applied. embedding. By default no input embedding is applied.
core_attention_bias_type: str, default = `no_bias`
Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`}
core_attention_bias: Optional[torch.Tensor], default = `None`
Bias tensor for Q * K.T
fast_zero_fill: bool, default = `True`
Whether to set output tensors to 0 before use.
""" """
hidden_states = hidden_states.contiguous() hidden_states = hidden_states.contiguous()
...@@ -473,6 +482,9 @@ class TransformerLayer(torch.nn.Module): ...@@ -473,6 +482,9 @@ class TransformerLayer(torch.nn.Module):
is_first_microbatch=is_first_microbatch, is_first_microbatch=is_first_microbatch,
checkpoint_core_attention=checkpoint_core_attention, checkpoint_core_attention=checkpoint_core_attention,
rotary_pos_emb=rotary_pos_emb, rotary_pos_emb=rotary_pos_emb,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
fast_zero_fill=fast_zero_fill,
) )
if self.apply_residual_connection_post_layernorm and not self.output_layernorm: if self.apply_residual_connection_post_layernorm and not self.output_layernorm:
...@@ -516,6 +528,9 @@ class TransformerLayer(torch.nn.Module): ...@@ -516,6 +528,9 @@ class TransformerLayer(torch.nn.Module):
encoder_output=encoder_output, encoder_output=encoder_output,
is_first_microbatch=is_first_microbatch, is_first_microbatch=is_first_microbatch,
checkpoint_core_attention=checkpoint_core_attention, checkpoint_core_attention=checkpoint_core_attention,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
fast_zero_fill=fast_zero_fill,
) )
if self.apply_residual_connection_post_layernorm: if self.apply_residual_connection_post_layernorm:
attention_output, attention_bias, residual = inter_attention_outputs attention_output, attention_bias, residual = inter_attention_outputs
......
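A hedged usage sketch of the new TransformerLayer forward arguments; the constructor settings, the bias shape, and the assumption that a fused-attention backend supporting post_scale_bias is selected are all illustrative, not taken from this diff:

```python
import torch
from transformer_engine.pytorch import TransformerLayer

# Assumed sizes: head_dim = hidden_size / num_heads = 64, seq_len <= 512,
# FP16 params, so the post-scale-bias row of the support matrix applies.
hidden_size, ffn_hidden_size, num_heads = 1024, 4096, 16
seq_len, batch = 512, 2

layer = TransformerLayer(
    hidden_size, ffn_hidden_size, num_heads, params_dtype=torch.float16
).cuda()

x = torch.randn(seq_len, batch, hidden_size, dtype=torch.float16, device="cuda")
# Bias added to the scaled Q * K^T scores; broadcastable shape assumed.
attn_bias = torch.zeros(1, num_heads, seq_len, seq_len,
                        dtype=torch.float16, device="cuda")

y = layer(
    x,
    core_attention_bias_type="post_scale_bias",
    core_attention_bias=attn_bias,
    fast_zero_fill=True,
)
```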