"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "b6c9f47fd6f911450024c52e382e544e5d04387a"
Unverified Commit 63c5e27e authored by fxmarty's avatar fxmarty Committed by GitHub
Browse files

Do not drop mask with SDPA for more cases (#30311)

* overlooked

* style

* cleaner
parent acab997b
...@@ -319,8 +319,12 @@ def _prepare_4d_causal_attention_mask_for_sdpa( ...@@ -319,8 +319,12 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
ignore_causal_mask = False ignore_causal_mask = False
if attention_mask is None: if attention_mask is None:
if sliding_window is None or key_value_length < sliding_window: if (
ignore_causal_mask = not is_tracing not is_tracing
and (query_length == 1 or key_value_length == query_length)
and (sliding_window is None or key_value_length < sliding_window)
):
ignore_causal_mask = True
elif sliding_window is None or key_value_length < sliding_window: elif sliding_window is None or key_value_length < sliding_window:
# 4d mask is passed through # 4d mask is passed through
if len(attention_mask.shape) == 4: if len(attention_mask.shape) == 4:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment