Unverified Commit c6a92a4d authored by Sudhakar Singh, committed by GitHub

Add support for SWA (left, right) with FusedAttention (#2477)

* SWA (left, right) with FusedAttention changes cherry-picked from https://github.com/NVIDIA/TransformerEngine/pull/1369

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix test_kv_cache failures
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* remove unnecessary comments
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* fix some more filter issues, address feedback
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* fix for local test case failures - `bottom_right_diagonal` should be calculated in `fused_attn_fwd` call as well
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* make conditions more accurate
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* add cp tests to test swa (left, right)
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* remove dead code and make conditions better
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix lint
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* feedback from Charlene
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* small er
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* plumb `bottom_right_diagonal` through jax
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* plumb `bottom_right_diagonal` through jax
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add missing fields
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* use proper mask type in CP
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 0f0e229b
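
Note for reviewers: the diff below threads a new `bottom_right_diagonal` flag through the C API, the cuDNN backend, and the JAX/PyTorch frontends so that a two-sided sliding window `window_size = (left, right)` can be used with FusedAttention. As a minimal, self-contained sketch of the intended mask semantics (illustration only, not code from this PR; the `swa_mask` helper name is made up):

```python
# Illustration only: how a (left, right) sliding window is conventionally read,
# with the diagonal anchored top-left or bottom-right of the s_q x s_kv score
# matrix. -1 means "no bound" on that side; (left, 0) is the old causal SWA.
import numpy as np

def swa_mask(s_q, s_kv, window_size=(-1, -1), bottom_right_diagonal=True):
    left, right = window_size
    offset = (s_kv - s_q) if bottom_right_diagonal else 0  # diagonal alignment
    q = np.arange(s_q)[:, None]
    k = np.arange(s_kv)[None, :]
    allowed = np.ones((s_q, s_kv), dtype=bool)
    if left >= 0:
        allowed &= k >= q + offset - left
    if right >= 0:
        allowed &= k <= q + offset + right
    return allowed  # True where a query position may attend to a key position

# window_size=(512, 512) is the new SWA (left, right) case exercised by the CP tests below.
print(swa_mask(4, 6, window_size=(1, 1)).astype(int))
```
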
......@@ -153,6 +153,7 @@ def test_dot_product_attention(
if config.window_size == (-1, -1) and swa:
config.window_size = [2, 2]
config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
if qkv_format == "thd" and "padding" not in config.attn_mask_type:
......@@ -171,6 +172,7 @@ def test_dot_product_attention(
deterministic=_deterministic,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if not fused_attn_supported:
is_training = False
available_backends, _, fused_attn_backends = get_available_attention_backends(
......@@ -701,9 +703,10 @@ model_configs_swa = {
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_swa])
@pytest.mark.parametrize("model", model_configs_swa.keys())
def test_dpa_sliding_window(dtype, model_configs, model):
@pytest.mark.parametrize("qkv_layout", ["thd_thd_thd", "sbhd_sbhd_sbhd"])
def test_dpa_sliding_window(dtype, model_configs, model, qkv_layout):
"""Test DotProductAttention module with sliding window attention"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, True, False)
test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, True, False)
model_configs_alibi_slopes = {
......
......@@ -147,7 +147,7 @@ model_configs_fused_attn = {
2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"
), # MHA
"cp_1_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias"), # MHA
"cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)), # MHA
"cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 512)), # MHA
"cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA
"cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2), # GQA
"cp_2_2": ModelConfig(
......@@ -163,7 +163,7 @@ model_configs_fused_attn = {
2, 4096, 12, 128, num_gqa_groups=2, attn_bias_type="post_scale_bias"
), # GQA
"cp_2_4": ModelConfig(
2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0)
2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 512)
), # GQA
"cp_3_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", head_dim_v=64), # MLA
"cp_3_1": ModelConfig(2, 4096, 12, 128, head_dim_v=64), # MLA
......@@ -187,7 +187,16 @@ dtypes = ["bf16", "fp16", "fp8"]
qkv_formats = ["bshd", "sbhd", "thd"]
cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"]
if test_essential:
configs = ["cp_1_0", "cp_1_1", "cp_1_4", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"]
configs = [
"cp_1_0",
"cp_1_1",
"cp_1_4",
"cp_2_0",
"cp_2_2",
"cp_2_4",
"cp_3_2",
"cp_4_2",
]
model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs}
dtypes = ["bf16", "fp8"]
qkv_formats = ["sbhd", "thd"]
......
......@@ -353,11 +353,11 @@ def get_available_attention_backends(
backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
if AttentionLogging._is_logging_setup is False:
AttentionLogging.setup_logging()
with logging_context(highest_level=AttentionLogging._log_level):
for i in range(3):
os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
_attention_backends["backend_selection_requires_update"] = True
available_backends, flash_attention_backend, fused_attention_backend = test()
if fused_attention_backend == FusedAttnBackend[backends[i]]:
fused_attn_backends.append(fused_attention_backend)
for i in range(3):
os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
_attention_backends["backend_selection_requires_update"] = True
available_backends, flash_attention_backend, fused_attention_backend = test()
if fused_attention_backend == FusedAttnBackend[backends[i]]:
fused_attn_backends.append(fused_attention_backend)
return available_backends, flash_attention_backend, fused_attn_backends
......@@ -406,9 +406,11 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
(window_size_right == -1 || window_size_right == 0)) ||
// 9.2: SWA (left, 0) + top-left diagonal + {bshd, sbhd}
(cudnn_runtime_version >= 90200 &&
((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) ||
((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 &&
(attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
((window_size_left == -1 && window_size_right == -1 &&
attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK) ||
((window_size_left == -1 || window_size_left >= 0) && window_size_right == 0 &&
(attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
(attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK &&
max_seqlen_q == max_seqlen_kv)) &&
max_seqlen_q <= max_seqlen_kv && dropout == 0.0 &&
......@@ -418,12 +420,14 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
// 9.6: SWA (left, 0) + top-left/bottom-right diagonal + {bshd, sbhd, thd}
(cudnn_runtime_version >= 90600 &&
((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) ||
((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 &&
((window_size_left >= 0 || window_size_left == -1) &&
(window_size_right >= 0 || window_size_right == -1) &&
((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK &&
// TODO(cyang): fix bug for BRCM + cross-attention on sm100
(sm_arch_ < 100 || (sm_arch_ >= 100 && ((max_seqlen_q == max_seqlen_kv &&
cudnn_runtime_version <= 90700) ||
cudnn_runtime_version > 90700)))) ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
(attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK &&
(sm_arch_ < 100 || (sm_arch_ >= 100 && ((max_seqlen_q == max_seqlen_kv &&
......@@ -515,16 +519,14 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
// NVTE fused attention FWD with packed QKV
// DEPRECATED: This API is deprecated.
// Please use nvte_fused_attn_fwd with separate Q, K, V tensors instead.
void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
const NVTETensor cu_seqlens_padded, const NVTETensor rng_state,
size_t max_seqlen, bool is_training, bool return_max_logit,
bool cuda_graph, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right,
NVTETensor workspace, cudaStream_t stream) {
void nvte_fused_attn_fwd_qkvpacked(
const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen,
bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked);
using namespace transformer_engine;
......@@ -598,10 +600,10 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
fused_attn_arbitrary_seqlen_fwd(
b, h, h, max_seqlen, max_seqlen, d, d, t, t, 0, 0, 0, 0, 0, 0, is_training,
return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type,
window_size_left, window_size_right, &Q_view, &K_view, &V_view, input_Bias,
input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens,
input_cu_seqlens_padded, input_cu_seqlens_padded, nullptr, nullptr, input_rng_state,
wkspace, stream, handle);
window_size_left, window_size_right, bottom_right_diagonal, &Q_view, &K_view, &V_view,
input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens,
input_cu_seqlens, input_cu_seqlens_padded, input_cu_seqlens_padded, nullptr, nullptr,
input_rng_state, wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. "
......@@ -639,8 +641,8 @@ void nvte_fused_attn_bwd_qkvpacked(
NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph,
NVTETensor workspace, cudaStream_t stream) {
int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
bool deterministic, bool cuda_graph, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked);
using namespace transformer_engine;
......@@ -736,10 +738,11 @@ void nvte_fused_attn_bwd_qkvpacked(
fused_attn_arbitrary_seqlen_bwd(
b, h, h, max_seqlen, max_seqlen, d, d, t, t, attn_scale, dropout, qkv_layout, bias_type,
attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic, &Q_view,
&K_view, &V_view, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, &dQ_view,
&dK_view, &dV_view, output_dBias, output_dSoftmaxOffset, input_cu_seqlens, input_cu_seqlens,
input_cu_seqlens_padded, input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle);
attn_mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal,
deterministic, &Q_view, &K_view, &V_view, input_O, input_dO, input_Bias,
input_SoftmaxOffset, output_S, &dQ_view, &dK_view, &dV_view, output_dBias,
output_dSoftmaxOffset, input_cu_seqlens, input_cu_seqlens, input_cu_seqlens_padded,
input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle);
#else
const char *err_msg =
"cuDNN 8.9.0 is required for BF16/FP16 fused attention "
......@@ -790,7 +793,8 @@ void nvte_fused_attn_fwd_kvpacked(
size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) {
int64_t window_size_right, bool bottom_right_diagonal, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
......@@ -902,10 +906,10 @@ void nvte_fused_attn_fwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, num_pages_k, num_pages_v,
page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training,
return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type,
window_size_left, window_size_right, input_Q, &K_view, &V_view, input_Bias,
input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv,
input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k,
input_page_table_v, input_rng_state, wkspace, stream, handle);
window_size_left, window_size_right, bottom_right_diagonal, input_Q, &K_view, &V_view,
input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded,
input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. "
......@@ -945,8 +949,8 @@ void nvte_fused_attn_bwd_kvpacked(
const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, bool cuda_graph, NVTETensor workspace,
cudaStream_t stream) {
int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, bool cuda_graph,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
......@@ -1052,11 +1056,11 @@ void nvte_fused_attn_bwd_kvpacked(
fused_attn_arbitrary_seqlen_bwd(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, attn_scale, dropout, qkv_layout,
bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic,
input_Q, &K_view, &V_view, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S,
output_dQ, &dK_view, &dV_view, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q,
input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state,
wkspace, stream, handle);
bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right,
bottom_right_diagonal, deterministic, input_Q, &K_view, &V_view, input_O, input_dO,
input_Bias, input_SoftmaxOffset, output_S, output_dQ, &dK_view, &dV_view, output_dBias,
output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
#else
const char *err_msg =
"cuDNN 8.9.3 is required for BF16/FP16 fused attention "
......@@ -1106,8 +1110,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
bool return_max_logit, bool cuda_graph, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, NVTETensor workspace,
cudaStream_t stream) {
int64_t window_size_left, int64_t window_size_right,
bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
......@@ -1195,10 +1199,10 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v,
page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training,
return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type,
window_size_left, window_size_right, input_Q, input_K, input_V, input_Bias,
input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv,
input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k,
input_page_table_v, input_rng_state, wkspace, stream, handle);
window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, input_V,
input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded,
input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. "
......@@ -1228,8 +1232,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
size_t max_seqlen_kv, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic,
bool cuda_graph, NVTETensor workspace, cudaStream_t stream) {
int64_t window_size_left, int64_t window_size_right,
bool bottom_right_diagonal, bool deterministic, bool cuda_graph,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_bwd);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
......@@ -1302,8 +1307,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
fused_attn_arbitrary_seqlen_bwd(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout,
qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right,
deterministic, input_Q, input_K, input_V, input_O, input_dO, input_Bias,
input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias,
bottom_right_diagonal, deterministic, input_Q, input_K, input_V, input_O, input_dO,
input_Bias, input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias,
output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
#else
......
......@@ -55,10 +55,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training,
bool return_max_logit, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK,
void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrS1, void *devPtrS2,
void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ,
void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV,
int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, void *devPtrQ,
void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrS1,
void *devPtrS2, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset,
void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV,
void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType,
void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
......@@ -75,6 +75,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
if (is_bottom_right && s_q == s_kv && !is_padding) {
is_causal = true;
is_bottom_right = false;
bottom_right_diagonal = false;
}
bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX);
bool is_dropout = (is_training && dropout_probability != 0.0f);
......@@ -129,6 +130,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
softmax_type,
window_size_left,
window_size_right,
bottom_right_diagonal,
true,
tensorType,
cudnn_frontend::DataType_t::NOT_SET,
......@@ -254,9 +256,16 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
.set_causal_mask_bottom_right(is_bottom_right)
.set_attn_scale(attn_scale);
fe::DiagonalAlignment_t const &diagonal_alignment =
bottom_right_diagonal ? fe::DiagonalAlignment_t::BOTTOM_RIGHT
: fe::DiagonalAlignment_t::TOP_LEFT;
sdpa_options.set_diagonal_alignment(diagonal_alignment);
if (cudnn_runtime_version >= 90200 && window_size_left != -1) {
sdpa_options.set_diagonal_band_left_bound(window_size_left + 1);
}
if (cudnn_runtime_version >= 90600 && window_size_right != -1) {
sdpa_options.set_diagonal_band_right_bound(window_size_right);
}
sdpa_options.set_alibi_mask(is_alibi);
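
The hunk above is where the window actually reaches cuDNN: `bottom_right_diagonal` selects the `DiagonalAlignment_t`, and the two window halves become diagonal band bounds (the left bound gated on cuDNN >= 9.2, the right bound on >= 9.6). A small sketch of that mapping, assuming the band's left bound counts the diagonal itself, which would explain why the code passes `window_size_left + 1`:

```python
# Sketch of the window -> diagonal-band mapping used in the fwd/bwd impls above.
# Assumption: a left bound of N keeps the diagonal plus N-1 elements to its left,
# hence window_size_left + 1; -1 means the corresponding bound is left unset.
def band_bounds_from_window(window_size_left, window_size_right):
    left_bound = None if window_size_left == -1 else window_size_left + 1
    right_bound = None if window_size_right == -1 else window_size_right
    return left_bound, right_bound

assert band_bounds_from_window(512, 0) == (513, 0)      # causal SWA (left, 0)
assert band_bounds_from_window(512, 512) == (513, 512)  # new SWA (left, right)
assert band_bounds_from_window(-1, -1) == (None, None)  # no sliding window
```
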
......@@ -542,13 +551,14 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h,
float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic, void *devPtrQ,
void *devPtrKTranspose, void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats,
void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrdQ, void *devPtrdK, void *devPtrdV,
void *devPtrdO, void *devPtrdBias, void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed,
void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV,
void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType,
void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
bool deterministic, void *devPtrQ, void *devPtrKTranspose, void *devPtrVTranspose,
void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias, void *devPtrSoftmaxOffset,
void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias,
void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed, void *devPtrDropoutOffset,
void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ,
void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace,
size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
......@@ -563,6 +573,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
if (is_bottom_right && s_q == s_kv && !is_padding) {
is_causal = true;
is_bottom_right = false;
bottom_right_diagonal = false;
}
bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX);
bool is_dropout = (dropout_probability != 0.0f);
......@@ -621,6 +632,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
softmax_type,
window_size_left,
window_size_right,
bottom_right_diagonal,
deterministic,
tensorType,
cudnn_frontend::DataType_t::NOT_SET,
......@@ -781,9 +793,17 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
sdpa_backward_options.set_max_total_seq_len_kv(s_kv);
}
fe::DiagonalAlignment_t const &diagonal_alignment =
bottom_right_diagonal ? fe::DiagonalAlignment_t::BOTTOM_RIGHT
: fe::DiagonalAlignment_t::TOP_LEFT;
sdpa_backward_options.set_diagonal_alignment(diagonal_alignment);
if (cudnn_runtime_version >= 90200 && window_size_left != -1) {
sdpa_backward_options.set_diagonal_band_left_bound(window_size_left + 1);
}
if (cudnn_runtime_version >= 90600 && window_size_right != -1) {
sdpa_backward_options.set_diagonal_band_right_bound(window_size_right);
}
if (cudnn_runtime_version >= 90000) {
sdpa_backward_options.set_deterministic_algorithm(deterministic);
......@@ -1044,8 +1064,8 @@ void fused_attn_arbitrary_seqlen_fwd(
size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training,
bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
......@@ -1180,11 +1200,11 @@ void fused_attn_arbitrary_seqlen_fwd(
max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k,
page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training,
return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type,
window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias,
devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ,
devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size,
stream, handle);
window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV,
devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed,
devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV,
devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
&workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......@@ -1206,13 +1226,14 @@ void fused_attn_arbitrary_seqlen_bwd(
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S,
Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias,
Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK,
Tensor *output_dV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const auto QKV_type = input_Q->data.dtype;
void *devPtrQ = input_Q->data.dptr;
......@@ -1273,8 +1294,8 @@ void fused_attn_arbitrary_seqlen_bwd(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v,
max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout,
qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right,
deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias,
devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
bottom_right_diagonal, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats,
devPtrBias, devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ,
devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
......
......@@ -25,8 +25,8 @@ void fused_attn_arbitrary_seqlen_fwd(
size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training,
bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
......@@ -37,13 +37,14 @@ void fused_attn_arbitrary_seqlen_bwd(
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S,
Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias,
Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK,
Tensor *output_dV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
#endif // CUDNN_VERSION >= 8900
} // namespace transformer_engine
......
......@@ -1707,6 +1707,7 @@ void fused_attn_fp8_fwd_impl_v1(
0,
0,
true,
true,
qkv_tensor_type,
o_tensor_type,
cudnn_frontend::DataType_t::NOT_SET,
......@@ -2035,6 +2036,7 @@ void fused_attn_fp8_bwd_impl_v1(
NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX,
0,
0,
true,
false,
qkv_tensor_type,
o_tensor_type,
......
......@@ -110,6 +110,7 @@ struct FADescriptor_v1 {
NVTE_Softmax_Type softmax_type;
std::int64_t window_size_left;
std::int64_t window_size_right;
bool bottom_right_diagonal;
bool deterministic;
cudnn_frontend::DataType_t qkv_tensor_type;
cudnn_frontend::DataType_t o_tensor_type;
......@@ -121,15 +122,16 @@ struct FADescriptor_v1 {
return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k,
page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h,
attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type,
window_size_left, window_size_right, deterministic, bias_type, qkv_tensor_type,
o_tensor_type, do_tensor_type, dqkv_tensor_type, generate_max_sum_exp) <
window_size_left, window_size_right, bottom_right_diagonal, deterministic,
bias_type, qkv_tensor_type, o_tensor_type, do_tensor_type, dqkv_tensor_type,
generate_max_sum_exp) <
std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k,
rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k,
rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining,
rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type,
rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type,
rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type,
rhs.dqkv_tensor_type, rhs.generate_max_sum_exp);
rhs.window_size_left, rhs.window_size_right, rhs.bottom_right_diagonal,
rhs.deterministic, rhs.bias_type, rhs.qkv_tensor_type, rhs.o_tensor_type,
rhs.do_tensor_type, rhs.dqkv_tensor_type, rhs.generate_max_sum_exp);
}
};
......
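
`FADescriptor_v1` serves as the comparison key for cached cuDNN execution graphs, so the new field has to participate in the `std::tie` ordering above; otherwise a plan built for top-left alignment could be reused for a bottom-right request with the same shapes. A toy illustration of that failure mode (hypothetical names, Python for brevity):

```python
# Toy illustration: the plan cache must key on bottom_right_diagonal as well.
from functools import lru_cache

@lru_cache(maxsize=None)
def build_plan(s_q, s_kv, window_size, bottom_right_diagonal):
    # stand-in for building a cuDNN frontend graph from these attributes
    return ("plan", s_q, s_kv, window_size, bottom_right_diagonal)

top_left = build_plan(1024, 2048, (512, 0), False)
bottom_right = build_plan(1024, 2048, (512, 0), True)
assert top_left != bottom_right  # distinct plans for distinct alignments
```
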
......@@ -270,22 +270,21 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
[[deprecated(
"nvte_fused_attn_fwd_qkvpacked() is deprecated. Please use nvte_fused_attn_fwd() with separate "
"Q, K, V tensors instead.")]]
void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
const NVTETensor cu_seqlens_padded, const NVTETensor rng_state,
size_t max_seqlen, bool is_training, bool return_max_logit,
bool cuda_graph, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right,
NVTETensor workspace, cudaStream_t stream);
void nvte_fused_attn_fwd_qkvpacked(
const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen,
bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with packed QKV input.
*
......@@ -333,6 +332,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix.
* \param[in] deterministic Whether to execute with deterministic behaviours.
* \param[in] cuda_graph Whether cuda graph capture is enabled or not.
* \param[in] workspace Workspace tensor.
......@@ -347,8 +347,8 @@ void nvte_fused_attn_bwd_qkvpacked(
NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph,
NVTETensor workspace, cudaStream_t stream);
int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
bool deterministic, bool cuda_graph, NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute dot product attention with packed KV input.
*
......@@ -410,6 +410,7 @@ void nvte_fused_attn_bwd_qkvpacked(
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
......@@ -425,7 +426,8 @@ void nvte_fused_attn_fwd_kvpacked(
size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
int64_t window_size_right, bool bottom_right_diagonal, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with packed KV input.
*
......@@ -479,6 +481,7 @@ void nvte_fused_attn_fwd_kvpacked(
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix.
* \param[in] deterministic Whether to execute with deterministic behaviours.
* \param[in] cuda_graph Whether cuda graph capture is enabled or not.
* \param[in] workspace Workspace tensor.
......@@ -495,8 +498,8 @@ void nvte_fused_attn_bwd_kvpacked(
const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, bool cuda_graph, NVTETensor workspace,
cudaStream_t stream);
int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, bool cuda_graph,
NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute dot product attention with separate Q, K and V.
*
......@@ -560,19 +563,23 @@ void nvte_fused_attn_bwd_kvpacked(
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_fwd(
const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias,
const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit,
bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
bool return_max_logit, bool cuda_graph, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right,
bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with separate Q, K and V.
*
......@@ -629,6 +636,7 @@ void nvte_fused_attn_fwd(
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix.
* \param[in] deterministic Whether to execute with deterministic behaviours.
* \param[in] cuda_graph Whether cuda graph capture is enabled or not.
* \param[in] workspace Workspace tensor.
......@@ -644,8 +652,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
size_t max_seqlen_kv, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic,
bool cuda_graph, NVTETensor workspace, cudaStream_t stream);
int64_t window_size_left, int64_t window_size_right,
bool bottom_right_diagonal, bool deterministic, bool cuda_graph,
NVTETensor workspace, cudaStream_t stream);
/*! \brief Update the RNG state with the seed and calculated offset.
*
......
......@@ -70,6 +70,7 @@ __all__ = [
"is_training",
"max_segments_per_seq",
"window_size",
"bottom_right_diagonal",
"context_parallel_load_balanced",
"cp_axis",
"cp_striped_window_size",
......@@ -91,6 +92,7 @@ class _FusedAttnConfig:
is_training: bool
max_segments_per_seq: int
window_size: Tuple[int, int]
bottom_right_diagonal: bool
context_parallel_load_balanced: bool
cp_axis: str
cp_striped_window_size: Tuple[int, int] # Only for CP + Ring P2P + THD + SWA
......@@ -371,6 +373,11 @@ class FusedAttnFwdPrimitive(BasePrimitive):
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
bottom_right_diagonal = config.attn_mask_type in [
AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
]
# do a dummy kernel call here to get workspace buffer shapes/dtypes that XLA needs to
# prepare for the active fused-attn backend
input_batch = reduce(operator.mul, batch_shape)
......@@ -395,6 +402,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
config.max_segments_per_seq,
config.window_size[0],
config.window_size[1],
bottom_right_diagonal,
)
wkspace_aval = q_aval.update(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
......@@ -503,6 +511,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
window_size_left=window_size_left,
window_size_right=window_size_right,
bottom_right_diagonal=config.bottom_right_diagonal,
softmax_type=int(config.softmax_type.value),
)
......@@ -813,6 +822,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
config.max_segments_per_seq,
config.window_size[0],
config.window_size[1],
config.bottom_right_diagonal,
)
dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
......@@ -948,6 +958,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
window_size_left=window_size_left,
window_size_right=window_size_right,
bottom_right_diagonal=config.bottom_right_diagonal,
softmax_type=int(config.softmax_type.value),
)
......@@ -1357,9 +1368,10 @@ class _FusedAttnCPWithAllGatherHelper:
def get_step_config(self) -> _FusedAttnConfig:
"""Returns a _FusedAttnConfig for single CP step call to fused attention."""
adjusted_mask = self.get_adjusted_mask()
return _FusedAttnConfig(
attn_bias_type=self.config.attn_bias_type,
attn_mask_type=self.get_adjusted_mask(),
attn_mask_type=adjusted_mask,
softmax_type=self.config.softmax_type,
qkv_layout=self.config.qkv_layout,
scaling_factor=self.config.scaling_factor,
......@@ -1367,6 +1379,7 @@ class _FusedAttnCPWithAllGatherHelper:
is_training=self.config.is_training,
max_segments_per_seq=self.config.max_segments_per_seq,
window_size=self.config.window_size,
bottom_right_diagonal=adjusted_mask.is_bottom_right(),
context_parallel_load_balanced=self.config.context_parallel_load_balanced,
cp_axis=self.config.cp_axis,
cp_striped_window_size=None,
......@@ -1375,9 +1388,10 @@ class _FusedAttnCPWithAllGatherHelper:
def get_step_config_for_striped(self, max_seqlen, cp_size) -> _FusedAttnConfig:
"""Returns a _FusedAttnConfig for single CP step call (made via a striped AG primitive) to fused attention."""
adjusted_mask = self.get_adjusted_mask()
return _FusedAttnConfig(
attn_bias_type=self.config.attn_bias_type,
attn_mask_type=self.get_adjusted_mask(),
attn_mask_type=adjusted_mask,
softmax_type=self.config.softmax_type,
qkv_layout=self.config.qkv_layout,
scaling_factor=self.config.scaling_factor,
......@@ -1385,6 +1399,7 @@ class _FusedAttnCPWithAllGatherHelper:
is_training=self.config.is_training,
max_segments_per_seq=self.get_adjusted_max_segments_per_seq(max_seqlen, cp_size),
window_size=self.config.window_size,
bottom_right_diagonal=adjusted_mask.is_bottom_right(),
context_parallel_load_balanced=self.config.context_parallel_load_balanced,
cp_axis=self.config.cp_axis,
cp_striped_window_size=None,
......@@ -2430,6 +2445,7 @@ class _FusedAttnCPWithP2PHelper:
is_training=self.config.is_training,
max_segments_per_seq=self.config.max_segments_per_seq,
window_size=self.config.window_size,
bottom_right_diagonal=attn_mask_type.is_bottom_right(),
context_parallel_load_balanced=self.config.context_parallel_load_balanced,
cp_axis=self.config.cp_axis,
cp_striped_window_size=None,
......@@ -3418,6 +3434,7 @@ def fused_attn_fwd(
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
window_size=(-1, -1) if window_size is None else window_size,
bottom_right_diagonal=attn_mask_type.is_bottom_right(),
context_parallel_load_balanced=context_parallel_causal_load_balanced,
cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
cp_striped_window_size=None,
......@@ -3590,6 +3607,7 @@ def fused_attn_bwd(
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
window_size=(-1, -1) if window_size is None else window_size,
bottom_right_diagonal=attn_mask_type.is_bottom_right(),
context_parallel_load_balanced=context_parallel_causal_load_balanced,
cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
cp_striped_window_size=None,
......
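
On the JAX side the flag is not a new user-facing argument: `_FusedAttnConfig.bottom_right_diagonal` is derived from the mask type via `is_bottom_right()`, and the CP helpers re-derive it after adjusting the mask for each step. A minimal sketch of that derivation, assuming an enum shaped like the one referenced in the hunks above:

```python
# Sketch (assumed enum members, matching the names used in the diff): the JAX
# config infers bottom_right_diagonal from the mask type instead of exposing it.
from enum import Enum, auto

class AttnMaskType(Enum):
    NO_MASK = auto()
    CAUSAL_MASK = auto()
    PADDING_CAUSAL_MASK = auto()
    CAUSAL_BOTTOM_RIGHT_MASK = auto()
    PADDING_CAUSAL_BOTTOM_RIGHT_MASK = auto()

    def is_bottom_right(self) -> bool:
        return self in (
            AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK,
            AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK,
        )

assert AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK.is_bottom_right()
assert not AttnMaskType.CAUSAL_MASK.is_bottom_right()
```
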
......@@ -121,7 +121,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
DType dtype, bool is_training, size_t max_segments_per_seq, int64_t window_size_left,
int64_t window_size_right);
int64_t window_size_right, bool bottom_right_diagonal);
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
......@@ -129,7 +129,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
DType dtype, bool is_training, bool deterministic, size_t max_segments_per_seq,
int64_t window_size_left, int64_t window_size_right);
int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal);
// GEMM
XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler);
......
......@@ -144,7 +144,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
DType dtype, bool is_training, size_t max_segments_per_seq, int64_t window_size_left,
int64_t window_size_right) {
int64_t window_size_right, bool bottom_right_diagonal) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
......@@ -192,7 +192,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
ragged_offset_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
window_size_left, window_size_right, query_workspace_tensor.data(), nullptr);
window_size_left, window_size_right, bottom_right_diagonal, query_workspace_tensor.data(),
nullptr);
}
nvte_tensor_pack_destroy(&aux_output_tensors);
......@@ -237,7 +238,7 @@ static void FusedAttnForwardImpl(
size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
DType dtype, DType wkspace_dtype, bool is_training, bool deterministic,
int64_t window_size_left, int64_t window_size_right) {
int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal) {
FUSED_ATTN_IMPL_COMMON_BLOCK;
/* Input tensors */
......@@ -328,7 +329,7 @@ static void FusedAttnForwardImpl(
k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
window_size_left, window_size_right, workspace_tensor.data(), stream);
window_size_left, window_size_right, bottom_right_diagonal, workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_output_tensors);
}
......@@ -346,6 +347,7 @@ static void FusedAttnForwardImpl(
size_t max_segments_per_seq = get_attr_value<int64_t>(attrs, "max_segments_per_seq"); \
auto window_size_left = get_attr_value<int64_t>(attrs, "window_size_left"); \
auto window_size_right = get_attr_value<int64_t>(attrs, "window_size_right"); \
bool bottom_right_diagonal = get_attr_value<bool>(attrs, "bottom_right_diagonal"); \
float scaling_factor = get_attr_value<double>(attrs, "scaling_factor"); \
float dropout_probability = get_attr_value<double>(attrs, "dropout_probability"); \
NVTE_Bias_Type bias_type = \
......@@ -384,7 +386,7 @@ Error_Type FusedAttnForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Ty
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads,
qk_head_dim, v_head_dim, max_segments_per_seq, wkspace_size, scaling_factor,
dropout_probability, bias_type, mask_type, softmax_type, qkv_layout, dtype, wkspace_dtype,
is_training, deterministic, window_size_left, window_size_right);
is_training, deterministic, window_size_left, window_size_right, bottom_right_diagonal);
return ffi_with_cuda_error_check();
}
......@@ -415,7 +417,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
DType dtype, bool is_training, bool deterministic, size_t max_segments_per_seq,
int64_t window_size_left, int64_t window_size_right) {
int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto q_tensor = TensorWrapper(nullptr, q_shape, dtype);
auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype);
......@@ -467,17 +469,18 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
auto dummy_ragged_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
nvte_fused_attn_bwd(
q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
window_size_right, deterministic, false, query_workspace_tensor.data(), nullptr);
nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability,
qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
window_size_right, bottom_right_diagonal, deterministic, false,
query_workspace_tensor.data(), nullptr);
}
nvte_tensor_pack_destroy(&aux_input_tensors);
......@@ -496,7 +499,7 @@ static void FusedAttnBackwardImpl(
size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, NVTE_QKV_Layout qkv_layout,
DType dtype, DType wkspace_dtype, bool is_training, bool deterministic,
int64_t window_size_left, int64_t window_size_right) {
int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal) {
FUSED_ATTN_IMPL_COMMON_BLOCK;
/* Input tensors */
......@@ -593,16 +596,17 @@ static void FusedAttnBackwardImpl(
}
}
nvte_fused_attn_bwd(
q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), dbias_tensor.data(),
dsoftmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
window_size_left, window_size_right, deterministic, false, workspace_tensor.data(), stream);
nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dbias_tensor.data(), dsoftmax_offset_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen,
kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type,
mask_type, softmax_type, window_size_left, window_size_right,
bottom_right_diagonal, deterministic, false, workspace_tensor.data(), stream);
nvte_tensor_pack_destroy(&aux_input_tensors);
}
......@@ -631,7 +635,7 @@ Error_Type FusedAttnBackwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_T
q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads, qk_head_dim, v_head_dim,
max_segments_per_seq, wkspace_size, scaling_factor, dropout_probability, bias_type, mask_type,
softmax_type, qkv_layout, dtype, wkspace_dtype, is_training, deterministic, window_size_left,
window_size_right);
window_size_right, bottom_right_diagonal);
return ffi_with_cuda_error_check();
}
......
......@@ -261,6 +261,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
attn_mask_type: str = "causal",
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
window_size: Optional[Tuple[int, int]] = None,
bottom_right_diagonal: Optional[bool] = None,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
......@@ -346,6 +347,11 @@ class UnfusedDotProductAttention(torch.nn.Module):
attention_mask=attention_mask,
window_size=window_size,
attention_type=self.attention_type,
bottom_right_alignment=(
attn_mask_type not in ["causal", "padding_causal"]
if bottom_right_diagonal is None
else bottom_right_diagonal
),
)
)
......@@ -449,7 +455,11 @@ class UnfusedDotProductAttention(torch.nn.Module):
actual_seqlens_q=actual_seqlens_q if "padding" in attn_mask_type else None,
actual_seqlens_kv=actual_seqlens_kv if "padding" in attn_mask_type else None,
alibi_slopes=alibi_slopes,
bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
bottom_right_alignment=(
attn_mask_type not in ["causal", "padding_causal"]
if bottom_right_diagonal is None
else bottom_right_diagonal
),
)
matmul_result = torch.baddbmm(
matmul_result,
......@@ -1110,6 +1120,7 @@ class FusedAttnFunc(torch.autograd.Function):
attn_mask_type,
softmax_type,
window_size,
bottom_right_diagonal,
rng_gen,
fused_attention_backend,
use_FAv2_bwd,
......@@ -1213,6 +1224,7 @@ class FusedAttnFunc(torch.autograd.Function):
attn_mask_type,
softmax_type,
window_size,
bottom_right_diagonal,
rng_gen,
softmax_offset,
cuda_graph=is_graph_capturing(),
......@@ -1290,6 +1302,7 @@ class FusedAttnFunc(torch.autograd.Function):
attn_mask_type,
softmax_type,
window_size,
bottom_right_diagonal,
rng_gen,
softmax_offset,
return_max_logit,
......@@ -1377,6 +1390,7 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.attn_mask_type = attn_mask_type
ctx.softmax_type = softmax_type
ctx.window_size = window_size
ctx.bottom_right_diagonal = bottom_right_diagonal
ctx.fused_attention_backend = (
fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"]
)
......@@ -1527,6 +1541,7 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.attn_mask_type,
ctx.softmax_type,
ctx.window_size,
ctx.bottom_right_diagonal,
ctx.deterministic,
is_graph_capturing(),
)
......@@ -1592,6 +1607,7 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.attn_mask_type,
ctx.softmax_type,
ctx.window_size,
ctx.bottom_right_diagonal,
ctx.deterministic,
is_graph_capturing(),
)
......@@ -1631,6 +1647,7 @@ class FusedAttnFunc(torch.autograd.Function):
None,
None,
None,
None,
d_softmax_offset,
None,
None,
......@@ -1728,6 +1745,7 @@ class FusedAttention(torch.nn.Module):
attn_mask_type: str = "causal",
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
window_size: Optional[Tuple[int, int]] = None,
bottom_right_diagonal: Optional[bool] = None,
fused_attention_backend: tex.NVTE_Fused_Attn_Backend = tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None,
......@@ -1935,6 +1953,7 @@ class FusedAttention(torch.nn.Module):
attn_mask_type,
self.softmax_type,
window_size,
bottom_right_diagonal,
None, # rng_gen
fused_attention_backend,
use_FAv2_bwd,
......
......@@ -228,6 +228,11 @@ class DotProductAttention(TransformerEngineBaseModule):
map to ``window_size = (-1, 0)`` and Transformer Engine distinguishes them based on
``attn_mask_type``. Similar to :attr:`attn_mask_type`, ``window_size`` can
be overridden by :attr:`window_size` in ``forward`` as well.
bottom_right_diagonal: Optional[bool], default = `None`
Align sliding window and ALiBi diagonal to the top left (`False`)
or bottom right (`True`) corner of the softmax matrix in the encoder.
If `None`, it will be set to `False` for `attn_mask_type` =
{'causal', 'padding_causal'} and `True` for other mask types.
attention_type : str, default = "self"
type of attention, either ``"self"`` or ``"cross"``.
layer_number : int, default = None
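A brief usage sketch of the constructor argument documented above. Only `attn_mask_type`, `window_size`, and `bottom_right_diagonal` come from this hunk; the remaining argument names and sizes are assumed placeholders based on the existing DotProductAttention signature:

    from transformer_engine.pytorch import DotProductAttention

    # Sliding window over the 128 previous tokens, diagonal pinned to the top-left
    # corner. Leaving bottom_right_diagonal=None would resolve to False here anyway,
    # because the mask type is "causal"; it is spelled out only for clarity.
    attn = DotProductAttention(
        num_attention_heads=16,
        kv_channels=64,
        attn_mask_type="causal",
        window_size=(128, 0),
        bottom_right_diagonal=False,
    )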
......@@ -324,6 +329,7 @@ class DotProductAttention(TransformerEngineBaseModule):
qkv_format: str = "sbhd",
attn_mask_type: str = "causal",
window_size: Optional[Tuple[int, int]] = None,
bottom_right_diagonal: Optional[bool] = None,
sequence_parallel: bool = False,
tp_size: int = 1,
get_rng_state_tracker: Optional[Callable] = None,
......@@ -350,6 +356,7 @@ class DotProductAttention(TransformerEngineBaseModule):
attn_mask_type = "padding_causal"
self.attn_mask_type = attn_mask_type
self.window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size)
self.bottom_right_diagonal = bottom_right_diagonal
if tp_group is None:
self.tp_size = tp_size
if tp_size == 1:
......@@ -811,6 +818,7 @@ class DotProductAttention(TransformerEngineBaseModule):
max_seqlen_kv: int = None,
attn_mask_type: Optional[str] = None,
window_size: Optional[Tuple[int, int]] = None,
bottom_right_diagonal: Optional[bool] = None,
checkpoint_core_attention: bool = False,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None,
......@@ -963,6 +971,16 @@ class DotProductAttention(TransformerEngineBaseModule):
causal masks are aligned to the bottom right corner.
window_size: Optional[Tuple[int, int]], default = None
Sliding window size for local attention.
bottom_right_diagonal: Optional[bool], default = None
Align sliding window and ALiBi diagonal to the top left (`False`)
or bottom right (`True`) corner of the softmax matrix in the encoder.
If `None`, it will be set to `False` for `attn_mask_type` =
{'causal', 'padding_causal'} and `True` for other mask types.
Note: This parameter will be automatically overridden based on the
`attn_mask_type` - it will be forced to `False` for 'causal' and
'padding_causal' mask types, and forced to `True` for mask types
containing 'bottom_right' (e.g., 'causal_bottom_right',
'padding_causal_bottom_right'), regardless of the explicitly passed value.
checkpoint_core_attention : bool, default = False
If true, forward activations for attention are recomputed
during the backward pass in order to save memory that would
......@@ -1081,6 +1099,15 @@ class DotProductAttention(TransformerEngineBaseModule):
if window_size is None:
window_size = self.window_size
window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size)
if bottom_right_diagonal is None:
bottom_right_diagonal = self.bottom_right_diagonal
if attn_mask_type in {"causal", "padding_causal"}:
bottom_right_diagonal = False
if bottom_right_diagonal is None or attn_mask_type in {
"causal_bottom_right",
"padding_causal_bottom_right",
}:
bottom_right_diagonal = True
# checks for qkv_format
if qkv_format is None:
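The per-call resolution implemented just above can be summarized with a small standalone helper; the name `resolve_bottom_right_diagonal` is illustrative and not part of the API, and `requested` stands in for the forward argument (or the module-level default it falls back to):

    def resolve_bottom_right_diagonal(attn_mask_type, requested=None):
        """Causal masks force top-left alignment, bottom-right masks force
        bottom-right alignment, otherwise the request wins (None meaning True)."""
        if attn_mask_type in ("causal", "padding_causal"):
            return False
        if requested is None or attn_mask_type in (
            "causal_bottom_right",
            "padding_causal_bottom_right",
        ):
            return True
        return requested

    assert resolve_bottom_right_diagonal("causal", requested=True) is False
    assert resolve_bottom_right_diagonal("padding", requested=None) is True
    assert resolve_bottom_right_diagonal("causal_bottom_right", requested=False) is True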
......@@ -1144,6 +1171,8 @@ class DotProductAttention(TransformerEngineBaseModule):
assert "padding" in attn_mask_type, "KV caching requires padding mask!"
if attn_mask_type == "padding_causal":
attn_mask_type = attn_mask_type + "_bottom_right"
# since attention mask is changed, set `bottom_right_diagonal` to True
bottom_right_diagonal = True
if self.attention_type != "cross":
self.fast_setattr("attention_type", "cross")
......@@ -1257,7 +1286,6 @@ class DotProductAttention(TransformerEngineBaseModule):
if self.layer_number == 1:
_alibi_cache["_alibi_slopes_require_update"] = True
_alibi_cache["_alibi_bias_require_update"] = True
bottom_right_alignment = (attn_mask_type not in ["causal", "padding_causal"],)
if core_attention_bias_type == "alibi":
assert (
core_attention_bias is None
......@@ -1266,7 +1294,7 @@ class DotProductAttention(TransformerEngineBaseModule):
_alibi_cache["_num_heads"] != query_layer.shape[-2]
or _alibi_cache["_max_seqlen_q"] != max_seqlen_q
or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv
or _alibi_cache["_bottom_right_alignment"] != bottom_right_alignment
or _alibi_cache["_bottom_right_alignment"] != bottom_right_diagonal
or _alibi_cache["_alibi_slopes"] is None
):
_alibi_cache["_alibi_slopes_require_update"] = True
......@@ -1323,6 +1351,7 @@ class DotProductAttention(TransformerEngineBaseModule):
head_dim_v=head_dim_v,
attn_mask_type=attn_mask_type,
window_size=window_size,
bottom_right_diagonal=bottom_right_diagonal,
alibi_slopes_shape=alibi_slopes.shape if alibi_slopes is not None else None,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias_shape=core_attention_bias_shape,
......@@ -1446,9 +1475,7 @@ class DotProductAttention(TransformerEngineBaseModule):
if use_fused_attention:
fu_core_attention_bias_type = core_attention_bias_type
fu_core_attention_bias = core_attention_bias
if core_attention_bias_type == "alibi" and (
alibi_slopes is not None or max_seqlen_q != max_seqlen_kv
):
if core_attention_bias_type == "alibi" and (alibi_slopes is not None):
fu_core_attention_bias_type = "post_scale_bias"
_, fu_core_attention_bias = dpa_utils.get_alibi(
_alibi_cache,
......@@ -1457,7 +1484,7 @@ class DotProductAttention(TransformerEngineBaseModule):
max_seqlen_kv,
alibi_slopes=alibi_slopes,
bias_dtype=query_layer.dtype,
bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
bottom_right_alignment=bottom_right_diagonal,
)
if checkpoint_core_attention:
return self._checkpointed_attention_forward(
......@@ -1475,6 +1502,7 @@ class DotProductAttention(TransformerEngineBaseModule):
attn_mask_type=attn_mask_type,
attention_mask=attention_mask,
window_size=window_size,
bottom_right_diagonal=bottom_right_diagonal,
fused_attention_backend=fused_attention_backend,
core_attention_bias_type=fu_core_attention_bias_type,
core_attention_bias=fu_core_attention_bias,
......@@ -1505,6 +1533,7 @@ class DotProductAttention(TransformerEngineBaseModule):
attn_mask_type=attn_mask_type,
attention_mask=attention_mask,
window_size=window_size,
bottom_right_diagonal=bottom_right_diagonal,
fused_attention_backend=fused_attention_backend,
core_attention_bias_type=fu_core_attention_bias_type,
core_attention_bias=fu_core_attention_bias,
......@@ -1539,6 +1568,7 @@ class DotProductAttention(TransformerEngineBaseModule):
attn_mask_type=attn_mask_type,
attention_mask=attention_mask,
window_size=window_size,
bottom_right_diagonal=bottom_right_diagonal,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes,
......@@ -1562,6 +1592,7 @@ class DotProductAttention(TransformerEngineBaseModule):
attn_mask_type=attn_mask_type,
attention_mask=attention_mask,
window_size=window_size,
bottom_right_diagonal=bottom_right_diagonal,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes,
......
......@@ -200,6 +200,9 @@ class AttentionParams:
`causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`}
window_size : Tuple[int, int], default = None
Sliding window attention size.
bottom_right_diagonal: bool, default = `True`
Whether to align sliding window and ALiBi diagonal to the bottom right corner
of the softmax matrix.
alibi_slopes_shape : Optional[Union[torch.Size, List]], default = None
Tensor shape of :attr:`alibi_slopes` in `DotProductAttention`.
core_attention_bias_type : str, default = no_bias
......@@ -249,6 +252,7 @@ class AttentionParams:
head_dim_v: int = 64
attn_mask_type: str = "no_mask"
window_size: Union[Tuple[int, int], None] = None
bottom_right_diagonal: bool = True
alibi_slopes_shape: Union[torch.Size, List, None] = None
core_attention_bias_type: str = "no_bias"
core_attention_bias_shape: str = "1hss"
......@@ -325,6 +329,7 @@ def get_attention_backend(
head_dim_v = attention_params.head_dim_v
attn_mask_type = attention_params.attn_mask_type
window_size = attention_params.window_size
bottom_right_diagonal = attention_params.bottom_right_diagonal
alibi_slopes_shape = attention_params.alibi_slopes_shape
core_attention_bias_type = attention_params.core_attention_bias_type
core_attention_bias_shape = attention_params.core_attention_bias_shape
......@@ -859,39 +864,43 @@ def get_attention_backend(
# backend | window_size | diagonal alignment
# ---------------------------------------------------------------------------------
# FlashAttention | (-1, -1) or (>=0, >=0) | bottom right
# FusedAttention | (-1, 0) or (>=0, 0) | top left
# UnfusedDotProductAttention | (-1, -1) or (>=0, >=0) | both;
# FusedAttention | (-1, 0) or (>=0, >=0) | top left, bottom right
# UnfusedDotProductAttention | (-1, -1) or (>=0, >=0) | top left, bottom right
# | | converts window_size to an 'arbitrary' mask
if window_size is None:
window_size = check_set_window_size(attn_mask_type, window_size)
else:
if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha):
logger.debug(
"Disabling FusedAttention as it does not support sliding window attention"
" for FP8"
)
use_fused_attention = False
elif window_size[1] != 0 or attention_dropout != 0.0:
logger.debug(
"Disabling FusedAttention as it only supports sliding window attention "
"with (left, 0) and no dropout"
)
use_fused_attention = False
elif max_seqlen_q > max_seqlen_kv:
logger.debug(
"Disabling FusedAttention as it does not support sliding window attention "
"with s_q > s_kv for cross-attention"
)
use_fused_attention = False
if use_flash_attention_2 and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
if not FlashAttentionUtils.is_installed:
FlashAttentionUtils.version_required = PkgVersion("2.3")
elif not FlashAttentionUtils.v2_3_plus:
logger.debug(
"Disabling FlashAttention as sliding window attention requires flash-attn 2.3+"
)
use_flash_attention_2 = False
if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha):
logger.debug(
"Disabling FusedAttention as it does not support sliding window attention for FP8"
)
use_fused_attention = False
elif attention_dropout != 0.0:
logger.debug(
"Disabling FusedAttention as it only supports sliding window attention "
"without dropout"
)
use_fused_attention = False
elif max_seqlen_q > max_seqlen_kv:
logger.debug(
"Disabling FusedAttention as it does not support sliding window attention "
"with s_q > s_kv for cross-attention"
)
use_fused_attention = False
if use_flash_attention_2 and (window_size[0] != -1 or window_size[1] not in [-1, 0]):
if not FlashAttentionUtils.is_installed:
FlashAttentionUtils.version_required = PkgVersion("2.3")
elif not FlashAttentionUtils.v2_3_plus:
logger.debug(
"Disabling FlashAttention as sliding window attention requires flash-attn 2.3+"
)
use_flash_attention_2 = False
elif not bottom_right_diagonal and max_seqlen_q != max_seqlen_kv:
logger.debug(
"Disabling FlashAttention as it only supports sliding window with bottom right"
" diagonal alignment for cross-attention"
)
use_flash_attention = False
# Filter: Attention bias
# backend | bias types | ALiBi diagonal alignment
......@@ -913,6 +922,12 @@ def get_attention_backend(
elif not FlashAttentionUtils.v2_4_plus:
logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+")
use_flash_attention_2 = False
elif not bottom_right_diagonal and max_seqlen_q != max_seqlen_kv:
logger.debug(
"Disabling FlashAttention as it only supports ALiBi with bottom right diagonal"
" alignment for cross-attention"
)
use_flash_attention = False
if (
core_attention_bias_type not in ["no_bias", "alibi"]
......@@ -930,13 +945,12 @@ def get_attention_backend(
if (
use_fused_attention
and core_attention_bias_type == "alibi"
and (alibi_slopes_shape is not None or max_seqlen_q != max_seqlen_kv)
and (alibi_slopes_shape is not None)
):
fu_core_attention_bias_type = "post_scale_bias"
fu_core_attention_bias_requires_grad = False
if alibi_slopes_shape is None:
fu_core_attention_bias_shape = "1hss"
elif len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads:
if len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads:
fu_core_attention_bias_shape = "1hss"
elif (
len(alibi_slopes_shape) == 2
......
......@@ -31,6 +31,7 @@ from transformer_engine.pytorch.distributed import (
from transformer_engine.pytorch.attention.dot_product_attention import DotProductAttention
from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.attention.rope import apply_rotary_pos_emb
from transformer_engine.pytorch.attention.dot_product_attention import utils as dpa_utils
from transformer_engine.pytorch.cpu_offload import start_offload, is_cpu_offload_enabled
......@@ -92,6 +93,11 @@ class MultiheadAttention(torch.nn.Module):
map to ``window_size = (-1, 0)`` and Transformer Engine distinguishes them based on
``attn_mask_type``. Similar to :attr:`attn_mask_type`, ``window_size`` can
be overridden by :attr:`window_size` in :meth:`forward` as well.
bottom_right_diagonal: Optional[bool], default = `None`
Align sliding window and ALiBi diagonal to the top left (`False`)
or bottom right (`True`) corner of the softmax matrix in the encoder.
If `None`, it will be set to `False` for `attn_mask_type` =
{`causal`, `padding_causal`} and `True` for other mask types.
num_gqa_groups : int, default = None
number of GQA groups in the transformer layer.
Grouped Query Attention is described in
......@@ -247,6 +253,7 @@ class MultiheadAttention(torch.nn.Module):
layer_number: Optional[int] = None,
attn_mask_type: str = "causal",
window_size: Optional[Tuple[int, int]] = None,
bottom_right_diagonal: Optional[bool] = None,
tp_group: Optional[dist_group_type] = None,
tp_size: int = 1,
num_gqa_groups: Optional[int] = None,
......@@ -285,6 +292,7 @@ class MultiheadAttention(torch.nn.Module):
self.qkv_format = qkv_format
self.attn_mask_type = attn_mask_type
self.window_size = window_size
self.bottom_right_diagonal = bottom_right_diagonal
self.layer_number = 1 if layer_number is None else layer_number
self.input_layernorm = input_layernorm
self.attention_type = attention_type
......@@ -621,6 +629,7 @@ class MultiheadAttention(torch.nn.Module):
encoder_output: Optional[torch.Tensor] = None,
attn_mask_type: Optional[str] = None,
window_size: Optional[Tuple[int, int]] = None,
bottom_right_diagonal: Optional[bool] = None,
is_first_microbatch: Optional[bool] = None,
checkpoint_core_attention: bool = False,
inference_params: Optional[InferenceParams] = None,
......@@ -667,6 +676,11 @@ class MultiheadAttention(torch.nn.Module):
aligned to the bottom right corner.
window_size: Optional[Tuple[int, int]], default = None
sliding window size for local attention.
bottom_right_diagonal: Optional[bool], default = `None`
Align sliding window and ALiBi diagonal to the top left (`False`)
or bottom right (`True`) corner of the softmax matrix in the encoder.
If `None`, it will be set to `False` for `attn_mask_type` =
{`causal`, `padding_causal`} and `True` for other mask types.
encoder_output : Optional[torch.Tensor], default = None
Output of the encoder block to be fed into the decoder block if using
``layer_type="decoder"``.
......@@ -731,6 +745,17 @@ class MultiheadAttention(torch.nn.Module):
if window_size is None:
window_size = self.window_size
window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size)
if bottom_right_diagonal is None:
bottom_right_diagonal = self.bottom_right_diagonal
if attn_mask_type in {"causal", "padding_causal"}:
bottom_right_diagonal = False
if bottom_right_diagonal is None or attn_mask_type in {
"causal_bottom_right",
"padding_causal_bottom_right",
}:
bottom_right_diagonal = True
if "padding" in attn_mask_type and attention_mask is not None:
for mask in attention_mask:
assert mask.dtype == torch.bool, "Attention mask must be in boolean type!"
......@@ -1001,6 +1026,7 @@ class MultiheadAttention(torch.nn.Module):
attention_mask=attention_mask,
attn_mask_type=attn_mask_type,
window_size=window_size,
bottom_right_diagonal=bottom_right_diagonal,
checkpoint_core_attention=checkpoint_core_attention,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
......
......@@ -137,6 +137,7 @@ def fused_attn_fwd(
attn_mask_type: str = "padding",
softmax_type: str = "vanilla",
window_size: Tuple[int, int] = (-1, -1),
bottom_right_diagonal: bool = None,
rng_gen: torch.Generator = None,
softmax_offset: torch.Tensor = None,
return_max_logit: bool = False,
......@@ -212,6 +213,9 @@ def fused_attn_fwd(
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and causal mask specifically.
bottom_right_diagonal: bool, default = None
whether to align sliding window and ALiBi diagonal to the top left (False) or
bottom right (True) corner of the softmax matrix.
rng_gen : torch.Generator, default = None
random number generator;
if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen
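The inclusive key range quoted above can be computed directly. A minimal sketch with an illustrative helper name; the `seqlen_k - seqlen_q` offset in the quoted formula corresponds to a bottom-right-aligned diagonal, and dropping it gives the top-left alignment selected by `bottom_right_diagonal=False`:

    def swa_key_range(i, seqlen_q, seqlen_kv, window_size):
        """Inclusive [low, high] range of key positions visible to query i under a
        (left, right) sliding window with a bottom-right-aligned diagonal."""
        diag = i + seqlen_kv - seqlen_q
        left, right = window_size
        low = 0 if left < 0 else max(0, diag - left)
        high = seqlen_kv - 1 if right < 0 else min(seqlen_kv - 1, diag + right)
        return low, high

    # seqlen_q = 4, seqlen_kv = 6, window_size = (2, 0): query 3 sees keys 3..5
    assert swa_key_range(3, 4, 6, (2, 0)) == (3, 5)
    # window_size = (-1, 0) reduces to a bottom-right causal mask
    assert swa_key_range(0, 4, 6, (-1, 0)) == (0, 2)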
......@@ -255,6 +259,12 @@ def fused_attn_fwd(
max_logit : if return_max_logit = True, shape [h] and same data type as O; otherwise None
"""
if bottom_right_diagonal is None:
bottom_right_diagonal = attn_mask_type in {
"causal_bottom_right",
"padding_causal_bottom_right",
}
if attn_scale is None:
d = q.size(-1)
attn_scale = 1.0 / math.sqrt(d)
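Note that at this level a `None` value is resolved from the mask type alone, so only the bottom-right mask variants default to `True`, which is narrower than the module-level default documented for DotProductAttention above. A short mirror of that fallback, with an illustrative helper name:

    def kernel_level_default(attn_mask_type):
        """Only the *_bottom_right mask types default to a bottom-right diagonal here."""
        return attn_mask_type in {"causal_bottom_right", "padding_causal_bottom_right"}

    assert kernel_level_default("padding_causal_bottom_right") is True
    assert kernel_level_default("padding") is False  # the module-level default would be True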
......@@ -306,6 +316,7 @@ def fused_attn_fwd(
AttnMaskType[attn_mask_type],
SoftmaxType[softmax_type],
window_size,
bottom_right_diagonal,
cu_seqlens_q,
cu_seqlens_kv,
q,
......@@ -370,6 +381,7 @@ def fused_attn_bwd(
attn_mask_type: str = "padding",
softmax_type: str = "vanilla",
window_size: Tuple[int, int] = (-1, -1),
bottom_right_diagonal: bool = None,
deterministic: bool = False,
cuda_graph: bool = False,
) -> Tuple[Union[torch.Tensor, None], ...]:
......@@ -442,6 +454,9 @@ def fused_attn_bwd(
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and causal mask specifically.
bottom_right_diagonal: bool, default = None
whether to align sliding window and ALiBi diagonal to the top left (False) or
bottom right (True) corner of the softmax matrix.
deterministic : bool, default = False
whether to execute the backward pass with deterministic behaviours.
cuda_graph : bool, default = False
......@@ -462,6 +477,12 @@ def fused_attn_bwd(
gradient tensor of softmax offset of shape [1, h_q, 1, 1].
See softmax_type in DotProductAttention for details.
"""
if bottom_right_diagonal is None:
bottom_right_diagonal = attn_mask_type in {
"causal_bottom_right",
"padding_causal_bottom_right",
}
if attn_scale is None:
d = q.size(-1)
attn_scale = 1.0 / math.sqrt(d)
......@@ -500,6 +521,7 @@ def fused_attn_bwd(
AttnMaskType[attn_mask_type],
SoftmaxType[softmax_type],
window_size,
bottom_right_diagonal,
deterministic,
cu_seqlens_q,
cu_seqlens_kv,
......
......@@ -87,9 +87,10 @@ std::vector<py::object> fused_attn_fwd(
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
const std::vector<int64_t> window_size, const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
const at::ScalarType fake_dtype, const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::vector<int64_t> window_size, bool bottom_right_diagonal,
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
const py::handle K, const py::handle V, const at::ScalarType fake_dtype,
const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::optional<at::Tensor> cu_seqlens_kv_padded,
const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
......@@ -99,10 +100,10 @@ std::vector<py::object> fused_attn_fwd(
std::vector<py::object> fused_attn_bwd(
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTE_Softmax_Type softmax_type, const std::vector<int64_t> window_size, bool deterministic,
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
const py::handle K, const py::handle V, const py::handle O, const py::handle dO,
const at::ScalarType fake_dtype, const DType dqkv_type,
NVTE_Softmax_Type softmax_type, const std::vector<int64_t> window_size,
bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
......
......@@ -100,9 +100,10 @@ std::vector<py::object> fused_attn_fwd(
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
const std::vector<int64_t> window_size, const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
const at::ScalarType fake_dtype, const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::vector<int64_t> window_size, bool bottom_right_diagonal,
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
const py::handle K, const py::handle V, const at::ScalarType fake_dtype,
const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::optional<at::Tensor> cu_seqlens_kv_padded,
const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
......@@ -235,7 +236,7 @@ std::vector<py::object> fused_attn_fwd(
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(),
te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
softmax_type, window_size[0], window_size[1], workspace.data(),
softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(),
at::cuda::getCurrentCUDAStream());
});
......@@ -295,7 +296,7 @@ std::vector<py::object> fused_attn_fwd(
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(),
te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training,
return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
softmax_type, window_size[0], window_size[1], workspace.data(),
softmax_type, window_size[0], window_size[1], bottom_right_diagonal, workspace.data(),
at::cuda::getCurrentCUDAStream());
});
......@@ -310,10 +311,10 @@ std::vector<py::object> fused_attn_fwd(
std::vector<py::object> fused_attn_bwd(
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTE_Softmax_Type softmax_type, const std::vector<int64_t> window_size, bool deterministic,
const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
const py::handle K, const py::handle V, const py::handle O, const py::handle dO,
const at::ScalarType fake_dtype, const DType dqkv_type,
NVTE_Softmax_Type softmax_type, const std::vector<int64_t> window_size,
bool bottom_right_diagonal, bool deterministic, const at::Tensor cu_seqlens_q,
const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type,
const std::vector<at::Tensor> Aux_CTX_Tensors,
const std::optional<at::Tensor> cu_seqlens_q_padded,
const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
......@@ -532,14 +533,14 @@ std::vector<py::object> fused_attn_bwd(
// populate tensors with appropriate shapes and dtypes
NVTE_SCOPED_GIL_RELEASE({
nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(),
te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(),
te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(),
te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
softmax_type, window_size[0], window_size[1], deterministic, cuda_graph,
workspace.data(), at::cuda::getCurrentCUDAStream());
nvte_fused_attn_bwd(
te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
&nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(),
te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv,
attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0],
window_size[1], bottom_right_diagonal, deterministic, cuda_graph, workspace.data(),
at::cuda::getCurrentCUDAStream());
});
// allocate memory for workspace
......@@ -549,14 +550,14 @@ std::vector<py::object> fused_attn_bwd(
// execute kernel
NVTE_SCOPED_GIL_RELEASE({
nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(),
te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(),
te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(),
te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q,
max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type,
softmax_type, window_size[0], window_size[1], deterministic, cuda_graph,
workspace.data(), at::cuda::getCurrentCUDAStream());
nvte_fused_attn_bwd(
te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(),
&nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(),
te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(),
te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv,
attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0],
window_size[1], bottom_right_diagonal, deterministic, cuda_graph, workspace.data(),
at::cuda::getCurrentCUDAStream());
});
// destroy tensor wrappers
......
......@@ -34,7 +34,7 @@ from transformer_engine.pytorch.constants import (
from transformer_engine.pytorch.distributed import get_distributed_world_size
from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
import transformer_engine.pytorch.attention.dot_product_attention.utils as dpa_utils
warnings.filterwarnings("module", category=DeprecationWarning, module="transformer")
......@@ -148,11 +148,21 @@ class TransformerLayer(torch.nn.Module):
distinguishes them based on :attr:`self_attn_mask_type` or :attr:`enc_dec_attn_mask_type`.
Similar to :attr:`self_attn_mask_type`, :attr:`window_size` can be overridden by
:attr:`window_size` in :meth:`forward` as well.
bottom_right_diagonal: Optional[bool], default = `None`
Align sliding window and ALiBi diagonal to the top left (`False`)
or bottom right (`True`) corner of the softmax matrix in the encoder.
If `None`, it will be set to `False` for `self_attn_mask_type` =
{`causal`, `padding_causal`} and `True` for other mask types.
enc_dec_attn_mask_type : {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
default = "no_mask"
type of attention mask passed into softmax operation for decoder.
enc_dec_window_size : Optional[Tuple[int, int]], default = None
sliding window size for local attention in decoder.
enc_dec_bottom_right_diagonal: Optional[bool], default = `None`
Align sliding window and ALiBi diagonal to the top left (`False`)
or bottom right (`True`) corner of the softmax matrix in the decoder.
If `None`, it will be set to `False` for `enc_dec_attn_mask_type` =
{`causal`, `padding_causal`} and `True` for other mask types.
zero_centered_gamma : bool, default = False
if set to ``True``, gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to
......@@ -301,7 +311,9 @@ class TransformerLayer(torch.nn.Module):
kv_channels: Optional[int] = None,
self_attn_mask_type: str = "causal",
window_size: Optional[Tuple[int, int]] = None,
bottom_right_diagonal: Optional[bool] = None,
enc_dec_attn_mask_type: str = "no_mask",
enc_dec_bottom_right_diagonal: Optional[bool] = None,
enc_dec_window_size: Optional[Tuple[int, int]] = None,
tp_group: Optional[dist_group_type] = None,
tp_size: int = 1,
......@@ -343,8 +355,10 @@ class TransformerLayer(torch.nn.Module):
self.self_attn_mask_type = self_attn_mask_type
self.window_size = window_size
self.bottom_right_diagonal = bottom_right_diagonal
self.enc_dec_attn_mask_type = enc_dec_attn_mask_type
self.enc_dec_window_size = enc_dec_window_size
self.enc_dec_bottom_right_diagonal = enc_dec_bottom_right_diagonal
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
ub_bulk_wgrad = ub_tp_comm_overlap and ub_bulk_wgrad
ub_bulk_dgrad = ub_tp_comm_overlap and ub_bulk_dgrad
......@@ -606,10 +620,12 @@ class TransformerLayer(torch.nn.Module):
attention_mask: Optional[torch.Tensor] = None,
self_attn_mask_type: Optional[str] = None,
window_size: Optional[Tuple[int, int]] = None,
bottom_right_diagonal: Optional[bool] = None,
encoder_output: Optional[torch.Tensor] = None,
enc_dec_attn_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
enc_dec_attn_mask_type: Optional[str] = None,
enc_dec_window_size: Optional[Tuple[int, int]] = None,
enc_dec_bottom_right_diagonal: Optional[bool] = None,
is_first_microbatch: Optional[bool] = None,
checkpoint_core_attention: bool = False,
inference_params: Optional[InferenceParams] = None,
......@@ -654,6 +670,11 @@ class TransformerLayer(torch.nn.Module):
causal masks are aligned to the bottom right corner.
window_size: Optional[Tuple[int, int]], default = None
Sliding window size for local attention in encoder.
bottom_right_diagonal: Optional[bool], default = `None`
Align sliding window and ALiBi diagonal to the top left (`False`)
or bottom right (`True`) corner of the softmax matrix in the encoder.
If `None`, it will be set to `False` for `self_attn_mask_type` =
{`causal`, `padding_causal`} and `True` for other mask types.
encoder_output : Optional[torch.Tensor], default = None
Output of the encoder block to be fed into the decoder block if using
:attr:`layer_type` = ``"decoder"``.
......@@ -670,6 +691,11 @@ class TransformerLayer(torch.nn.Module):
Type of attention mask passed into softmax operation for decoder.
enc_dec_window_size: Optional[Tuple[int, int]], default = None
Sliding window size for local attention in decoder.
enc_dec_bottom_right_diagonal: Optional[bool], default = `None`
Align sliding window and ALiBi diagonal to the top left (`False`)
or bottom right (`True`) corner of the softmax matrix in the decoder.
If `None`, it will be set to `False` for `enc_dec_attn_mask_type` =
{`causal`, `padding_causal`} and `True` for other mask types.
is_first_microbatch : {True, False, None}, default = None
During training using either gradient accumulation or
pipeline parallelism a minibatch of data is further split
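A sketch of how the two per-call alignment arguments documented above might be passed through a decoder layer; hidden sizes, dtypes, and input tensors are placeholders rather than part of the change, and a CUDA device is assumed:

    import torch
    from transformer_engine.pytorch import TransformerLayer

    layer = TransformerLayer(
        hidden_size=1024,
        ffn_hidden_size=4096,
        num_attention_heads=16,
        layer_type="decoder",
        self_attn_mask_type="causal",
        params_dtype=torch.bfloat16,
    )
    # sbhd inputs: [seq, batch, hidden]
    hidden_states = torch.randn(128, 2, 1024, dtype=torch.bfloat16, device="cuda")
    encoder_output = torch.randn(64, 2, 1024, dtype=torch.bfloat16, device="cuda")
    out = layer(
        hidden_states,
        encoder_output=encoder_output,
        bottom_right_diagonal=False,         # self-attention diagonal stays top-left
        enc_dec_bottom_right_diagonal=True,  # cross-attention diagonal at bottom-right
    )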
......@@ -736,10 +762,35 @@ class TransformerLayer(torch.nn.Module):
self_attn_mask_type = self.self_attn_mask_type
if window_size is None:
window_size = self.window_size
window_size = dpa_utils.check_set_window_size(self_attn_mask_type, window_size)
if enc_dec_attn_mask_type is None:
enc_dec_attn_mask_type = self.enc_dec_attn_mask_type
if enc_dec_window_size is None:
enc_dec_window_size = self.enc_dec_window_size
enc_dec_window_size = dpa_utils.check_set_window_size(
enc_dec_attn_mask_type, enc_dec_window_size
)
if bottom_right_diagonal is None:
bottom_right_diagonal = self.bottom_right_diagonal
if self_attn_mask_type in {"causal", "padding_causal"}:
bottom_right_diagonal = False
if bottom_right_diagonal is None or self_attn_mask_type in {
"causal_bottom_right",
"padding_causal_bottom_right",
}:
bottom_right_diagonal = True
if enc_dec_bottom_right_diagonal is None:
enc_dec_bottom_right_diagonal = self.enc_dec_bottom_right_diagonal
if enc_dec_attn_mask_type in {"causal", "padding_causal"}:
enc_dec_bottom_right_diagonal = False
if enc_dec_bottom_right_diagonal is None or enc_dec_attn_mask_type in {
"causal_bottom_right",
"padding_causal_bottom_right",
}:
enc_dec_bottom_right_diagonal = True
assert (
self_attn_mask_type in AttnMaskTypes
......@@ -778,6 +829,7 @@ class TransformerLayer(torch.nn.Module):
attention_mask=attention_mask,
attn_mask_type=self_attn_mask_type,
window_size=window_size,
bottom_right_diagonal=bottom_right_diagonal,
inference_params=inference_params,
is_first_microbatch=is_first_microbatch,
checkpoint_core_attention=checkpoint_core_attention,
......@@ -813,6 +865,7 @@ class TransformerLayer(torch.nn.Module):
attention_mask=enc_dec_attn_mask,
attn_mask_type=enc_dec_attn_mask_type,
window_size=enc_dec_window_size,
bottom_right_diagonal=enc_dec_bottom_right_diagonal,
encoder_output=encoder_output,
inference_params=inference_params,
is_first_microbatch=is_first_microbatch,
......