# default_runner.py — default inference runner pipeline.
import gc

import torch
import torch.distributed as dist
from loguru import logger

from lightx2v.utils.envs import *
from lightx2v.utils.memory_profiler import peak_memory_decorator
from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
from lightx2v.utils.prompt_enhancer import PromptEnhancer
from lightx2v.utils.utils import cache_video, save_videos_grid



class DefaultRunner:
    def __init__(self, config):
        """Initialize the runner: optionally load the prompt enhancer, then all models.

        Args:
            config: runner configuration object (attribute and mapping access
                are both used on it throughout this class).
        """
        self.config = config
        # Prompt enhancement only applies to text-to-video generation.
        wants_enhancer = self.config.prompt_enhancer is not None and self.config.task == "t2v"
        if wants_enhancer:
            self.load_prompt_enhancer()
        loaded = self.load_model()
        self.model, self.text_encoders, self.vae_model, self.image_encoder = loaded
    @ProfilingContext("Load prompt enhancer")
    def load_prompt_enhancer(self):
        """Load the prompt-enhancer model onto a dedicated second GPU.

        Side effects:
            Sets ``self.prompt_enhancer`` and forces ``use_prompt_enhancer`` on.

        Raises:
            NotImplementedError: when fewer than two GPUs are available — the
                enhancer is hard-wired to ``cuda:1`` and CPU offload is not
                implemented.
        """
        gpu_count = torch.cuda.device_count()
        # Bug fix: the original checked `== 1` only, so a CPU-only host
        # (device_count() == 0) fell through and crashed on device_map="cuda:1".
        if gpu_count < 2:
            logger.info("Fewer than two GPUs; prompt enhancer would need cpu offload")
            raise NotImplementedError("prompt enhancer cpu offload is not supported.")
        self.prompt_enhancer = PromptEnhancer(model_name=self.config.prompt_enhancer, device_map="cuda:1")
        self.config["use_prompt_enhancer"] = True  # Set use_prompt_enhancer to True now. (Default is False)
    def set_inputs(self, inputs):
        """Copy per-request fields from *inputs* (a mapping) into the runner config."""
        read = inputs.get
        self.config["prompt"] = read("prompt", "")
        # Reset use_prompt_enhancer from the client side.
        self.config["use_prompt_enhancer"] = read("use_prompt_enhancer", False)
        self.config["negative_prompt"] = read("negative_prompt", "")
        self.config["image_path"] = read("image_path", "")
        self.config["save_video_path"] = read("save_video_path", "")
    def run_input_encoder(self):
        """Run the image encoder (i2v only) and text encoder, then record target shape and inputs.

        Side effects: sets ``self.inputs`` and calls ``self.set_target_shape()``;
        frees CUDA cache afterwards.
        """
        image_encoder_output = None
        if self.config["task"] == "i2v":
            with ProfilingContext("Run Img Encoder"):
                image_encoder_output = self.run_image_encoder(self.config, self.image_encoder, self.vae_model)
        with ProfilingContext("Run Text Encoder"):
            # Prefer the enhanced prompt when the enhancer is enabled.
            if self.config["use_prompt_enhancer"]:
                prompt = self.config["prompt_enhanced"]
            else:
                prompt = self.config["prompt"]
            text_encoder_output = self.run_text_encoder(prompt, self.text_encoders, self.config, image_encoder_output)
        self.set_target_shape()
        self.inputs = {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": image_encoder_output,
        }

        gc.collect()
        torch.cuda.empty_cache()
    @peak_memory_decorator
    def run(self):
        """Run the full denoising loop.

        Returns:
            Tuple of ``(scheduler.latents, scheduler.generator)`` after all steps.
        """
        scheduler = self.model.scheduler
        for idx in range(scheduler.infer_steps):
            logger.info(f"==> step_index: {idx + 1} / {scheduler.infer_steps}")

            with ProfilingContext4Debug("step_pre"):
                scheduler.step_pre(step_index=idx)

            with ProfilingContext4Debug("infer"):
                self.model.infer(self.inputs)

            with ProfilingContext4Debug("step_post"):
                scheduler.step_post()

        return scheduler.latents, scheduler.generator

    def run_step(self, step_index=0):
        """Debug helper: prepare encoders and execute exactly one scheduler step."""
        self.init_scheduler()
        self.run_input_encoder()
        scheduler = self.model.scheduler
        scheduler.prepare(self.inputs["image_encoder_output"])
        scheduler.step_pre(step_index=step_index)
        self.model.infer(self.inputs)
        scheduler.step_post()

    def end_run(self):
        """Tear down per-run state so its tensors can be reclaimed."""
        self.model.scheduler.clear()
        # Drop the run inputs and the scheduler attribute, then release cached GPU memory.
        del self.inputs
        del self.model.scheduler
        torch.cuda.empty_cache()

    @ProfilingContext("Run VAE")
    @peak_memory_decorator
    def run_vae(self, latents, generator):
        """Decode latents into video frames with the VAE and return them."""
        return self.vae_model.decode(latents, generator=generator, config=self.config)

    @ProfilingContext("Save video")
    def save_video(self, images):
        """Write decoded frames to ``config.save_video_path``.

        With parallel attention enabled only distributed rank 0 writes the
        file; in single-process mode the write always happens.
        """
        # Simplified from `not A or (A and B)` to the equivalent `not A or B`;
        # short-circuiting still prevents dist.get_rank() when no parallelism.
        if not self.config.parallel_attn_type or dist.get_rank() == 0:
            if self.config.model_cls in ["wan2.1", "wan2.1_causvid", "wan2.1_skyreels_v2_df"]:
                # wan-family models are saved with [-1, 1]-normalized tensors.
                cache_video(tensor=images, save_file=self.config.save_video_path, fps=self.config.get("fps", 16), nrow=1, normalize=True, value_range=(-1, 1))
            else:
                save_videos_grid(images, self.config.save_video_path, fps=self.config.get("fps", 24))

    def run_pipeline(self):
        """End-to-end generation: enhance prompt, encode, denoise, decode, save."""
        # Optionally rewrite the prompt before encoding.
        if self.config["use_prompt_enhancer"]:
            self.config["prompt_enhanced"] = self.prompt_enhancer(self.config["prompt"])

        # Encode inputs and run the denoising loop.
        self.init_scheduler()
        self.run_input_encoder()
        self.model.scheduler.prepare(self.inputs["image_encoder_output"])
        latents, generator = self.run()
        self.end_run()

        # Decode and persist the result.
        images = self.run_vae(latents, generator)
        self.save_video(images)

        # Free the large intermediates before returning.
        del latents, generator, images
        gc.collect()
        torch.cuda.empty_cache()