Commit bc7c3e87 authored by gushiqiao's avatar gushiqiao
Browse files

Fix

parent f4213c00
......@@ -258,7 +258,7 @@ class WanTransformerInfer:
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (weights.modulation.tensor + embed0).chunk(6, dim=1)
x = self._infer_self_attn(
weights.compute_phases[1],
weights.compute_phases[0],
x,
shift_msa,
scale_msa,
......@@ -267,6 +267,6 @@ class WanTransformerInfer:
freqs,
seq_lens,
)
x = self._infer_cross_attn(weights.compute_phases[2], x, context)
x = self._infer_ffn(weights.compute_phases[3], x, c_shift_msa, c_scale_msa, c_gate_msa)
x = self._infer_cross_attn(weights.compute_phases[1], x, context)
x = self._infer_ffn(weights.compute_phases[2], x, c_shift_msa, c_scale_msa, c_gate_msa)
return x
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment