Commit 2f874771 authored by GoatWu's avatar GoatWu
Browse files

bug fixed

parent 429dcc45
......@@ -90,16 +90,8 @@ class WanTransformerInferCausVid(WanTransformerInfer):
kv_end,
)
return x
def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end):
if embed0.dim() == 3:
modulation = weights.modulation.tensor.unsqueeze(2) # 1, 6, 1, dim
embed0 = embed0.unsqueeze(0) #
embed0 = (modulation + embed0).chunk(6, dim=1)
embed0 = [ei.squeeze(1) for ei in embed0]
elif embed0.dim() == 2:
embed0 = (weights.modulation.tensor + embed0).chunk(6, dim=1)
def _infer_self_attn(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end):
norm1_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
norm1_out = (norm1_out * (1 + embed0[1]) + embed0[0]).squeeze(0)
......@@ -120,7 +112,7 @@ class WanTransformerInferCausVid(WanTransformerInfer):
self.kv_cache[block_idx]["k"][kv_start:kv_end] = k
self.kv_cache[block_idx]["v"][kv_start:kv_end] = v
cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(q=q, k=self.kv_cache[block_idx]["k"][:kv_end], k_lens=torch.tensor([kv_end], dtype=torch.int32, device=k.device))
cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q=q, k_lens=torch.tensor([kv_end], dtype=torch.int32, device=k.device))
if not self.parallel_attention:
attn_out = weights.self_attn_1.apply(
......@@ -129,8 +121,8 @@ class WanTransformerInferCausVid(WanTransformerInfer):
v=self.kv_cache[block_idx]["v"][:kv_end],
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_k,
max_seqlen_q=lq,
max_seqlen_kv=lk,
max_seqlen_q=q.size(0),
max_seqlen_kv=k.size(0),
model_cls=self.config["model_cls"],
)
else:
......@@ -141,6 +133,9 @@ class WanTransformerInferCausVid(WanTransformerInfer):
x = x + y * embed0[2].squeeze(0)
return x
def _infer_cross_attn(self, weights, x, context, block_idx):
norm3_out = weights.norm3.apply(x)
if self.task == "i2v":
......@@ -159,7 +154,7 @@ class WanTransformerInferCausVid(WanTransformerInfer):
k = self.crossattn_cache[block_idx]["k"]
v = self.crossattn_cache[block_idx]["v"]
cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(q, k, k_lens=torch.tensor([k.size(0)], dtype=torch.int32, device=k.device))
cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q, k_lens=torch.tensor([k.size(0)], dtype=torch.int32, device=k.device))
attn_out = weights.cross_attn_1.apply(
q=q,
......@@ -167,8 +162,8 @@ class WanTransformerInferCausVid(WanTransformerInfer):
v=v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_k,
max_seqlen_q=lq,
max_seqlen_kv=lk,
max_seqlen_q=q.size(0),
max_seqlen_kv=k.size(0),
model_cls=self.config["model_cls"],
)
......@@ -176,9 +171,8 @@ class WanTransformerInferCausVid(WanTransformerInfer):
k_img = weights.cross_attn_norm_k_img.apply(weights.cross_attn_k_img.apply(context_img)).view(-1, n, d)
v_img = weights.cross_attn_v_img.apply(context_img).view(-1, n, d)
cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(
cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(
q,
k_img,
k_lens=torch.tensor([k_img.size(0)], dtype=torch.int32, device=k.device),
)
......@@ -188,8 +182,8 @@ class WanTransformerInferCausVid(WanTransformerInfer):
v=v_img,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_k,
max_seqlen_q=lq,
max_seqlen_kv=lk,
max_seqlen_q=q.size(0),
max_seqlen_kv=k_img.size(0),
model_cls=self.config["model_cls"],
)
......@@ -198,9 +192,52 @@ class WanTransformerInferCausVid(WanTransformerInfer):
attn_out = weights.cross_attn_o.apply(attn_out)
x = x + attn_out
return x
def _infer_ffn(self, weights, x, embed0):
norm2_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
y = weights.ffn_0.apply(norm2_out * (1 + embed0[4].squeeze(0)) + embed0[3].squeeze(0))
y = torch.nn.functional.gelu(y, approximate="tanh")
y = weights.ffn_2.apply(y)
x = x + y * embed0[5].squeeze(0)
return x
def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end):
if embed0.dim() == 3:
modulation = weights.modulation.tensor.unsqueeze(2) # 1, 6, 1, dim
embed0 = embed0.unsqueeze(0) #
embed0 = (modulation + embed0).chunk(6, dim=1)
embed0 = [ei.squeeze(1) for ei in embed0]
elif embed0.dim() == 2:
embed0 = (weights.modulation.tensor + embed0).chunk(6, dim=1)
x = self._infer_self_attn(
weights.compute_phases[0],
grid_sizes, embed,
x,
embed0,
seq_lens,
freqs,
context,
block_idx,
kv_start,
kv_end
)
x = self._infer_cross_attn(
weights.compute_phases[1],
x,
context,
block_idx
)
x = self._infer_ffn(
weights.compute_phases[2],
x,
embed0
)
return x
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment