Commit f4b343f6 authored by helloyongyang's avatar helloyongyang
Browse files

update sage_attn2

parent 1c4bd4d8
...@@ -13,34 +13,28 @@ else: ...@@ -13,34 +13,28 @@ else:
def sage_attn2(q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls="hunyuan"): def sage_attn2(q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls="hunyuan"):
q, k, v = ( q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
q.transpose(1, 0).contiguous(),
k.transpose(1, 0).contiguous(),
v.transpose(1, 0).contiguous(),
)
if model_cls == "hunyuan": if model_cls == "hunyuan":
x1 = sageattn( x1 = sageattn(
q[:, : cu_seqlens_q[1], :].unsqueeze(0), q[: cu_seqlens_q[1]].unsqueeze(0),
k[:, : cu_seqlens_kv[1], :].unsqueeze(0), k[: cu_seqlens_kv[1]].unsqueeze(0),
v[:, : cu_seqlens_kv[1], :].unsqueeze(0), v[: cu_seqlens_kv[1]].unsqueeze(0),
tensor_layout="NHD",
) )
x2 = sageattn( x2 = sageattn(
q[:, cu_seqlens_q[1] :, :].unsqueeze(0), q[cu_seqlens_q[1] :].unsqueeze(0),
k[:, cu_seqlens_kv[1] :, :].unsqueeze(0), k[cu_seqlens_kv[1] :].unsqueeze(0),
v[:, cu_seqlens_kv[1] :, :].unsqueeze(0), v[cu_seqlens_kv[1] :].unsqueeze(0),
tensor_layout="NHD",
) )
x = torch.cat((x1, x2), dim=-2).transpose(2, 1).contiguous() x = torch.cat((x1, x2), dim=1)
x = x.view(max_seqlen_q, -1) x = x.view(max_seqlen_q, -1)
elif model_cls == "wan2.1": elif model_cls == "wan2.1":
x = ( x = sageattn(
sageattn( q.unsqueeze(0),
q[:, : cu_seqlens_q[1], :].unsqueeze(0), k.unsqueeze(0),
k[:, : cu_seqlens_kv[1], :].unsqueeze(0), v.unsqueeze(0),
v[:, : cu_seqlens_kv[1], :].unsqueeze(0), tensor_layout="NHD",
)
.transpose(2, 1)
.contiguous()
) )
x = x.view(max_seqlen_q, -1) x = x.view(max_seqlen_q, -1)
return x return x
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment