Commit 2fb721ab authored by gushiqiao, committed by GitHub

Merge pull request #48 from ModelTC/dev_fixbugs

Dev fixbugs
parents f4213c00 4fd83968
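
The fixes in this merge are twofold: the `embed` tensor produced by `pre_infer.infer` is now threaded explicitly through every `WanTransformerInfer` entry point, and the per-block `compute_phases` indices are shifted down by one. A minimal before/after sketch of the call-site change (a repro under assumptions, not the full code; the surrounding objects stand in for the real attributes):

```python
# Sketch of the fix at the WanModel call site. Names mirror the diff;
# the objects themselves are assumed, not shown in this PR.
embed, grid_sizes, pre_infer_out = pre_infer.infer(pre_weight, inputs, positive=True)

# Before: `embed` stopped at the pre-infer stage and never reached the
# transformer blocks.
# x = transformer_infer.infer(transformer_weights, grid_sizes, *pre_infer_out)

# After: `embed` is passed through infer() and every _infer_* variant.
x = transformer_infer.infer(transformer_weights, grid_sizes, embed, *pre_infer_out)
```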
@@ -37,10 +37,10 @@ class WanTransformerInfer:
         return cu_seqlens_q, cu_seqlens_k
 
     @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
-    def infer(self, weights, grid_sizes, x, embed0, seq_lens, freqs, context):
-        return self.infer_func(weights, grid_sizes, x, embed0, seq_lens, freqs, context)
+    def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
+        return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
 
-    def _infer_with_offload(self, weights, grid_sizes, x, embed0, seq_lens, freqs, context):
+    def _infer_with_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
         for block_idx in range(self.blocks_num):
             if block_idx == 0:
                 self.weights_stream_mgr.active_weights[0] = weights.blocks[0]
@@ -63,7 +63,7 @@ class WanTransformerInfer:
         return x
 
-    def _infer_with_phases_offload(self, weights, grid_sizes, x, embed0, seq_lens, freqs, context):
+    def _infer_with_phases_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
         for block_idx in range(weights.blocks_num):
             weights.blocks[block_idx].modulation.to_cuda()
@@ -114,7 +114,7 @@ class WanTransformerInfer:
         return x
 
-    def _infer_without_offload(self, weights, grid_sizes, x, embed0, seq_lens, freqs, context):
+    def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
         for block_idx in range(self.blocks_num):
             x = self.infer_block(
                 weights.blocks[block_idx],
@@ -249,7 +249,7 @@ class WanTransformerInfer:
             x.add_(y * c_gate_msa.squeeze(0))
         return x
 
-    def infer_block(self, weights, grid_sizes, x, embed0, seq_lens, freqs, context):
+    def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
         if embed0.dim() == 3:
             modulation = weights.modulation.tensor.unsqueeze(2)
             embed0 = (modulation + embed0).chunk(6, dim=1)
@@ -258,7 +258,7 @@ class WanTransformerInfer:
             shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (weights.modulation.tensor + embed0).chunk(6, dim=1)
         x = self._infer_self_attn(
-            weights.compute_phases[1],
+            weights.compute_phases[0],
             x,
             shift_msa,
             scale_msa,
@@ -267,6 +267,6 @@ class WanTransformerInfer:
             freqs,
             seq_lens,
         )
-        x = self._infer_cross_attn(weights.compute_phases[2], x, context)
-        x = self._infer_ffn(weights.compute_phases[3], x, c_shift_msa, c_scale_msa, c_gate_msa)
+        x = self._infer_cross_attn(weights.compute_phases[1], x, context)
+        x = self._infer_ffn(weights.compute_phases[2], x, c_shift_msa, c_scale_msa, c_gate_msa)
         return x
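
The last three `WanTransformerInfer` hunks re-index `compute_phases`: self-attention moves from `[1]` to `[0]`, cross-attention from `[2]` to `[1]`, and the FFN from `[3]` to `[2]`. The phase container itself is not part of this diff, so the following layout is inferred from the new indices only:

```python
# Hypothetical per-block phase layout implied by the new indices; the
# real compute_phases structure is not shown in this PR.
compute_phases = [
    "self_attn_weights",   # [0] -> _infer_self_attn
    "cross_attn_weights",  # [1] -> _infer_cross_attn
    "ffn_weights",         # [2] -> _infer_ffn
]
```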
@@ -183,7 +183,7 @@ class WanModel:
         self.post_weight.to_cuda()
         embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
-        x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, *pre_infer_out)
+        x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
         noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
         if self.config["feature_caching"] == "Tea":
@@ -194,7 +194,7 @@ class WanModel:
         if self.config["enable_cfg"]:
             embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
-            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, *pre_infer_out)
+            x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
             noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
             if self.config["feature_caching"] == "Tea":
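
In both the positive and the CFG branch, the new argument order only works if `pre_infer_out` unpacks to the tail of the `infer` signature. Expanding the starred call under that assumption (the tuple contents are inferred from the signature, not shown in the diff):

```python
# Equivalent expansion of the new call, assuming
# pre_infer_out == (x, embed0, seq_lens, freqs, context)
# to match WanTransformerInfer.infer(weights, grid_sizes, embed, ...).
x, embed0, seq_lens, freqs, context = pre_infer_out
x = self.transformer_infer.infer(
    self.transformer_weights, grid_sizes, embed,
    x, embed0, seq_lens, freqs, context,
)
```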