"vscode:/vscode.git/clone" did not exist on "7f22f90e8cb423fdaa35203d41badd734d9c2e86"
Unverified commit 26aad6b0, authored by Kirthi Shankar Sivamani and committed by GitHub

Disable cuDNN attention for known IMA and NaNs (#2344)



* Fix cuDNN backend selection for more cases. Add CG (CUDA graph) as an option as well
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix logic
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix cuDNN checks
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add more checks
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix cuDNN version
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix error message
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add check for window size
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent f62cad90
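A note for reading the version guards in the first file below: cuDNN reports its runtime version as a single integer encoded as MAJOR*10000 + MINOR*100 + PATCH, which is why the new checks compare against literals such as 91400 (9.14.0) and 90201 (9.2.1, the comment fixed in this commit). A minimal Python sketch of that decoding; the helper name is ours, for illustration only:

def decode_cudnn_version(v: int) -> str:
    """Decode cuDNN's integer version (MAJOR*10000 + MINOR*100 + PATCH)."""
    major, rest = divmod(v, 10000)
    minor, patch = divmod(rest, 100)
    return f"{major}.{minor}.{patch}"

assert decode_cudnn_version(91400) == "9.14.0"  # sliding-window IMA guard below
assert decode_cudnn_version(91500) == "9.15.0"  # graph-capture NaN guard below
assert decode_cudnn_version(90201) == "9.2.1"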
@@ -138,7 +138,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
     NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
     float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
     size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
-    int64_t window_size_right, bool return_max_logit) {
+    int64_t window_size_right, bool return_max_logit, bool cuda_graph) {
   using namespace transformer_engine;
   NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
   const int device_id = cuda::current_device();
@@ -166,7 +166,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
          qkv_layout == NVTE_QKV_Layout::NVTE_T3HD && max_seqlen_q == max_seqlen_kv &&
          max_seqlen_q <= 512 && head_dim_qk == 64 && head_dim_v == 64 &&
          attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) ||
-        // 9.2: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal}
+        // 9.2.1: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal}
         (cudnn_runtime_version >= 90201 && sm_arch_ < 100 && max_seqlen_q % 128 == 0 &&
          max_seqlen_kv % 128 == 0 && head_dim_qk == 128 && head_dim_v == 128 &&
          (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
@@ -407,6 +407,28 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
                    " Please upgrade your cuDNN version if possible."
                 << std::endl;
     }
+    if ((cudnn_runtime_version == 91400) && (max_seqlen_kv > 1024) && (window_size_left != -1) &&
+        (attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_MASK) &&
+        (attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK)) {
+      backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
+      std::cout << "Warning: Given combination of attention mask (non-causal) and "
+                   "max_seqlen_kv (> 1024) does not support fused attention for cuDNN 9.14.0. "
+                   " Please upgrade your cuDNN version if possible."
+                << std::endl;
+    }
+    if ((cudnn_runtime_version <= 91500) && is_training &&
+        (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) &&
+        (max_seqlen_kv % 128 != 0) && cuda_graph &&
+        (attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK) &&
+        (attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) &&
+        (attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)) {
+      backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
+      std::cout << "Warning: Given combination of attention mask (non-padding),"
+                   " max_seqlen_kv (not divisible by 128), and qkv_format (BSHD/SBHD) for"
+                   " backward fused attention with graph capture requires cuDNN 9.15.1+. "
+                   "Please upgrade your cuDNN version if possible."
+                << std::endl;
+    }
   } else {
     backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
   }
@@ -419,11 +441,11 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
                                    NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
                                    const NVTETensor cu_seqlens_padded, const NVTETensor rng_state,
                                    size_t max_seqlen, bool is_training, bool return_max_logit,
-                                   float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
-                                   NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-                                   NVTE_Softmax_Type softmax_type, int64_t window_size_left,
-                                   int64_t window_size_right, NVTETensor workspace,
-                                   cudaStream_t stream) {
+                                   bool cuda_graph, float attn_scale, float dropout,
+                                   NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
+                                   NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
+                                   int64_t window_size_left, int64_t window_size_right,
+                                   NVTETensor workspace, cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked);
   using namespace transformer_engine;
@@ -460,7 +482,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
       is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
-      h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, return_max_logit);
+      h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, return_max_logit,
+      cuda_graph);
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -496,16 +519,14 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
   }
 }
 // NVTE fused attention BWD with packed QKV
-void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO,
-                                   const NVTETensor S, NVTETensor dP,
-                                   const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV,
-                                   NVTETensor dBias, NVTETensor dSoftmaxOffset,
-                                   const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
-                                   size_t max_seqlen, float attn_scale, float dropout,
-                                   NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-                                   NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
-                                   int64_t window_size_left, int64_t window_size_right,
-                                   bool deterministic, NVTETensor workspace, cudaStream_t stream) {
+void nvte_fused_attn_bwd_qkvpacked(
+    const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S,
+    NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias,
+    NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
+    size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
+    int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph,
+    NVTETensor workspace, cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked);
   using namespace transformer_engine;
@@ -544,7 +565,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
       true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h,
-      max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false);
+      max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false, cuda_graph);
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -602,10 +623,10 @@ void nvte_fused_attn_fwd_kvpacked(
     const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
     const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
     const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q,
-    size_t max_seqlen_kv, bool is_training, bool return_max_logit, float attn_scale, float dropout,
-    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-    NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
-    NVTETensor workspace, cudaStream_t stream) {
+    size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph,
+    float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
+    NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
+    int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked);
   using namespace transformer_engine;
   const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
@@ -681,7 +702,7 @@ void nvte_fused_attn_fwd_kvpacked(
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
       is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
       h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right,
-      return_max_logit);
+      return_max_logit, cuda_graph);
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -728,7 +749,8 @@ void nvte_fused_attn_bwd_kvpacked(
     const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv,
     float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
     NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
-    int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream) {
+    int64_t window_size_right, bool deterministic, bool cuda_graph, NVTETensor workspace,
+    cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked);
   using namespace transformer_engine;
   const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
@@ -776,9 +798,10 @@ void nvte_fused_attn_bwd_kvpacked(
   const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
   const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
-  NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q,
-      h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, false);
+  NVTE_Fused_Attn_Backend fused_attention_backend =
+      nvte_get_fused_attn_backend(true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type,
+                                  softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
+                                  d, window_size_left, window_size_right, false, cuda_graph);
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -833,16 +856,19 @@ void nvte_fused_attn_bwd_kvpacked(
   }
 }
 // NVTE fused attention FWD with separate Q, K and V
-void nvte_fused_attn_fwd(
-    const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias,
-    const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
-    const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
-    const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
-    const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
-    size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit,
-    float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
-    int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) {
+void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
+                         const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
+                         NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
+                         const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
+                         const NVTETensor cu_seqlens_q_padded,
+                         const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
+                         const NVTETensor page_table_v, const NVTETensor rng_state,
+                         size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
+                         bool return_max_logit, bool cuda_graph, float attn_scale, float dropout,
+                         NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
+                         NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
+                         int64_t window_size_left, int64_t window_size_right, NVTETensor workspace,
+                         cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_fwd);
   using namespace transformer_engine;
   const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
@@ -913,7 +939,7 @@ void nvte_fused_attn_fwd(
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
       is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
       h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right,
-      return_max_logit);
+      return_max_logit, cuda_graph);
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -963,7 +989,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
                          NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                          NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
                          int64_t window_size_left, int64_t window_size_right, bool deterministic,
-                         NVTETensor workspace, cudaStream_t stream) {
+                         bool cuda_graph, NVTETensor workspace, cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_bwd);
   using namespace transformer_engine;
   const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
@@ -1008,7 +1034,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
       true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q,
-      h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false);
+      h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false,
+      cuda_graph);
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
...
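The two guards added above are plain predicates over the backend-selection inputs: one disables fused attention outright for the cuDNN 9.14.0 sliding-window IMA, the other only when a CUDA graph is being captured for the known backward-pass NaNs. A minimal Python sketch of the same conditions, with mask and format values abbreviated to strings rather than the NVTE enums (our simplification, not TE API):

def cudnn_fused_attn_disabled(cudnn_version: int, is_training: bool, cuda_graph: bool,
                              qkv_format: str, mask: str, max_seqlen_kv: int,
                              window_size_left: int) -> bool:
    """Mirror of the two new disable conditions (True means No_Backend)."""
    # cuDNN 9.14.0: known IMA for sliding-window attention with a non-causal
    # mask and max_seqlen_kv > 1024.
    ima_914 = (cudnn_version == 91400 and max_seqlen_kv > 1024
               and window_size_left != -1
               and mask not in ("causal", "causal_bottom_right"))
    # cuDNN <= 9.15.0: known NaNs in the graph-captured backward pass for
    # BSHD/SBHD when max_seqlen_kv is not divisible by 128 and the mask is
    # not a padding variant.
    nan_capture = (cudnn_version <= 91500 and is_training and cuda_graph
                   and qkv_format in ("bshd", "sbhd") and max_seqlen_kv % 128 != 0
                   and mask not in ("padding", "padding_causal",
                                    "padding_causal_bottom_right"))
    return ima_914 or nan_capture

assert cudnn_fused_attn_disabled(91400, False, False, "bshd", "no_mask", 2048, 128)
assert not cudnn_fused_attn_disabled(91501, True, True, "sbhd", "no_mask", 1000, -1)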
@@ -207,13 +207,14 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
  * \param[in] window_size_left Sliding window size (the left half).
  * \param[in] window_size_right Sliding window size (the right half).
  * \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats.
+ * \param[in] cuda_graph Whether cuda graph capture is enabled or not.
  */
 NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
     bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
     NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
     float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
     size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
-    int64_t window_size_right, bool return_max_logit);
+    int64_t window_size_right, bool return_max_logit, bool cuda_graph);
 /*! \brief Compute dot product attention with packed QKV input.
  *
@@ -257,6 +258,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
  *                 it may be >= max(seqlen_i) for i=0,...batch_size-1.
  * \param[in] is_training Whether this is in training mode or inference.
  * \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats.
+ * \param[in] cuda_graph Whether cuda graph capture is enabled or not.
  * \param[in] attn_scale Scaling factor for Q * K.T.
  * \param[in] dropout Dropout probability.
  * \param[in] qkv_layout QKV tensor's layout.
@@ -273,11 +275,11 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
                                    NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
                                    const NVTETensor cu_seqlens_padded, const NVTETensor rng_state,
                                    size_t max_seqlen, bool is_training, bool return_max_logit,
-                                   float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
-                                   NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-                                   NVTE_Softmax_Type softmax_type, int64_t window_size_left,
-                                   int64_t window_size_right, NVTETensor workspace,
-                                   cudaStream_t stream);
+                                   bool cuda_graph, float attn_scale, float dropout,
+                                   NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
+                                   NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
+                                   int64_t window_size_left, int64_t window_size_right,
+                                   NVTETensor workspace, cudaStream_t stream);
 /*! \brief Compute the backward of the dot product attention with packed QKV input.
  *
@@ -324,19 +326,18 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
  * \param[in] window_size_left Sliding window size (the left half).
  * \param[in] window_size_right Sliding window size (the right half).
  * \param[in] deterministic Whether to execute with deterministic behaviours.
+ * \param[in] cuda_graph Whether cuda graph capture is enabled or not.
  * \param[in] workspace Workspace tensor.
 * \param[in] stream CUDA stream used for this operation.
  */
-void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO,
-                                   const NVTETensor S, NVTETensor dP,
-                                   const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV,
-                                   NVTETensor dBias, NVTETensor dSoftmaxOffset,
-                                   const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
-                                   size_t max_seqlen, float attn_scale, float dropout,
-                                   NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-                                   NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
-                                   int64_t window_size_left, int64_t window_size_right,
-                                   bool deterministic, NVTETensor workspace, cudaStream_t stream);
+void nvte_fused_attn_bwd_qkvpacked(
+    const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S,
+    NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias,
+    NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
+    size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
+    int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph,
+    NVTETensor workspace, cudaStream_t stream);
 /*! \brief Compute dot product attention with packed KV input.
  *
@@ -387,6 +388,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
  *                 it may be >= max(seqlen_kv_i) for i=0,...batch_size-1.
  * \param[in] is_training Whether this is in training mode or inference.
  * \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats.
+ * \param[in] cuda_graph Whether cuda graph capture is enabled or not.
  * \param[in] attn_scale Scaling factor for Q * K.T.
  * \param[in] dropout Dropout probability.
  * \param[in] qkv_layout QKV tensor's layout.
@@ -405,10 +407,10 @@ void nvte_fused_attn_fwd_kvpacked(
     const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
     const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
     const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q,
-    size_t max_seqlen_kv, bool is_training, bool return_max_logit, float attn_scale, float dropout,
-    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-    NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
-    NVTETensor workspace, cudaStream_t stream);
+    size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph,
+    float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
+    NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
+    int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
 /*! \brief Compute the backward of the dot product attention with packed KV input.
  *
@@ -461,6 +463,7 @@ void nvte_fused_attn_fwd_kvpacked(
  * \param[in] window_size_left Sliding window size (the left half).
  * \param[in] window_size_right Sliding window size (the right half).
  * \param[in] deterministic Whether to execute with deterministic behaviours.
+ * \param[in] cuda_graph Whether cuda graph capture is enabled or not.
  * \param[in] workspace Workspace tensor.
 * \param[in] stream CUDA stream used for this operation.
  */
@@ -472,7 +475,8 @@ void nvte_fused_attn_bwd_kvpacked(
     const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv,
     float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
     NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
-    int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream);
+    int64_t window_size_right, bool deterministic, bool cuda_graph, NVTETensor workspace,
+    cudaStream_t stream);
 /*! \brief Compute dot product attention with separate Q, K and V.
  *
@@ -527,6 +531,7 @@ void nvte_fused_attn_bwd_kvpacked(
  *                 it may be >= max(seqlen_kv_i) for i=0,...batch_size-1.
  * \param[in] is_training Whether this is in training mode or inference.
  * \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats.
+ * \param[in] cuda_graph Whether cuda graph capture is enabled or not.
  * \param[in] attn_scale Scaling factor for Q * K.T.
  * \param[in] dropout Dropout probability.
  * \param[in] qkv_layout QKV tensors' layout.
@@ -545,9 +550,9 @@ void nvte_fused_attn_fwd(
     const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
     const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
     size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit,
-    float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
-    int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
+    bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
+    int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
 /*! \brief Compute the backward of the dot product attention with separate Q, K and V.
  *
@@ -605,6 +610,7 @@ void nvte_fused_attn_fwd(
  * \param[in] window_size_left Sliding window size (the left half).
  * \param[in] window_size_right Sliding window size (the right half).
  * \param[in] deterministic Whether to execute with deterministic behaviours.
+ * \param[in] cuda_graph Whether cuda graph capture is enabled or not.
  * \param[in] workspace Workspace tensor.
 * \param[in] stream CUDA stream used for this operation.
  */
@@ -619,7 +625,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
                          NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                          NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
                          int64_t window_size_left, int64_t window_size_right, bool deterministic,
-                         NVTETensor workspace, cudaStream_t stream);
+                         bool cuda_graph, NVTETensor workspace, cudaStream_t stream);
 /*! \brief Update the RNG state with the seed and calculated offset.
  *
...
@@ -23,7 +23,7 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DTy
       is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
       bias_type, mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads,
       q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right,
-      false);
+      false, false);
   return backend;
 }
@@ -180,7 +180,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
         qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(),
         s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
         ragged_offset_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, is_training,
-        false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
+        false, false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type,
         softmax_type, window_size_left, window_size_right, query_workspace_tensor.data(),
         nullptr);
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
@@ -189,7 +189,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
         s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
         kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(),
         dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
-        dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false,
+        dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false,
         scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
         window_size_left, window_size_right, query_workspace_tensor.data(), nullptr);
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
@@ -199,7 +199,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
         q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(),
         ragged_offset_tensor.data(), dummy_page_table_tensor.data(),
         dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen,
-        kv_max_seqlen, is_training, false, scaling_factor, dropout_probability, qkv_layout,
+        kv_max_seqlen, is_training, false, false, scaling_factor, dropout_probability, qkv_layout,
         bias_type, mask_type, softmax_type, window_size_left, window_size_right,
         query_workspace_tensor.data(), nullptr);
   } else {
@@ -279,7 +279,7 @@ static void FusedAttnForwardImpl(
       is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
       bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups,
       q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right,
-      false);
+      false, false);
   nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
   /* Auxiliary tensors (to be propagated to the backward pass later) */
@@ -298,7 +298,7 @@ static void FusedAttnForwardImpl(
         qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), s_tensor.data(),
         o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
         q_seq_offsets_tensor.data(), rng_state_tensor.data(), q_max_seqlen, is_training, false,
-        scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
+        false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
         window_size_left, window_size_right, workspace_tensor.data(), stream);
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
     auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
@@ -311,7 +311,7 @@ static void FusedAttnForwardImpl(
         s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
         kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
         dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(),
-        q_max_seqlen, kv_max_seqlen, is_training, false, scaling_factor, dropout_probability,
+        q_max_seqlen, kv_max_seqlen, is_training, false, false, scaling_factor, dropout_probability,
         qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right,
         workspace_tensor.data(), stream);
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
@@ -326,9 +326,9 @@ static void FusedAttnForwardImpl(
         dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
         q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
         k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
-        rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, scaling_factor,
-        dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
-        window_size_right, workspace_tensor.data(), stream);
+        rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false,
+        scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
+        window_size_left, window_size_right, workspace_tensor.data(), stream);
   } else {
     NVTE_ERROR("Unsupported qkv_layout.");
   }
@@ -480,7 +480,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
         dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(),
         dummy_ragged_offset_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability,
         qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right,
-        deterministic, query_workspace_tensor.data(), nullptr);
+        deterministic, false, query_workspace_tensor.data(), nullptr);
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
     nvte_fused_attn_bwd_kvpacked(
         q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
@@ -491,19 +491,19 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
         kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
         dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
         dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
-        window_size_right, deterministic, query_workspace_tensor.data(), nullptr);
+        window_size_right, deterministic, false, query_workspace_tensor.data(), nullptr);
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
-    nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
-                        doutput_tensor.data(),
-                        s_tensor.data(),  // not used for F16
-                        s_tensor.data(),  // not used for F16
-                        &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
-                        dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(),
-                        q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
-                        dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
-                        q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability,
-                        qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
-                        window_size_right, deterministic, query_workspace_tensor.data(), nullptr);
+    nvte_fused_attn_bwd(
+        q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
+        doutput_tensor.data(),
+        s_tensor.data(),  // not used for F16
+        s_tensor.data(),  // not used for F16
+        &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
+        dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(),
+        kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
+        dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
+        dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
+        window_size_right, deterministic, false, query_workspace_tensor.data(), nullptr);
   } else {
     NVTE_ERROR("Unsupported qkv_layout.");
   }
@@ -546,7 +546,7 @@ static void FusedAttnBackwardImpl(
      is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
      bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups,
      q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right,
-      false);
+      false, false);
   PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads,
                                      bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend,
                                      softmax_aux, rng_state, bias);
@@ -568,7 +568,7 @@ static void FusedAttnBackwardImpl(
                               q_seq_offsets_tensor.data(), q_max_seqlen, scaling_factor,
                               dropout_probability, qkv_layout, bias_type, mask_type,
                               softmax_type, window_size_left, window_size_right, deterministic,
-                              workspace_tensor.data(), stream);
+                              false, workspace_tensor.data(), stream);
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
     auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
     auto kv_shape =
@@ -590,7 +590,7 @@ static void FusedAttnBackwardImpl(
         dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(),
         kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
         q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type,
-        mask_type, softmax_type, window_size_left, window_size_right, deterministic,
+        mask_type, softmax_type, window_size_left, window_size_right, deterministic, false,
        workspace_tensor.data(), stream);
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
     auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
@@ -617,7 +617,7 @@ static void FusedAttnBackwardImpl(
                         q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen,
                         kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type,
                         mask_type, softmax_type, window_size_left, window_size_right, deterministic,
-                        workspace_tensor.data(), stream);
+                        false, workspace_tensor.data(), stream);
   } else {
     NVTE_ERROR("Unsupported qkv_layout.");
   }
...
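A readability note on the JAX bindings above: the extra boolean literals line up with the trailing parameters of the C prototypes in the header diff; on the backend-query and forward paths the pair is (return_max_logit, cuda_graph), and on the backward paths the `false` after `deterministic` is cuda_graph. All of them are hard-coded here, so graph capture is not threaded through the JAX layer by this commit. As a mnemonic (our own naming, not TE API):

# Assumed mapping of the hard-coded literals in the JAX call sites above,
# inferred from the C prototypes: "false, false" on query/forward paths,
# and the "false" after `deterministic` on backward paths.
JAX_FUSED_ATTN_DEFAULTS = {"return_max_logit": False, "cuda_graph": False}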
@@ -66,6 +66,7 @@ from transformer_engine.pytorch.attention.dot_product_attention.utils import (
 )
 from transformer_engine.pytorch import export
 from transformer_engine.pytorch.export import is_in_onnx_export_mode
+from transformer_engine.pytorch.graph import is_graph_capturing
 # Global vars for flash attn v2 and v3 imports
 flash_attn_cuda_bwd = None
@@ -1199,6 +1200,7 @@ class FusedAttnFunc(torch.autograd.Function):
                 window_size,
                 rng_gen,
                 softmax_offset,
+                cuda_graph=is_graph_capturing(),
             )
             # out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16
@@ -1276,6 +1278,7 @@ class FusedAttnFunc(torch.autograd.Function):
                 rng_gen,
                 softmax_offset,
                 return_max_logit,
+                is_graph_capturing(),
             )
             out = out_
             out_ret = out_
@@ -1515,6 +1518,7 @@ class FusedAttnFunc(torch.autograd.Function):
                     ctx.softmax_type,
                     ctx.window_size,
                     ctx.deterministic,
+                    is_graph_capturing(),
                 )
                 # dq, dk, dv: torch.Tensor; dtype = torch.float16 or torch.bfloat16
@@ -1579,6 +1583,7 @@ class FusedAttnFunc(torch.autograd.Function):
                     ctx.softmax_type,
                     ctx.window_size,
                     ctx.deterministic,
+                    is_graph_capturing(),
                 )
                 d_bias = None
...
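On the PyTorch side the flag is never user-facing: each call site samples transformer_engine.pytorch.graph.is_graph_capturing() at invocation time, in both forward and backward, so the extra backend restriction applies only while a CUDA graph is actually being captured and eager replays of the same module are unaffected. A minimal sketch of the pattern, assuming Transformer Engine is installed; the launcher and stub names are ours, standing in for the real fused_attn_fwd/fused_attn_bwd argument lists:

from transformer_engine.pytorch.graph import is_graph_capturing

def _fused_attn_stub(**kwargs):
    # Hypothetical stand-in for the real kernel entry point.
    return kwargs["cuda_graph"]

def launch_fused_attn(**kwargs):
    # Capture state is sampled per call rather than cached in ctx, matching
    # the call sites above, so a graphed and a non-graphed invocation of the
    # same module can legitimately select different cuDNN backends.
    kwargs["cuda_graph"] = is_graph_capturing()
    return _fused_attn_stub(**kwargs)

print(launch_fused_attn())  # False outside graph capture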
@@ -23,6 +23,7 @@ from transformer_engine.pytorch.quantization import FP8GlobalStateManager
 from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
 from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage
 from transformer_engine.pytorch.jit import jit_fuser
+from transformer_engine.pytorch.graph import is_graph_capturing
 from transformer_engine.pytorch.constants import (
     dist_group_type,
     TE_DType,
@@ -33,6 +34,7 @@ from transformer_engine.pytorch.distributed import (
     gather_along_first_dim,
     reduce_scatter_along_first_dim,
 )
 from transformer_engine.pytorch.quantized_tensor import (
     prepare_for_saving,
     restore_from_saved,
@@ -715,6 +717,7 @@ def cp_p2p_fwd_fused_attn(
         cu_seqlens_kv_padded=cu_seqlens_kv_padded_,
         **fp8_meta_kwargs,
         return_max_logit=return_max_logit,
+        cuda_graph=is_graph_capturing(),
     )
     if fp8:
@@ -977,6 +980,7 @@ def cp_p2p_bwd_fused_attn(
         attn_mask_type=attn_mask_type_,
         attn_bias_type=attn_bias_type,
         deterministic=deterministic,
+        cuda_graph=is_graph_capturing(),
         **fp8_meta_kwargs,
     )
@@ -2772,6 +2776,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
                     cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i],
                     window_size=window_size_per_step[i],
                     return_max_logit=return_max_logit,
+                    cuda_graph=is_graph_capturing(),
                 )
                 if return_max_logit:
                     max_logit_per_step[i] = max_logit_[0]
@@ -2986,6 +2991,7 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
                     attn_bias_type=ctx.attn_bias_type,
                     window_size=window_size_per_step[i],
                     deterministic=ctx.deterministic,
+                    cuda_graph=is_graph_capturing(),
                 )
             else:
                 dq_per_step[i], dk_per_step[i], dv_per_step[i] = [
@@ -3282,6 +3288,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
             softmax_type=softmax_type,
             softmax_offset=softmax_offset,
             return_max_logit=return_max_logit,
+            cuda_graph=is_graph_capturing(),
         )
         if isinstance(out_, Float8Tensor):
             out_fp8 = out_
@@ -3559,6 +3566,7 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
                     attn_bias_type=ctx.attn_bias_type,
                     window_size=ctx.window_size,
                     deterministic=ctx.deterministic,
+                    cuda_graph=is_graph_capturing(),
                     **fp8_meta_kwargs,
                     softmax_type=ctx.softmax_type,
                 )
...
@@ -1314,6 +1314,7 @@ class DotProductAttention(TransformerEngineBaseModule):
             inference_params=inference_params,
             softmax_type=self.softmax_type,
             return_max_logit=self.return_max_logit,
+            cuda_graph=is_graph_capturing(),
         )
         global _attention_backends
         if is_in_onnx_export_mode():
...
@@ -231,6 +231,8 @@ class AttentionParams:
The type of softmax operation. See DotProductAttention for details.
return_max_logit: bool, default = `False`
Whether to output max_logit.
cuda_graph: bool, default = `False`
Whether support for CUDA graph capture is needed.
"""
qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor
@@ -260,6 +262,7 @@ class AttentionParams:
inference_params: Optional[InferenceParams] = None
softmax_type: str = "vanilla"
return_max_logit: bool = False
cuda_graph: bool = False
def __eq__(self, other):
"""
@@ -334,6 +337,7 @@ def get_attention_backend(
inference_params = attention_params.inference_params
softmax_type = attention_params.softmax_type
return_max_logit = attention_params.return_max_logit
cuda_graph = attention_params.cuda_graph
# Run config
logger = logging.getLogger("DotProductAttention")
@@ -979,6 +983,7 @@ def get_attention_backend(
window_size[0],
window_size[1],
return_max_logit,
cuda_graph,
)
if fused_attention_backend == FusedAttnBackend["No_Backend"]:
logger.debug("Disabling FusedAttention as no backend supports the provided input")
...
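The selected backend is memoized in the module-level _attention_backends dict (see `global _attention_backends` above), so selection re-runs whenever the incoming AttentionParams differ from the cached ones, including a flip of cuda_graph. A hypothetical, trimmed sketch of that memoization pattern (the real cache stores several backend flags alongside the parameters):

_attention_backends_sketch = {"attention_params": None, "backend": None}

def select_backend_cached(params, select_fn):
    # Re-run selection only when the parameters (including cuda_graph) change.
    if _attention_backends_sketch["attention_params"] != params:
        _attention_backends_sketch["attention_params"] = params
        _attention_backends_sketch["backend"] = select_fn(params)
    return _attention_backends_sketch["backend"]

backend = select_backend_cached(("fp16", True), lambda p: "FusedAttention")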
@@ -140,6 +140,7 @@ def fused_attn_fwd(
rng_gen: torch.Generator = None,
softmax_offset: torch.Tensor = None,
return_max_logit: bool = False,
cuda_graph: bool = False,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused Attention FWD for separate QKV input.
@@ -219,6 +220,8 @@ def fused_attn_fwd(
See softmax_type in DotProductAttention for details.
return_max_logit: bool, default = False
whether to return the maximum attention score
cuda_graph: bool, default = False
whether CUDA graph capture is enabled
Returns
----------
@@ -320,6 +323,7 @@ def fused_attn_fwd(
rng_gen,
rng_elts_per_thread,
return_max_logit,
cuda_graph,
)
if return_max_logit:
@@ -367,6 +371,7 @@ def fused_attn_bwd(
softmax_type: str = "vanilla",
window_size: Tuple[int, int] = (-1, -1),
deterministic: bool = False,
cuda_graph: bool = False,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused Attention BWD for packed KV input.
@@ -439,6 +444,8 @@ def fused_attn_bwd(
window and causal mask specifically.
deterministic: bool, default = False
whether to execute the backward pass with deterministic behaviours.
cuda_graph: bool, default = False
whether CUDA graph capture is enabled.
Returns
----------
@@ -509,6 +516,7 @@ def fused_attn_bwd(
s_quantizer,
dp_quantizer,
dqkv_quantizer,
cuda_graph,
)
return output_tensors
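Both wrappers default the new argument to False, keeping pre-existing call sites source-compatible; only capture-aware callers pass cuda_graph=is_graph_capturing(). A minimal sketch of that compatibility pattern with a hypothetical stand-in function:

def fused_attn_sketch(*tensors, deterministic: bool = False,
                      cuda_graph: bool = False) -> dict:
    # Defaulting cuda_graph to False preserves the old behavior; the flag
    # only influences backend selection when a caller opts in explicitly.
    return {"deterministic": deterministic, "cuda_graph": cuda_graph}

assert fused_attn_sketch()["cuda_graph"] is False        # legacy call site
assert fused_attn_sketch(cuda_graph=True)["cuda_graph"]  # capture-aware call site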
@@ -76,7 +76,7 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
int64_t window_size_right, bool return_max_logit, bool cuda_graph);
std::pair<TensorWrapper, py::object> quantizer_helper(py::handle quantizer,
const std::vector<size_t> &shape, DType dtype,
@@ -94,7 +94,7 @@ std::vector<py::object> fused_attn_fwd(
const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
const std::optional<at::Tensor> SoftmaxOffset, const std::optional<at::Generator> rng_gen,
size_t rng_elts_per_thread, bool return_max_logit, bool cuda_graph);
std::vector<py::object> fused_attn_bwd(
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
@@ -106,7 +106,7 @@ std::vector<py::object> fused_attn_bwd(
const std::vector<at::Tensor> Aux_CTX_Tensors,
const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
py::handle dp_quantizer, py::handle dqkv_quantizer, bool cuda_graph);
at::Tensor fa_prepare_fwd(at::Tensor qkvi);
at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v);
...
@@ -45,12 +45,12 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
int64_t window_size_right, bool return_max_logit, bool cuda_graph) {
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
bias_type, attn_mask_type, softmax_type, p_dropout, num_attn_heads, num_gqa_groups,
max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right,
return_max_logit, cuda_graph);
return fused_attention_backend;
}
@@ -107,7 +107,7 @@ std::vector<py::object> fused_attn_fwd(
const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
const std::optional<at::Tensor> SoftmaxOffset, const std::optional<at::Generator> rng_gen,
size_t rng_elts_per_thread, bool return_max_logit, bool cuda_graph) {
auto none = py::none();
// create QKV tensor wrappers
@@ -229,7 +229,7 @@ std::vector<py::object> fused_attn_fwd(
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(),
te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
softmax_type, window_size[0], window_size[1], workspace.data(),
at::cuda::getCurrentCUDAStream());
});
@@ -289,7 +289,7 @@ std::vector<py::object> fused_attn_fwd(
te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(),
te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
softmax_type, window_size[0], window_size[1], workspace.data(),
at::cuda::getCurrentCUDAStream());
});
@@ -312,7 +312,7 @@ std::vector<py::object> fused_attn_bwd(
const std::vector<at::Tensor> Aux_CTX_Tensors,
const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
py::handle dp_quantizer, py::handle dqkv_quantizer, bool cuda_graph) {
auto none = py::none();
// create QKV, O, dO tensor wrappers
@@ -527,13 +527,14 @@ std::vector<py::object> fused_attn_bwd(
// populate tensors with appropriate shapes and dtypes
NVTE_SCOPED_GIL_RELEASE({
nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(),
te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(),
te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(),
te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
softmax_type, window_size[0], window_size[1], deterministic, cuda_graph,
workspace.data(), at::cuda::getCurrentCUDAStream());
});
// allocate memory for workspace
@@ -543,13 +544,14 @@ std::vector<py::object> fused_attn_bwd(
// execute kernel
NVTE_SCOPED_GIL_RELEASE({
nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(),
te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(),
te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(),
te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
softmax_type, window_size[0], window_size[1], deterministic, cuda_graph,
workspace.data(), at::cuda::getCurrentCUDAStream());
});
// destroy tensor wrappers
...
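On the C++ side the flag is threaded straight through to nvte_get_fused_attn_backend and nvte_fused_attn_bwd, where the backend selector can take capture support into account. Capture must be known up front because kernels recorded into a CUDA graph replay verbatim, so a backend that misbehaves under capture cannot be swapped out afterwards. A self-contained PyTorch sketch of that capture flow, using plain scaled_dot_product_attention as a stand-in for fused attention:

import torch
import torch.nn.functional as F

if torch.cuda.is_available():
    q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)
    out = torch.empty_like(q)

    # Warm-up on a side stream before capture, per the CUDA graphs recipe.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        out.copy_(F.scaled_dot_product_attention(q, k, v))
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        # torch.cuda.is_current_stream_capturing() is True inside this block,
        # which is the signal a helper like is_graph_capturing() reports.
        out.copy_(F.scaled_dot_product_attention(q, k, v))

    g.replay()  # replays the recorded kernels with the same backend choice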