Unverified Commit c3b3cd21 authored by Charlene Yang, committed by GitHub
Browse files

Fix cuDNN sliding window size (#1212)



* adjust window size to (i-window_size_left,i] for cuDNN
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* reduce the window to make any errors more pronounced
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent c24a4c41
......@@ -249,7 +249,7 @@ def test_dot_product_attention(
# Test backend availability
window_size = (-1, -1)
if swa:
window_size = tuple(torch.randint(0, config.max_seqlen_kv, [2], dtype=torch.int32).tolist())
window_size = [2, 2]
config.window_size = check_set_window_size(config.attn_mask_type, window_size)
available_backends, fused_attn_backends = _get_attention_backends(
config,
......
......@@ -75,9 +75,6 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
if (is_ragged) {
NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!");
}
if (window_size_left == -1) {
window_size_left = s_q;
}
auto cudnn_runtime_version = cudnnGetVersion();
try {
......@@ -221,8 +218,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
.set_causal_mask_bottom_right(is_bottom_right)
.set_attn_scale(attn_scale);
if (cudnn_runtime_version >= 90200 && window_size_left != s_q) {
sdpa_options.set_sliding_window_length(window_size_left);
if (cudnn_runtime_version >= 90200 && window_size_left != -1) {
sdpa_options.set_sliding_window_length(window_size_left + 1);
}
sdpa_options.set_alibi_mask(is_alibi);
......@@ -407,9 +404,6 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_dropout = (dropout_probability != 0.0f);
bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD);
if (window_size_left == -1) {
window_size_left = s_q;
}
auto cudnn_runtime_version = cudnnGetVersion();
const int device_id = cuda::current_device();
const int sm_arch_ = cuda::sm_arch(device_id);
......@@ -584,8 +578,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
.set_causal_mask_bottom_right(is_bottom_right)
.set_attn_scale(attn_scale);
if (cudnn_runtime_version >= 90200 && window_size_left != s_q) {
sdpa_backward_options.set_sliding_window_length(window_size_left);
if (cudnn_runtime_version >= 90200 && window_size_left != -1) {
sdpa_backward_options.set_sliding_window_length(window_size_left + 1);
}
if (cudnn_runtime_version >= 90000) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment