Commit 0fcc5842 authored by wangshankun

Merge branch 'main' of https://github.com/ModelTC/LightX2V into main

parents c550409c 8689e4c7
@@ -54,7 +54,7 @@ class WanCausVidModel(WanModel):
         self.pre_weight.to_cuda()
         self.post_weight.to_cuda()
-        embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True, kv_start=kv_start, kv_end=kv_end)
+        embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, kv_start=kv_start, kv_end=kv_end)
         x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out, kv_start, kv_end)
         self.scheduler.noise_pred = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]

@@ -35,7 +35,7 @@ class WanAudioPreInfer(WanPreInfer):
         else:
             self.sp_size = 1

-    def infer(self, weights, inputs, positive):
+    def infer(self, weights, inputs):
         prev_latents = inputs["previmg_encoder_output"]["prev_latents"]
         if self.config.model_cls == "wan2.2_audio":
             hidden_states = self.scheduler.latents
@@ -71,7 +71,7 @@ class WanAudioPreInfer(WanPreInfer):
             audio_dit_blocks.append(inputs["audio_adapter_pipe"](**audio_model_input))
         # audio_dit_blocks = None  ## Debug: drop audio
-        if positive:
+        if self.scheduler.infer_condition:
             context = inputs["text_encoder_output"]["context"]
         else:
             context = inputs["text_encoder_output"]["context_null"]

@@ -24,7 +24,6 @@ class WanTransformerInferCaching(WanTransformerInfer):
 class WanTransformerInferTeaCaching(WanTransformerInferCaching):
     def __init__(self, config):
         super().__init__(config)
-        self.cnt = 0
         self.teacache_thresh = config.teacache_thresh
         self.accumulated_rel_l1_distance_even = 0
         self.previous_e0_even = None
@@ -35,12 +34,12 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
         self.use_ret_steps = config.use_ret_steps
         if self.use_ret_steps:
             self.coefficients = self.config.coefficients[0]
-            self.ret_steps = 5 * 2
-            self.cutoff_steps = self.config.infer_steps * 2
+            self.ret_steps = 5
+            self.cutoff_steps = self.config.infer_steps
         else:
             self.coefficients = self.config.coefficients[1]
-            self.ret_steps = 1 * 2
-            self.cutoff_steps = self.config.infer_steps * 2 - 2
+            self.ret_steps = 1
+            self.cutoff_steps = self.config.infer_steps - 1

     # calculate should_calc
     @torch.no_grad()
@@ -50,8 +49,8 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
         # 2. L1 calculate
         should_calc = False
-        if self.infer_conditional:
-            if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
+        if self.scheduler.infer_condition:
+            if self.scheduler.step_index < self.ret_steps or self.scheduler.step_index >= self.cutoff_steps:
                 should_calc = True
                 self.accumulated_rel_l1_distance_even = 0
             else:
@@ -67,7 +66,7 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
                 self.previous_e0_even = self.previous_e0_even.cpu()
         else:
-            if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
+            if self.scheduler.step_index < self.ret_steps or self.scheduler.step_index >= self.cutoff_steps:
                 should_calc = True
                 self.accumulated_rel_l1_distance_odd = 0
             else:
@@ -97,7 +96,7 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
         return should_calc

     def infer_main_blocks(self, weights, pre_infer_out):
-        if self.infer_conditional:
+        if self.scheduler.infer_condition:
             index = self.scheduler.step_index
             caching_records = self.scheduler.caching_records
             if index <= self.scheduler.infer_steps - 1:
@@ -121,11 +120,6 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
         else:
             x = self.infer_using_cache(pre_infer_out.x)
-        if self.config.enable_cfg:
-            self.switch_status()
-        self.cnt += 1
         if self.clean_cuda_cache:
             del grid_sizes, embed, embed0, seq_lens, freqs, context
             torch.cuda.empty_cache()
@@ -136,7 +130,7 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
         ori_x = pre_infer_out.x.clone()
         x = super().infer_main_blocks(weights, pre_infer_out)
-        if self.infer_conditional:
+        if self.scheduler.infer_condition:
             self.previous_residual_even = x - ori_x
             if self.config["cpu_offload"]:
                 self.previous_residual_even = self.previous_residual_even.cpu()
@@ -153,7 +147,7 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
         return x

     def infer_using_cache(self, x):
-        if self.infer_conditional:
+        if self.scheduler.infer_condition:
             x.add_(self.previous_residual_even.cuda())
         else:
             x.add_(self.previous_residual_odd.cuda())

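For orientation, below is a minimal sketch of the skip decision this file rewires: warm-up (`ret_steps`) and tail (`cutoff_steps`) steps always recompute, and in between an accumulated, polynomial-rescaled relative L1 distance between timestep embeddings gates cache reuse. Only `step_index`, `ret_steps`, `cutoff_steps`, and the coefficient list come from the diff; the helper name and the distance bookkeeping are illustrative stand-ins, not the repo's exact code.

```python
import torch

def teacache_should_calc(scheduler, e0, prev_e0, accumulated,
                         coefficients, ret_steps, cutoff_steps, thresh):
    """Illustrative stand-in for the TeaCache gate after this change.

    Returns (should_calc, new_accumulated). The step counter now lives on
    the scheduler (`step_index`), not on the transformer (`self.cnt`).
    """
    # Warm-up and tail steps always run the full transformer.
    if scheduler.step_index < ret_steps or scheduler.step_index >= cutoff_steps:
        return True, 0.0

    # Relative L1 drift between current and previous modulation embeddings.
    rel = ((e0 - prev_e0).abs().mean() / prev_e0.abs().mean()).item()

    # Rescale with the model-specific polynomial (highest degree first,
    # numpy-polyval convention) and accumulate across skipped steps.
    accumulated += sum(c * rel**i for i, c in enumerate(reversed(coefficients)))

    if accumulated >= thresh:
        return True, 0.0        # drift too large: recompute and reset
    return False, accumulated   # reuse the cached residual
```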
@@ -33,7 +33,7 @@ class WanPreInfer:
         self.scheduler = scheduler

     @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
-    def infer(self, weights, inputs, positive, kv_start=0, kv_end=0):
+    def infer(self, weights, inputs, kv_start=0, kv_end=0):
         x = self.scheduler.latents
         if self.scheduler.flag_df:
@@ -45,7 +45,7 @@ class WanPreInfer:
         if self.config["model_cls"] == "wan2.2" and self.config["task"] == "i2v":
             t = (self.scheduler.mask[0][:, ::2, ::2] * t).flatten()
-        if positive:
+        if self.scheduler.infer_condition:
             context = inputs["text_encoder_output"]["context"]
         else:
             context = inputs["text_encoder_output"]["context_null"]

@@ -78,11 +78,6 @@ class WanTransformerInfer(BaseTransformerInfer):
         else:
             self.infer_func = self._infer_without_offload
-        self.infer_conditional = True
-
-    def switch_status(self):
-        self.infer_conditional = not self.infer_conditional

     def _calculate_q_k_len(self, q, k_lens):
         q_lens = torch.tensor([q.size(0)], dtype=torch.int32, device=q.device)
         cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)

@@ -329,9 +329,9 @@ class WanModel:
             cfg_p_rank = dist.get_rank(cfg_p_group)
             if cfg_p_rank == 0:
-                noise_pred = self._infer_cond_uncond(inputs, positive=True)
+                noise_pred = self._infer_cond_uncond(inputs, infer_condition=True)
             else:
-                noise_pred = self._infer_cond_uncond(inputs, positive=False)
+                noise_pred = self._infer_cond_uncond(inputs, infer_condition=False)
             noise_pred_list = [torch.zeros_like(noise_pred) for _ in range(2)]
             dist.all_gather(noise_pred_list, noise_pred, group=cfg_p_group)
@@ -339,13 +339,13 @@ class WanModel:
                 noise_pred_uncond = noise_pred_list[1]  # cfg_p_rank == 1
             else:
                 # ==================== CFG Processing ====================
-                noise_pred_cond = self._infer_cond_uncond(inputs, positive=True)
-                noise_pred_uncond = self._infer_cond_uncond(inputs, positive=False)
+                noise_pred_cond = self._infer_cond_uncond(inputs, infer_condition=True)
+                noise_pred_uncond = self._infer_cond_uncond(inputs, infer_condition=False)
                 self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
         else:
             # ==================== No CFG ====================
-            self.scheduler.noise_pred = self._infer_cond_uncond(inputs, positive=True)
+            self.scheduler.noise_pred = self._infer_cond_uncond(inputs, infer_condition=True)

         if self.cpu_offload:
             if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1:
@@ -355,8 +355,10 @@ class WanModel:
                 self.transformer_weights.post_weights_to_cpu()

     @torch.no_grad()
-    def _infer_cond_uncond(self, inputs, positive=True):
-        pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=positive)
+    def _infer_cond_uncond(self, inputs, infer_condition=True):
+        self.scheduler.infer_condition = infer_condition
+        pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs)
         if self.config["seq_parallel"]:
             pre_infer_out = self._seq_parallel_pre_process(pre_infer_out)

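Since the CFG plumbing is the heart of this change, a compact sketch of the driver may help: `_infer_cond_uncond` now sets `scheduler.infer_condition` once per pass, and every downstream module (pre-infer, TeaCache, audio adapter) reads that shared flag instead of receiving a `positive` argument. The combine step is standard classifier-free guidance, exactly as in the diff; everything else below is an illustrative harness, not the repo's API.

```python
def cfg_step(model, scheduler, inputs):
    """Illustrative harness around the calls shown in the diff."""
    if model.config["enable_cfg"]:
        # Each call flips scheduler.infer_condition before pre-infer runs,
        # so context selection and residual caching branch on shared state.
        cond = model._infer_cond_uncond(inputs, infer_condition=True)
        uncond = model._infer_cond_uncond(inputs, infer_condition=False)
        # Classifier-free guidance: extrapolate from the unconditional
        # prediction toward the conditional one by the guidance scale.
        scheduler.noise_pred = uncond + scheduler.sample_guide_scale * (cond - uncond)
    else:
        scheduler.noise_pred = model._infer_cond_uncond(inputs, infer_condition=True)
```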
@@ -10,6 +10,7 @@ class BaseScheduler:
         self.caching_records = [True] * config.infer_steps
         self.flag_df = False
         self.transformer_infer = None
+        self.infer_condition = True  # cfg status

     def step_pre(self, step_index):
         self.step_index = step_index
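Taken together, this single scheduler attribute replaces both the `positive` flag threaded through `infer(...)` and the transformer's `self.cnt`/`switch_status()` bookkeeping. A minimal sketch of the resulting state flow, using only the attribute names from the diff with an otherwise illustrative consumer:

```python
class BaseSchedulerSketch:
    """Illustrative reduction of BaseScheduler to the state used here."""

    def __init__(self, infer_steps):
        self.infer_steps = infer_steps
        self.step_index = 0
        self.infer_condition = True  # cfg status: True = conditional pass

    def step_pre(self, step_index):
        self.step_index = step_index


def select_context(scheduler, text_encoder_output):
    # Consumers no longer receive a `positive` flag; they read shared state.
    key = "context" if scheduler.infer_condition else "context_null"
    return text_encoder_output[key]
```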