Commit bba65ffd authored by gushiqiao, committed by GitHub

[Fix] fix encoder bugs. (#254)

parent 09d769cc
@@ -174,8 +174,8 @@ class DefaultRunner(BaseRunner):
     @ProfilingContext("Run Encoders")
     def _run_input_encoder_local_flf2v(self):
         prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
-        first_frame = self.read_image_input(self.config["image_path"])
-        last_frame = self.read_image_input(self.config["last_frame_path"])
+        first_frame, _ = self.read_image_input(self.config["image_path"])
+        last_frame, _ = self.read_image_input(self.config["last_frame_path"])
         clip_encoder_out = self.run_image_encoder(first_frame, last_frame) if self.config.get("use_image_encoder", True) else None
         vae_encode_out = self.run_vae_encoder(first_frame, last_frame)
         text_encoder_output = self.run_text_encoder(prompt, first_frame)
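
Side note on this hunk: the call sites now unpack a tuple from read_image_input and discard the second element. A hypothetical sketch of that contract (this helper body and its metadata field are assumptions, not the repo's actual implementation):

# Hypothetical sketch: read_image_input is assumed to return (image, metadata).
from PIL import Image
import torchvision.transforms.functional as TF

def read_image_input(image_path):
    img = Image.open(image_path).convert("RGB")
    tensor = TF.to_tensor(img)        # C x H x W float tensor in [0, 1]
    meta = {"size": img.size}         # assumed second return value
    return tensor, meta

first_frame, _ = read_image_input("first_frame.png")   # metadata ignored, as in the diff
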
@@ -222,13 +222,13 @@ class DefaultRunner(BaseRunner):
         # 2. main inference loop
         latents, generator = self.run_segment(total_steps=total_steps)
         # 3. vae decoder
-        self.gen_video = self.run_vae_decoder(latents, generator)
+        self.gen_video = self.run_vae_decoder(latents)
         # 4. default do nothing
         self.end_run_segment()
         self.end_run()

     @ProfilingContext("Run VAE Decoder")
-    def run_vae_decoder(self, latents, generator):
+    def run_vae_decoder(self, latents):
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.vae_decoder = self.load_vae_decoder()
         images = self.vae_decoder.decode(latents)
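
For reference, a minimal sketch of the contract after this hunk: the torch.Generator produced by sampling stays with the sampling loop, and the VAE decoder is called with the latents alone (the decoder class below is a stand-in, not the project's module):

import torch

class _StubVAEDecoder:
    # stand-in for the real VAE decoder
    def decode(self, latents):
        return latents * 0.5   # pretend reconstruction

def run_vae_decoder(vae_decoder, latents):
    # generator is no longer a parameter; decoding depends only on the latents
    return vae_decoder.decode(latents)

generator = torch.Generator().manual_seed(42)             # used only during sampling
latents = torch.randn(1, 16, 8, 8, generator=generator)   # stand-in for denoised latents
video = run_vae_decoder(_StubVAEDecoder(), latents)
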
@@ -60,8 +60,6 @@ class QwenImageRunner(DefaultRunner):
             self.load_model()
         elif self.config.get("lazy_load", False):
             assert self.config.get("cpu_offload", False)
-        self.run_dit = self._run_dit_local
-        self.run_vae_decoder = self._run_vae_decoder_local
         if self.config["task"] == "t2i":
             self.run_input_encoder = self._run_input_encoder_local_t2i
         elif self.config["task"] == "i2i":
@@ -156,7 +154,7 @@ class QwenImageRunner(DefaultRunner):
         self.vfi_model = self.load_vfi_model() if "video_frame_interpolation" in self.config else None

     @ProfilingContext("Run VAE Decoder")
-    def _run_vae_decoder_local(self, latents, generator):
+    def run_vae_decoder(self, latents):
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.vae_decoder = self.load_vae()
         images = self.vae.decode(latents)
@@ -174,7 +172,7 @@ class QwenImageRunner(DefaultRunner):
         self.set_target_shape()
         latents, generator = self.run_dit()
-        images = self.run_vae_decoder(latents, generator)
+        images = self.run_vae_decoder(latents)
         image = images[0]
         image.save(f"{self.config.save_video_path}")
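
Taken together, the QwenImageRunner hunks replace the attribute rebinding self.run_vae_decoder = self._run_vae_decoder_local with a plain override that reuses the base-class name and signature. A self-contained toy sketch of that dispatch pattern (toy classes, not the repo's runners):

class _ToyVAE:
    def decode(self, latents):
        return f"decoded<{latents}>"

class DefaultRunner:
    def __init__(self):
        self.vae_decoder = _ToyVAE()

    def run_vae_decoder(self, latents):
        return self.vae_decoder.decode(latents)

class QwenImageRunner(DefaultRunner):
    def __init__(self):
        super().__init__()
        self.vae = _ToyVAE()   # the Qwen runner keeps its VAE on self.vae

    def run_vae_decoder(self, latents):
        # Same name and signature as the base class, so callers of
        # self.run_vae_decoder(latents) reach this override without rebinding.
        return self.vae.decode(latents)

print(QwenImageRunner().run_vae_decoder("latents"))   # -> decoded<latents>
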
@@ -96,22 +96,22 @@ class WanVaceRunner(WanRunner):
             assert len(frames) == len(ref_images)

         if masks is None:
-            latents = self.vae_encoder.encode(frames)
+            latents = [self.vae_encoder.encode(frame.unsqueeze(0)) for frame in frames]
         else:
             masks = [torch.where(m > 0.5, 1.0, 0.0) for m in masks]
             inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)]
             reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
-            inactive = self.vae_encoder.encode(inactive)
-            reactive = self.vae_encoder.encode(reactive)
+            inactive = [self.vae_encoder.encode(inact.unsqueeze(0)) for inact in inactive]
+            reactive = [self.vae_encoder.encode(react.unsqueeze(0)) for react in reactive]
             latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]

         cat_latents = []
         for latent, refs in zip(latents, ref_images):
             if refs is not None:
                 if masks is None:
-                    ref_latent = self.vae_encoder.encode(refs)
+                    ref_latent = [self.vae_encoder.encode(ref.unsqueeze(0)) for ref in refs]
                 else:
-                    ref_latent = self.vae_encoder.encode(refs)
+                    ref_latent = [self.vae_encoder.encode(ref.unsqueeze(0)) for ref in refs]
                     ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent]
                 assert all([x.shape[1] == 1 for x in ref_latent])
                 latent = torch.cat([*ref_latent, latent], dim=1)
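
The WanVace hunk switches from one encode call over a whole list to one call per clip, with unsqueeze(0) supplying the batch dimension the encoder expects. A toy sketch of that pattern (stand-in encoder and made-up shapes, not the real VAE):

import torch

class _ToyEncoder:
    # stand-in encoder that expects a batched input of shape B x C x T x H x W
    def encode(self, x):
        assert x.dim() == 5, "expects a leading batch dimension"
        return x.mean(dim=(-1, -2))   # fake latent of shape B x C x T

frames = [torch.randn(3, 8, 32, 32) for _ in range(2)]   # hypothetical C x T x H x W clips
encoder = _ToyEncoder()

# One encode call per clip, batch dimension added explicitly, as in the new code path.
latents = [encoder.encode(frame.unsqueeze(0)) for frame in frames]
print([lat.shape for lat in latents])   # [torch.Size([1, 3, 8]), torch.Size([1, 3, 8])]
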
 #!/bin/bash
 # set path and first
-lightx2v_path=/mtc/gushiqiao/llmc_workspace/LightX2V
-model_path=/data/nvme0/gushiqiao/models/Lightx2v_models/Wan2.1-R2V721-Audio-14B-720P
+lightx2v_path=
+model_path=
-export CUDA_VISIBLE_DEVICES=2
+export CUDA_VISIBLE_DEVICES=0
 # set environment variables
 source ${lightx2v_path}/scripts/base/base.sh