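"""Default end-to-end runner for LightX2V.

Wires together the text/image encoders, the VAE, and the DiT transformer,
drives the per-segment denoising loop, and applies optional frame
interpolation (VFI) and super-resolution (VSR) post-processing.
"""
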
import gc
import os

import requests

import torch
import torch.distributed as dist
import torchvision.transforms.functional as TF

from PIL import Image
from loguru import logger
from requests.exceptions import RequestException

from lightx2v.server.metrics import monitor_cli
from lightx2v.utils.envs import *
from lightx2v.utils.generate_task_id import generate_task_id
from lightx2v.utils.global_paras import CALIB
from lightx2v.utils.memory_profiler import peak_memory_decorator
from lightx2v.utils.profiler import *
from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image
from lightx2v_platform.base.global_var import AI_DEVICE

from .base_runner import BaseRunner


class DefaultRunner(BaseRunner):
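    """Reference implementation of BaseRunner.

    Owns the scheduler, encoders, transformer, and VAE; run_pipeline() is the
    public entry point that encodes inputs, denoises segment by segment, and
    decodes/saves the final video.
    """
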
    def __init__(self, config):
        super().__init__(config)
        self.has_prompt_enhancer = False
        self.progress_callback = None
        if self.config["task"] == "t2v" and self.config.get("sub_servers", {}).get("prompt_enhancer") is not None:
            self.has_prompt_enhancer = True
            if not self.check_sub_servers("prompt_enhancer"):
                self.has_prompt_enhancer = False
                logger.warning("No prompt enhancer server available, disabling prompt enhancer.")
        if not self.has_prompt_enhancer:
            self.config["use_prompt_enhancer"] = False
        self.set_init_device()
        self.init_scheduler()

    def init_modules(self):
        logger.info("Initializing runner modules...")
        if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
            self.load_model()
        elif self.config.get("lazy_load", False):
            assert self.config.get("cpu_offload", False)
        if hasattr(self, "model"):
            self.model.set_scheduler(self.scheduler)  # set scheduler to model
        if self.config["task"] == "i2v":
            self.run_input_encoder = self._run_input_encoder_local_i2v
        elif self.config["task"] == "flf2v":
            self.run_input_encoder = self._run_input_encoder_local_flf2v
        elif self.config["task"] == "t2v":
            self.run_input_encoder = self._run_input_encoder_local_t2v
        elif self.config["task"] == "vace":
            self.run_input_encoder = self._run_input_encoder_local_vace
        elif self.config["task"] == "animate":
            self.run_input_encoder = self._run_input_encoder_local_animate
        elif self.config["task"] == "s2v":
            self.run_input_encoder = self._run_input_encoder_local_s2v
        self.config.lock()  # lock config to avoid modification
        if self.config.get("compile", False) and hasattr(self.model, "compile"):
            logger.info(f"[Compile] Compile all shapes: {self.config.get('compile_shapes', [])}")
            self.model.compile(self.config.get("compile_shapes", []))

    def set_init_device(self):
        if self.config["cpu_offload"]:
            self.init_device = torch.device("cpu")
        else:
            self.init_device = torch.device(AI_DEVICE)

    def load_vfi_model(self):
        if self.config["video_frame_interpolation"].get("algo", None) == "rife":
            from lightx2v.models.vfi.rife.rife_comfyui_wrapper import RIFEWrapper

            logger.info("Loading RIFE model...")
            return RIFEWrapper(self.config["video_frame_interpolation"]["model_path"])
        else:
            raise ValueError(f"Unsupported VFI model: {self.config['video_frame_interpolation']['algo']}")

    def load_vsr_model(self):
        if "video_super_resolution" in self.config:
            from lightx2v.models.runners.vsr.vsr_wrapper import VSRWrapper

            logger.info("Loading VSR model...")
            return VSRWrapper(self.config["video_super_resolution"]["model_path"])
        else:
            return None

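    # Eager path: load every sub-module up front. Under lazy_load/unload_modules,
    # init_modules() skips this and init_run() reloads the transformer per request.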
    @ProfilingContext4DebugL2("Load models")
    def load_model(self):
        self.model = self.load_transformer()
        self.text_encoders = self.load_text_encoder()
        self.image_encoder = self.load_image_encoder()
        self.vae_encoder, self.vae_decoder = self.load_vae()
        self.vfi_model = self.load_vfi_model() if "video_frame_interpolation" in self.config else None
        self.vsr_model = self.load_vsr_model() if "video_super_resolution" in self.config else None

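    # Probe each configured sub-server's status endpoint and keep only the
    # reachable ones; the filtered list replaces config["sub_servers"][task_type].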
    def check_sub_servers(self, task_type):
        urls = self.config.get("sub_servers", {}).get(task_type, [])
        available_servers = []
        for url in urls:
            try:
                status_url = f"{url}/v1/local/{task_type}/generate/service_status"
                response = requests.get(status_url, timeout=2)
                if response.status_code == 200:
                    available_servers.append(url)
                else:
                    logger.warning(f"Service {url} returned status code {response.status_code}")

            except RequestException as e:
                logger.warning(f"Failed to connect to {url}: {str(e)}")
                continue
        logger.info(f"{task_type} available servers: {available_servers}")
        self.config["sub_servers"][task_type] = available_servers
        return len(available_servers) > 0

    def set_inputs(self, inputs):
        self.input_info.seed = inputs.get("seed", 42)
        self.input_info.prompt = inputs.get("prompt", "")
        if self.config["use_prompt_enhancer"]:
            self.input_info.prompt_enhanced = inputs.get("prompt_enhanced", "")
        self.input_info.negative_prompt = inputs.get("negative_prompt", "")
        # optional media paths are only set when the InputInfo dataclass declares them
        if "image_path" in self.input_info.__dataclass_fields__:
            self.input_info.image_path = inputs.get("image_path", "")
        if "audio_path" in self.input_info.__dataclass_fields__:
            self.input_info.audio_path = inputs.get("audio_path", "")
        if "video_path" in self.input_info.__dataclass_fields__:
            self.input_info.video_path = inputs.get("video_path", "")
        self.input_info.save_result_path = inputs.get("save_result_path", "")

    def set_config(self, config_modify):
        logger.info(f"Modifying config: {config_modify}")
        with self.config.temporarily_unlocked():
            self.config.update(config_modify)

    def set_progress_callback(self, callback):
        self.progress_callback = callback

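    # One denoising pass for a single video segment: steps the scheduler through
    # infer_steps DiT iterations and reports progress via progress_callback.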
    @peak_memory_decorator
    def run_segment(self, segment_idx=0):
        infer_steps = self.model.scheduler.infer_steps

        for step_index in range(infer_steps):
            # only for single segment, check stop signal every step
            with ProfilingContext4DebugL1(
                "Run DiT every step",
                recorder_mode=GET_RECORDER_MODE(),
                metrics_func=monitor_cli.lightx2v_run_per_step_dit_duration,
                metrics_labels=[step_index + 1, infer_steps],
            ):
                if self.video_segment_num == 1:
                    self.check_stop()
                logger.info(f"==> step_index: {step_index + 1} / {infer_steps}")

                with ProfilingContext4DebugL1("step_pre"):
                    self.model.scheduler.step_pre(step_index=step_index)

                with ProfilingContext4DebugL1("🚀 infer_main"):
                    self.model.infer(self.inputs)

                with ProfilingContext4DebugL1("step_post"):
                    self.model.scheduler.step_post()

                if self.progress_callback:
                    current_step = segment_idx * infer_steps + step_index + 1
                    total_all_steps = self.video_segment_num * infer_steps
                    self.progress_callback((current_step / total_all_steps) * 100, 100)

        # free encoder inputs once the final segment has been denoised
        if segment_idx == self.video_segment_num - 1:
            del self.inputs
            torch.cuda.empty_cache()

        return self.model.scheduler.latents

    def run_step(self):
        self.inputs = self.run_input_encoder()
        if hasattr(self, "sr_version") and self.sr_version is not None:
            self.config_sr["is_sr_running"] = True
            self.inputs_sr = self.run_input_encoder()
            self.config_sr["is_sr_running"] = False

        self.run_main(total_steps=1)

    def end_run(self):
        self.model.scheduler.clear()
        if hasattr(self, "inputs"):
            del self.inputs
        self.input_info = None
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            # release offload managers and model weights so memory is returned between requests
            if hasattr(self.model, "model") and len(self.model.model) == 2:  # MultiModelStruct
                for model in self.model.model:
                    if hasattr(model.transformer_infer, "offload_manager"):
                        del model.transformer_infer.offload_manager
                        torch.cuda.empty_cache()
                        gc.collect()
                    del model
            else:
                if hasattr(self.model.transformer_infer, "offload_manager"):
                    del self.model.transformer_infer.offload_manager
                    torch.cuda.empty_cache()
                    gc.collect()
                del self.model
        if self.config.get("do_mm_calib", False):
            calib_path = os.path.join(os.getcwd(), "calib.pt")
            torch.save(CALIB, calib_path)
            logger.info(f"[CALIB] Saved calibration data successfully to: {calib_path}")
        torch.cuda.empty_cache()
        gc.collect()

    def read_image_input(self, img_path):
        if isinstance(img_path, Image.Image):
            img_ori = img_path
        else:
            img_ori = Image.open(img_path).convert("RGB")
        if GET_RECORDER_MODE():
            width, height = img_ori.size
            monitor_cli.lightx2v_input_image_len.observe(width * height)
        # to CHW tensor normalized to [-1, 1], with a leading batch dim
        img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).to(self.init_device)
        self.input_info.original_size = img_ori.size
        return img, img_ori

    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_i2v(self):
        img, img_ori = self.read_image_input(self.input_info.image_path)
        clip_encoder_out = self.run_image_encoder(img) if self.config.get("use_image_encoder", True) else None
        vae_encode_out, latent_shape = self.run_vae_encoder(img_ori if self.vae_encoder_need_img_original else img)
        self.input_info.latent_shape = latent_shape  # Important: set latent_shape in input_info
        text_encoder_output = self.run_text_encoder(self.input_info)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)

    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_t2v(self):
        self.input_info.latent_shape = self.get_latent_shape_with_target_hw()  # Important: set latent_shape in input_info
        text_encoder_output = self.run_text_encoder(self.input_info)
        torch.cuda.empty_cache()
        gc.collect()
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": None,
        }

    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_flf2v(self):
        first_frame, _ = self.read_image_input(self.input_info.image_path)
        last_frame, _ = self.read_image_input(self.input_info.last_frame_path)
        clip_encoder_out = self.run_image_encoder(first_frame, last_frame) if self.config.get("use_image_encoder", True) else None
        vae_encode_out, latent_shape = self.run_vae_encoder(first_frame, last_frame)
        self.input_info.latent_shape = latent_shape  # Important: set latent_shape in input_info
        text_encoder_output = self.run_text_encoder(self.input_info)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output)

    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_vace(self):
        src_video = self.input_info.src_video
        src_mask = self.input_info.src_mask
        src_ref_images = self.input_info.src_ref_images
        src_video, src_mask, src_ref_images = self.prepare_source(
            [src_video],
            [src_mask],
            [None if src_ref_images is None else src_ref_images.split(",")],
            (self.config["target_width"], self.config["target_height"]),
        )
        self.src_ref_images = src_ref_images
        vae_encoder_out, latent_shape = self.run_vae_encoder(src_video, src_ref_images, src_mask)
        self.input_info.latent_shape = latent_shape  # Important: set latent_shape in input_info
        text_encoder_output = self.run_text_encoder(self.input_info)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(None, vae_encoder_out, text_encoder_output)

    @ProfilingContext4DebugL2("Run Text Encoder")
    def _run_input_encoder_local_animate(self):
        text_encoder_output = self.run_text_encoder(self.input_info)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(None, None, text_encoder_output, None)

    def _run_input_encoder_local_s2v(self):
        # s2v input encoding is not implemented by the default runner
        pass

    def init_run(self):
        self.gen_video_final = None
        self.get_video_segment_num()

        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.model = self.load_transformer()
            self.model.set_scheduler(self.scheduler)

        self.model.scheduler.prepare(seed=self.input_info.seed, latent_shape=self.input_info.latent_shape, image_encoder_output=self.inputs["image_encoder_output"])
        if self.config.get("model_cls") == "wan2.2" and self.config["task"] in ["i2v", "s2v"]:
            self.inputs["image_encoder_output"]["vae_encoder_out"] = None

        if hasattr(self, "sr_version") and self.sr_version is not None:
            self.lq_latents_shape = self.model.scheduler.latents.shape
            self.model_sr.set_scheduler(self.scheduler_sr)
            self.config_sr["is_sr_running"] = True
            self.inputs_sr = self.run_input_encoder()
            self.config_sr["is_sr_running"] = False

    @ProfilingContext4DebugL2("Run DiT")
    def run_main(self, total_steps=None):  # total_steps is passed by run_step() but unused here
        self.init_run()
        if self.config.get("compile", False) and hasattr(self.model, "compile"):
            self.model.select_graph_for_compile(self.input_info)
        for segment_idx in range(self.video_segment_num):
            logger.info(f"🔄 start segment {segment_idx + 1}/{self.video_segment_num}")
            with ProfilingContext4DebugL1(
                f"segment end2end {segment_idx + 1}/{self.video_segment_num}",
                recorder_mode=GET_RECORDER_MODE(),
                metrics_func=monitor_cli.lightx2v_run_segments_end2end_duration,
                metrics_labels=["DefaultRunner"],
            ):
                self.check_stop()
                # 1. default do nothing
                self.init_run_segment(segment_idx)
                # 2. main inference loop
                latents = self.run_segment(segment_idx)
                # 3. vae decoder
                self.gen_video = self.run_vae_decoder(latents)
                # 4. default do nothing
                self.end_run_segment(segment_idx)
        gen_video_final = self.process_images_after_vae_decoder()
        self.end_run()
        return gen_video_final

    @ProfilingContext4DebugL1("Run VAE Decoder", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_run_vae_decode_duration, metrics_labels=["DefaultRunner"])
    def run_vae_decoder(self, latents):
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.vae_decoder = self.load_vae_decoder()
        images = self.vae_decoder.decode(latents.to(GET_DTYPE()))
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.vae_decoder
            torch.cuda.empty_cache()
            gc.collect()
        return images

    def post_prompt_enhancer(self):
        # poll the enhancer sub-servers until one reports idle, then dispatch the request
        while True:
            for url in self.config["sub_servers"]["prompt_enhancer"]:
                response = requests.get(f"{url}/v1/local/prompt_enhancer/generate/service_status").json()
                if response["service_status"] == "idle":
                    response = requests.post(
                        f"{url}/v1/local/prompt_enhancer/generate",
                        json={
                            "task_id": generate_task_id(),
                            "prompt": self.config["prompt"],
                        },
                    )
                    enhanced_prompt = response.json()["output"]
                    logger.info(f"Enhanced prompt: {enhanced_prompt}")
                    return enhanced_prompt

    def process_images_after_vae_decoder(self):
        self.gen_video_final = vae_to_comfyui_image(self.gen_video_final)

        if "video_frame_interpolation" in self.config:
            assert self.vfi_model is not None and self.config["video_frame_interpolation"].get("target_fps", None) is not None
            target_fps = self.config["video_frame_interpolation"]["target_fps"]
            logger.info(f"Interpolating frames from {self.config.get('fps', 16)} to {target_fps}")
            self.gen_video_final = self.vfi_model.interpolate_frames(
                self.gen_video_final,
                source_fps=self.config.get("fps", 16),
                target_fps=target_fps,
            )

        if self.input_info.return_result_tensor:
            return {"video": self.gen_video_final}
        elif self.input_info.save_result_path is not None:
            if "video_frame_interpolation" in self.config and self.config["video_frame_interpolation"].get("target_fps"):
                fps = self.config["video_frame_interpolation"]["target_fps"]
            else:
                fps = self.config.get("fps", 16)

            if not dist.is_initialized() or dist.get_rank() == 0:
                logger.info("🎬 Start to save video 🎬")
                save_to_video(self.gen_video_final, self.input_info.save_result_path, fps=fps, method="ffmpeg")
                logger.info(f"✅ Video saved successfully to: {self.input_info.save_result_path} ✅")
            return {"video": None}

    @ProfilingContext4DebugL1("RUN pipeline", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_worker_request_duration, metrics_labels=["DefaultRunner"])
    def run_pipeline(self, input_info):
        if GET_RECORDER_MODE():
            monitor_cli.lightx2v_worker_request_count.inc()
        self.input_info = input_info

        if self.config["use_prompt_enhancer"]:
            self.input_info.prompt_enhanced = self.post_prompt_enhancer()

        self.inputs = self.run_input_encoder()

        gen_video_final = self.run_main()

        if GET_RECORDER_MODE():
            monitor_cli.lightx2v_worker_request_success.inc()
        return gen_video_final
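

# Minimal usage sketch (assumes a BaseRunner-compatible config object and a
# populated InputInfo instance; names mirror this file, not a documented API):
#
#     runner = DefaultRunner(config)
#     runner.init_modules()
#     result = runner.run_pipeline(input_info)  # {"video": tensor} or {"video": None}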