import gc

import requests

import torch
import torch.distributed as dist
import torchvision.transforms.functional as TF

from PIL import Image
from loguru import logger
from requests.exceptions import RequestException

from lightx2v.server.metrics import monitor_cli
helloyongyang's avatar
helloyongyang committed
12
from lightx2v.utils.envs import *
from lightx2v.utils.generate_task_id import generate_task_id
from lightx2v.utils.memory_profiler import peak_memory_decorator
from lightx2v.utils.profiler import *
from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image

from .base_runner import BaseRunner


class DefaultRunner(BaseRunner):
    def __init__(self, config):
        super().__init__(config)
        self.has_prompt_enhancer = False
        self.progress_callback = None
        if self.config["task"] == "t2v" and self.config.get("sub_servers", {}).get("prompt_enhancer") is not None:
            self.has_prompt_enhancer = True
            if not self.check_sub_servers("prompt_enhancer"):
                self.has_prompt_enhancer = False
                logger.warning("No prompt enhancer server available, disable prompt enhancer.")
        if not self.has_prompt_enhancer:
            self.config["use_prompt_enhancer"] = False
        self.set_init_device()
        self.init_scheduler()

    def init_modules(self):
        logger.info("Initializing runner modules...")
        if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
            self.load_model()
        elif self.config.get("lazy_load", False):
            assert self.config.get("cpu_offload", False)
        self.model.set_scheduler(self.scheduler)  # set scheduler to model
        if self.config["task"] == "i2v":
            self.run_input_encoder = self._run_input_encoder_local_i2v
        elif self.config["task"] == "flf2v":
            self.run_input_encoder = self._run_input_encoder_local_flf2v
        elif self.config["task"] == "t2v":
            self.run_input_encoder = self._run_input_encoder_local_t2v
        elif self.config["task"] == "vace":
            self.run_input_encoder = self._run_input_encoder_local_vace
        elif self.config["task"] == "animate":
            self.run_input_encoder = self._run_input_encoder_local_animate
        elif self.config["task"] == "s2v":
            self.run_input_encoder = self._run_input_encoder_local_s2v
        self.config.lock()  # lock config to avoid modification
        if self.config.get("compile", False):
            logger.info(f"[Compile] Compile all shapes: {self.config.get('compile_shapes', [])}")
            self.model.compile(self.config.get("compile_shapes", []))

    def set_init_device(self):
        if self.config["cpu_offload"]:
            self.init_device = torch.device("cpu")
        else:
            self.init_device = torch.device("cuda")

    def load_vfi_model(self):
        if self.config["video_frame_interpolation"].get("algo", None) == "rife":
            from lightx2v.models.vfi.rife.rife_comfyui_wrapper import RIFEWrapper

            logger.info("Loading RIFE model...")
            return RIFEWrapper(self.config["video_frame_interpolation"]["model_path"])
        else:
            raise ValueError(f"Unsupported VFI model: {self.config['video_frame_interpolation']['algo']}")

    @ProfilingContext4DebugL2("Load models")
    def load_model(self):
        self.model = self.load_transformer()
        self.text_encoders = self.load_text_encoder()
        self.image_encoder = self.load_image_encoder()
        self.vae_encoder, self.vae_decoder = self.load_vae()
        self.vfi_model = self.load_vfi_model() if "video_frame_interpolation" in self.config else None

    def check_sub_servers(self, task_type):
        urls = self.config.get("sub_servers", {}).get(task_type, [])
        available_servers = []
        for url in urls:
            try:
                status_url = f"{url}/v1/local/{task_type}/generate/service_status"
                response = requests.get(status_url, timeout=2)
                if response.status_code == 200:
                    available_servers.append(url)
                else:
                    logger.warning(f"Service {url} returned status code {response.status_code}")

            except RequestException as e:
                logger.warning(f"Failed to connect to {url}: {str(e)}")
                continue
        logger.info(f"{task_type} available servers: {available_servers}")
        self.config["sub_servers"][task_type] = available_servers
        return len(available_servers) > 0
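
    # Example `sub_servers` layout this method consumes (illustrative values,
    # not a definitive schema):
    #
    #   "sub_servers": {
    #       "prompt_enhancer": ["http://localhost:9001", "http://localhost:9002"]
    #   }
    #
    # Each URL is probed at /v1/local/<task_type>/generate/service_status and
    # kept only if it answers 200 within the 2-second timeout.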

    def set_inputs(self, inputs):
        self.input_info.seed = inputs.get("seed", 42)
        self.input_info.prompt = inputs.get("prompt", "")
        if self.config["use_prompt_enhancer"]:
            self.input_info.prompt_enhanced = inputs.get("prompt_enhanced", "")
        self.input_info.negative_prompt = inputs.get("negative_prompt", "")
        if "image_path" in self.input_info.__dataclass_fields__:
            self.input_info.image_path = inputs.get("image_path", "")
        if "audio_path" in self.input_info.__dataclass_fields__:
            self.input_info.audio_path = inputs.get("audio_path", "")
        if "video_path" in self.input_info.__dataclass_fields__:
            self.input_info.video_path = inputs.get("video_path", "")
        self.input_info.save_result_path = inputs.get("save_result_path", "")

    def set_config(self, config_modify):
        logger.info(f"modify config: {config_modify}")
        with self.config.temporarily_unlocked():
            self.config.update(config_modify)

    def set_progress_callback(self, callback):
        self.progress_callback = callback
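        # run_segment invokes this as callback(progress, total), with progress a
        # percentage in [0, 100] and total fixed at 100, e.g.:
        #   runner.set_progress_callback(lambda done, total: print(f"{done:.0f}%"))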

    @peak_memory_decorator
    def run_segment(self, total_steps=None):
        if total_steps is None:
            total_steps = self.model.scheduler.infer_steps
        for step_index in range(total_steps):
            # For single-segment runs, check the stop signal at every step.
            with ProfilingContext4DebugL1(
                f"Run Dit every step",
                recorder_mode=GET_RECORDER_MODE(),
                metrics_func=monitor_cli.lightx2v_run_per_step_dit_duration,
                metrics_labels=[step_index + 1, total_steps],
            ):
                if self.video_segment_num == 1:
                    self.check_stop()
                logger.info(f"==> step_index: {step_index + 1} / {total_steps}")

                with ProfilingContext4DebugL1("step_pre"):
                    self.model.scheduler.step_pre(step_index=step_index)

                with ProfilingContext4DebugL1("🚀 infer_main"):
                    self.model.infer(self.inputs)

                with ProfilingContext4DebugL1("step_post"):
                    self.model.scheduler.step_post()

                if self.progress_callback:
                    self.progress_callback(((step_index + 1) / total_steps) * 100, 100)

        return self.model.scheduler.latents

    def run_step(self):
        self.inputs = self.run_input_encoder()
        self.run_main(total_steps=1)

    def end_run(self):
        self.model.scheduler.clear()
        del self.inputs
        self.input_info = None
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            if hasattr(self.model.transformer_infer, "weights_stream_mgr"):
                self.model.transformer_infer.weights_stream_mgr.clear()
            if hasattr(self.model.transformer_weights, "clear"):
                self.model.transformer_weights.clear()
            self.model.pre_weight.clear()
            del self.model
        torch.cuda.empty_cache()
        gc.collect()

    def read_image_input(self, img_path):
        if isinstance(img_path, Image.Image):
            img_ori = img_path
        else:
            img_ori = Image.open(img_path).convert("RGB")
        if GET_RECORDER_MODE():
            width, height = img_ori.size
            monitor_cli.lightx2v_input_image_len.observe(width * height)
        img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).cuda()  # scale to [-1, 1], add batch dim, move to GPU
        self.input_info.original_size = img_ori.size
        return img, img_ori

    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_i2v(self):
        img, img_ori = self.read_image_input(self.input_info.image_path)
        clip_encoder_out = self.run_image_encoder(img) if self.config.get("use_image_encoder", True) else None
        vae_encode_out, latent_shape = self.run_vae_encoder(img_ori if self.vae_encoder_need_img_original else img)
        self.input_info.latent_shape = latent_shape  # Important: set latent_shape in input_info
        text_encoder_output = self.run_text_encoder(self.input_info)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)

    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_t2v(self):
        self.input_info.latent_shape = self.get_latent_shape_with_target_hw(self.config["target_height"], self.config["target_width"])  # Important: set latent_shape in input_info
        text_encoder_output = self.run_text_encoder(self.input_info)
        torch.cuda.empty_cache()
        gc.collect()
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": None,
        }

    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_flf2v(self):
        first_frame, _ = self.read_image_input(self.input_info.image_path)
        last_frame, _ = self.read_image_input(self.input_info.last_frame_path)
        clip_encoder_out = self.run_image_encoder(first_frame, last_frame) if self.config.get("use_image_encoder", True) else None
        vae_encode_out, latent_shape = self.run_vae_encoder(first_frame, last_frame)
        self.input_info.latent_shape = latent_shape  # Important: set latent_shape in input_info
        text_encoder_output = self.run_text_encoder(self.input_info)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output)

    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_vace(self):
        src_video = self.input_info.src_video
        src_mask = self.input_info.src_mask
        src_ref_images = self.input_info.src_ref_images
        src_video, src_mask, src_ref_images = self.prepare_source(
            [src_video],
            [src_mask],
            [None if src_ref_images is None else src_ref_images.split(",")],
            (self.config["target_width"], self.config["target_height"]),
        )
        self.src_ref_images = src_ref_images

        vae_encoder_out, latent_shape = self.run_vae_encoder(src_video, src_ref_images, src_mask)
        self.input_info.latent_shape = latent_shape  # Important: set latent_shape in input_info
        text_encoder_output = self.run_text_encoder(self.input_info)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(None, vae_encoder_out, text_encoder_output)

    @ProfilingContext4DebugL2("Run Text Encoder")
    def _run_input_encoder_local_animate(self):
        text_encoder_output = self.run_text_encoder(self.input_info)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(None, None, text_encoder_output, None)

    def _run_input_encoder_local_s2v(self):
        pass

    def init_run(self):
        self.gen_video_final = None
        self.get_video_segment_num()
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.model = self.load_transformer()

        self.model.scheduler.prepare(seed=self.input_info.seed, latent_shape=self.input_info.latent_shape, image_encoder_output=self.inputs["image_encoder_output"])
        if self.config.get("model_cls") == "wan2.2" and self.config["task"] in ["i2v", "s2v"]:
            # wan2.2 i2v/s2v: drop the VAE-encoded image once scheduler.prepare has consumed it
            self.inputs["image_encoder_output"]["vae_encoder_out"] = None

    @ProfilingContext4DebugL2("Run DiT")
    def run_main(self, total_steps=None):
        self.init_run()
        if self.config.get("compile", False):
            self.model.select_graph_for_compile(self.input_info)
        for segment_idx in range(self.video_segment_num):
            logger.info(f"🔄 start segment {segment_idx + 1}/{self.video_segment_num}")
            with ProfilingContext4DebugL1(
                f"segment end2end {segment_idx + 1}/{self.video_segment_num}",
                recorder_mode=GET_RECORDER_MODE(),
                metrics_func=monitor_cli.lightx2v_run_segments_end2end_duration,
                metrics_labels=["DefaultRunner"],
            ):
                self.check_stop()
                # 1. per-segment setup (a no-op in the default runner)
                self.init_run_segment(segment_idx)
                # 2. main denoising loop
                latents = self.run_segment(total_steps=total_steps)
                # 3. VAE decode
                self.gen_video = self.run_vae_decoder(latents)
                # 4. per-segment teardown (a no-op in the default runner)
                self.end_run_segment(segment_idx)
        # process_images_after_vae_decoder already returns a {"video": ...} dict,
        # so return it directly instead of wrapping it a second time.
        result = self.process_images_after_vae_decoder()
        self.end_run()
        return result

    @ProfilingContext4DebugL1("Run VAE Decoder", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_run_vae_decode_duration, metrics_labels=["DefaultRunner"])
    def run_vae_decoder(self, latents):
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.vae_decoder = self.load_vae_decoder()
        images = self.vae_decoder.decode(latents.to(GET_DTYPE()))
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.vae_decoder
            torch.cuda.empty_cache()
            gc.collect()
        return images

    def post_prompt_enhancer(self):
        # Busy-poll the enhancer sub-servers until one reports an idle status.
        while True:
            for url in self.config["sub_servers"]["prompt_enhancer"]:
                response = requests.get(f"{url}/v1/local/prompt_enhancer/generate/service_status").json()
                if response["service_status"] == "idle":
                    response = requests.post(
                        f"{url}/v1/local/prompt_enhancer/generate",
                        json={
                            "task_id": generate_task_id(),
                            "prompt": self.config["prompt"],
                        },
                    )
                    enhanced_prompt = response.json()["output"]
                    logger.info(f"Enhanced prompt: {enhanced_prompt}")
                    return enhanced_prompt

    def process_images_after_vae_decoder(self):
        self.gen_video_final = vae_to_comfyui_image(self.gen_video_final)

        if "video_frame_interpolation" in self.config:
            assert self.vfi_model is not None and self.config["video_frame_interpolation"].get("target_fps", None) is not None
            target_fps = self.config["video_frame_interpolation"]["target_fps"]
            logger.info(f"Interpolating frames from {self.config.get('fps', 16)} to {target_fps}")
            self.gen_video_final = self.vfi_model.interpolate_frames(
                self.gen_video_final,
                source_fps=self.config.get("fps", 16),
                target_fps=target_fps,
            )

        if self.input_info.return_result_tensor:
            return {"video": self.gen_video_final}
        elif self.input_info.save_result_path is not None:
            if "video_frame_interpolation" in self.config and self.config["video_frame_interpolation"].get("target_fps"):
                fps = self.config["video_frame_interpolation"]["target_fps"]
            else:
                fps = self.config.get("fps", 16)

            if not dist.is_initialized() or dist.get_rank() == 0:
                logger.info(f"🎬 Start to save video 🎬")
                save_to_video(self.gen_video_final, self.input_info.save_result_path, fps=fps, method="ffmpeg")
                logger.info(f"✅ Video saved successfully to: {self.input_info.save_result_path} ✅")
            return {"video": None}

    @ProfilingContext4DebugL1("RUN pipeline", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_worker_request_duration, metrics_labels=["DefaultRunner"])
    def run_pipeline(self, input_info):
        if GET_RECORDER_MODE():
            monitor_cli.lightx2v_worker_request_count.inc()
        self.input_info = input_info

        if self.config["use_prompt_enhancer"]:
            self.input_info.prompt_enhanced = self.post_prompt_enhancer()

        self.inputs = self.run_input_encoder()

        gen_video_final = self.run_main()

        if GET_RECORDER_MODE():
            monitor_cli.lightx2v_worker_request_success.inc()
        return gen_video_final
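

# Minimal usage sketch (illustrative only; `build_config` and `make_input_info`
# are hypothetical stand-ins for however the caller assembles its config dict
# and input-info dataclass):
#
#   config = build_config(task="t2v", cpu_offload=False, use_prompt_enhancer=False)
#   runner = DefaultRunner(config)
#   runner.init_modules()  # binds the task-specific input encoder and locks config
#   runner.set_progress_callback(lambda done, total: print(f"{done:.0f}%"))
#   result = runner.run_pipeline(make_input_info(prompt="a cat surfing", seed=42))
#   # result["video"] holds the tensor when input_info.return_result_tensor is
#   # set; otherwise the video is saved to input_info.save_result_path.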