Unverified commit 6d4f9d3a, authored by haosdent and committed by GitHub
Browse files

[Bugfix] Fix DCP + FA3 crash due to missing num_splits in _forward_with_dcp (#35082)


Signed-off-by: haosdent <haosdent@gmail.com>
parent fbe3f012
......@@ -847,6 +847,7 @@ class FlashAttentionImpl(AttentionImpl):
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
num_splits=attn_metadata.max_num_splits,
)
# FA returns LSE in shape [ H, B ] but cp_lse_ag_out_rs wants [ B, H ]
context_attn_out_cor, context_lse_cor = cp_lse_ag_out_rs(
......@@ -876,6 +877,7 @@ class FlashAttentionImpl(AttentionImpl):
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
num_splits=attn_metadata.max_num_splits,
)
assert context_attn_out_cor.shape == query_attn_out.shape
assert context_lse_cor.shape == query_lse.shape
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment