Unverified Commit c3b3cd21 authored by Charlene Yang's avatar Charlene Yang Committed by GitHub
Browse files

Fix cuDNN sliding window size (#1212)



* adjust window size to (i-window_size_left,i] for cuDNN
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* reduce the window to make any errors more pronouced
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent c24a4c41
...@@ -249,7 +249,7 @@ def test_dot_product_attention( ...@@ -249,7 +249,7 @@ def test_dot_product_attention(
# Test backend availability # Test backend availability
window_size = (-1, -1) window_size = (-1, -1)
if swa: if swa:
window_size = tuple(torch.randint(0, config.max_seqlen_kv, [2], dtype=torch.int32).tolist()) window_size = [2, 2]
config.window_size = check_set_window_size(config.attn_mask_type, window_size) config.window_size = check_set_window_size(config.attn_mask_type, window_size)
available_backends, fused_attn_backends = _get_attention_backends( available_backends, fused_attn_backends = _get_attention_backends(
config, config,
......
...@@ -75,9 +75,6 @@ void fused_attn_arbitrary_seqlen_fwd_impl( ...@@ -75,9 +75,6 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
if (is_ragged) { if (is_ragged) {
NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!"); NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!");
} }
if (window_size_left == -1) {
window_size_left = s_q;
}
auto cudnn_runtime_version = cudnnGetVersion(); auto cudnn_runtime_version = cudnnGetVersion();
try { try {
...@@ -221,8 +218,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( ...@@ -221,8 +218,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
.set_causal_mask_bottom_right(is_bottom_right) .set_causal_mask_bottom_right(is_bottom_right)
.set_attn_scale(attn_scale); .set_attn_scale(attn_scale);
if (cudnn_runtime_version >= 90200 && window_size_left != s_q) { if (cudnn_runtime_version >= 90200 && window_size_left != -1) {
sdpa_options.set_sliding_window_length(window_size_left); sdpa_options.set_sliding_window_length(window_size_left + 1);
} }
sdpa_options.set_alibi_mask(is_alibi); sdpa_options.set_alibi_mask(is_alibi);
...@@ -407,9 +404,6 @@ void fused_attn_arbitrary_seqlen_bwd_impl( ...@@ -407,9 +404,6 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_dropout = (dropout_probability != 0.0f); bool is_dropout = (dropout_probability != 0.0f);
bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD); bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD);
if (window_size_left == -1) {
window_size_left = s_q;
}
auto cudnn_runtime_version = cudnnGetVersion(); auto cudnn_runtime_version = cudnnGetVersion();
const int device_id = cuda::current_device(); const int device_id = cuda::current_device();
const int sm_arch_ = cuda::sm_arch(device_id); const int sm_arch_ = cuda::sm_arch(device_id);
...@@ -584,8 +578,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl( ...@@ -584,8 +578,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
.set_causal_mask_bottom_right(is_bottom_right) .set_causal_mask_bottom_right(is_bottom_right)
.set_attn_scale(attn_scale); .set_attn_scale(attn_scale);
if (cudnn_runtime_version >= 90200 && window_size_left != s_q) { if (cudnn_runtime_version >= 90200 && window_size_left != -1) {
sdpa_backward_options.set_sliding_window_length(window_size_left); sdpa_backward_options.set_sliding_window_length(window_size_left + 1);
} }
if (cudnn_runtime_version >= 90000) { if (cudnn_runtime_version >= 90000) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment