Commit cefaf6cf authored by gushiqiao

Fix

parent c962f4ce
@@ -96,6 +96,22 @@ class WanTransformerInfer(BaseTransformerInfer):
    def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
        return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks)
    def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
        for block_idx in range(self.blocks_num):
            self.block_idx = block_idx
            x = self.infer_block(
                weights.blocks[block_idx],
                grid_sizes,
                embed,
                x,
                embed0,
                seq_lens,
                freqs,
                context,
                audio_dit_blocks,
            )
        return x
    def _infer_with_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
        for block_idx in range(self.blocks_num):
            self.block_idx = block_idx
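
For context, the newly added _infer_without_offload above simply threads the hidden states through every transformer block in order on the current device. Below is a minimal, standalone sketch of that block-loop pattern; it is illustrative only, and ToyBlock / ToyTransformerInfer are hypothetical stand-ins for the real weights and infer_block machinery, not code from this repository.

import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    """Hypothetical stand-in for one transformer block (attention + FFN)."""
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        # A real block applies modulated attention and an FFN; a residual
        # linear layer is enough to show the data flow.
        return x + self.proj(x)

class ToyTransformerInfer:
    """Hypothetical harness mirroring the control flow of _infer_without_offload."""
    def __init__(self, dim=64, blocks_num=4):
        self.blocks_num = blocks_num
        self.blocks = nn.ModuleList(ToyBlock(dim) for _ in range(blocks_num))

    def infer_without_offload(self, x):
        # Run every block sequentially, reusing x as the running hidden state,
        # exactly as the loop over weights.blocks[block_idx] does above.
        for block_idx in range(self.blocks_num):
            x = self.blocks[block_idx](x)
        return x

if __name__ == "__main__":
    hidden = torch.randn(2, 16, 64)   # (batch, tokens, dim)
    out = ToyTransformerInfer().infer_without_offload(hidden)
    print(out.shape)                  # torch.Size([2, 16, 64])
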
@@ -271,37 +287,6 @@ class WanTransformerInfer(BaseTransformerInfer):
        return x
    def zero_temporal_component_in_3DRoPE(self, valid_token_length, rotary_emb=None):
        if rotary_emb is None:
            return None
        self.use_real = False
        rope_t_dim = 44
        if self.use_real:
            freqs_cos, freqs_sin = rotary_emb
            freqs_cos[valid_token_length:, :, :rope_t_dim] = 0
            freqs_sin[valid_token_length:, :, :rope_t_dim] = 0
            return freqs_cos, freqs_sin
        else:
            freqs_cis = rotary_emb
            freqs_cis[valid_token_length:, :, : rope_t_dim // 2] = 0
            return freqs_cis
    def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
        for block_idx in range(self.blocks_num):
            self.block_idx = block_idx
            x = self.infer_block(
                weights.blocks[block_idx],
                grid_sizes,
                embed,
                x,
                embed0,
                seq_lens,
                freqs,
                context,
                audio_dit_blocks,
            )
        return x
    def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_modulation(
            weights.compute_phases[0],
......
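
The second hunk removes zero_temporal_component_in_3DRoPE, which zeroed the temporal slice of the 3D rotary embedding for tokens beyond valid_token_length (only the complex-frequency branch was reachable, since use_real was forced to False just before the check). Below is a hedged, standalone sketch of that operation on a complex frequency table; the tensor shapes and the zero_temporal_component name are illustrative assumptions, not part of the repository.

import torch

def zero_temporal_component(freqs_cis, valid_token_length, rope_t_dim=44):
    # freqs_cis: complex rotary table of shape (seq_len, heads, head_dim // 2).
    # The first rope_t_dim // 2 complex entries per token encode the temporal
    # axis of the 3D RoPE; zeroing them for tokens past valid_token_length
    # mirrors what the removed method did on its else-branch.
    freqs_cis = freqs_cis.clone()  # the removed code modified the tensor in place
    freqs_cis[valid_token_length:, :, : rope_t_dim // 2] = 0
    return freqs_cis

if __name__ == "__main__":
    # Illustrative sizes only: 128 tokens, 1 head slot, 64 complex frequencies.
    angles = torch.randn(128, 1, 64)
    freqs = torch.polar(torch.ones_like(angles), angles)
    masked = zero_temporal_component(freqs, valid_token_length=100)
    print(masked[100:, :, :22].abs().sum())  # tensor(0.) -> temporal part zeroed
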