"vscode:/vscode.git/clone" did not exist on "069ff336a2ddd2432461e95276746420716c8432"
Commit f804574f authored by GoatWu's avatar GoatWu
Browse files

bug fixed

parent 2f874771
......@@ -204,7 +204,6 @@ class WanTransformerInferCausVid(WanTransformerInfer):
return x
def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end):
if embed0.dim() == 3:
modulation = weights.modulation.tensor.unsqueeze(2) # 1, 6, 1, dim
......@@ -214,30 +213,10 @@ class WanTransformerInferCausVid(WanTransformerInfer):
elif embed0.dim() == 2:
embed0 = (weights.modulation.tensor + embed0).chunk(6, dim=1)
x = self._infer_self_attn(
weights.compute_phases[0],
grid_sizes, embed,
x,
embed0,
seq_lens,
freqs,
context,
block_idx,
kv_start,
kv_end
)
x = self._infer_self_attn(weights.compute_phases[0], grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end)
x = self._infer_cross_attn(
weights.compute_phases[1],
x,
context,
block_idx
)
x = self._infer_cross_attn(weights.compute_phases[1], x, context, block_idx)
x = self._infer_ffn(
weights.compute_phases[2],
x,
embed0
)
x = self._infer_ffn(weights.compute_phases[2], x, embed0)
return x
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment