Commit 8689e4c7 authored by helloyongyang's avatar helloyongyang
Browse files

update tea cache

parent 0b755a97
...@@ -54,7 +54,7 @@ class WanCausVidModel(WanModel): ...@@ -54,7 +54,7 @@ class WanCausVidModel(WanModel):
self.pre_weight.to_cuda() self.pre_weight.to_cuda()
self.post_weight.to_cuda() self.post_weight.to_cuda()
embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True, kv_start=kv_start, kv_end=kv_end) embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, kv_start=kv_start, kv_end=kv_end)
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out, kv_start, kv_end) x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out, kv_start, kv_end)
self.scheduler.noise_pred = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0] self.scheduler.noise_pred = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
......
...@@ -35,7 +35,7 @@ class WanAudioPreInfer(WanPreInfer): ...@@ -35,7 +35,7 @@ class WanAudioPreInfer(WanPreInfer):
else: else:
self.sp_size = 1 self.sp_size = 1
def infer(self, weights, inputs, positive): def infer(self, weights, inputs):
prev_latents = inputs["previmg_encoder_output"]["prev_latents"] prev_latents = inputs["previmg_encoder_output"]["prev_latents"]
if self.config.model_cls == "wan2.2_audio": if self.config.model_cls == "wan2.2_audio":
hidden_states = self.scheduler.latents hidden_states = self.scheduler.latents
...@@ -71,7 +71,7 @@ class WanAudioPreInfer(WanPreInfer): ...@@ -71,7 +71,7 @@ class WanAudioPreInfer(WanPreInfer):
audio_dit_blocks.append(inputs["audio_adapter_pipe"](**audio_model_input)) audio_dit_blocks.append(inputs["audio_adapter_pipe"](**audio_model_input))
# audio_dit_blocks = None##Debug Drop Audio # audio_dit_blocks = None##Debug Drop Audio
if positive: if self.scheduler.infer_condition:
context = inputs["text_encoder_output"]["context"] context = inputs["text_encoder_output"]["context"]
else: else:
context = inputs["text_encoder_output"]["context_null"] context = inputs["text_encoder_output"]["context_null"]
......
...@@ -24,7 +24,6 @@ class WanTransformerInferCaching(WanTransformerInfer): ...@@ -24,7 +24,6 @@ class WanTransformerInferCaching(WanTransformerInfer):
class WanTransformerInferTeaCaching(WanTransformerInferCaching): class WanTransformerInferTeaCaching(WanTransformerInferCaching):
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
self.cnt = 0
self.teacache_thresh = config.teacache_thresh self.teacache_thresh = config.teacache_thresh
self.accumulated_rel_l1_distance_even = 0 self.accumulated_rel_l1_distance_even = 0
self.previous_e0_even = None self.previous_e0_even = None
...@@ -35,12 +34,12 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching): ...@@ -35,12 +34,12 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
self.use_ret_steps = config.use_ret_steps self.use_ret_steps = config.use_ret_steps
if self.use_ret_steps: if self.use_ret_steps:
self.coefficients = self.config.coefficients[0] self.coefficients = self.config.coefficients[0]
self.ret_steps = 5 * 2 self.ret_steps = 5
self.cutoff_steps = self.config.infer_steps * 2 self.cutoff_steps = self.config.infer_steps
else: else:
self.coefficients = self.config.coefficients[1] self.coefficients = self.config.coefficients[1]
self.ret_steps = 1 * 2 self.ret_steps = 1
self.cutoff_steps = self.config.infer_steps * 2 - 2 self.cutoff_steps = self.config.infer_steps - 1
# calculate should_calc # calculate should_calc
@torch.no_grad() @torch.no_grad()
...@@ -50,8 +49,8 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching): ...@@ -50,8 +49,8 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
# 2. L1 calculate # 2. L1 calculate
should_calc = False should_calc = False
if self.infer_conditional: if self.scheduler.infer_condition:
if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps: if self.scheduler.step_index < self.ret_steps or self.scheduler.step_index >= self.cutoff_steps:
should_calc = True should_calc = True
self.accumulated_rel_l1_distance_even = 0 self.accumulated_rel_l1_distance_even = 0
else: else:
...@@ -67,7 +66,7 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching): ...@@ -67,7 +66,7 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
self.previous_e0_even = self.previous_e0_even.cpu() self.previous_e0_even = self.previous_e0_even.cpu()
else: else:
if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps: if self.scheduler.step_index < self.ret_steps or self.scheduler.step_index >= self.cutoff_steps:
should_calc = True should_calc = True
self.accumulated_rel_l1_distance_odd = 0 self.accumulated_rel_l1_distance_odd = 0
else: else:
...@@ -97,7 +96,7 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching): ...@@ -97,7 +96,7 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
return should_calc return should_calc
def infer_main_blocks(self, weights, pre_infer_out): def infer_main_blocks(self, weights, pre_infer_out):
if self.infer_conditional: if self.scheduler.infer_condition:
index = self.scheduler.step_index index = self.scheduler.step_index
caching_records = self.scheduler.caching_records caching_records = self.scheduler.caching_records
if index <= self.scheduler.infer_steps - 1: if index <= self.scheduler.infer_steps - 1:
...@@ -121,11 +120,6 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching): ...@@ -121,11 +120,6 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
else: else:
x = self.infer_using_cache(pre_infer_out.x) x = self.infer_using_cache(pre_infer_out.x)
if self.config.enable_cfg:
self.switch_status()
self.cnt += 1
if self.clean_cuda_cache: if self.clean_cuda_cache:
del grid_sizes, embed, embed0, seq_lens, freqs, context del grid_sizes, embed, embed0, seq_lens, freqs, context
torch.cuda.empty_cache() torch.cuda.empty_cache()
...@@ -136,7 +130,7 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching): ...@@ -136,7 +130,7 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
ori_x = pre_infer_out.x.clone() ori_x = pre_infer_out.x.clone()
x = super().infer_main_blocks(weights, pre_infer_out) x = super().infer_main_blocks(weights, pre_infer_out)
if self.infer_conditional: if self.scheduler.infer_condition:
self.previous_residual_even = x - ori_x self.previous_residual_even = x - ori_x
if self.config["cpu_offload"]: if self.config["cpu_offload"]:
self.previous_residual_even = self.previous_residual_even.cpu() self.previous_residual_even = self.previous_residual_even.cpu()
...@@ -153,7 +147,7 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching): ...@@ -153,7 +147,7 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
return x return x
def infer_using_cache(self, x): def infer_using_cache(self, x):
if self.infer_conditional: if self.scheduler.infer_condition:
x.add_(self.previous_residual_even.cuda()) x.add_(self.previous_residual_even.cuda())
else: else:
x.add_(self.previous_residual_odd.cuda()) x.add_(self.previous_residual_odd.cuda())
......
...@@ -33,7 +33,7 @@ class WanPreInfer: ...@@ -33,7 +33,7 @@ class WanPreInfer:
self.scheduler = scheduler self.scheduler = scheduler
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE()) @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
def infer(self, weights, inputs, positive, kv_start=0, kv_end=0): def infer(self, weights, inputs, kv_start=0, kv_end=0):
x = self.scheduler.latents x = self.scheduler.latents
if self.scheduler.flag_df: if self.scheduler.flag_df:
...@@ -45,7 +45,7 @@ class WanPreInfer: ...@@ -45,7 +45,7 @@ class WanPreInfer:
if self.config["model_cls"] == "wan2.2" and self.config["task"] == "i2v": if self.config["model_cls"] == "wan2.2" and self.config["task"] == "i2v":
t = (self.scheduler.mask[0][:, ::2, ::2] * t).flatten() t = (self.scheduler.mask[0][:, ::2, ::2] * t).flatten()
if positive: if self.scheduler.infer_condition:
context = inputs["text_encoder_output"]["context"] context = inputs["text_encoder_output"]["context"]
else: else:
context = inputs["text_encoder_output"]["context_null"] context = inputs["text_encoder_output"]["context_null"]
......
...@@ -78,11 +78,6 @@ class WanTransformerInfer(BaseTransformerInfer): ...@@ -78,11 +78,6 @@ class WanTransformerInfer(BaseTransformerInfer):
else: else:
self.infer_func = self._infer_without_offload self.infer_func = self._infer_without_offload
self.infer_conditional = True
def switch_status(self):
self.infer_conditional = not self.infer_conditional
def _calculate_q_k_len(self, q, k_lens): def _calculate_q_k_len(self, q, k_lens):
q_lens = torch.tensor([q.size(0)], dtype=torch.int32, device=q.device) q_lens = torch.tensor([q.size(0)], dtype=torch.int32, device=q.device)
cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32) cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
......
...@@ -329,9 +329,9 @@ class WanModel: ...@@ -329,9 +329,9 @@ class WanModel:
cfg_p_rank = dist.get_rank(cfg_p_group) cfg_p_rank = dist.get_rank(cfg_p_group)
if cfg_p_rank == 0: if cfg_p_rank == 0:
noise_pred = self._infer_cond_uncond(inputs, positive=True) noise_pred = self._infer_cond_uncond(inputs, infer_condition=True)
else: else:
noise_pred = self._infer_cond_uncond(inputs, positive=False) noise_pred = self._infer_cond_uncond(inputs, infer_condition=False)
noise_pred_list = [torch.zeros_like(noise_pred) for _ in range(2)] noise_pred_list = [torch.zeros_like(noise_pred) for _ in range(2)]
dist.all_gather(noise_pred_list, noise_pred, group=cfg_p_group) dist.all_gather(noise_pred_list, noise_pred, group=cfg_p_group)
...@@ -339,13 +339,13 @@ class WanModel: ...@@ -339,13 +339,13 @@ class WanModel:
noise_pred_uncond = noise_pred_list[1] # cfg_p_rank == 1 noise_pred_uncond = noise_pred_list[1] # cfg_p_rank == 1
else: else:
# ==================== CFG Processing ==================== # ==================== CFG Processing ====================
noise_pred_cond = self._infer_cond_uncond(inputs, positive=True) noise_pred_cond = self._infer_cond_uncond(inputs, infer_condition=True)
noise_pred_uncond = self._infer_cond_uncond(inputs, positive=False) noise_pred_uncond = self._infer_cond_uncond(inputs, infer_condition=False)
self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (noise_pred_cond - noise_pred_uncond) self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
else: else:
# ==================== No CFG ==================== # ==================== No CFG ====================
self.scheduler.noise_pred = self._infer_cond_uncond(inputs, positive=True) self.scheduler.noise_pred = self._infer_cond_uncond(inputs, infer_condition=True)
if self.cpu_offload: if self.cpu_offload:
if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1: if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1:
...@@ -355,8 +355,10 @@ class WanModel: ...@@ -355,8 +355,10 @@ class WanModel:
self.transformer_weights.post_weights_to_cpu() self.transformer_weights.post_weights_to_cpu()
@torch.no_grad() @torch.no_grad()
def _infer_cond_uncond(self, inputs, positive=True): def _infer_cond_uncond(self, inputs, infer_condition=True):
pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=positive) self.scheduler.infer_condition = infer_condition
pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs)
if self.config["seq_parallel"]: if self.config["seq_parallel"]:
pre_infer_out = self._seq_parallel_pre_process(pre_infer_out) pre_infer_out = self._seq_parallel_pre_process(pre_infer_out)
......
...@@ -10,6 +10,7 @@ class BaseScheduler: ...@@ -10,6 +10,7 @@ class BaseScheduler:
self.caching_records = [True] * config.infer_steps self.caching_records = [True] * config.infer_steps
self.flag_df = False self.flag_df = False
self.transformer_infer = None self.transformer_infer = None
self.infer_condition = True # cfg status
def step_pre(self, step_index): def step_pre(self, step_index):
self.step_index = step_index self.step_index = step_index
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment