Commit 2f874771 authored by GoatWu's avatar GoatWu
Browse files

bug fixed

parent 429dcc45
...@@ -90,16 +90,8 @@ class WanTransformerInferCausVid(WanTransformerInfer): ...@@ -90,16 +90,8 @@ class WanTransformerInferCausVid(WanTransformerInfer):
kv_end, kv_end,
) )
return x return x
def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end): def _infer_self_attn(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end):
if embed0.dim() == 3:
modulation = weights.modulation.tensor.unsqueeze(2) # 1, 6, 1, dim
embed0 = embed0.unsqueeze(0) #
embed0 = (modulation + embed0).chunk(6, dim=1)
embed0 = [ei.squeeze(1) for ei in embed0]
elif embed0.dim() == 2:
embed0 = (weights.modulation.tensor + embed0).chunk(6, dim=1)
norm1_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6) norm1_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
norm1_out = (norm1_out * (1 + embed0[1]) + embed0[0]).squeeze(0) norm1_out = (norm1_out * (1 + embed0[1]) + embed0[0]).squeeze(0)
...@@ -120,7 +112,7 @@ class WanTransformerInferCausVid(WanTransformerInfer): ...@@ -120,7 +112,7 @@ class WanTransformerInferCausVid(WanTransformerInfer):
self.kv_cache[block_idx]["k"][kv_start:kv_end] = k self.kv_cache[block_idx]["k"][kv_start:kv_end] = k
self.kv_cache[block_idx]["v"][kv_start:kv_end] = v self.kv_cache[block_idx]["v"][kv_start:kv_end] = v
cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(q=q, k=self.kv_cache[block_idx]["k"][:kv_end], k_lens=torch.tensor([kv_end], dtype=torch.int32, device=k.device)) cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q=q, k_lens=torch.tensor([kv_end], dtype=torch.int32, device=k.device))
if not self.parallel_attention: if not self.parallel_attention:
attn_out = weights.self_attn_1.apply( attn_out = weights.self_attn_1.apply(
...@@ -129,8 +121,8 @@ class WanTransformerInferCausVid(WanTransformerInfer): ...@@ -129,8 +121,8 @@ class WanTransformerInferCausVid(WanTransformerInfer):
v=self.kv_cache[block_idx]["v"][:kv_end], v=self.kv_cache[block_idx]["v"][:kv_end],
cu_seqlens_q=cu_seqlens_q, cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_k, cu_seqlens_kv=cu_seqlens_k,
max_seqlen_q=lq, max_seqlen_q=q.size(0),
max_seqlen_kv=lk, max_seqlen_kv=k.size(0),
model_cls=self.config["model_cls"], model_cls=self.config["model_cls"],
) )
else: else:
...@@ -141,6 +133,9 @@ class WanTransformerInferCausVid(WanTransformerInfer): ...@@ -141,6 +133,9 @@ class WanTransformerInferCausVid(WanTransformerInfer):
x = x + y * embed0[2].squeeze(0) x = x + y * embed0[2].squeeze(0)
return x
def _infer_cross_attn(self, weights, x, context, block_idx):
norm3_out = weights.norm3.apply(x) norm3_out = weights.norm3.apply(x)
if self.task == "i2v": if self.task == "i2v":
...@@ -159,7 +154,7 @@ class WanTransformerInferCausVid(WanTransformerInfer): ...@@ -159,7 +154,7 @@ class WanTransformerInferCausVid(WanTransformerInfer):
k = self.crossattn_cache[block_idx]["k"] k = self.crossattn_cache[block_idx]["k"]
v = self.crossattn_cache[block_idx]["v"] v = self.crossattn_cache[block_idx]["v"]
cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(q, k, k_lens=torch.tensor([k.size(0)], dtype=torch.int32, device=k.device)) cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q, k_lens=torch.tensor([k.size(0)], dtype=torch.int32, device=k.device))
attn_out = weights.cross_attn_1.apply( attn_out = weights.cross_attn_1.apply(
q=q, q=q,
...@@ -167,8 +162,8 @@ class WanTransformerInferCausVid(WanTransformerInfer): ...@@ -167,8 +162,8 @@ class WanTransformerInferCausVid(WanTransformerInfer):
v=v, v=v,
cu_seqlens_q=cu_seqlens_q, cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_k, cu_seqlens_kv=cu_seqlens_k,
max_seqlen_q=lq, max_seqlen_q=q.size(0),
max_seqlen_kv=lk, max_seqlen_kv=k.size(0),
model_cls=self.config["model_cls"], model_cls=self.config["model_cls"],
) )
...@@ -176,9 +171,8 @@ class WanTransformerInferCausVid(WanTransformerInfer): ...@@ -176,9 +171,8 @@ class WanTransformerInferCausVid(WanTransformerInfer):
k_img = weights.cross_attn_norm_k_img.apply(weights.cross_attn_k_img.apply(context_img)).view(-1, n, d) k_img = weights.cross_attn_norm_k_img.apply(weights.cross_attn_k_img.apply(context_img)).view(-1, n, d)
v_img = weights.cross_attn_v_img.apply(context_img).view(-1, n, d) v_img = weights.cross_attn_v_img.apply(context_img).view(-1, n, d)
cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len( cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(
q, q,
k_img,
k_lens=torch.tensor([k_img.size(0)], dtype=torch.int32, device=k.device), k_lens=torch.tensor([k_img.size(0)], dtype=torch.int32, device=k.device),
) )
...@@ -188,8 +182,8 @@ class WanTransformerInferCausVid(WanTransformerInfer): ...@@ -188,8 +182,8 @@ class WanTransformerInferCausVid(WanTransformerInfer):
v=v_img, v=v_img,
cu_seqlens_q=cu_seqlens_q, cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_k, cu_seqlens_kv=cu_seqlens_k,
max_seqlen_q=lq, max_seqlen_q=q.size(0),
max_seqlen_kv=lk, max_seqlen_kv=k_img.size(0),
model_cls=self.config["model_cls"], model_cls=self.config["model_cls"],
) )
...@@ -198,9 +192,52 @@ class WanTransformerInferCausVid(WanTransformerInfer): ...@@ -198,9 +192,52 @@ class WanTransformerInferCausVid(WanTransformerInfer):
attn_out = weights.cross_attn_o.apply(attn_out) attn_out = weights.cross_attn_o.apply(attn_out)
x = x + attn_out x = x + attn_out
return x
def _infer_ffn(self, weights, x, embed0):
    """Feed-forward sub-block with AdaLN-style modulation.

    Normalizes ``x`` (affine-free LayerNorm over the channel dim), applies
    the shift/scale pair ``embed0[3]``/``embed0[4]``, runs the two-layer
    tanh-GELU MLP, and adds the result — gated by ``embed0[5]`` — back
    onto the residual stream.
    """
    # Affine-free LayerNorm; the affine part is supplied by embed0 below.
    hidden = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
    hidden = hidden * (1 + embed0[4].squeeze(0)) + embed0[3].squeeze(0)
    hidden = weights.ffn_0.apply(hidden)
    hidden = torch.nn.functional.gelu(hidden, approximate="tanh")
    hidden = weights.ffn_2.apply(hidden)
    # Gated residual connection.
    return x + hidden * embed0[5].squeeze(0)
def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end):
    """Run one transformer block: self-attention, cross-attention, then FFN.

    First expands ``embed0`` into the six AdaLN modulation tensors by
    combining it with the block's learned modulation table, handling both
    the per-token (3-D) and global (2-D) layouts, then dispatches each
    sub-block to its corresponding entry in ``weights.compute_phases``.
    """
    if embed0.dim() == 3:
        # Per-token modulation: broadcast the learned table over tokens.
        table = weights.modulation.tensor.unsqueeze(2)  # 1, 6, 1, dim
        chunks = (table + embed0.unsqueeze(0)).chunk(6, dim=1)
        embed0 = [c.squeeze(1) for c in chunks]
    elif embed0.dim() == 2:
        # Global modulation: one set of six vectors for the whole sequence.
        embed0 = (weights.modulation.tensor + embed0).chunk(6, dim=1)

    # One compute phase per sub-block, in execution order.
    phase_self = weights.compute_phases[0]
    phase_cross = weights.compute_phases[1]
    phase_ffn = weights.compute_phases[2]

    x = self._infer_self_attn(
        phase_self, grid_sizes, embed, x, embed0,
        seq_lens, freqs, context, block_idx, kv_start, kv_end,
    )
    x = self._infer_cross_attn(phase_cross, x, context, block_idx)
    x = self._infer_ffn(phase_ffn, x, embed0)
    return x
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment