Unverified commit c6a92a4d authored by Sudhakar Singh, committed by GitHub

Add support for SWA (left, right) with FusedAttention (#2477)

* SWA (left, right) with FusedAttention changes cherry-picked from https://github.com/NVIDIA/TransformerEngine/pull/1369

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix test_kv_cache failures

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* remove unnecessary comments

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* fix some more filter issues, address feedback

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* fix for local test case failures - `bottom_right_diagonal` should be calculated in the `fused_attn_fwd` call as well

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* make conditions more accurate

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* add cp tests to test swa (left, right)

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove dead code and make conditions better

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix lint

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* feedback from Charlene

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* small er

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* plumb `bottom_right_diagonal` through jax

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* plumb `bottom_right_diagonal` through jax

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add missing fields

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* use proper mask type in CP

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 0f0e229b
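
Summary for readers of the hunks below: the change threads a new `bottom_right_diagonal` flag from the public `nvte_fused_attn_*` C API down to the cuDNN frontend graph, and relaxes the backend filter so a non-zero right window (SWA (left, right)) is accepted on cuDNN 9.6+. As a rough, non-authoritative sketch of what the flag encodes, a hypothetical helper (not part of this diff; the framework-side callers compute the flag, and their exact logic may differ):

```cpp
// Hypothetical helper, for illustration only: with bottom-right-aligned causal masks the
// sliding-window/ALiBi diagonal ends at the last KV token, so the flag tracks the mask type.
#include <transformer_engine/fused_attn.h>

static bool infer_bottom_right_diagonal(NVTE_Mask_Type mask) {
  return mask == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK ||
         mask == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK;
}
```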
@@ -153,6 +153,7 @@ def test_dot_product_attention(
     if config.window_size == (-1, -1) and swa:
         config.window_size = [2, 2]
     config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
     qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
     if qkv_format == "thd" and "padding" not in config.attn_mask_type:
@@ -171,6 +172,7 @@ def test_dot_product_attention(
         deterministic=_deterministic,
     )
     flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
     if not fused_attn_supported:
         is_training = False
     available_backends, _, fused_attn_backends = get_available_attention_backends(
@@ -701,9 +703,10 @@ model_configs_swa = {
 @pytest.mark.parametrize("dtype", param_types_lean)
 @pytest.mark.parametrize("model_configs", [model_configs_swa])
 @pytest.mark.parametrize("model", model_configs_swa.keys())
-def test_dpa_sliding_window(dtype, model_configs, model):
+@pytest.mark.parametrize("qkv_layout", ["thd_thd_thd", "sbhd_sbhd_sbhd"])
+def test_dpa_sliding_window(dtype, model_configs, model, qkv_layout):
     """Test DotProductAttention module with sliding window attention"""
-    test_dot_product_attention(dtype, model_configs, model, False, True, None, True, False)
+    test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, True, False)

 model_configs_alibi_slopes = {
...
@@ -147,7 +147,7 @@ model_configs_fused_attn = {
         2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
     ),  # MHA
     "cp_1_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias"),  # MHA
-    "cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)),  # MHA
+    "cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 512)),  # MHA
     "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"),  # GQA
     "cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2),  # GQA
     "cp_2_2": ModelConfig(
@@ -163,7 +163,7 @@ model_configs_fused_attn = {
         2, 4096, 12, 128, num_gqa_groups=2, attn_bias_type="post_scale_bias"
     ),  # GQA
     "cp_2_4": ModelConfig(
-        2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0)
+        2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 512)
     ),  # GQA
     "cp_3_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", head_dim_v=64),  # MLA
     "cp_3_1": ModelConfig(2, 4096, 12, 128, head_dim_v=64),  # MLA
@@ -187,7 +187,16 @@ dtypes = ["bf16", "fp16", "fp8"]
 qkv_formats = ["bshd", "sbhd", "thd"]
 cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"]
 if test_essential:
-    configs = ["cp_1_0", "cp_1_1", "cp_1_4", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"]
+    configs = [
+        "cp_1_0",
+        "cp_1_1",
+        "cp_1_4",
+        "cp_2_0",
+        "cp_2_2",
+        "cp_2_4",
+        "cp_3_2",
+        "cp_4_2",
+    ]
     model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs}
     dtypes = ["bf16", "fp8"]
     qkv_formats = ["sbhd", "thd"]
...
@@ -353,11 +353,11 @@ def get_available_attention_backends(
     backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
     if AttentionLogging._is_logging_setup is False:
         AttentionLogging.setup_logging()
+    with logging_context(highest_level=AttentionLogging._log_level):
         for i in range(3):
             os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
             _attention_backends["backend_selection_requires_update"] = True
             available_backends, flash_attention_backend, fused_attention_backend = test()
             if fused_attention_backend == FusedAttnBackend[backends[i]]:
                 fused_attn_backends.append(fused_attention_backend)
     return available_backends, flash_attention_backend, fused_attn_backends
@@ -406,9 +406,11 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
              (window_size_right == -1 || window_size_right == 0)) ||
         // 9.2: SWA (left, 0) + top-left diagonal + {bshd, sbhd}
         (cudnn_runtime_version >= 90200 &&
-         ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) ||
-          ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 &&
-           (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
+         ((window_size_left == -1 && window_size_right == -1 &&
+           attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK) ||
+          ((window_size_left == -1 || window_size_left >= 0) && window_size_right == 0 &&
+           (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK ||
+            attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
             (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK &&
              max_seqlen_q == max_seqlen_kv)) &&
           max_seqlen_q <= max_seqlen_kv && dropout == 0.0 &&
@@ -418,12 +420,14 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
         // 9.6: SWA (left, 0) + top-left/bottom-right diagonal + {bshd, sbhd, thd}
         (cudnn_runtime_version >= 90600 &&
          ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) ||
-          ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 &&
+          ((window_size_left >= 0 || window_size_left == -1) &&
+           (window_size_right >= 0 || window_size_right == -1) &&
            ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK &&
              // TODO(cyang): fix bug for BRCM + cross-attention on sm100
              (sm_arch_ < 100 || (sm_arch_ >= 100 && ((max_seqlen_q == max_seqlen_kv &&
                                                       cudnn_runtime_version <= 90700) ||
                                                      cudnn_runtime_version > 90700)))) ||
+            attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
             attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
             (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK &&
              (sm_arch_ < 100 || (sm_arch_ >= 100 && ((max_seqlen_q == max_seqlen_kv &&
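
For orientation, the masking rule these gates protect can be written out directly. A minimal sketch, assuming the usual TE convention that query i may attend key j only when j falls inside the band [d + i - left, d + i + right], with d = s_kv - s_q for a bottom-right-aligned diagonal and d = 0 for top-left (this helper is illustrative, not code from the diff):

```cpp
// Illustrative only: element-wise form of SWA (left, right) with a selectable diagonal.
// -1 on either side means that side of the window is unbounded.
#include <cstdint>

static bool in_sliding_window(int64_t i, int64_t j, int64_t s_q, int64_t s_kv,
                              int64_t window_size_left, int64_t window_size_right,
                              bool bottom_right_diagonal) {
  const int64_t d = bottom_right_diagonal ? (s_kv - s_q) : 0;  // diagonal offset
  const bool left_ok = (window_size_left == -1) || (j >= d + i - window_size_left);
  const bool right_ok = (window_size_right == -1) || (j <= d + i + window_size_right);
  return left_ok && right_ok;
}
```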
@@ -515,16 +519,14 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
 // NVTE fused attention FWD with packed QKV
 // DEPRECATED: This API is deprecated.
 // Please use nvte_fused_attn_fwd with separate Q, K, V tensors instead.
-void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
-                                   const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O,
-                                   NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
-                                   const NVTETensor cu_seqlens_padded, const NVTETensor rng_state,
-                                   size_t max_seqlen, bool is_training, bool return_max_logit,
-                                   bool cuda_graph, float attn_scale, float dropout,
-                                   NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-                                   NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
-                                   int64_t window_size_left, int64_t window_size_right,
-                                   NVTETensor workspace, cudaStream_t stream) {
+void nvte_fused_attn_fwd_qkvpacked(
+    const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
+    NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
+    const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen,
+    bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout,
+    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
+    NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
+    bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked);
   using namespace transformer_engine;
@@ -598,10 +600,10 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
     fused_attn_arbitrary_seqlen_fwd(
         b, h, h, max_seqlen, max_seqlen, d, d, t, t, 0, 0, 0, 0, 0, 0, is_training,
         return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type,
-        window_size_left, window_size_right, &Q_view, &K_view, &V_view, input_Bias,
-        input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens,
-        input_cu_seqlens_padded, input_cu_seqlens_padded, nullptr, nullptr, input_rng_state,
-        wkspace, stream, handle);
+        window_size_left, window_size_right, bottom_right_diagonal, &Q_view, &K_view, &V_view,
+        input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens,
+        input_cu_seqlens, input_cu_seqlens_padded, input_cu_seqlens_padded, nullptr, nullptr,
+        input_rng_state, wkspace, stream, handle);
 #else
     NVTE_ERROR(
         "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. "
@@ -639,8 +641,8 @@ void nvte_fused_attn_bwd_qkvpacked(
     NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
     size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
     NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
-    int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph,
-    NVTETensor workspace, cudaStream_t stream) {
+    int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
+    bool deterministic, bool cuda_graph, NVTETensor workspace, cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked);
   using namespace transformer_engine;
@@ -736,10 +738,11 @@ void nvte_fused_attn_bwd_qkvpacked(
     fused_attn_arbitrary_seqlen_bwd(
         b, h, h, max_seqlen, max_seqlen, d, d, t, t, attn_scale, dropout, qkv_layout, bias_type,
-        attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic, &Q_view,
-        &K_view, &V_view, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, &dQ_view,
-        &dK_view, &dV_view, output_dBias, output_dSoftmaxOffset, input_cu_seqlens, input_cu_seqlens,
-        input_cu_seqlens_padded, input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle);
+        attn_mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal,
+        deterministic, &Q_view, &K_view, &V_view, input_O, input_dO, input_Bias,
+        input_SoftmaxOffset, output_S, &dQ_view, &dK_view, &dV_view, output_dBias,
+        output_dSoftmaxOffset, input_cu_seqlens, input_cu_seqlens, input_cu_seqlens_padded,
+        input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle);
 #else
     const char *err_msg =
         "cuDNN 8.9.0 is required for BF16/FP16 fused attention "
@@ -790,7 +793,8 @@ void nvte_fused_attn_fwd_kvpacked(
     size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph,
     float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
     NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
-    int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) {
+    int64_t window_size_right, bool bottom_right_diagonal, NVTETensor workspace,
+    cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked);
   using namespace transformer_engine;
   const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
@@ -902,10 +906,10 @@ void nvte_fused_attn_fwd_kvpacked(
         b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, num_pages_k, num_pages_v,
         page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training,
         return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type,
-        window_size_left, window_size_right, input_Q, &K_view, &V_view, input_Bias,
-        input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv,
-        input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k,
-        input_page_table_v, input_rng_state, wkspace, stream, handle);
+        window_size_left, window_size_right, bottom_right_diagonal, input_Q, &K_view, &V_view,
+        input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
+        input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded,
+        input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle);
 #else
     NVTE_ERROR(
         "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. "
@@ -945,8 +949,8 @@ void nvte_fused_attn_bwd_kvpacked(
     const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv,
     float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
     NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
-    int64_t window_size_right, bool deterministic, bool cuda_graph, NVTETensor workspace,
-    cudaStream_t stream) {
+    int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, bool cuda_graph,
+    NVTETensor workspace, cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked);
   using namespace transformer_engine;
   const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
@@ -1052,11 +1056,11 @@ void nvte_fused_attn_bwd_kvpacked(
     fused_attn_arbitrary_seqlen_bwd(
         b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, attn_scale, dropout, qkv_layout,
-        bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic,
-        input_Q, &K_view, &V_view, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S,
-        output_dQ, &dK_view, &dV_view, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q,
-        input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state,
-        wkspace, stream, handle);
+        bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right,
+        bottom_right_diagonal, deterministic, input_Q, &K_view, &V_view, input_O, input_dO,
+        input_Bias, input_SoftmaxOffset, output_S, output_dQ, &dK_view, &dV_view, output_dBias,
+        output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
+        input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
 #else
     const char *err_msg =
         "cuDNN 8.9.3 is required for BF16/FP16 fused attention "
@@ -1106,8 +1110,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
                          bool return_max_logit, bool cuda_graph, float attn_scale, float dropout,
                          NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                          NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
-                         int64_t window_size_left, int64_t window_size_right, NVTETensor workspace,
-                         cudaStream_t stream) {
+                         int64_t window_size_left, int64_t window_size_right,
+                         bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_fwd);
   using namespace transformer_engine;
   const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
@@ -1195,10 +1199,10 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
         b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v,
         page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training,
         return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type,
-        window_size_left, window_size_right, input_Q, input_K, input_V, input_Bias,
-        input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv,
-        input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k,
-        input_page_table_v, input_rng_state, wkspace, stream, handle);
+        window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, input_V,
+        input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
+        input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded,
+        input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle);
 #else
     NVTE_ERROR(
         "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. "
@@ -1228,8 +1232,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
                          size_t max_seqlen_kv, float attn_scale, float dropout,
                          NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                          NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
-                         int64_t window_size_left, int64_t window_size_right, bool deterministic,
-                         bool cuda_graph, NVTETensor workspace, cudaStream_t stream) {
+                         int64_t window_size_left, int64_t window_size_right,
+                         bool bottom_right_diagonal, bool deterministic, bool cuda_graph,
+                         NVTETensor workspace, cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_bwd);
   using namespace transformer_engine;
   const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
@@ -1302,8 +1307,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
     fused_attn_arbitrary_seqlen_bwd(
         b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout,
         qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right,
-        deterministic, input_Q, input_K, input_V, input_O, input_dO, input_Bias,
-        input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias,
+        bottom_right_diagonal, deterministic, input_Q, input_K, input_V, input_O, input_dO,
+        input_Bias, input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias,
         output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
         input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
 #else
...
@@ -55,10 +55,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
     int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training,
     bool return_max_logit, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
-    int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK,
-    void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrS1, void *devPtrS2,
-    void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ,
-    void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV,
+    int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, void *devPtrQ,
+    void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrS1,
+    void *devPtrS2, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset,
+    void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV,
     void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType,
     void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
   using namespace transformer_engine;
@@ -75,6 +75,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
   if (is_bottom_right && s_q == s_kv && !is_padding) {
     is_causal = true;
     is_bottom_right = false;
+    bottom_right_diagonal = false;
   }
   bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX);
   bool is_dropout = (is_training && dropout_probability != 0.0f);
@@ -129,6 +130,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
                                 softmax_type,
                                 window_size_left,
                                 window_size_right,
+                                bottom_right_diagonal,
                                 true,
                                 tensorType,
                                 cudnn_frontend::DataType_t::NOT_SET,
@@ -254,9 +256,16 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
             .set_causal_mask_bottom_right(is_bottom_right)
             .set_attn_scale(attn_scale);
+    fe::DiagonalAlignment_t const &diagonal_alignment =
+        bottom_right_diagonal ? fe::DiagonalAlignment_t::BOTTOM_RIGHT
+                              : fe::DiagonalAlignment_t::TOP_LEFT;
+    sdpa_options.set_diagonal_alignment(diagonal_alignment);
     if (cudnn_runtime_version >= 90200 && window_size_left != -1) {
       sdpa_options.set_diagonal_band_left_bound(window_size_left + 1);
     }
+    if (cudnn_runtime_version >= 90600 && window_size_right != -1) {
+      sdpa_options.set_diagonal_band_right_bound(window_size_right);
+    }
     sdpa_options.set_alibi_mask(is_alibi);
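
A condensed restatement of the forward hunk above (the same pattern appears in the backward path further down). The setter names are taken from the diff itself; the templated `Opts` stands in for the cuDNN frontend SDPA attribute objects (`sdpa_options` / `sdpa_backward_options`), so treat this as a sketch rather than the exact TE code:

```cpp
// Sketch under the assumptions stated above: map TE's (left, right) window and diagonal
// flag onto the cuDNN frontend diagonal-band options, mirroring the version gates in the hunk.
#include <cstdint>
#include <cudnn_frontend.h>

namespace fe = cudnn_frontend;

template <typename Opts>
void set_swa_options(Opts &opts, int64_t window_size_left, int64_t window_size_right,
                     bool bottom_right_diagonal, int cudnn_runtime_version) {
  opts.set_diagonal_alignment(bottom_right_diagonal ? fe::DiagonalAlignment_t::BOTTOM_RIGHT
                                                    : fe::DiagonalAlignment_t::TOP_LEFT);
  // The left bound includes the diagonal element, hence the +1; -1 leaves that side unbounded.
  if (cudnn_runtime_version >= 90200 && window_size_left != -1) {
    opts.set_diagonal_band_left_bound(window_size_left + 1);
  }
  // A right bound is only honoured by cuDNN 9.6+, matching the backend gate earlier in the diff.
  if (cudnn_runtime_version >= 90600 && window_size_right != -1) {
    opts.set_diagonal_band_right_bound(window_size_right);
  }
}
```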
@@ -542,13 +551,14 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
     int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h,
     float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
-    int64_t window_size_left, int64_t window_size_right, bool deterministic, void *devPtrQ,
-    void *devPtrKTranspose, void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats,
-    void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrdQ, void *devPtrdK, void *devPtrdV,
-    void *devPtrdO, void *devPtrdBias, void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed,
-    void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV,
-    void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType,
-    void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
+    int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
+    bool deterministic, void *devPtrQ, void *devPtrKTranspose, void *devPtrVTranspose,
+    void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias, void *devPtrSoftmaxOffset,
+    void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias,
+    void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed, void *devPtrDropoutOffset,
+    void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ,
+    void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace,
+    size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
   using namespace transformer_engine;
   bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
@@ -563,6 +573,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
   if (is_bottom_right && s_q == s_kv && !is_padding) {
     is_causal = true;
     is_bottom_right = false;
+    bottom_right_diagonal = false;
   }
   bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX);
   bool is_dropout = (dropout_probability != 0.0f);
@@ -621,6 +632,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
                                 softmax_type,
                                 window_size_left,
                                 window_size_right,
+                                bottom_right_diagonal,
                                 deterministic,
                                 tensorType,
                                 cudnn_frontend::DataType_t::NOT_SET,
@@ -781,9 +793,17 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
       sdpa_backward_options.set_max_total_seq_len_kv(s_kv);
     }
+    fe::DiagonalAlignment_t const &diagonal_alignment =
+        bottom_right_diagonal ? fe::DiagonalAlignment_t::BOTTOM_RIGHT
+                              : fe::DiagonalAlignment_t::TOP_LEFT;
+    sdpa_backward_options.set_diagonal_alignment(diagonal_alignment);
     if (cudnn_runtime_version >= 90200 && window_size_left != -1) {
       sdpa_backward_options.set_diagonal_band_left_bound(window_size_left + 1);
     }
+    if (cudnn_runtime_version >= 90600 && window_size_right != -1) {
+      sdpa_backward_options.set_diagonal_band_right_bound(window_size_right);
+    }
     if (cudnn_runtime_version >= 90000) {
       sdpa_backward_options.set_deterministic_algorithm(deterministic);
@@ -1044,8 +1064,8 @@ void fused_attn_arbitrary_seqlen_fwd(
     size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training,
     bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
-    int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
-    const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
+    int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
+    const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
     const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
     const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
     const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
@@ -1180,11 +1200,11 @@ void fused_attn_arbitrary_seqlen_fwd(
         max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k,
         page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training,
         return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type,
-        window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias,
-        devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset,
-        devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ,
-        devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size,
-        stream, handle);
+        window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV,
+        devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed,
+        devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV,
+        devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
+        &workspace_size, stream, handle);
     if (workspace_size > 0) {
       if (workspace->data.dptr == nullptr) {
@@ -1206,13 +1226,14 @@ void fused_attn_arbitrary_seqlen_bwd(
     size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
     size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
-    int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q,
-    const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO,
-    const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S,
-    Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias,
-    Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
-    const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
-    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
+    int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
+    bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
+    const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
+    const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK,
+    Tensor *output_dV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset,
+    const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
+    const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
+    cudaStream_t stream, cudnnHandle_t handle) {
   using namespace transformer_engine;
   const auto QKV_type = input_Q->data.dtype;
   void *devPtrQ = input_Q->data.dptr;
@@ -1273,8 +1294,8 @@ void fused_attn_arbitrary_seqlen_bwd(
         batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v,
         max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout,
         qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right,
-        deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias,
-        devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
+        bottom_right_diagonal, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats,
+        devPtrBias, devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
         devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ,
         devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type),
         workspace->data.dptr, &workspace_size, stream, handle);
...
@@ -25,8 +25,8 @@ void fused_attn_arbitrary_seqlen_fwd(
     size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training,
     bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
-    int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
-    const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
+    int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
+    const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
     const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
     const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
     const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
@@ -37,13 +37,14 @@ void fused_attn_arbitrary_seqlen_bwd(
     size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
     size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
-    int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q,
-    const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO,
-    const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S,
-    Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias,
-    Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
-    const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
-    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
+    int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
+    bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
+    const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
+    const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK,
+    Tensor *output_dV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset,
+    const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
+    const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
+    cudaStream_t stream, cudnnHandle_t handle);
 #endif  // CUDNN_VERSION >= 8900
 }  // namespace transformer_engine
...
@@ -1707,6 +1707,7 @@ void fused_attn_fp8_fwd_impl_v1(
                                 0,
                                 0,
                                 true,
+                                true,
                                 qkv_tensor_type,
                                 o_tensor_type,
                                 cudnn_frontend::DataType_t::NOT_SET,
@@ -2035,6 +2036,7 @@ void fused_attn_fp8_bwd_impl_v1(
                                 NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX,
                                 0,
                                 0,
+                                true,
                                 false,
                                 qkv_tensor_type,
                                 o_tensor_type,
...
@@ -110,6 +110,7 @@ struct FADescriptor_v1 {
   NVTE_Softmax_Type softmax_type;
   std::int64_t window_size_left;
   std::int64_t window_size_right;
+  bool bottom_right_diagonal;
   bool deterministic;
   cudnn_frontend::DataType_t qkv_tensor_type;
   cudnn_frontend::DataType_t o_tensor_type;
@@ -121,15 +122,16 @@ struct FADescriptor_v1 {
     return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k,
                     page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h,
                     attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type,
-                    window_size_left, window_size_right, deterministic, bias_type, qkv_tensor_type,
-                    o_tensor_type, do_tensor_type, dqkv_tensor_type, generate_max_sum_exp) <
+                    window_size_left, window_size_right, bottom_right_diagonal, deterministic,
+                    bias_type, qkv_tensor_type, o_tensor_type, do_tensor_type, dqkv_tensor_type,
+                    generate_max_sum_exp) <
            std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k,
                     rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k,
                     rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining,
                     rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type,
-                    rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type,
-                    rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type,
-                    rhs.dqkv_tensor_type, rhs.generate_max_sum_exp);
+                    rhs.window_size_left, rhs.window_size_right, rhs.bottom_right_diagonal,
+                    rhs.deterministic, rhs.bias_type, rhs.qkv_tensor_type, rhs.o_tensor_type,
+                    rhs.do_tensor_type, rhs.dqkv_tensor_type, rhs.generate_max_sum_exp);
   }
 };
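
The comparator change above matters because `FADescriptor_v1` is the key under which built cuDNN attention graphs are cached; a field missing from either `std::tie` list would let plans that differ only in that field collide. A reduced sketch of the idiom (the struct below is hypothetical and keeps only the window and diagonal fields):

```cpp
// Illustration only: lexicographic ordering via std::tie must list every field that
// distinguishes one cached plan from another, including the new bottom_right_diagonal.
#include <cstdint>
#include <tuple>

struct KeySketch {
  std::int64_t window_size_left;
  std::int64_t window_size_right;
  bool bottom_right_diagonal;  // newly added field, present on both sides of the comparison
  bool deterministic;

  bool operator<(const KeySketch &rhs) const {
    return std::tie(window_size_left, window_size_right, bottom_right_diagonal, deterministic) <
           std::tie(rhs.window_size_left, rhs.window_size_right, rhs.bottom_right_diagonal,
                    rhs.deterministic);
  }
};
```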
...
@@ -270,22 +270,21 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
  *  \param[in]     softmax_type               Attention softmax type.
  *  \param[in]     window_size_left           Sliding window size (the left half).
  *  \param[in]     window_size_right          Sliding window size (the right half).
+ *  \param[in]     bottom_right_diagonal      Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix.
  *  \param[in]     workspace                  Workspace tensor.
  *  \param[in]     stream                     CUDA stream used for this operation.
  */
 [[deprecated(
     "nvte_fused_attn_fwd_qkvpacked() is deprecated. Please use nvte_fused_attn_fwd() with separate "
     "Q, K, V tensors instead.")]]
-void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
-                                   const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O,
-                                   NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
-                                   const NVTETensor cu_seqlens_padded, const NVTETensor rng_state,
-                                   size_t max_seqlen, bool is_training, bool return_max_logit,
-                                   bool cuda_graph, float attn_scale, float dropout,
-                                   NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-                                   NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
-                                   int64_t window_size_left, int64_t window_size_right,
-                                   NVTETensor workspace, cudaStream_t stream);
+void nvte_fused_attn_fwd_qkvpacked(
+    const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
+    NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
+    const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen,
+    bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout,
+    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
+    NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
+    bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream);

 /*! \brief Compute the backward of the dot product attention with packed QKV input.
  *
@@ -333,6 +332,7 @@ void nvte_fused_attn_fwd_qkvpacked(
  *  \param[in]     softmax_type               Attention softmax type.
  *  \param[in]     window_size_left           Sliding window size (the left half).
  *  \param[in]     window_size_right          Sliding window size (the right half).
+ *  \param[in]     bottom_right_diagonal      Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix.
  *  \param[in]     deterministic              Whether to execute with deterministic behaviours.
  *  \param[in]     cuda_graph                 Whether cuda graph capture is enabled or not.
  *  \param[in]     workspace                  Workspace tensor.
@@ -347,8 +347,8 @@ void nvte_fused_attn_bwd_qkvpacked(
     NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
     size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
     NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
-    int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph,
-    NVTETensor workspace, cudaStream_t stream);
+    int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
+    bool deterministic, bool cuda_graph, NVTETensor workspace, cudaStream_t stream);

 /*! \brief Compute dot product attention with packed KV input.
  *
@@ -410,6 +410,7 @@ void nvte_fused_attn_bwd_qkvpacked(
  *  \param[in]     softmax_type               Attention softmax type.
  *  \param[in]     window_size_left           Sliding window size (the left half).
  *  \param[in]     window_size_right          Sliding window size (the right half).
+ *  \param[in]     bottom_right_diagonal      Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix.
  *  \param[in]     workspace                  Workspace tensor.
  *  \param[in]     stream                     CUDA stream used for this operation.
  */
@@ -425,7 +426,8 @@ void nvte_fused_attn_fwd_kvpacked(
     size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph,
     float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
     NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
-    int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
+    int64_t window_size_right, bool bottom_right_diagonal, NVTETensor workspace,
+    cudaStream_t stream);

 /*! \brief Compute the backward of the dot product attention with packed KV input.
  *
@@ -479,6 +481,7 @@ void nvte_fused_attn_fwd_kvpacked(
  *  \param[in]     softmax_type               Attention softmax type.
  *  \param[in]     window_size_left           Sliding window size (the left half).
  *  \param[in]     window_size_right          Sliding window size (the right half).
+ *  \param[in]     bottom_right_diagonal      Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix.
  *  \param[in]     deterministic              Whether to execute with deterministic behaviours.
  *  \param[in]     cuda_graph                 Whether cuda graph capture is enabled or not.
  *  \param[in]     workspace                  Workspace tensor.
@@ -495,8 +498,8 @@ void nvte_fused_attn_bwd_kvpacked(
     const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv,
     float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
     NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
-    int64_t window_size_right, bool deterministic, bool cuda_graph, NVTETensor workspace,
-    cudaStream_t stream);
+    int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, bool cuda_graph,
+    NVTETensor workspace, cudaStream_t stream);

 /*! \brief Compute dot product attention with separate Q, K and V.
  *
...@@ -560,19 +563,23 @@ void nvte_fused_attn_bwd_kvpacked( ...@@ -560,19 +563,23 @@ void nvte_fused_attn_bwd_kvpacked(
* \param[in] softmax_type Attention softmax type. * \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half). * \param[in] window_size_right Sliding window size (the right half).
* \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix.
* \param[in] workspace Workspace tensor. * \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation. * \param[in] stream CUDA stream used for this operation.
*/ */
void nvte_fused_attn_fwd( void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, const NVTETensor cu_seqlens_q_padded,
const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit, const NVTETensor page_table_v, const NVTETensor rng_state,
bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout,
int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right,
bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with separate Q, K and V. /*! \brief Compute the backward of the dot product attention with separate Q, K and V.
* *
...@@ -629,6 +636,7 @@ void nvte_fused_attn_fwd( ...@@ -629,6 +636,7 @@ void nvte_fused_attn_fwd(
* \param[in] softmax_type Attention softmax type. * \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half). * \param[in] window_size_right Sliding window size (the right half).
* \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix.
* \param[in] deterministic Whether to execute with deterministic behaviours. * \param[in] deterministic Whether to execute with deterministic behaviours.
* \param[in] cuda_graph Whether cuda graph capture is enabled or not. * \param[in] cuda_graph Whether cuda graph capture is enabled or not.
* \param[in] workspace Workspace tensor. * \param[in] workspace Workspace tensor.
...@@ -644,8 +652,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso ...@@ -644,8 +652,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
size_t max_seqlen_kv, float attn_scale, float dropout, size_t max_seqlen_kv, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic, int64_t window_size_left, int64_t window_size_right,
bool cuda_graph, NVTETensor workspace, cudaStream_t stream); bool bottom_right_diagonal, bool deterministic, bool cuda_graph,
NVTETensor workspace, cudaStream_t stream);
/*! \brief Update the RNG state with the seed and calculated offset. /*! \brief Update the RNG state with the seed and calculated offset.
* *
......
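For orientation: the new `bottom_right_diagonal` flag only changes the mask when the query and key/value sequence lengths differ. A minimal sketch of the masking convention it selects (illustration only; the fused kernels never materialize this mask, and the helper below is not part of the library):

import torch

def sliding_window_mask(s_q, s_kv, window_size, bottom_right_diagonal):
    """Boolean mask (True = attend) for a (left, right) sliding window; sketch only."""
    left, right = window_size                 # -1 means unbounded on that side
    # Bottom-right alignment shifts the diagonal by (s_kv - s_q); for s_q == s_kv
    # the two alignments coincide.
    offset = (s_kv - s_q) if bottom_right_diagonal else 0
    q = torch.arange(s_q).unsqueeze(1)        # query positions, shape [s_q, 1]
    kv = torch.arange(s_kv).unsqueeze(0)      # key/value positions, shape [1, s_kv]
    diag = q + offset                         # diagonal column for each query row
    keep = torch.ones(s_q, s_kv, dtype=torch.bool)
    if left >= 0:
        keep &= kv >= diag - left
    if right >= 0:
        keep &= kv <= diag + right
    return keep

# With s_q=2, s_kv=4 and window_size=(1, 0): top-left alignment keeps columns {0} and
# {0, 1}; bottom-right alignment keeps columns {1, 2} and {2, 3}.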
...@@ -70,6 +70,7 @@ __all__ = [ ...@@ -70,6 +70,7 @@ __all__ = [
"is_training", "is_training",
"max_segments_per_seq", "max_segments_per_seq",
"window_size", "window_size",
"bottom_right_diagonal",
"context_parallel_load_balanced", "context_parallel_load_balanced",
"cp_axis", "cp_axis",
"cp_striped_window_size", "cp_striped_window_size",
...@@ -91,6 +92,7 @@ class _FusedAttnConfig: ...@@ -91,6 +92,7 @@ class _FusedAttnConfig:
is_training: bool is_training: bool
max_segments_per_seq: int max_segments_per_seq: int
window_size: Tuple[int, int] window_size: Tuple[int, int]
bottom_right_diagonal: bool
context_parallel_load_balanced: bool context_parallel_load_balanced: bool
cp_axis: str cp_axis: str
cp_striped_window_size: Tuple[int, int] # Only for CP + Ring P2P + THD + SWA cp_striped_window_size: Tuple[int, int] # Only for CP + Ring P2P + THD + SWA
...@@ -371,6 +373,11 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -371,6 +373,11 @@ class FusedAttnFwdPrimitive(BasePrimitive):
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape) bias_batch = reduce(operator.mul, bias_batch_shape)
bottom_right_diagonal = config.attn_mask_type in [
AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
]
# do a dummy kernel call here to get workspace buffer shapes/dtypes that XLA needs to # do a dummy kernel call here to get workspace buffer shapes/dtypes that XLA needs to
# prepare for the active fused-attn backend # prepare for the active fused-attn backend
input_batch = reduce(operator.mul, batch_shape) input_batch = reduce(operator.mul, batch_shape)
...@@ -395,6 +402,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -395,6 +402,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
config.max_segments_per_seq, config.max_segments_per_seq,
config.window_size[0], config.window_size[0],
config.window_size[1], config.window_size[1],
bottom_right_diagonal,
) )
wkspace_aval = q_aval.update( wkspace_aval = q_aval.update(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
...@@ -503,6 +511,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -503,6 +511,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
deterministic=not FusedAttnHelper.is_non_deterministic_allowed(), deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
window_size_left=window_size_left, window_size_left=window_size_left,
window_size_right=window_size_right, window_size_right=window_size_right,
bottom_right_diagonal=config.bottom_right_diagonal,
softmax_type=int(config.softmax_type.value), softmax_type=int(config.softmax_type.value),
) )
...@@ -813,6 +822,7 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -813,6 +822,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
config.max_segments_per_seq, config.max_segments_per_seq,
config.window_size[0], config.window_size[0],
config.window_size[1], config.window_size[1],
config.bottom_right_diagonal,
) )
dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype) dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
...@@ -948,6 +958,7 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -948,6 +958,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
deterministic=not FusedAttnHelper.is_non_deterministic_allowed(), deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
window_size_left=window_size_left, window_size_left=window_size_left,
window_size_right=window_size_right, window_size_right=window_size_right,
bottom_right_diagonal=config.bottom_right_diagonal,
softmax_type=int(config.softmax_type.value), softmax_type=int(config.softmax_type.value),
) )
...@@ -1357,9 +1368,10 @@ class _FusedAttnCPWithAllGatherHelper: ...@@ -1357,9 +1368,10 @@ class _FusedAttnCPWithAllGatherHelper:
def get_step_config(self) -> _FusedAttnConfig: def get_step_config(self) -> _FusedAttnConfig:
"""Returns a _FusedAttnConfig for single CP step call to fused attention.""" """Returns a _FusedAttnConfig for single CP step call to fused attention."""
adjusted_mask = self.get_adjusted_mask()
return _FusedAttnConfig( return _FusedAttnConfig(
attn_bias_type=self.config.attn_bias_type, attn_bias_type=self.config.attn_bias_type,
attn_mask_type=self.get_adjusted_mask(), attn_mask_type=adjusted_mask,
softmax_type=self.config.softmax_type, softmax_type=self.config.softmax_type,
qkv_layout=self.config.qkv_layout, qkv_layout=self.config.qkv_layout,
scaling_factor=self.config.scaling_factor, scaling_factor=self.config.scaling_factor,
...@@ -1367,6 +1379,7 @@ class _FusedAttnCPWithAllGatherHelper: ...@@ -1367,6 +1379,7 @@ class _FusedAttnCPWithAllGatherHelper:
is_training=self.config.is_training, is_training=self.config.is_training,
max_segments_per_seq=self.config.max_segments_per_seq, max_segments_per_seq=self.config.max_segments_per_seq,
window_size=self.config.window_size, window_size=self.config.window_size,
bottom_right_diagonal=adjusted_mask.is_bottom_right(),
context_parallel_load_balanced=self.config.context_parallel_load_balanced, context_parallel_load_balanced=self.config.context_parallel_load_balanced,
cp_axis=self.config.cp_axis, cp_axis=self.config.cp_axis,
cp_striped_window_size=None, cp_striped_window_size=None,
...@@ -1375,9 +1388,10 @@ class _FusedAttnCPWithAllGatherHelper: ...@@ -1375,9 +1388,10 @@ class _FusedAttnCPWithAllGatherHelper:
def get_step_config_for_striped(self, max_seqlen, cp_size) -> _FusedAttnConfig: def get_step_config_for_striped(self, max_seqlen, cp_size) -> _FusedAttnConfig:
"""Returns a _FusedAttnConfig for single CP step call (made via a striped AG primitive) to fused attention.""" """Returns a _FusedAttnConfig for single CP step call (made via a striped AG primitive) to fused attention."""
adjusted_mask = self.get_adjusted_mask()
return _FusedAttnConfig( return _FusedAttnConfig(
attn_bias_type=self.config.attn_bias_type, attn_bias_type=self.config.attn_bias_type,
attn_mask_type=self.get_adjusted_mask(), attn_mask_type=adjusted_mask,
softmax_type=self.config.softmax_type, softmax_type=self.config.softmax_type,
qkv_layout=self.config.qkv_layout, qkv_layout=self.config.qkv_layout,
scaling_factor=self.config.scaling_factor, scaling_factor=self.config.scaling_factor,
...@@ -1385,6 +1399,7 @@ class _FusedAttnCPWithAllGatherHelper: ...@@ -1385,6 +1399,7 @@ class _FusedAttnCPWithAllGatherHelper:
is_training=self.config.is_training, is_training=self.config.is_training,
max_segments_per_seq=self.get_adjusted_max_segments_per_seq(max_seqlen, cp_size), max_segments_per_seq=self.get_adjusted_max_segments_per_seq(max_seqlen, cp_size),
window_size=self.config.window_size, window_size=self.config.window_size,
bottom_right_diagonal=adjusted_mask.is_bottom_right(),
context_parallel_load_balanced=self.config.context_parallel_load_balanced, context_parallel_load_balanced=self.config.context_parallel_load_balanced,
cp_axis=self.config.cp_axis, cp_axis=self.config.cp_axis,
cp_striped_window_size=None, cp_striped_window_size=None,
...@@ -2430,6 +2445,7 @@ class _FusedAttnCPWithP2PHelper: ...@@ -2430,6 +2445,7 @@ class _FusedAttnCPWithP2PHelper:
is_training=self.config.is_training, is_training=self.config.is_training,
max_segments_per_seq=self.config.max_segments_per_seq, max_segments_per_seq=self.config.max_segments_per_seq,
window_size=self.config.window_size, window_size=self.config.window_size,
bottom_right_diagonal=attn_mask_type.is_bottom_right(),
context_parallel_load_balanced=self.config.context_parallel_load_balanced, context_parallel_load_balanced=self.config.context_parallel_load_balanced,
cp_axis=self.config.cp_axis, cp_axis=self.config.cp_axis,
cp_striped_window_size=None, cp_striped_window_size=None,
...@@ -3418,6 +3434,7 @@ def fused_attn_fwd( ...@@ -3418,6 +3434,7 @@ def fused_attn_fwd(
is_training=is_training, is_training=is_training,
max_segments_per_seq=max_segments_per_seq, max_segments_per_seq=max_segments_per_seq,
window_size=(-1, -1) if window_size is None else window_size, window_size=(-1, -1) if window_size is None else window_size,
bottom_right_diagonal=attn_mask_type.is_bottom_right(),
context_parallel_load_balanced=context_parallel_causal_load_balanced, context_parallel_load_balanced=context_parallel_causal_load_balanced,
cp_axis=_maybe_context_parallel_axis(context_parallel_axis), cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
cp_striped_window_size=None, cp_striped_window_size=None,
...@@ -3590,6 +3607,7 @@ def fused_attn_bwd( ...@@ -3590,6 +3607,7 @@ def fused_attn_bwd(
is_training=is_training, is_training=is_training,
max_segments_per_seq=max_segments_per_seq, max_segments_per_seq=max_segments_per_seq,
window_size=(-1, -1) if window_size is None else window_size, window_size=(-1, -1) if window_size is None else window_size,
bottom_right_diagonal=attn_mask_type.is_bottom_right(),
context_parallel_load_balanced=context_parallel_causal_load_balanced, context_parallel_load_balanced=context_parallel_causal_load_balanced,
cp_axis=_maybe_context_parallel_axis(context_parallel_axis), cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
cp_striped_window_size=None, cp_striped_window_size=None,
......
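On the JAX side the flag is not user facing; as the config changes above show, it is derived from the mask type via `is_bottom_right()`. A quick way to inspect that mapping (sketch; assumes the enum lives at its usual public location):

from transformer_engine.jax.attention import AttnMaskType

# Print, for every mask type, whether the new config field would mark it as
# bottom-right aligned (is_bottom_right() is the helper used in the diff above).
for mask in AttnMaskType:
    print(f"{mask.name}: bottom_right_diagonal={mask.is_bottom_right()}")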
...@@ -121,7 +121,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( ...@@ -121,7 +121,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
DType dtype, bool is_training, size_t max_segments_per_seq, int64_t window_size_left, DType dtype, bool is_training, size_t max_segments_per_seq, int64_t window_size_left,
int64_t window_size_right); int64_t window_size_right, bool bottom_right_diagonal);
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
...@@ -129,7 +129,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( ...@@ -129,7 +129,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
DType dtype, bool is_training, bool deterministic, size_t max_segments_per_seq, DType dtype, bool is_training, bool deterministic, size_t max_segments_per_seq,
int64_t window_size_left, int64_t window_size_right); int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal);
// GEMM // GEMM
XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler);
......
...@@ -144,7 +144,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( ...@@ -144,7 +144,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
DType dtype, bool is_training, size_t max_segments_per_seq, int64_t window_size_left, DType dtype, bool is_training, size_t max_segments_per_seq, int64_t window_size_left,
int64_t window_size_right) { int64_t window_size_right, bool bottom_right_diagonal) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
...@@ -192,7 +192,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( ...@@ -192,7 +192,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
ragged_offset_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), ragged_offset_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false, dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); window_size_left, window_size_right, bottom_right_diagonal, query_workspace_tensor.data(),
nullptr);
} }
nvte_tensor_pack_destroy(&aux_output_tensors); nvte_tensor_pack_destroy(&aux_output_tensors);
...@@ -237,7 +238,7 @@ static void FusedAttnForwardImpl( ...@@ -237,7 +238,7 @@ static void FusedAttnForwardImpl(
size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
DType dtype, DType wkspace_dtype, bool is_training, bool deterministic, DType dtype, DType wkspace_dtype, bool is_training, bool deterministic,
int64_t window_size_left, int64_t window_size_right) { int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal) {
FUSED_ATTN_IMPL_COMMON_BLOCK; FUSED_ATTN_IMPL_COMMON_BLOCK;
/* Input tensors */ /* Input tensors */
...@@ -328,7 +329,7 @@ static void FusedAttnForwardImpl( ...@@ -328,7 +329,7 @@ static void FusedAttnForwardImpl(
k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false, rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
window_size_left, window_size_right, workspace_tensor.data(), stream); window_size_left, window_size_right, bottom_right_diagonal, workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_output_tensors); nvte_tensor_pack_destroy(&aux_output_tensors);
} }
...@@ -346,6 +347,7 @@ static void FusedAttnForwardImpl( ...@@ -346,6 +347,7 @@ static void FusedAttnForwardImpl(
size_t max_segments_per_seq = get_attr_value<int64_t>(attrs, "max_segments_per_seq"); \ size_t max_segments_per_seq = get_attr_value<int64_t>(attrs, "max_segments_per_seq"); \
auto window_size_left = get_attr_value<int64_t>(attrs, "window_size_left"); \ auto window_size_left = get_attr_value<int64_t>(attrs, "window_size_left"); \
auto window_size_right = get_attr_value<int64_t>(attrs, "window_size_right"); \ auto window_size_right = get_attr_value<int64_t>(attrs, "window_size_right"); \
bool bottom_right_diagonal = get_attr_value<bool>(attrs, "bottom_right_diagonal"); \
float scaling_factor = get_attr_value<double>(attrs, "scaling_factor"); \ float scaling_factor = get_attr_value<double>(attrs, "scaling_factor"); \
float dropout_probability = get_attr_value<double>(attrs, "dropout_probability"); \ float dropout_probability = get_attr_value<double>(attrs, "dropout_probability"); \
NVTE_Bias_Type bias_type = \ NVTE_Bias_Type bias_type = \
...@@ -384,7 +386,7 @@ Error_Type FusedAttnForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Ty ...@@ -384,7 +386,7 @@ Error_Type FusedAttnForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Ty
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads, input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads,
qk_head_dim, v_head_dim, max_segments_per_seq, wkspace_size, scaling_factor, qk_head_dim, v_head_dim, max_segments_per_seq, wkspace_size, scaling_factor,
dropout_probability, bias_type, mask_type, softmax_type, qkv_layout, dtype, wkspace_dtype, dropout_probability, bias_type, mask_type, softmax_type, qkv_layout, dtype, wkspace_dtype,
is_training, deterministic, window_size_left, window_size_right); is_training, deterministic, window_size_left, window_size_right, bottom_right_diagonal);
return ffi_with_cuda_error_check(); return ffi_with_cuda_error_check();
} }
...@@ -415,7 +417,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( ...@@ -415,7 +417,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
DType dtype, bool is_training, bool deterministic, size_t max_segments_per_seq, DType dtype, bool is_training, bool deterministic, size_t max_segments_per_seq,
int64_t window_size_left, int64_t window_size_right) { int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype); auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
...@@ -467,17 +469,18 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( ...@@ -467,17 +469,18 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
auto dummy_ragged_offset_tensor = auto dummy_ragged_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32); TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
nvte_fused_attn_bwd( nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), doutput_tensor.data(),
doutput_tensor.data(), s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16 &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(),
dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability,
dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
window_size_right, deterministic, false, query_workspace_tensor.data(), nullptr); window_size_right, bottom_right_diagonal, deterministic, false,
query_workspace_tensor.data(), nullptr);
} }
nvte_tensor_pack_destroy(&aux_input_tensors); nvte_tensor_pack_destroy(&aux_input_tensors);
...@@ -496,7 +499,7 @@ static void FusedAttnBackwardImpl( ...@@ -496,7 +499,7 @@ static void FusedAttnBackwardImpl(
size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
DType dtype, DType wkspace_dtype, bool is_training, bool deterministic, DType dtype, DType wkspace_dtype, bool is_training, bool deterministic,
int64_t window_size_left, int64_t window_size_right) { int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal) {
FUSED_ATTN_IMPL_COMMON_BLOCK; FUSED_ATTN_IMPL_COMMON_BLOCK;
/* Input tensors */ /* Input tensors */
...@@ -593,16 +596,17 @@ static void FusedAttnBackwardImpl( ...@@ -593,16 +596,17 @@ static void FusedAttnBackwardImpl(
} }
} }
nvte_fused_attn_bwd( nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), doutput_tensor.data(),
doutput_tensor.data(), s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16 &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), dbias_tensor.data(), dbias_tensor.data(), dsoftmax_offset_tensor.data(),
dsoftmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type,
window_size_left, window_size_right, deterministic, false, workspace_tensor.data(), stream); mask_type, softmax_type, window_size_left, window_size_right,
bottom_right_diagonal, deterministic, false, workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_input_tensors); nvte_tensor_pack_destroy(&aux_input_tensors);
} }
...@@ -631,7 +635,7 @@ Error_Type FusedAttnBackwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_T ...@@ -631,7 +635,7 @@ Error_Type FusedAttnBackwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_T
q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads, qk_head_dim, v_head_dim, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads, qk_head_dim, v_head_dim,
max_segments_per_seq, wkspace_size, scaling_factor, dropout_probability, bias_type, mask_type, max_segments_per_seq, wkspace_size, scaling_factor, dropout_probability, bias_type, mask_type,
softmax_type, qkv_layout, dtype, wkspace_dtype, is_training, deterministic, window_size_left, softmax_type, qkv_layout, dtype, wkspace_dtype, is_training, deterministic, window_size_left,
window_size_right); window_size_right, bottom_right_diagonal);
return ffi_with_cuda_error_check(); return ffi_with_cuda_error_check();
} }
......
...@@ -261,6 +261,7 @@ class UnfusedDotProductAttention(torch.nn.Module): ...@@ -261,6 +261,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
attn_mask_type: str = "causal", attn_mask_type: str = "causal",
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
window_size: Optional[Tuple[int, int]] = None, window_size: Optional[Tuple[int, int]] = None,
bottom_right_diagonal: Optional[bool] = None,
core_attention_bias_type: str = "no_bias", core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None, core_attention_bias: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None,
...@@ -346,6 +347,11 @@ class UnfusedDotProductAttention(torch.nn.Module): ...@@ -346,6 +347,11 @@ class UnfusedDotProductAttention(torch.nn.Module):
attention_mask=attention_mask, attention_mask=attention_mask,
window_size=window_size, window_size=window_size,
attention_type=self.attention_type, attention_type=self.attention_type,
bottom_right_alignment=(
attn_mask_type not in ["causal", "padding_causal"]
if bottom_right_diagonal is None
else bottom_right_diagonal
),
) )
) )
...@@ -449,7 +455,11 @@ class UnfusedDotProductAttention(torch.nn.Module): ...@@ -449,7 +455,11 @@ class UnfusedDotProductAttention(torch.nn.Module):
actual_seqlens_q=actual_seqlens_q if "padding" in attn_mask_type else None, actual_seqlens_q=actual_seqlens_q if "padding" in attn_mask_type else None,
actual_seqlens_kv=actual_seqlens_kv if "padding" in attn_mask_type else None, actual_seqlens_kv=actual_seqlens_kv if "padding" in attn_mask_type else None,
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"], bottom_right_alignment=(
attn_mask_type not in ["causal", "padding_causal"]
if bottom_right_diagonal is None
else bottom_right_diagonal
),
) )
matmul_result = torch.baddbmm( matmul_result = torch.baddbmm(
matmul_result, matmul_result,
...@@ -1110,6 +1120,7 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1110,6 +1120,7 @@ class FusedAttnFunc(torch.autograd.Function):
attn_mask_type, attn_mask_type,
softmax_type, softmax_type,
window_size, window_size,
bottom_right_diagonal,
rng_gen, rng_gen,
fused_attention_backend, fused_attention_backend,
use_FAv2_bwd, use_FAv2_bwd,
...@@ -1213,6 +1224,7 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1213,6 +1224,7 @@ class FusedAttnFunc(torch.autograd.Function):
attn_mask_type, attn_mask_type,
softmax_type, softmax_type,
window_size, window_size,
bottom_right_diagonal,
rng_gen, rng_gen,
softmax_offset, softmax_offset,
cuda_graph=is_graph_capturing(), cuda_graph=is_graph_capturing(),
...@@ -1290,6 +1302,7 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1290,6 +1302,7 @@ class FusedAttnFunc(torch.autograd.Function):
attn_mask_type, attn_mask_type,
softmax_type, softmax_type,
window_size, window_size,
bottom_right_diagonal,
rng_gen, rng_gen,
softmax_offset, softmax_offset,
return_max_logit, return_max_logit,
...@@ -1377,6 +1390,7 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1377,6 +1390,7 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.attn_mask_type = attn_mask_type ctx.attn_mask_type = attn_mask_type
ctx.softmax_type = softmax_type ctx.softmax_type = softmax_type
ctx.window_size = window_size ctx.window_size = window_size
ctx.bottom_right_diagonal = bottom_right_diagonal
ctx.fused_attention_backend = ( ctx.fused_attention_backend = (
fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"] fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
) )
...@@ -1527,6 +1541,7 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1527,6 +1541,7 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.attn_mask_type, ctx.attn_mask_type,
ctx.softmax_type, ctx.softmax_type,
ctx.window_size, ctx.window_size,
ctx.bottom_right_diagonal,
ctx.deterministic, ctx.deterministic,
is_graph_capturing(), is_graph_capturing(),
) )
...@@ -1592,6 +1607,7 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1592,6 +1607,7 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.attn_mask_type, ctx.attn_mask_type,
ctx.softmax_type, ctx.softmax_type,
ctx.window_size, ctx.window_size,
ctx.bottom_right_diagonal,
ctx.deterministic, ctx.deterministic,
is_graph_capturing(), is_graph_capturing(),
) )
...@@ -1631,6 +1647,7 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1631,6 +1647,7 @@ class FusedAttnFunc(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
d_softmax_offset, d_softmax_offset,
None, None,
None, None,
...@@ -1728,6 +1745,7 @@ class FusedAttention(torch.nn.Module): ...@@ -1728,6 +1745,7 @@ class FusedAttention(torch.nn.Module):
attn_mask_type: str = "causal", attn_mask_type: str = "causal",
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
window_size: Optional[Tuple[int, int]] = None, window_size: Optional[Tuple[int, int]] = None,
bottom_right_diagonal: Optional[bool] = None,
fused_attention_backend: tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend, fused_attention_backend: tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend,
core_attention_bias_type: str = "no_bias", core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None, core_attention_bias: Optional[torch.Tensor] = None,
...@@ -1935,6 +1953,7 @@ class FusedAttention(torch.nn.Module): ...@@ -1935,6 +1953,7 @@ class FusedAttention(torch.nn.Module):
attn_mask_type, attn_mask_type,
self.softmax_type, self.softmax_type,
window_size, window_size,
bottom_right_diagonal,
None, # rng_gen None, # rng_gen
fused_attention_backend, fused_attention_backend,
use_FAv2_bwd, use_FAv2_bwd,
......
...@@ -228,6 +228,11 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -228,6 +228,11 @@ class DotProductAttention(TransformerEngineBaseModule):
map to ``window_size = (-1, 0)`` and Transformer Engine distinguishes them based on map to ``window_size = (-1, 0)`` and Transformer Engine distinguishes them based on
``attn_mask_type``. Similar to :attr:`attn_mask_type`, ``window_size`` can ``attn_mask_type``. Similar to :attr:`attn_mask_type`, ``window_size`` can
be overridden by :attr:`window_size` in ``forward`` as well. be overridden by :attr:`window_size` in ``forward`` as well.
bottom_right_diagonal: Optional[bool], default = `None`
Align sliding window and ALiBi diagonal to the top left (`False`)
or bottom right (`True`) corner of the softmax matrix in the encoder.
If `None`, it will be set to `False` for `attn_mask_type` =
{'causal', 'padding_causal'} and `True` for other mask types.
attention_type : str, default = "self" attention_type : str, default = "self"
type of attention, either ``"self"`` or ``"cross"``. type of attention, either ``"self"`` or ``"cross"``.
layer_number : int, default = None layer_number : int, default = None
...@@ -324,6 +329,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -324,6 +329,7 @@ class DotProductAttention(TransformerEngineBaseModule):
qkv_format: str = "sbhd", qkv_format: str = "sbhd",
attn_mask_type: str = "causal", attn_mask_type: str = "causal",
window_size: Optional[Tuple[int, int]] = None, window_size: Optional[Tuple[int, int]] = None,
bottom_right_diagonal: Optional[bool] = None,
sequence_parallel: bool = False, sequence_parallel: bool = False,
tp_size: int = 1, tp_size: int = 1,
get_rng_state_tracker: Optional[Callable] = None, get_rng_state_tracker: Optional[Callable] = None,
...@@ -350,6 +356,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -350,6 +356,7 @@ class DotProductAttention(TransformerEngineBaseModule):
attn_mask_type = "padding_causal" attn_mask_type = "padding_causal"
self.attn_mask_type = attn_mask_type self.attn_mask_type = attn_mask_type
self.window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size) self.window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size)
self.bottom_right_diagonal = bottom_right_diagonal
if tp_group is None: if tp_group is None:
self.tp_size = tp_size self.tp_size = tp_size
if tp_size == 1: if tp_size == 1:
...@@ -811,6 +818,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -811,6 +818,7 @@ class DotProductAttention(TransformerEngineBaseModule):
max_seqlen_kv: int = None, max_seqlen_kv: int = None,
attn_mask_type: Optional[str] = None, attn_mask_type: Optional[str] = None,
window_size: Optional[Tuple[int, int]] = None, window_size: Optional[Tuple[int, int]] = None,
bottom_right_diagonal: Optional[bool] = None,
checkpoint_core_attention: bool = False, checkpoint_core_attention: bool = False,
core_attention_bias_type: str = "no_bias", core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None, core_attention_bias: Optional[torch.Tensor] = None,
...@@ -963,6 +971,16 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -963,6 +971,16 @@ class DotProductAttention(TransformerEngineBaseModule):
causal masks are aligned to the bottom right corner. causal masks are aligned to the bottom right corner.
window_size: Optional[Tuple[int, int]], default = None window_size: Optional[Tuple[int, int]], default = None
Sliding window size for local attention. Sliding window size for local attention.
bottom_right_diagonal: Optional[bool], default = None
Align the sliding window and ALiBi diagonal to the top left (`False`)
or bottom right (`True`) corner of the softmax matrix in the encoder.
If `None`, it defaults to `False` for `attn_mask_type` in
{'causal', 'padding_causal'} and to `True` for all other mask types.
Note: the mask type can override an explicitly passed value: 'causal' and
'padding_causal' masks force `False`, while mask types containing
'bottom_right' (e.g. 'causal_bottom_right', 'padding_causal_bottom_right')
force `True`.
checkpoint_core_attention : bool, default = False checkpoint_core_attention : bool, default = False
If true, forward activations for attention are recomputed If true, forward activations for attention are recomputed
during the backward pass in order to save memory that would during the backward pass in order to save memory that would
...@@ -1081,6 +1099,15 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -1081,6 +1099,15 @@ class DotProductAttention(TransformerEngineBaseModule):
if window_size is None: if window_size is None:
window_size = self.window_size window_size = self.window_size
window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size) window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size)
if bottom_right_diagonal is None:
bottom_right_diagonal = self.bottom_right_diagonal
if attn_mask_type in {"causal", "padding_causal"}:
bottom_right_diagonal = False
if bottom_right_diagonal is None or attn_mask_type in {
"causal_bottom_right",
"padding_causal_bottom_right",
}:
bottom_right_diagonal = True
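Spelled out as a standalone rule (a sketch that mirrors the precedence of the lines added above; it is not library code): the per-call argument wins over the module default, and the mask type can then force either alignment.

def resolve_bottom_right_diagonal(call_arg, module_default, attn_mask_type):
    # Per-call argument first, then the module-level setting.
    value = call_arg if call_arg is not None else module_default
    # Purely causal masks force top-left alignment ...
    if attn_mask_type in {"causal", "padding_causal"}:
        value = False
    # ... bottom_right masks force bottom-right alignment, and anything still
    # undecided defaults to bottom-right.
    if value is None or attn_mask_type in {"causal_bottom_right", "padding_causal_bottom_right"}:
        value = True
    return value

# resolve_bottom_right_diagonal(None, None, "padding")              -> True
# resolve_bottom_right_diagonal(True, None, "causal")               -> False (forced)
# resolve_bottom_right_diagonal(False, None, "causal_bottom_right") -> True  (forced)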
# checks for qkv_format # checks for qkv_format
if qkv_format is None: if qkv_format is None:
...@@ -1144,6 +1171,8 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -1144,6 +1171,8 @@ class DotProductAttention(TransformerEngineBaseModule):
assert "padding" in attn_mask_type, "KV caching requires padding mask!" assert "padding" in attn_mask_type, "KV caching requires padding mask!"
if attn_mask_type == "padding_causal": if attn_mask_type == "padding_causal":
attn_mask_type = attn_mask_type + "_bottom_right" attn_mask_type = attn_mask_type + "_bottom_right"
# the mask type was changed to a bottom_right variant, so set `bottom_right_diagonal` to True
bottom_right_diagonal = True
if self.attention_type != "cross": if self.attention_type != "cross":
self.fast_setattr("attention_type", "cross") self.fast_setattr("attention_type", "cross")
...@@ -1257,7 +1286,6 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -1257,7 +1286,6 @@ class DotProductAttention(TransformerEngineBaseModule):
if self.layer_number == 1: if self.layer_number == 1:
_alibi_cache["_alibi_slopes_require_update"] = True _alibi_cache["_alibi_slopes_require_update"] = True
_alibi_cache["_alibi_bias_require_update"] = True _alibi_cache["_alibi_bias_require_update"] = True
bottom_right_alignment = (attn_mask_type not in ["causal", "padding_causal"],)
if core_attention_bias_type == "alibi": if core_attention_bias_type == "alibi":
assert ( assert (
core_attention_bias is None core_attention_bias is None
...@@ -1266,7 +1294,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -1266,7 +1294,7 @@ class DotProductAttention(TransformerEngineBaseModule):
_alibi_cache["_num_heads"] != query_layer.shape[-2] _alibi_cache["_num_heads"] != query_layer.shape[-2]
or _alibi_cache["_max_seqlen_q"] != max_seqlen_q or _alibi_cache["_max_seqlen_q"] != max_seqlen_q
or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv
or _alibi_cache["_bottom_right_alignment"] != bottom_right_alignment or _alibi_cache["_bottom_right_alignment"] != bottom_right_diagonal
or _alibi_cache["_alibi_slopes"] is None or _alibi_cache["_alibi_slopes"] is None
): ):
_alibi_cache["_alibi_slopes_require_update"] = True _alibi_cache["_alibi_slopes_require_update"] = True
...@@ -1323,6 +1351,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -1323,6 +1351,7 @@ class DotProductAttention(TransformerEngineBaseModule):
head_dim_v=head_dim_v, head_dim_v=head_dim_v,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
window_size=window_size, window_size=window_size,
bottom_right_diagonal=bottom_right_diagonal,
alibi_slopes_shape=alibi_slopes.shape if alibi_slopes is not None else None, alibi_slopes_shape=alibi_slopes.shape if alibi_slopes is not None else None,
core_attention_bias_type=core_attention_bias_type, core_attention_bias_type=core_attention_bias_type,
core_attention_bias_shape=core_attention_bias_shape, core_attention_bias_shape=core_attention_bias_shape,
...@@ -1446,9 +1475,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -1446,9 +1475,7 @@ class DotProductAttention(TransformerEngineBaseModule):
if use_fused_attention: if use_fused_attention:
fu_core_attention_bias_type = core_attention_bias_type fu_core_attention_bias_type = core_attention_bias_type
fu_core_attention_bias = core_attention_bias fu_core_attention_bias = core_attention_bias
if core_attention_bias_type == "alibi" and ( if core_attention_bias_type == "alibi" and (alibi_slopes is not None):
alibi_slopes is not None or max_seqlen_q != max_seqlen_kv
):
fu_core_attention_bias_type = "post_scale_bias" fu_core_attention_bias_type = "post_scale_bias"
_, fu_core_attention_bias = dpa_utils.get_alibi( _, fu_core_attention_bias = dpa_utils.get_alibi(
_alibi_cache, _alibi_cache,
...@@ -1457,7 +1484,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -1457,7 +1484,7 @@ class DotProductAttention(TransformerEngineBaseModule):
max_seqlen_kv, max_seqlen_kv,
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
bias_dtype=query_layer.dtype, bias_dtype=query_layer.dtype,
bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"], bottom_right_alignment=bottom_right_diagonal,
) )
if checkpoint_core_attention: if checkpoint_core_attention:
return self._checkpointed_attention_forward( return self._checkpointed_attention_forward(
...@@ -1475,6 +1502,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -1475,6 +1502,7 @@ class DotProductAttention(TransformerEngineBaseModule):
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
attention_mask=attention_mask, attention_mask=attention_mask,
window_size=window_size, window_size=window_size,
bottom_right_diagonal=bottom_right_diagonal,
fused_attention_backend=fused_attention_backend, fused_attention_backend=fused_attention_backend,
core_attention_bias_type=fu_core_attention_bias_type, core_attention_bias_type=fu_core_attention_bias_type,
core_attention_bias=fu_core_attention_bias, core_attention_bias=fu_core_attention_bias,
...@@ -1505,6 +1533,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -1505,6 +1533,7 @@ class DotProductAttention(TransformerEngineBaseModule):
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
attention_mask=attention_mask, attention_mask=attention_mask,
window_size=window_size, window_size=window_size,
bottom_right_diagonal=bottom_right_diagonal,
fused_attention_backend=fused_attention_backend, fused_attention_backend=fused_attention_backend,
core_attention_bias_type=fu_core_attention_bias_type, core_attention_bias_type=fu_core_attention_bias_type,
core_attention_bias=fu_core_attention_bias, core_attention_bias=fu_core_attention_bias,
...@@ -1539,6 +1568,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -1539,6 +1568,7 @@ class DotProductAttention(TransformerEngineBaseModule):
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
attention_mask=attention_mask, attention_mask=attention_mask,
window_size=window_size, window_size=window_size,
bottom_right_diagonal=bottom_right_diagonal,
core_attention_bias_type=core_attention_bias_type, core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias, core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
...@@ -1562,6 +1592,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -1562,6 +1592,7 @@ class DotProductAttention(TransformerEngineBaseModule):
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
attention_mask=attention_mask, attention_mask=attention_mask,
window_size=window_size, window_size=window_size,
bottom_right_diagonal=bottom_right_diagonal,
core_attention_bias_type=core_attention_bias_type, core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias, core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes, alibi_slopes=alibi_slopes,
......
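End to end, the new knob is used as below. This is a hedged usage sketch: the shapes, dtypes and CUDA device are illustrative, and everything except the two `bottom_right_diagonal` arguments follows the pre-existing DotProductAttention API.

import torch
from transformer_engine.pytorch import DotProductAttention

# Cross-attention with s_q != s_kv, where diagonal alignment actually changes the mask.
attn = DotProductAttention(
    num_attention_heads=16,
    kv_channels=64,
    attn_mask_type="no_mask",
    window_size=(128, 128),        # (left, right) sliding window, now usable with FusedAttention
    bottom_right_diagonal=True,    # new module-level default added in this PR
    attention_type="cross",
)
q = torch.randn(512, 2, 16, 64, dtype=torch.bfloat16, device="cuda")   # [s_q, b, h, d]
k = torch.randn(1024, 2, 16, 64, dtype=torch.bfloat16, device="cuda")  # [s_kv, b, h, d]
v = torch.randn(1024, 2, 16, 64, dtype=torch.bfloat16, device="cuda")
out = attn(q, k, v, bottom_right_diagonal=True)                        # per-call override also accepted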
...@@ -200,6 +200,9 @@ class AttentionParams: ...@@ -200,6 +200,9 @@ class AttentionParams:
`causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`} `causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`}
window_size : Tuple[int, int], default = None window_size : Tuple[int, int], default = None
Sliding window attention size. Sliding window attention size.
bottom_right_diagonal: bool, default = `True`
Whether to align the sliding window and ALiBi diagonal to the bottom right corner
of the softmax matrix.
alibi_slopes_shape : Optional[Union[torch.Size, List]], default = None alibi_slopes_shape : Optional[Union[torch.Size, List]], default = None
Tensor shape of :attr:`alibi_slopes` in `DotProductAttention`. Tensor shape of :attr:`alibi_slopes` in `DotProductAttention`.
core_attention_bias_type : str, default = no_bias core_attention_bias_type : str, default = no_bias
...@@ -249,6 +252,7 @@ class AttentionParams: ...@@ -249,6 +252,7 @@ class AttentionParams:
head_dim_v: int = 64 head_dim_v: int = 64
attn_mask_type: str = "no_mask" attn_mask_type: str = "no_mask"
window_size: Union[Tuple[int, int], None] = None window_size: Union[Tuple[int, int], None] = None
bottom_right_diagonal: bool = True
alibi_slopes_shape: Union[torch.Size, List, None] = None alibi_slopes_shape: Union[torch.Size, List, None] = None
core_attention_bias_type: str = "no_bias" core_attention_bias_type: str = "no_bias"
core_attention_bias_shape: str = "1hss" core_attention_bias_shape: str = "1hss"
...@@ -325,6 +329,7 @@ def get_attention_backend( ...@@ -325,6 +329,7 @@ def get_attention_backend(
head_dim_v = attention_params.head_dim_v head_dim_v = attention_params.head_dim_v
attn_mask_type = attention_params.attn_mask_type attn_mask_type = attention_params.attn_mask_type
window_size = attention_params.window_size window_size = attention_params.window_size
bottom_right_diagonal = attention_params.bottom_right_diagonal
alibi_slopes_shape = attention_params.alibi_slopes_shape alibi_slopes_shape = attention_params.alibi_slopes_shape
core_attention_bias_type = attention_params.core_attention_bias_type core_attention_bias_type = attention_params.core_attention_bias_type
core_attention_bias_shape = attention_params.core_attention_bias_shape core_attention_bias_shape = attention_params.core_attention_bias_shape
...@@ -859,39 +864,43 @@ def get_attention_backend( ...@@ -859,39 +864,43 @@ def get_attention_backend(
# backend | window_size | diagonal alignment # backend | window_size | diagonal alignment
# --------------------------------------------------------------------------------- # ---------------------------------------------------------------------------------
# FlashAttention | (-1, -1) or (>=0, >=0) | bottom right # FlashAttention | (-1, -1) or (>=0, >=0) | bottom right
# FusedAttention | (-1, 0) or (>=0, 0) | top left # FusedAttention | (-1, 0) or (>=0, >=0) | top left, bottom right
# UnfusedDotProductAttention | (-1, -1) or (>=0, >=0) | both; # UnfusedDotProductAttention | (-1, -1) or (>=0, >=0) | top left, bottom right
# | | converts window_size to an 'arbitrary' mask # | | converts window_size to an 'arbitrary' mask
if window_size is None: if window_size is None:
window_size = check_set_window_size(attn_mask_type, window_size) window_size = check_set_window_size(attn_mask_type, window_size)
else: if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]): if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha):
if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha): logger.debug(
logger.debug( "Disabling FusedAttention as it does not support sliding window attention for FP8"
"Disabling FusedAttention as it does not support sliding window attention" )
" for FP8" use_fused_attention = False
) elif attention_dropout != 0.0:
use_fused_attention = False logger.debug(
elif window_size[1] != 0 or attention_dropout != 0.0: "Disabling FusedAttention as it only supports sliding window attention "
logger.debug( "without dropout"
"Disabling FusedAttention as it only supports sliding window attention " )
"with (left, 0) and no dropout" use_fused_attention = False
) elif max_seqlen_q > max_seqlen_kv:
use_fused_attention = False logger.debug(
elif max_seqlen_q > max_seqlen_kv: "Disabling FusedAttention as it does not support sliding window attention "
logger.debug( "with s_q > s_kv for cross-attention"
"Disabling FusedAttention as it does not support sliding window attention " )
"with s_q > s_kv for cross-attention" use_fused_attention = False
) if use_flash_attention_2 and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
use_fused_attention = False if not FlashAttentionUtils.is_installed:
if use_flash_attention_2 and (window_size[0] != -1 or window_size[1] not in [-1, 0]): FlashAttentionUtils.version_required = PkgVersion("2.3")
if not FlashAttentionUtils.is_installed: elif not FlashAttentionUtils.v2_3_plus:
FlashAttentionUtils.version_required = PkgVersion("2.3") logger.debug(
elif not FlashAttentionUtils.v2_3_plus: "Disabling FlashAttention as sliding window attention requires flash-attn 2.3+"
logger.debug( )
"Disabling FlashAttention as sliding window attention requires flash-attn 2.3+" use_flash_attention_2 = False
) elif not bottom_right_diagonal and max_seqlen_q != max_seqlen_kv:
use_flash_attention_2 = False logger.debug(
"Disabling FlashAttention as it only supports sliding window with bottom right"
" diagonal alignment for cross-attention"
)
use_flash_attention = False
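Condensed, the sliding-window filters above now amount to the following (a sketch with standalone names that ignores the flash-attn install/version checks):

def swa_backend_filter(window_size, attention_dropout, max_seqlen_q, max_seqlen_kv,
                       bottom_right_diagonal, fp8_dpa_or_mha):
    is_swa = window_size[0] != -1 or window_size[1] not in (-1, 0)
    fused_ok, flash_ok = True, True
    if is_swa:
        # FusedAttention: no SWA under FP8, no dropout with SWA, and s_q <= s_kv only.
        if fp8_dpa_or_mha or attention_dropout != 0.0 or max_seqlen_q > max_seqlen_kv:
            fused_ok = False
        # FlashAttention: cross-attention SWA only with bottom-right diagonal alignment.
        if not bottom_right_diagonal and max_seqlen_q != max_seqlen_kv:
            flash_ok = False
    return fused_ok, flash_ok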
# Filter: Attention bias # Filter: Attention bias
# backend | bias types | ALiBi diagonal alignment # backend | bias types | ALiBi diagonal alignment
...@@ -913,6 +922,12 @@ def get_attention_backend( ...@@ -913,6 +922,12 @@ def get_attention_backend(
elif not FlashAttentionUtils.v2_4_plus: elif not FlashAttentionUtils.v2_4_plus:
logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+") logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+")
use_flash_attention_2 = False use_flash_attention_2 = False
elif not bottom_right_diagonal and max_seqlen_q != max_seqlen_kv:
logger.debug(
"Disabling FlashAttention as it only supports ALiBi with bottom right diagonal"
" alignment for cross-attention"
)
use_flash_attention = False
if ( if (
core_attention_bias_type not in ["no_bias", "alibi"] core_attention_bias_type not in ["no_bias", "alibi"]
...@@ -930,13 +945,12 @@ def get_attention_backend( ...@@ -930,13 +945,12 @@ def get_attention_backend(
if ( if (
use_fused_attention use_fused_attention
and core_attention_bias_type == "alibi" and core_attention_bias_type == "alibi"
and (alibi_slopes_shape is not None or max_seqlen_q != max_seqlen_kv) and (alibi_slopes_shape is not None)
): ):
fu_core_attention_bias_type = "post_scale_bias" fu_core_attention_bias_type = "post_scale_bias"
fu_core_attention_bias_requires_grad = False fu_core_attention_bias_requires_grad = False
if alibi_slopes_shape is None:
fu_core_attention_bias_shape = "1hss" if len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads:
elif len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads:
fu_core_attention_bias_shape = "1hss" fu_core_attention_bias_shape = "1hss"
elif ( elif (
len(alibi_slopes_shape) == 2 len(alibi_slopes_shape) == 2
......
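To see which backend these filters end up selecting, TE's attention debug logging can be enabled; the environment variables below are the ones the attention docs already describe, so treat the exact values as an assumption rather than part of this PR.

import os
# Set before importing transformer_engine.
os.environ["NVTE_DEBUG"] = "1"        # enable Transformer Engine debug logging
os.environ["NVTE_DEBUG_LEVEL"] = "2"  # verbose level prints the backend-selection reasons above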
...@@ -31,6 +31,7 @@ from transformer_engine.pytorch.distributed import ( ...@@ -31,6 +31,7 @@ from transformer_engine.pytorch.distributed import (
from transformer_engine.pytorch.attention.dot_product_attention import DotProductAttention from transformer_engine.pytorch.attention.dot_product_attention import DotProductAttention
from transformer_engine.pytorch.attention.inference import InferenceParams from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.attention.rope import apply_rotary_pos_emb from transformer_engine.pytorch.attention.rope import apply_rotary_pos_emb
from transformer_engine.pytorch.attention.dot_product_attention import utils as dpa_utils
from transformer_engine.pytorch.cpu_offload import start_offload, is_cpu_offload_enabled from transformer_engine.pytorch.cpu_offload import start_offload, is_cpu_offload_enabled
...@@ -92,6 +93,11 @@ class MultiheadAttention(torch.nn.Module): ...@@ -92,6 +93,11 @@ class MultiheadAttention(torch.nn.Module):
map to ``window_size = (-1, 0)`` and Transformer Engine distinguishes them based on map to ``window_size = (-1, 0)`` and Transformer Engine distinguishes them based on
``attn_mask_type``. Similar to :attr:`attn_mask_type`, ``window_size`` can ``attn_mask_type``. Similar to :attr:`attn_mask_type`, ``window_size`` can
be overridden by :attr:`window_size` in :meth:`forward` as well. be overridden by :attr:`window_size` in :meth:`forward` as well.
bottom_right_diagonal: Optional[bool], default = `None`
Align sliding window and ALiBi diagonal to the top left (`False`)
or bottom right (`True`) corner of the softmax matrix in the encoder.
If `None`, it will be set to `False` for `attn_mask_type` =
{`causal`, `padding_causal`} and `True` for other mask types.
num_gqa_groups : int, default = None num_gqa_groups : int, default = None
number of GQA groups in the transformer layer. number of GQA groups in the transformer layer.
Grouped Query Attention is described in Grouped Query Attention is described in
...@@ -247,6 +253,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -247,6 +253,7 @@ class MultiheadAttention(torch.nn.Module):
layer_number: Optional[int] = None, layer_number: Optional[int] = None,
attn_mask_type: str = "causal", attn_mask_type: str = "causal",
window_size: Optional[Tuple[int, int]] = None, window_size: Optional[Tuple[int, int]] = None,
bottom_right_diagonal: Optional[bool] = None,
tp_group: Optional[dist_group_type] = None, tp_group: Optional[dist_group_type] = None,
tp_size: int = 1, tp_size: int = 1,
num_gqa_groups: Optional[int] = None, num_gqa_groups: Optional[int] = None,
...@@ -285,6 +292,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -285,6 +292,7 @@ class MultiheadAttention(torch.nn.Module):
self.qkv_format = qkv_format self.qkv_format = qkv_format
self.attn_mask_type = attn_mask_type self.attn_mask_type = attn_mask_type
self.window_size = window_size self.window_size = window_size
self.bottom_right_diagonal = bottom_right_diagonal
self.layer_number = 1 if layer_number is None else layer_number self.layer_number = 1 if layer_number is None else layer_number
self.input_layernorm = input_layernorm self.input_layernorm = input_layernorm
self.attention_type = attention_type self.attention_type = attention_type
...@@ -621,6 +629,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -621,6 +629,7 @@ class MultiheadAttention(torch.nn.Module):
encoder_output: Optional[torch.Tensor] = None, encoder_output: Optional[torch.Tensor] = None,
attn_mask_type: Optional[str] = None, attn_mask_type: Optional[str] = None,
window_size: Optional[Tuple[int, int]] = None, window_size: Optional[Tuple[int, int]] = None,
bottom_right_diagonal: Optional[bool] = None,
is_first_microbatch: Optional[bool] = None, is_first_microbatch: Optional[bool] = None,
checkpoint_core_attention: bool = False, checkpoint_core_attention: bool = False,
inference_params: Optional[InferenceParams] = None, inference_params: Optional[InferenceParams] = None,
...@@ -667,6 +676,11 @@ class MultiheadAttention(torch.nn.Module): ...@@ -667,6 +676,11 @@ class MultiheadAttention(torch.nn.Module):
aligned to the bottom right corner. aligned to the bottom right corner.
window_size: Optional[Tuple[int, int]], default = None window_size: Optional[Tuple[int, int]], default = None
sliding window size for local attention. sliding window size for local attention.
bottom_right_diagonal: Optional[bool], default = `None`
Align sliding window and ALiBi diagonal to the top left (`False`)
or bottom right (`True`) corner of the softmax matrix in the encoder.
If `None`, it will be set to `False` for `attn_mask_type` =
{`causal`, `padding_causal`} and `True` for other mask types.
encoder_output : Optional[torch.Tensor], default = None encoder_output : Optional[torch.Tensor], default = None
Output of the encoder block to be fed into the decoder block if using Output of the encoder block to be fed into the decoder block if using
``layer_type="decoder"``. ``layer_type="decoder"``.
...@@ -731,6 +745,17 @@ class MultiheadAttention(torch.nn.Module): ...@@ -731,6 +745,17 @@ class MultiheadAttention(torch.nn.Module):
if window_size is None: if window_size is None:
window_size = self.window_size window_size = self.window_size
window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size)
if bottom_right_diagonal is None:
bottom_right_diagonal = self.bottom_right_diagonal
if attn_mask_type in {"causal", "padding_causal"}:
bottom_right_diagonal = False
if bottom_right_diagonal is None or attn_mask_type in {
"causal_bottom_right",
"padding_causal_bottom_right",
}:
bottom_right_diagonal = True
if "padding" in attn_mask_type and attention_mask is not None: if "padding" in attn_mask_type and attention_mask is not None:
for mask in attention_mask: for mask in attention_mask:
assert mask.dtype == torch.bool, "Attention mask must be in boolean type!" assert mask.dtype == torch.bool, "Attention mask must be in boolean type!"
...@@ -1001,6 +1026,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -1001,6 +1026,7 @@ class MultiheadAttention(torch.nn.Module):
attention_mask=attention_mask, attention_mask=attention_mask,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
window_size=window_size, window_size=window_size,
bottom_right_diagonal=bottom_right_diagonal,
checkpoint_core_attention=checkpoint_core_attention, checkpoint_core_attention=checkpoint_core_attention,
core_attention_bias_type=core_attention_bias_type, core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias, core_attention_bias=core_attention_bias,
......
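The resolution logic added to forward above can be read as one standalone rule: the per-call argument falls back to the module-level setting; purely causal masks (causal, padding_causal) then force top-left alignment, the bottom-right mask variants force bottom-right alignment, and a still-unresolved None also resolves to bottom-right. A sketch with a hypothetical helper name, mirroring that logic:

from typing import Optional

def resolve_bottom_right_diagonal(
    attn_mask_type: str,
    per_call: Optional[bool],
    module_default: Optional[bool],
) -> bool:
    # Mirrors the defaulting performed in MultiheadAttention.forward (illustration only).
    value = module_default if per_call is None else per_call
    if attn_mask_type in {"causal", "padding_causal"}:
        return False
    if value is None or attn_mask_type in {
        "causal_bottom_right",
        "padding_causal_bottom_right",
    }:
        return True
    return value

# e.g. resolve_bottom_right_diagonal("padding", None, None) -> True
#      resolve_bottom_right_diagonal("causal", True, None)  -> False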
...@@ -137,6 +137,7 @@ def fused_attn_fwd( ...@@ -137,6 +137,7 @@ def fused_attn_fwd(
attn_mask_type: str = "padding", attn_mask_type: str = "padding",
softmax_type: str = "vanilla", softmax_type: str = "vanilla",
window_size: Tuple[int, int] = (-1, -1), window_size: Tuple[int, int] = (-1, -1),
bottom_right_diagonal: bool = None,
rng_gen: torch.Generator = None, rng_gen: torch.Generator = None,
softmax_offset: torch.Tensor = None, softmax_offset: torch.Tensor = None,
return_max_logit: bool = False, return_max_logit: bool = False,
...@@ -212,6 +213,9 @@ def fused_attn_fwd( ...@@ -212,6 +213,9 @@ def fused_attn_fwd(
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and causal mask specifically. window and causal mask specifically.
bottom_right_diagonal: bool, default = None
whether to align sliding window and ALiBi diagonal to the top left (False) or
bottom right (True) corner of the softmax matrix.
rng_gen : torch.Generator, default = None rng_gen : torch.Generator, default = None
random number generator; random number generator;
if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen
...@@ -255,6 +259,12 @@ def fused_attn_fwd( ...@@ -255,6 +259,12 @@ def fused_attn_fwd(
max_logit : if return_max_logit = True, shape [h] and same data type as O; otherwise None max_logit : if return_max_logit = True, shape [h] and same data type as O; otherwise None
""" """
if bottom_right_diagonal is None:
bottom_right_diagonal = attn_mask_type in {
"causal_bottom_right",
"padding_causal_bottom_right",
}
if attn_scale is None: if attn_scale is None:
d = q.size(-1) d = q.size(-1)
attn_scale = 1.0 / math.sqrt(d) attn_scale = 1.0 / math.sqrt(d)
...@@ -306,6 +316,7 @@ def fused_attn_fwd( ...@@ -306,6 +316,7 @@ def fused_attn_fwd(
AttnMaskType[attn_mask_type], AttnMaskType[attn_mask_type],
SoftmaxType[softmax_type], SoftmaxType[softmax_type],
window_size, window_size,
bottom_right_diagonal,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens_kv, cu_seqlens_kv,
q, q,
...@@ -370,6 +381,7 @@ def fused_attn_bwd( ...@@ -370,6 +381,7 @@ def fused_attn_bwd(
attn_mask_type: str = "padding", attn_mask_type: str = "padding",
softmax_type: str = "vanilla", softmax_type: str = "vanilla",
window_size: Tuple[int, int] = (-1, -1), window_size: Tuple[int, int] = (-1, -1),
bottom_right_diagonal: bool = None,
deterministic: bool = False, deterministic: bool = False,
cuda_graph: bool = False, cuda_graph: bool = False,
) -> Tuple[Union[torch.Tensor, None], ...]: ) -> Tuple[Union[torch.Tensor, None], ...]:
...@@ -442,6 +454,9 @@ def fused_attn_bwd( ...@@ -442,6 +454,9 @@ def fused_attn_bwd(
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and causal mask specifically. window and causal mask specifically.
bottom_right_diagonal: bool, default = None
whether to align sliding window and ALiBi diagonal to the top left (False) or
bottom right (True) corner of the softmax matrix.
deterministic : bool, default = False deterministic : bool, default = False
whether to execute the backward pass with deterministic behaviours. whether to execute the backward pass with deterministic behaviours.
cuda_graph : bool, default = False cuda_graph : bool, default = False
...@@ -462,6 +477,12 @@ def fused_attn_bwd( ...@@ -462,6 +477,12 @@ def fused_attn_bwd(
gradient tensor of softmax offset of shape [1, h_q, 1, 1]. gradient tensor of softmax offset of shape [1, h_q, 1, 1].
See softmax_type in DotProductAttention for details. See softmax_type in DotProductAttention for details.
""" """
if bottom_right_diagonal is None:
bottom_right_diagonal = attn_mask_type in {
"causal_bottom_right",
"padding_causal_bottom_right",
}
if attn_scale is None: if attn_scale is None:
d = q.size(-1) d = q.size(-1)
attn_scale = 1.0 / math.sqrt(d) attn_scale = 1.0 / math.sqrt(d)
...@@ -500,6 +521,7 @@ def fused_attn_bwd( ...@@ -500,6 +521,7 @@ def fused_attn_bwd(
AttnMaskType[attn_mask_type], AttnMaskType[attn_mask_type],
SoftmaxType[softmax_type], SoftmaxType[softmax_type],
window_size, window_size,
bottom_right_diagonal,
deterministic, deterministic,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens_kv, cu_seqlens_kv,
......
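The window docstring kept above defines, for bottom-right alignment, the inclusive key range [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] for query i; with bottom_right_diagonal = False the seqlen_k - seqlen_q offset presumably drops out, so the window is centred on key position i itself. A small sketch of that arithmetic (assumed semantics, for illustration only):

def attended_key_range(i, seqlen_q, seqlen_kv, window_size, bottom_right_diagonal):
    # Inclusive [low, high] range of key positions visible to query i under a
    # sliding window; -1 in window_size means unbounded on that side.
    diag = i + (seqlen_kv - seqlen_q if bottom_right_diagonal else 0)
    low = 0 if window_size[0] == -1 else max(0, diag - window_size[0])
    high = seqlen_kv - 1 if window_size[1] == -1 else min(seqlen_kv - 1, diag + window_size[1])
    return low, high

# seqlen_q=4, seqlen_kv=8, window_size=(2, 0):
#   bottom_right_diagonal=True  -> query 0 sees keys 2..4 (diagonal at key 4)
#   bottom_right_diagonal=False -> query 0 sees only key 0 (diagonal at key 0)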
...@@ -87,9 +87,10 @@ std::vector<py::object> fused_attn_fwd( ...@@ -87,9 +87,10 @@ std::vector<py::object> fused_attn_fwd(
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
const std::vector<int64_t> window_size, const at::Tensor cu_seqlens_q, const std::vector<int64_t> window_size, bool bottom_right_diagonal,
const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
const at::ScalarType fake_dtype, const std::optional<at::Tensor> cu_seqlens_q_padded, const py::handle K, const py::handle V, const at::ScalarType fake_dtype,
const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::optional<at::Tensor> cu_seqlens_kv_padded, const std::optional<at::Tensor> cu_seqlens_kv_padded,
const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v, const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias, py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
...@@ -99,10 +100,10 @@ std::vector<py::object> fused_attn_fwd( ...@@ -99,10 +100,10 @@ std::vector<py::object> fused_attn_fwd(
std::vector<py::object> fused_attn_bwd( std::vector<py::object> fused_attn_bwd(
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTE_Softmax_Type softmax_type, const std::vector<int64_t> window_size, bool deterministic, NVTE_Softmax_Type softmax_type, const std::vector<int64_t> window_size,
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q,
const py::handle K, const py::handle V, const py::handle O, const py::handle dO, const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
const at::ScalarType fake_dtype, const DType dqkv_type, const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors, const std::vector<at::Tensor> Aux_CTX_Tensors,
const std::optional<at::Tensor> cu_seqlens_q_padded, const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer, const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
......
...@@ -100,9 +100,10 @@ std::vector<py::object> fused_attn_fwd( ...@@ -100,9 +100,10 @@ std::vector<py::object> fused_attn_fwd(
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
const std::vector<int64_t> window_size, const at::Tensor cu_seqlens_q, const std::vector<int64_t> window_size, bool bottom_right_diagonal,
const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
const at::ScalarType fake_dtype, const std::optional<at::Tensor> cu_seqlens_q_padded, const py::handle K, const py::handle V, const at::ScalarType fake_dtype,
const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::optional<at::Tensor> cu_seqlens_kv_padded, const std::optional<at::Tensor> cu_seqlens_kv_padded,
const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v, const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias, py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
...@@ -235,7 +236,7 @@ std::vector<py::object> fused_attn_fwd( ...@@ -235,7 +236,7 @@ std::vector<py::object> fused_attn_fwd(
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(),
te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
softmax_type, window_size[0], window_size[1], workspace.data(), softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(),
at::cuda::getCurrentCUDAStream()); at::cuda::getCurrentCUDAStream());
}); });
...@@ -295,7 +296,7 @@ std::vector<py::object> fused_attn_fwd( ...@@ -295,7 +296,7 @@ std::vector<py::object> fused_attn_fwd(
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(),
te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
softmax_type, window_size[0], window_size[1], workspace.data(), softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(),
at::cuda::getCurrentCUDAStream()); at::cuda::getCurrentCUDAStream());
}); });
...@@ -310,10 +311,10 @@ std::vector<py::object> fused_attn_fwd( ...@@ -310,10 +311,10 @@ std::vector<py::object> fused_attn_fwd(
std::vector<py::object> fused_attn_bwd( std::vector<py::object> fused_attn_bwd(
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTE_Softmax_Type softmax_type, const std::vector<int64_t> window_size, bool deterministic, NVTE_Softmax_Type softmax_type, const std::vector<int64_t> window_size,
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q,
const py::handle K, const py::handle V, const py::handle O, const py::handle dO, const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
const at::ScalarType fake_dtype, const DType dqkv_type, const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors, const std::vector<at::Tensor> Aux_CTX_Tensors,
const std::optional<at::Tensor> cu_seqlens_q_padded, const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer, const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
...@@ -532,14 +533,14 @@ std::vector<py::object> fused_attn_bwd( ...@@ -532,14 +533,14 @@ std::vector<py::object> fused_attn_bwd(
// populate tensors with appropriate shapes and dtypes // populate tensors with appropriate shapes and dtypes
NVTE_SCOPED_GIL_RELEASE({ NVTE_SCOPED_GIL_RELEASE({
nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), nvte_fused_attn_bwd(
te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(),
te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0],
softmax_type, window_size[0], window_size[1], deterministic, cuda_graph, window_size[1], bottom_right_diagonal, deterministic, cuda_graph, workspace.data(),
workspace.data(), at::cuda::getCurrentCUDAStream()); at::cuda::getCurrentCUDAStream());
}); });
// allocate memory for workspace // allocate memory for workspace
...@@ -549,14 +550,14 @@ std::vector<py::object> fused_attn_bwd( ...@@ -549,14 +550,14 @@ std::vector<py::object> fused_attn_bwd(
// execute kernel // execute kernel
NVTE_SCOPED_GIL_RELEASE({ NVTE_SCOPED_GIL_RELEASE({
nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), nvte_fused_attn_bwd(
te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(),
te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0],
softmax_type, window_size[0], window_size[1], deterministic, cuda_graph, window_size[1], bottom_right_diagonal, deterministic, cuda_graph, workspace.data(),
workspace.data(), at::cuda::getCurrentCUDAStream()); at::cuda::getCurrentCUDAStream());
}); });
// destroy tensor wrappers // destroy tensor wrappers
......
...@@ -34,7 +34,7 @@ from transformer_engine.pytorch.constants import ( ...@@ -34,7 +34,7 @@ from transformer_engine.pytorch.constants import (
from transformer_engine.pytorch.distributed import get_distributed_world_size from transformer_engine.pytorch.distributed import get_distributed_world_size
from transformer_engine.pytorch.export import is_in_onnx_export_mode from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
import transformer_engine.pytorch.attention.dot_product_attention.utils as dpa_utils
warnings.filterwarnings("module", category=DeprecationWarning, module="transformer") warnings.filterwarnings("module", category=DeprecationWarning, module="transformer")
...@@ -148,11 +148,21 @@ class TransformerLayer(torch.nn.Module): ...@@ -148,11 +148,21 @@ class TransformerLayer(torch.nn.Module):
distinguishes them based on :attr:`self_attn_mask_type` or :attr:`enc_dec_attn_mask_type`. distinguishes them based on :attr:`self_attn_mask_type` or :attr:`enc_dec_attn_mask_type`.
Similar to :attr:`self_attn_mask_type`, :attr:`window_size` can be overridden by Similar to :attr:`self_attn_mask_type`, :attr:`window_size` can be overridden by
:attr:`window_size` in :meth:`forward` as well. :attr:`window_size` in :meth:`forward` as well.
bottom_right_diagonal: Optional[bool], default = `None`
Align sliding window and ALiBi diagonal to the top left (`False`)
or bottom right (`True`) corner of the softmax matrix in the encoder.
If `None`, it will be set to `False` for `self_attn_mask_type` =
{`causal`, `padding_causal`} and `True` for other mask types.
enc_dec_attn_mask_type : {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'}, enc_dec_attn_mask_type : {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
default = "no_mask" default = "no_mask"
type of attention mask passed into softmax operation for decoder. type of attention mask passed into softmax operation for decoder.
enc_dec_window_size : Optional[Tuple[int, int]], default = None enc_dec_window_size : Optional[Tuple[int, int]], default = None
sliding window size for local attention in decoder. sliding window size for local attention in decoder.
enc_dec_bottom_right_diagonal: Optional[bool], default = `None`
Align sliding window and ALiBi diagonal to the top left (`False`)
or bottom right (`True`) corner of the softmax matrix in the decoder.
If `None`, it will be set to `False` for `enc_dec_attn_mask_type` =
{`causal`, `padding_causal`} and `True` for other mask types.
zero_centered_gamma : bool, default = False zero_centered_gamma : bool, default = False
if set to ``True``, gamma parameter in LayerNorm is initialized to 0 and if set to ``True``, gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to the LayerNorm formula changes to
...@@ -301,7 +311,9 @@ class TransformerLayer(torch.nn.Module): ...@@ -301,7 +311,9 @@ class TransformerLayer(torch.nn.Module):
kv_channels: Optional[int] = None, kv_channels: Optional[int] = None,
self_attn_mask_type: str = "causal", self_attn_mask_type: str = "causal",
window_size: Optional[Tuple[int, int]] = None, window_size: Optional[Tuple[int, int]] = None,
bottom_right_diagonal: Optional[bool] = None,
enc_dec_attn_mask_type: str = "no_mask", enc_dec_attn_mask_type: str = "no_mask",
enc_dec_bottom_right_diagonal: Optional[bool] = None,
enc_dec_window_size: Optional[Tuple[int, int]] = None, enc_dec_window_size: Optional[Tuple[int, int]] = None,
tp_group: Optional[dist_group_type] = None, tp_group: Optional[dist_group_type] = None,
tp_size: int = 1, tp_size: int = 1,
...@@ -343,8 +355,10 @@ class TransformerLayer(torch.nn.Module): ...@@ -343,8 +355,10 @@ class TransformerLayer(torch.nn.Module):
self.self_attn_mask_type = self_attn_mask_type self.self_attn_mask_type = self_attn_mask_type
self.window_size = window_size self.window_size = window_size
self.bottom_right_diagonal = bottom_right_diagonal
self.enc_dec_attn_mask_type = enc_dec_attn_mask_type self.enc_dec_attn_mask_type = enc_dec_attn_mask_type
self.enc_dec_window_size = enc_dec_window_size self.enc_dec_window_size = enc_dec_window_size
self.enc_dec_bottom_right_diagonal = enc_dec_bottom_right_diagonal
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
ub_bulk_wgrad = ub_tp_comm_overlap and ub_bulk_wgrad ub_bulk_wgrad = ub_tp_comm_overlap and ub_bulk_wgrad
ub_bulk_dgrad = ub_tp_comm_overlap and ub_bulk_dgrad ub_bulk_dgrad = ub_tp_comm_overlap and ub_bulk_dgrad
...@@ -606,10 +620,12 @@ class TransformerLayer(torch.nn.Module): ...@@ -606,10 +620,12 @@ class TransformerLayer(torch.nn.Module):
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
self_attn_mask_type: Optional[str] = None, self_attn_mask_type: Optional[str] = None,
window_size: Optional[Tuple[int, int]] = None, window_size: Optional[Tuple[int, int]] = None,
bottom_right_diagonal: Optional[bool] = None,
encoder_output: Optional[torch.Tensor] = None, encoder_output: Optional[torch.Tensor] = None,
enc_dec_attn_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, enc_dec_attn_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
enc_dec_attn_mask_type: Optional[str] = None, enc_dec_attn_mask_type: Optional[str] = None,
enc_dec_window_size: Optional[Tuple[int, int]] = None, enc_dec_window_size: Optional[Tuple[int, int]] = None,
enc_dec_bottom_right_diagonal: Optional[bool] = None,
is_first_microbatch: Optional[bool] = None, is_first_microbatch: Optional[bool] = None,
checkpoint_core_attention: bool = False, checkpoint_core_attention: bool = False,
inference_params: Optional[InferenceParams] = None, inference_params: Optional[InferenceParams] = None,
...@@ -654,6 +670,11 @@ class TransformerLayer(torch.nn.Module): ...@@ -654,6 +670,11 @@ class TransformerLayer(torch.nn.Module):
causal masks are aligned to the bottom right corner. causal masks are aligned to the bottom right corner.
window_size: Optional[Tuple[int, int]], default = None window_size: Optional[Tuple[int, int]], default = None
Sliding window size for local attention in encoder. Sliding window size for local attention in encoder.
bottom_right_diagonal: Optional[bool], default = `None`
Align sliding window and ALiBi diagonal to the top left (`False`)
or bottom right (`True`) corner of the softmax matrix in the encoder.
If `None`, it will be set to `False` for `self_attn_mask_type` =
{`causal`, `padding_causal`} and `True` for other mask types.
encoder_output : Optional[torch.Tensor], default = None encoder_output : Optional[torch.Tensor], default = None
Output of the encoder block to be fed into the decoder block if using Output of the encoder block to be fed into the decoder block if using
:attr:`layer_type` = ``"decoder"``. :attr:`layer_type` = ``"decoder"``.
...@@ -670,6 +691,11 @@ class TransformerLayer(torch.nn.Module): ...@@ -670,6 +691,11 @@ class TransformerLayer(torch.nn.Module):
Type of attention mask passed into softmax operation for decoder. Type of attention mask passed into softmax operation for decoder.
enc_dec_window_size: Optional[Tuple[int, int]], default = None enc_dec_window_size: Optional[Tuple[int, int]], default = None
Sliding window size for local attention in decoder. Sliding window size for local attention in decoder.
enc_dec_bottom_right_diagonal: Optional[bool], default = `None`
Align sliding window and ALiBi diagonal to the top left (`False`)
or bottom right (`True`) corner of the softmax matrix in the decoder.
If `None`, it will be set to `False` for `enc_dec_attn_mask_type` =
{`causal`, `padding_causal`} and `True` for other mask types.
is_first_microbatch : {True, False, None}, default = None is_first_microbatch : {True, False, None}, default = None
During training using either gradient accumulation or During training using either gradient accumulation or
pipeline parallelism a minibatch of data is further split pipeline parallelism a minibatch of data is further split
...@@ -736,10 +762,35 @@ class TransformerLayer(torch.nn.Module): ...@@ -736,10 +762,35 @@ class TransformerLayer(torch.nn.Module):
self_attn_mask_type = self.self_attn_mask_type self_attn_mask_type = self.self_attn_mask_type
if window_size is None: if window_size is None:
window_size = self.window_size window_size = self.window_size
window_size = dpa_utils.check_set_window_size(self_attn_mask_type, window_size)
if enc_dec_attn_mask_type is None: if enc_dec_attn_mask_type is None:
enc_dec_attn_mask_type = self.enc_dec_attn_mask_type enc_dec_attn_mask_type = self.enc_dec_attn_mask_type
if enc_dec_window_size is None: if enc_dec_window_size is None:
enc_dec_window_size = self.enc_dec_window_size enc_dec_window_size = self.enc_dec_window_size
enc_dec_window_size = dpa_utils.check_set_window_size(
enc_dec_attn_mask_type, enc_dec_window_size
)
if bottom_right_diagonal is None:
bottom_right_diagonal = self.bottom_right_diagonal
if self_attn_mask_type in {"causal", "padding_causal"}:
bottom_right_diagonal = False
if bottom_right_diagonal is None or self_attn_mask_type in {
"causal_bottom_right",
"padding_causal_bottom_right",
}:
bottom_right_diagonal = True
if enc_dec_bottom_right_diagonal is None:
enc_dec_bottom_right_diagonal = self.enc_dec_bottom_right_diagonal
if enc_dec_attn_mask_type in {"causal", "padding_causal"}:
enc_dec_bottom_right_diagonal = False
if enc_dec_bottom_right_diagonal is None or enc_dec_attn_mask_type in {
"causal_bottom_right",
"padding_causal_bottom_right",
}:
enc_dec_bottom_right_diagonal = True
assert ( assert (
self_attn_mask_type in AttnMaskTypes self_attn_mask_type in AttnMaskTypes
...@@ -778,6 +829,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -778,6 +829,7 @@ class TransformerLayer(torch.nn.Module):
attention_mask=attention_mask, attention_mask=attention_mask,
attn_mask_type=self_attn_mask_type, attn_mask_type=self_attn_mask_type,
window_size=window_size, window_size=window_size,
bottom_right_diagonal=bottom_right_diagonal,
inference_params=inference_params, inference_params=inference_params,
is_first_microbatch=is_first_microbatch, is_first_microbatch=is_first_microbatch,
checkpoint_core_attention=checkpoint_core_attention, checkpoint_core_attention=checkpoint_core_attention,
...@@ -813,6 +865,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -813,6 +865,7 @@ class TransformerLayer(torch.nn.Module):
attention_mask=enc_dec_attn_mask, attention_mask=enc_dec_attn_mask,
attn_mask_type=enc_dec_attn_mask_type, attn_mask_type=enc_dec_attn_mask_type,
window_size=enc_dec_window_size, window_size=enc_dec_window_size,
bottom_right_diagonal=enc_dec_bottom_right_diagonal,
encoder_output=encoder_output, encoder_output=encoder_output,
inference_params=inference_params, inference_params=inference_params,
is_first_microbatch=is_first_microbatch, is_first_microbatch=is_first_microbatch,
......
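Putting the new knobs together at the TransformerLayer level, a usage sketch (the sizes, head count, and mask choices below are illustrative assumptions; only the bottom_right_diagonal keywords come from this change):

import torch
import transformer_engine.pytorch as te

layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    self_attn_mask_type="causal_bottom_right",
    window_size=(128, 0),           # sliding window for local self-attention
    bottom_right_diagonal=True,     # explicit; None also resolves to True for this mask
).cuda()

x = torch.randn(8, 2, 1024, device="cuda")   # [s, b, h], assuming the default sequence-first ("sbhd") layout
y = layer(x)

# The per-call argument overrides the module-level setting for that call:
y = layer(x, bottom_right_diagonal=True)

# For layer_type="decoder", enc_dec_bottom_right_diagonal controls the cross-attention
# diagonal in the same way.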