Commit 066d7f19 authored by gushiqiao's avatar gushiqiao Committed by GitHub
Browse files

[Fix] support sage for wan-moe (#278)

parent a32f6801
...@@ -52,7 +52,7 @@ class SageAttn2Weight(AttnWeightTemplate): ...@@ -52,7 +52,7 @@ class SageAttn2Weight(AttnWeightTemplate):
) )
x = torch.cat((x1, x2), dim=1) x = torch.cat((x1, x2), dim=1)
x = x.view(max_seqlen_q, -1) x = x.view(max_seqlen_q, -1)
elif model_cls in ["wan2.1", "wan2.1_distill", "wan2.1_causvid", "wan2.1_df", "seko_talk", "wan2.2", "wan2.1_vace"]: elif model_cls in ["wan2.1", "wan2.1_distill", "wan2.1_causvid", "wan2.1_df", "seko_talk", "wan2.2", "wan2.1_vace", "wan2.2_moe"]:
x = sageattn( x = sageattn(
q.unsqueeze(0), q.unsqueeze(0),
k.unsqueeze(0), k.unsqueeze(0),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment