import gc

from PIL import Image
from loguru import logger
import requests
from requests.exceptions import RequestException
import torch
import torch.distributed as dist

from lightx2v.utils.envs import *
from lightx2v.utils.generate_task_id import generate_task_id
from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image

from .base_runner import BaseRunner


class DefaultRunner(BaseRunner):
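    """Single-process runner that wires the text/image encoders, the DiT
    transformer, and the VAE together behind the BaseRunner interface.
    Optional features (prompt enhancement via sub-servers, lazy loading,
    frame interpolation) are toggled through the config."""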
    def __init__(self, config):
        super().__init__(config)
        self.has_prompt_enhancer = False
        self.progress_callback = None
        if self.config["task"] == "t2v" and self.config.get("sub_servers", {}).get("prompt_enhancer") is not None:
            self.has_prompt_enhancer = True
            if not self.check_sub_servers("prompt_enhancer"):
                self.has_prompt_enhancer = False
                logger.warning("No prompt enhancer server available, disabling prompt enhancer.")
        if not self.has_prompt_enhancer:
            self.config["use_prompt_enhancer"] = False
        self.set_init_device()

    def init_modules(self):
        logger.info("Initializing runner modules...")
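        # Eager path: load every submodule up front. With lazy_load or
        # unload_modules set, loading is deferred to _run_dit_local /
        # _run_vae_decoder_local; lazy_load additionally requires cpu_offload.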
        if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
            self.load_model()
        elif self.config.get("lazy_load", False):
            assert self.config.get("cpu_offload", False)
        self.run_dit = self._run_dit_local
        self.run_vae_decoder = self._run_vae_decoder_local
        if self.config["task"] == "i2v":
            self.run_input_encoder = self._run_input_encoder_local_i2v
        else:
            self.run_input_encoder = self._run_input_encoder_local_t2v

    def set_init_device(self):
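        # When a parallel attention backend is configured, bind this process to
        # the CUDA device matching its distributed rank before any allocation.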
        if self.config["parallel_attn_type"]:
            cur_rank = dist.get_rank()
            torch.cuda.set_device(cur_rank)
        if self.config.cpu_offload:
            self.init_device = torch.device("cpu")
        else:
            self.init_device = torch.device("cuda")

    def load_vfi_model(self):
        if self.config["video_frame_interpolation"].get("algo", None) == "rife":
            from lightx2v.models.vfi.rife.rife_comfyui_wrapper import RIFEWrapper

            logger.info("Loading RIFE model...")
            return RIFEWrapper(self.config["video_frame_interpolation"]["model_path"])
        else:
            raise ValueError(f"Unsupported VFI model: {self.config['video_frame_interpolation']['algo']}")

    @ProfilingContext("Load models")
    def load_model(self):
        self.model = self.load_transformer()
        self.text_encoders = self.load_text_encoder()
        self.image_encoder = self.load_image_encoder()
        self.vae_encoder, self.vae_decoder = self.load_vae()
        self.vfi_model = self.load_vfi_model() if "video_frame_interpolation" in self.config else None

    def check_sub_servers(self, task_type):
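        # Probe each configured sub-server's status endpoint and keep only the
        # reachable ones; returns False when none respond, so the caller can
        # disable the feature.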
        urls = self.config.get("sub_servers", {}).get(task_type, [])
        available_servers = []
        for url in urls:
            try:
                status_url = f"{url}/v1/local/{task_type}/generate/service_status"
                response = requests.get(status_url, timeout=2)
                if response.status_code == 200:
                    available_servers.append(url)
                else:
                    logger.warning(f"Service {url} returned status code {response.status_code}")

            except RequestException as e:
                logger.warning(f"Failed to connect to {url}: {str(e)}")
                continue
        logger.info(f"{task_type} available servers: {available_servers}")
        self.config["sub_servers"][task_type] = available_servers
        return len(available_servers) > 0

    def set_inputs(self, inputs):
        self.config["prompt"] = inputs.get("prompt", "")
        self.config["use_prompt_enhancer"] = False
        if self.has_prompt_enhancer:
            self.config["use_prompt_enhancer"] = inputs.get("use_prompt_enhancer", False)  # Reset use_prompt_enhancer from client side.
        self.config["negative_prompt"] = inputs.get("negative_prompt", "")
        self.config["image_path"] = inputs.get("image_path", "")
        self.config["save_video_path"] = inputs.get("save_video_path", "")
        self.config["infer_steps"] = inputs.get("infer_steps", self.config.get("infer_steps", 5))
        self.config["target_video_length"] = inputs.get("target_video_length", self.config.get("target_video_length", 81))
        self.config["seed"] = inputs.get("seed", self.config.get("seed", 42))
        self.config["audio_path"] = inputs.get("audio_path", "")  # for wan-audio
        self.config["video_duration"] = inputs.get("video_duration", 5)  # for wan-audio

        # self.config["sample_shift"] = inputs.get("sample_shift", self.config.get("sample_shift", 5))
        # self.config["sample_guide_scale"] = inputs.get("sample_guide_scale", self.config.get("sample_guide_scale", 5))

    def set_progress_callback(self, callback):
        self.progress_callback = callback

    def run(self):
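        # Denoising loop: each scheduler step runs step_pre, the transformer
        # forward pass, and step_post, then reports progress to the optional
        # callback.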
        total_steps = self.model.scheduler.infer_steps
        for step_index in range(total_steps):
            logger.info(f"==> step_index: {step_index + 1} / {total_steps}")

            with ProfilingContext4Debug("step_pre"):
                self.model.scheduler.step_pre(step_index=step_index)

            with ProfilingContext4Debug("infer"):
                self.model.infer(self.inputs)

            with ProfilingContext4Debug("step_post"):
                self.model.scheduler.step_post()

            if self.progress_callback:
                self.progress_callback(((step_index + 1) / total_steps) * 100, 100)

        return self.model.scheduler.latents, self.model.scheduler.generator

    def run_step(self, step_index=0):
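        # Debug helper: runs the input encoder plus a single scheduler step,
        # without decoding or saving video.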
        self.init_scheduler()
        self.inputs = self.run_input_encoder()
        self.model.scheduler.prepare(self.inputs["image_encoder_output"])
        self.model.scheduler.step_pre(step_index=step_index)
        self.model.infer(self.inputs)
        self.model.scheduler.step_post()

    def end_run(self):
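        # Tear down per-run state; under lazy_load/unload_modules also drop the
        # transformer weights so GPU memory is returned between requests.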
        self.model.scheduler.clear()
        del self.inputs, self.model.scheduler
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            if hasattr(self.model.transformer_infer, "weights_stream_mgr"):
                self.model.transformer_infer.weights_stream_mgr.clear()
            if hasattr(self.model.transformer_weights, "clear"):
                self.model.transformer_weights.clear()
            self.model.pre_weight.clear()
            self.model.post_weight.clear()
            del self.model
        torch.cuda.empty_cache()
        gc.collect()

    @ProfilingContext("Run Encoders")
    def _run_input_encoder_local_i2v(self):
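        # i2v conditioning: encode the reference image with the CLIP image
        # encoder and the VAE encoder, and the (optionally enhanced) prompt
        # with the text encoder.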
        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
        img = Image.open(self.config["image_path"]).convert("RGB")
        clip_encoder_out = self.run_image_encoder(img)
        vae_encode_out = self.run_vae_encoder(img)
        text_encoder_output = self.run_text_encoder(prompt, img)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)

    @ProfilingContext("Run Encoders")
    def _run_input_encoder_local_t2v(self):
        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
        text_encoder_output = self.run_text_encoder(prompt, None)
        torch.cuda.empty_cache()
        gc.collect()
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": None,
        }

    @ProfilingContext("Run DiT")
    def _run_dit_local(self):
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.model = self.load_transformer()
        self.init_scheduler()
        self.model.scheduler.prepare(self.inputs["image_encoder_output"])
        latents, generator = self.run()
        self.end_run()
        return latents, generator

    @ProfilingContext("Run VAE Decoder")
    def _run_vae_decoder_local(self, latents, generator):
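        # Under lazy_load/unload_modules the VAE decoder is created just for
        # this call and released immediately after decoding.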
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.vae_decoder = self.load_vae_decoder()
        images = self.vae_decoder.decode(latents, generator=generator, config=self.config)
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.vae_decoder
            torch.cuda.empty_cache()
            gc.collect()
        return images

    def post_prompt_enhancer(self):
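        # Busy-polls the prompt-enhancer sub-servers until one reports "idle",
        # then submits the prompt and returns the enhanced text. Note this
        # loops indefinitely if no server ever becomes idle.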
        while True:
            for url in self.config["sub_servers"]["prompt_enhancer"]:
                response = requests.get(f"{url}/v1/local/prompt_enhancer/generate/service_status", timeout=2).json()
                if response["service_status"] == "idle":
                    response = requests.post(
                        f"{url}/v1/local/prompt_enhancer/generate",
                        json={
                            "task_id": generate_task_id(),
                            "prompt": self.config["prompt"],
                        },
                    )
                    enhanced_prompt = response.json()["output"]
                    logger.info(f"Enhanced prompt: {enhanced_prompt}")
                    return enhanced_prompt

    def run_pipeline(self, save_video=True):
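        # Full generation pipeline: optional prompt enhancement -> input
        # encoding -> target-shape setup -> DiT denoising -> VAE decode ->
        # optional frame interpolation -> optional save to disk.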
        if self.config["use_prompt_enhancer"]:
            self.config["prompt_enhanced"] = self.post_prompt_enhancer()

        self.inputs = self.run_input_encoder()

        self.set_target_shape()

        latents, generator = self.run_dit()

        images = self.run_vae_decoder(latents, generator)
        images = vae_to_comfyui_image(images)

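        # Optional VFI pass: interpolate from the source fps (default 16) up to
        # the configured target_fps before saving.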
        if "video_frame_interpolation" in self.config:
            assert self.vfi_model is not None and self.config["video_frame_interpolation"].get("target_fps", None) is not None
            target_fps = self.config["video_frame_interpolation"]["target_fps"]
            logger.info(f"Interpolating frames from {self.config.get('fps', 16)} to {target_fps}")
            images = self.vfi_model.interpolate_frames(
                images,
                source_fps=self.config.get("fps", 16),
                target_fps=target_fps,
            )

        if save_video:
            if "video_frame_interpolation" in self.config and self.config["video_frame_interpolation"].get("target_fps"):
                fps = self.config["video_frame_interpolation"]["target_fps"]
            else:
                fps = self.config.get("fps", 16)
            logger.info(f"Saving video to {self.config.save_video_path}")
            save_to_video(images, self.config.save_video_path, fps=fps, method="ffmpeg")  # type: ignore

        del latents, generator
        torch.cuda.empty_cache()
        gc.collect()

        # Return (images, audio) - audio is None for default runner
        return images, None
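

if __name__ == "__main__":
    # Hypothetical usage sketch (not part of the original module): drives the
    # runner end to end. It assumes a config object that supports both
    # dict-style and attribute access (easydict.EasyDict matches the mixed
    # access patterns above) and that the load_* hooks inherited from
    # BaseRunner can build the model from the placeholder model_path.
    from easydict import EasyDict

    config = EasyDict(
        {
            "task": "t2v",
            "model_path": "/path/to/model",  # placeholder; model-specific keys go here too
            "parallel_attn_type": None,
            "cpu_offload": False,
            "prompt": "a cat surfing a wave at sunset",
            "negative_prompt": "",
            "save_video_path": "./output.mp4",
            "infer_steps": 5,
            "target_video_length": 81,
            "seed": 42,
            "fps": 16,
        }
    )
    runner = DefaultRunner(config)
    runner.init_modules()
    images, _ = runner.run_pipeline(save_video=True)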