Commit f804574f authored by GoatWu

bug fixed

parent 2f874771
@@ -90,7 +90,7 @@ class WanTransformerInferCausVid(WanTransformerInfer):
                kv_end,
            )
        return x

    def _infer_self_attn(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end):
        norm1_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
        norm1_out = (norm1_out * (1 + embed0[1]) + embed0[0]).squeeze(0)
@@ -194,7 +194,7 @@ class WanTransformerInferCausVid(WanTransformerInfer):
        x = x + attn_out
        return x

    def _infer_ffn(self, weights, x, embed0):
        norm2_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
        y = weights.ffn_0.apply(norm2_out * (1 + embed0[4].squeeze(0)) + embed0[3].squeeze(0))
@@ -204,7 +204,6 @@ class WanTransformerInferCausVid(WanTransformerInfer):
        return x

    def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end):
        if embed0.dim() == 3:
            modulation = weights.modulation.tensor.unsqueeze(2)  # 1, 6, 1, dim
@@ -214,30 +213,10 @@ class WanTransformerInferCausVid(WanTransformerInfer):
        elif embed0.dim() == 2:
            embed0 = (weights.modulation.tensor + embed0).chunk(6, dim=1)

-        x = self._infer_self_attn(
-            weights.compute_phases[0],
-            grid_sizes, embed,
-            x,
-            embed0,
-            seq_lens,
-            freqs,
-            context,
-            block_idx,
-            kv_start,
-            kv_end
-        )
+        x = self._infer_self_attn(weights.compute_phases[0], grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end)

-        x = self._infer_cross_attn(
-            weights.compute_phases[1],
-            x,
-            context,
-            block_idx
-        )
+        x = self._infer_cross_attn(weights.compute_phases[1], x, context, block_idx)

-        x = self._infer_ffn(
-            weights.compute_phases[2],
-            x,
-            embed0
-        )
+        x = self._infer_ffn(weights.compute_phases[2], x, embed0)

        return x
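
For orientation, here is a minimal sketch of the per-block flow that infer_block performs: adaLN-style modulation from the six chunks of embed0, then self-attention, cross-attention, and a feed-forward network. The shift/scale ordering is taken from the embed0[0]/embed0[1] and embed0[3]/embed0[4] usage in the hunks above; the helper names, the gate terms (embed0[2]/embed0[5]), and the gated residual form of the attention and FFN branches are illustrative assumptions, not the repository's actual code.

import torch

def infer_block_sketch(x, embed0, self_attn, cross_attn, ffn, context):
    # embed0: the six modulation tensors produced by `.chunk(6, dim=1)` above.
    # Assumed ordering: shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp.
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = embed0

    def modulate(h, shift, scale):
        # Non-affine LayerNorm followed by scale/shift, mirroring
        # `layer_norm(x, (x.shape[1],), None, None, 1e-6) * (1 + scale) + shift` in the diff.
        h = torch.nn.functional.layer_norm(h, (h.shape[-1],), None, None, 1e-6)
        return h * (1 + scale) + shift

    # One pass per block: self-attention, cross-attention, feed-forward.
    x = x + gate_msa * self_attn(modulate(x, shift_msa, scale_msa))  # gating assumed
    x = x + cross_attn(x, context)                                   # residual add as in `x = x + attn_out`
    x = x + gate_mlp * ffn(modulate(x, shift_mlp, scale_mlp))        # gating assumed
    return x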