Commit 067a2b61 authored by helloyongyang's avatar helloyongyang
Browse files

update default runner

parent b086210c
......@@ -218,15 +218,7 @@ class DefaultRunner(BaseRunner):
logger.info(f"Enhanced prompt: {enhanced_prompt}")
return enhanced_prompt
def run_pipeline(self, save_video=True):
if self.config["use_prompt_enhancer"]:
self.config["prompt_enhanced"] = self.post_prompt_enhancer()
self.inputs = self.run_input_encoder()
self.set_target_shape()
latents, generator = self.run_dit()
images = self.run_vae_decoder(latents, generator)
def process_images_after_vae_decoder(self, images, save_video=True):
images = vae_to_comfyui_image(images)
if "video_frame_interpolation" in self.config:
......@@ -251,6 +243,18 @@ class DefaultRunner(BaseRunner):
save_to_video(images, self.config.save_video_path, fps=fps, method="ffmpeg")
logger.info(f"✅ Video saved successfully to: {self.config.save_video_path} ✅")
def run_pipeline(self, save_video=True):
if self.config["use_prompt_enhancer"]:
self.config["prompt_enhanced"] = self.post_prompt_enhancer()
self.inputs = self.run_input_encoder()
self.set_target_shape()
latents, generator = self.run_dit()
images = self.run_vae_decoder(latents, generator)
self.process_images_after_vae_decoder(images, save_video=save_video)
del latents, generator
torch.cuda.empty_cache()
gc.collect()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment