Commit bc7c3e87 authored by gushiqiao's avatar gushiqiao
Browse files

Fix

parent f4213c00
...@@ -258,7 +258,7 @@ class WanTransformerInfer: ...@@ -258,7 +258,7 @@ class WanTransformerInfer:
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (weights.modulation.tensor + embed0).chunk(6, dim=1) shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (weights.modulation.tensor + embed0).chunk(6, dim=1)
x = self._infer_self_attn( x = self._infer_self_attn(
weights.compute_phases[1], weights.compute_phases[0],
x, x,
shift_msa, shift_msa,
scale_msa, scale_msa,
...@@ -267,6 +267,6 @@ class WanTransformerInfer: ...@@ -267,6 +267,6 @@ class WanTransformerInfer:
freqs, freqs,
seq_lens, seq_lens,
) )
x = self._infer_cross_attn(weights.compute_phases[2], x, context) x = self._infer_cross_attn(weights.compute_phases[1], x, context)
x = self._infer_ffn(weights.compute_phases[3], x, c_shift_msa, c_scale_msa, c_gate_msa) x = self._infer_ffn(weights.compute_phases[2], x, c_shift_msa, c_scale_msa, c_gate_msa)
return x return x
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment