"...vision/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "5f215b71a31b83fc39012eb5e467319b29872f55"
Unverified Commit 63c5e27e authored by fxmarty's avatar fxmarty Committed by GitHub
Browse files

Do not drop mask with SDPA for more cases (#30311)

* overlooked

* style

* cleaner
parent acab997b
......@@ -319,8 +319,12 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
ignore_causal_mask = False
if attention_mask is None:
if sliding_window is None or key_value_length < sliding_window:
ignore_causal_mask = not is_tracing
if (
not is_tracing
and (query_length == 1 or key_value_length == query_length)
and (sliding_window is None or key_value_length < sliding_window)
):
ignore_causal_mask = True
elif sliding_window is None or key_value_length < sliding_window:
# 4d mask is passed through
if len(attention_mask.shape) == 4:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment