Commit 1c4bd4d8 authored by helloyongyang
Browse files

fix sage_attn2 cu_seqlens_kv

parent 83c12f2b
......@@ -22,7 +22,7 @@ def sage_attn2(q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None
if model_cls == "hunyuan":
x1 = sageattn(
q[:, : cu_seqlens_q[1], :].unsqueeze(0),
- k[:, : cu_seqlens_q[1], :].unsqueeze(0),
+ k[:, : cu_seqlens_kv[1], :].unsqueeze(0),
v[:, : cu_seqlens_kv[1], :].unsqueeze(0),
)
x2 = sageattn(
......@@ -36,7 +36,7 @@ def sage_attn2(q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None
x = (
sageattn(
q[:, : cu_seqlens_q[1], :].unsqueeze(0),
- k[:, : cu_seqlens_q[1], :].unsqueeze(0),
+ k[:, : cu_seqlens_kv[1], :].unsqueeze(0),
v[:, : cu_seqlens_kv[1], :].unsqueeze(0),
)
.transpose(2, 1)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment