Commit d2fb49af authored by helloyongyang's avatar helloyongyang
Browse files

fix tea cache

parent cc04b3fb
...@@ -43,6 +43,7 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching): ...@@ -43,6 +43,7 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
self.cutoff_steps = self.config.infer_steps * 2 - 2 self.cutoff_steps = self.config.infer_steps * 2 - 2
# calculate should_calc # calculate should_calc
@torch.no_grad()
def calculate_should_calc(self, embed, embed0): def calculate_should_calc(self, embed, embed0):
# 1. timestep embedding # 1. timestep embedding
modulated_inp = embed0 if self.use_ret_steps else embed modulated_inp = embed0 if self.use_ret_steps else embed
...@@ -95,30 +96,30 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching): ...@@ -95,30 +96,30 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
# 3. return the judgement # 3. return the judgement
return should_calc return should_calc
def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context): def infer_main_blocks(self, weights, pre_infer_out):
if self.infer_conditional: if self.infer_conditional:
index = self.scheduler.step_index index = self.scheduler.step_index
caching_records = self.scheduler.caching_records caching_records = self.scheduler.caching_records
if index <= self.scheduler.infer_steps - 1: if index <= self.scheduler.infer_steps - 1:
should_calc = self.calculate_should_calc(embed, embed0) should_calc = self.calculate_should_calc(pre_infer_out.embed, pre_infer_out.embed0)
self.scheduler.caching_records[index] = should_calc self.scheduler.caching_records[index] = should_calc
if caching_records[index] or self.must_calc(index): if caching_records[index] or self.must_calc(index):
x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context) x = self.infer_calculating(weights, pre_infer_out)
else: else:
x = self.infer_using_cache(x) x = self.infer_using_cache(pre_infer_out.x)
else: else:
index = self.scheduler.step_index index = self.scheduler.step_index
caching_records_2 = self.scheduler.caching_records_2 caching_records_2 = self.scheduler.caching_records_2
if index <= self.scheduler.infer_steps - 1: if index <= self.scheduler.infer_steps - 1:
should_calc = self.calculate_should_calc(embed, embed0) should_calc = self.calculate_should_calc(pre_infer_out.embed, pre_infer_out.embed0)
self.scheduler.caching_records_2[index] = should_calc self.scheduler.caching_records_2[index] = should_calc
if caching_records_2[index] or self.must_calc(index): if caching_records_2[index] or self.must_calc(index):
x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context) x = self.infer_calculating(weights, pre_infer_out)
else: else:
x = self.infer_using_cache(x) x = self.infer_using_cache(pre_infer_out.x)
if self.config.enable_cfg: if self.config.enable_cfg:
self.switch_status() self.switch_status()
...@@ -131,19 +132,10 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching): ...@@ -131,19 +132,10 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
return x return x
def infer_calculating(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context): def infer_calculating(self, weights, pre_infer_out):
ori_x = x.clone() ori_x = pre_infer_out.x.clone()
x = super().infer( x = super().infer_main_blocks(weights, pre_infer_out)
weights,
grid_sizes,
embed,
x,
embed0,
seq_lens,
freqs,
context,
)
if self.infer_conditional: if self.infer_conditional:
self.previous_residual_even = x - ori_x self.previous_residual_even = x - ori_x
if self.config["cpu_offload"]: if self.config["cpu_offload"]:
......
...@@ -104,6 +104,10 @@ class WanTransformerInfer(BaseTransformerInfer): ...@@ -104,6 +104,10 @@ class WanTransformerInfer(BaseTransformerInfer):
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE()) @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
def infer(self, weights, pre_infer_out): def infer(self, weights, pre_infer_out):
x = self.infer_main_blocks(weights, pre_infer_out)
return self.infer_post_blocks(weights, x, pre_infer_out.embed)
def infer_main_blocks(self, weights, pre_infer_out):
x = self.infer_func( x = self.infer_func(
weights, weights,
pre_infer_out.grid_sizes, pre_infer_out.grid_sizes,
...@@ -115,9 +119,9 @@ class WanTransformerInfer(BaseTransformerInfer): ...@@ -115,9 +119,9 @@ class WanTransformerInfer(BaseTransformerInfer):
pre_infer_out.context, pre_infer_out.context,
pre_infer_out.audio_dit_blocks, pre_infer_out.audio_dit_blocks,
) )
return self._infer_post_blocks(weights, x, pre_infer_out.embed) return x
def _infer_post_blocks(self, weights, x, e): def infer_post_blocks(self, weights, x, e):
if e.dim() == 2: if e.dim() == 2:
modulation = weights.head_modulation.tensor # 1, 2, dim modulation = weights.head_modulation.tensor # 1, 2, dim
e = (modulation + e.unsqueeze(1)).chunk(2, dim=1) e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment