Commit f804574f authored by GoatWu's avatar GoatWu
Browse files

bug fixed

parent 2f874771
...@@ -204,7 +204,6 @@ class WanTransformerInferCausVid(WanTransformerInfer): ...@@ -204,7 +204,6 @@ class WanTransformerInferCausVid(WanTransformerInfer):
return x return x
def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end): def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end):
if embed0.dim() == 3: if embed0.dim() == 3:
modulation = weights.modulation.tensor.unsqueeze(2) # 1, 6, 1, dim modulation = weights.modulation.tensor.unsqueeze(2) # 1, 6, 1, dim
...@@ -214,30 +213,10 @@ class WanTransformerInferCausVid(WanTransformerInfer): ...@@ -214,30 +213,10 @@ class WanTransformerInferCausVid(WanTransformerInfer):
elif embed0.dim() == 2: elif embed0.dim() == 2:
embed0 = (weights.modulation.tensor + embed0).chunk(6, dim=1) embed0 = (weights.modulation.tensor + embed0).chunk(6, dim=1)
x = self._infer_self_attn( x = self._infer_self_attn(weights.compute_phases[0], grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end)
weights.compute_phases[0],
grid_sizes, embed,
x,
embed0,
seq_lens,
freqs,
context,
block_idx,
kv_start,
kv_end
)
x = self._infer_cross_attn( x = self._infer_cross_attn(weights.compute_phases[1], x, context, block_idx)
weights.compute_phases[1],
x,
context,
block_idx
)
x = self._infer_ffn( x = self._infer_ffn(weights.compute_phases[2], x, embed0)
weights.compute_phases[2],
x,
embed0
)
return x return x
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment