import gc
import os

import requests
import torch
import torch.distributed as dist
import torchvision.transforms.functional as TF
from PIL import Image
from loguru import logger
from requests.exceptions import RequestException

from lightx2v.server.metrics import monitor_cli
from lightx2v.utils.envs import *
from lightx2v.utils.generate_task_id import generate_task_id
from lightx2v.utils.global_paras import CALIB
from lightx2v.utils.memory_profiler import peak_memory_decorator
from lightx2v.utils.profiler import *
from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image

from .base_runner import BaseRunner


class DefaultRunner(BaseRunner):
    def __init__(self, config):
        super().__init__(config)
        self.has_prompt_enhancer = False
        self.progress_callback = None
        if self.config["task"] == "t2v" and self.config.get("sub_servers", {}).get("prompt_enhancer") is not None:
            self.has_prompt_enhancer = True
            if not self.check_sub_servers("prompt_enhancer"):
                self.has_prompt_enhancer = False
                logger.warning("No prompt enhancer server available; disabling prompt enhancer.")
        if not self.has_prompt_enhancer:
            self.config["use_prompt_enhancer"] = False
        self.set_init_device()
        self.init_scheduler()

    def init_modules(self):
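        """Load model modules (unless lazy loading defers this), attach the
        scheduler, and pick the input-encoder routine for the configured task."""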
        logger.info("Initializing runner modules...")
        if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
            self.load_model()
        elif self.config.get("lazy_load", False):
            assert self.config.get("cpu_offload", False)
        self.model.set_scheduler(self.scheduler)  # set scheduler to model
        if self.config["task"] == "i2v":
            self.run_input_encoder = self._run_input_encoder_local_i2v
        elif self.config["task"] == "flf2v":
            self.run_input_encoder = self._run_input_encoder_local_flf2v
        elif self.config["task"] == "t2v":
            self.run_input_encoder = self._run_input_encoder_local_t2v
        elif self.config["task"] == "vace":
            self.run_input_encoder = self._run_input_encoder_local_vace
        elif self.config["task"] == "animate":
            self.run_input_encoder = self._run_input_encoder_local_animate
        elif self.config["task"] == "s2v":
            self.run_input_encoder = self._run_input_encoder_local_s2v
        self.config.lock()  # lock config to avoid modification
        if self.config.get("compile", False) and hasattr(self.model, "compile"):
            logger.info(f"[Compile] Compile all shapes: {self.config.get('compile_shapes', [])}")
            self.model.compile(self.config.get("compile_shapes", []))

    def set_init_device(self):
        self.run_device = self.config.get("run_device", "cuda")
        if self.config["cpu_offload"]:
            self.init_device = torch.device("cpu")
        else:
            self.init_device = torch.device(self.config.get("run_device", "cuda"))

    def load_vfi_model(self):
        if self.config["video_frame_interpolation"].get("algo", None) == "rife":
            from lightx2v.models.vfi.rife.rife_comfyui_wrapper import RIFEWrapper

            logger.info("Loading RIFE model...")
            return RIFEWrapper(self.config["video_frame_interpolation"]["model_path"])
        else:
            raise ValueError(f"Unsupported VFI model: {self.config['video_frame_interpolation']['algo']}")

    def load_vsr_model(self):
        if "video_super_resolution" in self.config:
            from lightx2v.models.runners.vsr.vsr_wrapper import VSRWrapper

            logger.info("Loading VSR model...")
            return VSRWrapper(self.config["video_super_resolution"]["model_path"])
        else:
            return None

    @ProfilingContext4DebugL2("Load models")
    def load_model(self):
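        """Load the transformer, text encoders, image encoder, and VAE, plus the
        optional frame-interpolation (VFI) and super-resolution (VSR) models."""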
        self.model = self.load_transformer()
        self.text_encoders = self.load_text_encoder()
        self.image_encoder = self.load_image_encoder()
        self.vae_encoder, self.vae_decoder = self.load_vae()
        self.vfi_model = self.load_vfi_model() if "video_frame_interpolation" in self.config else None
        self.vsr_model = self.load_vsr_model() if "video_super_resolution" in self.config else None

    def check_sub_servers(self, task_type):
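        """Probe each configured sub-server's status endpoint and keep only the
        servers that respond with HTTP 200; returns True if any are available."""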
        urls = self.config.get("sub_servers", {}).get(task_type, [])
        available_servers = []
        for url in urls:
            try:
                status_url = f"{url}/v1/local/{task_type}/generate/service_status"
                response = requests.get(status_url, timeout=2)
                if response.status_code == 200:
                    available_servers.append(url)
                else:
                    logger.warning(f"Service {url} returned status code {response.status_code}")

            except RequestException as e:
                logger.warning(f"Failed to connect to {url}: {str(e)}")
                continue
        logger.info(f"{task_type} available servers: {available_servers}")
        self.config["sub_servers"][task_type] = available_servers
        return len(available_servers) > 0

    def set_inputs(self, inputs):
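        """Populate input_info from a request dict, setting optional media paths
        only when the input_info dataclass actually defines those fields."""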
        self.input_info.seed = inputs.get("seed", 42)
        self.input_info.prompt = inputs.get("prompt", "")
        if self.config["use_prompt_enhancer"]:
            self.input_info.prompt_enhanced = inputs.get("prompt_enhanced", "")
        self.input_info.negative_prompt = inputs.get("negative_prompt", "")
        if "image_path" in self.input_info.__dataclass_fields__:
            self.input_info.image_path = inputs.get("image_path", "")
        if "audio_path" in self.input_info.__dataclass_fields__:
            self.input_info.audio_path = inputs.get("audio_path", "")
        if "video_path" in self.input_info.__dataclass_fields__:
            self.input_info.video_path = inputs.get("video_path", "")
        self.input_info.save_result_path = inputs.get("save_result_path", "")

    def set_config(self, config_modify):
        logger.info(f"modify config: {config_modify}")
        with self.config.temporarily_unlocked():
            self.config.update(config_modify)

    def set_progress_callback(self, callback):
        self.progress_callback = callback

    @peak_memory_decorator
    def run_segment(self, segment_idx=0):
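        """Run the denoising loop for one video segment, profiling each DiT step
        and reporting overall progress through the optional callback."""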
        infer_steps = self.model.scheduler.infer_steps

        for step_index in range(infer_steps):
            with ProfilingContext4DebugL1(
                "Run DiT every step",
                recorder_mode=GET_RECORDER_MODE(),
                metrics_func=monitor_cli.lightx2v_run_per_step_dit_duration,
                metrics_labels=[step_index + 1, infer_steps],
            ):
                # when there is only a single segment, check the stop signal every step
                if self.video_segment_num == 1:
                    self.check_stop()
                logger.info(f"==> step_index: {step_index + 1} / {infer_steps}")

                with ProfilingContext4DebugL1("step_pre"):
                    self.model.scheduler.step_pre(step_index=step_index)

                with ProfilingContext4DebugL1("🚀 infer_main"):
                    self.model.infer(self.inputs)

                with ProfilingContext4DebugL1("step_post"):
                    self.model.scheduler.step_post()

                if self.progress_callback:
                    current_step = segment_idx * infer_steps + step_index + 1
                    total_all_steps = self.video_segment_num * infer_steps
                    self.progress_callback((current_step / total_all_steps) * 100, 100)

        if segment_idx is not None and segment_idx == self.video_segment_num - 1:
            del self.inputs
            torch.cuda.empty_cache()

        return self.model.scheduler.latents

    def run_step(self):
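        """Debug helper: run the input encoders (plus an extra pass for the
        optional super-resolution branch), then the main generation loop."""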
        self.inputs = self.run_input_encoder()
        if hasattr(self, "sr_version") and self.sr_version is not None:
            self.config_sr["is_sr_running"] = True
            self.inputs_sr = self.run_input_encoder()
            self.config_sr["is_sr_running"] = False

        self.run_main()

    def end_run(self):
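        """Tear down per-run state: clear the scheduler and cached inputs, unload
        lazily loaded weights, save calibration data if requested, and free memory."""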
        self.model.scheduler.clear()
        if hasattr(self, "inputs"):
            del self.inputs
        self.input_info = None
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            if hasattr(self.model.transformer_infer, "weights_stream_mgr"):
                self.model.transformer_infer.weights_stream_mgr.clear()
            if hasattr(self.model.transformer_weights, "clear"):
                self.model.transformer_weights.clear()
            self.model.pre_weight.clear()
            del self.model
        if self.config.get("do_mm_calib", False):
            calib_path = os.path.join(os.getcwd(), "calib.pt")
            torch.save(CALIB, calib_path)
            logger.info(f"[CALIB] Saved calibration data successfully to: {calib_path}")
        torch.cuda.empty_cache()
        gc.collect()

    def read_image_input(self, img_path):
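        """Accept a PIL image or an image path; return the image normalized to
        [-1, 1] as a batched tensor on the init device plus the original image."""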
        if isinstance(img_path, Image.Image):
            img_ori = img_path
        else:
            img_ori = Image.open(img_path).convert("RGB")
        if GET_RECORDER_MODE():
            width, height = img_ori.size
            monitor_cli.lightx2v_input_image_len.observe(width * height)
        img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).to(self.init_device)
        self.input_info.original_size = img_ori.size
        return img, img_ori

    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_i2v(self):
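        """i2v: encode the input image with the image encoder (optional) and the
        VAE encoder, run the text encoder, and assemble the combined outputs."""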
        img, img_ori = self.read_image_input(self.input_info.image_path)
        clip_encoder_out = self.run_image_encoder(img) if self.config.get("use_image_encoder", True) else None
        vae_encode_out, latent_shape = self.run_vae_encoder(img_ori if self.vae_encoder_need_img_original else img)
        self.input_info.latent_shape = latent_shape  # Important: set latent_shape in input_info
        text_encoder_output = self.run_text_encoder(self.input_info)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)

    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_t2v(self):
        self.input_info.latent_shape = self.get_latent_shape_with_target_hw()  # Important: set latent_shape in input_info
        text_encoder_output = self.run_text_encoder(self.input_info)
        torch.cuda.empty_cache()
        gc.collect()
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": None,
        }

    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_flf2v(self):
        first_frame, _ = self.read_image_input(self.input_info.image_path)
        last_frame, _ = self.read_image_input(self.input_info.last_frame_path)
        clip_encoder_out = self.run_image_encoder(first_frame, last_frame) if self.config.get("use_image_encoder", True) else None
        vae_encode_out, latent_shape = self.run_vae_encoder(first_frame, last_frame)
        self.input_info.latent_shape = latent_shape  # Important: set latent_shape in input_info
        text_encoder_output = self.run_text_encoder(self.input_info)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output)

    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_vace(self):
        src_video = self.input_info.src_video
        src_mask = self.input_info.src_mask
        src_ref_images = self.input_info.src_ref_images
        src_video, src_mask, src_ref_images = self.prepare_source(
            [src_video],
            [src_mask],
            [None if src_ref_images is None else src_ref_images.split(",")],
            (self.config["target_width"], self.config["target_height"]),
        )
        self.src_ref_images = src_ref_images

        vae_encoder_out, latent_shape = self.run_vae_encoder(src_video, src_ref_images, src_mask)
        self.input_info.latent_shape = latent_shape  # Important: set latent_shape in input_info
        text_encoder_output = self.run_text_encoder(self.input_info)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(None, vae_encoder_out, text_encoder_output)

    @ProfilingContext4DebugL2("Run Text Encoder")
    def _run_input_encoder_local_animate(self):
        text_encoder_output = self.run_text_encoder(self.input_info)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(None, None, text_encoder_output, None)

    def _run_input_encoder_local_s2v(self):
        pass

    def init_run(self):
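        """Prepare a generation run: compute the segment count, lazily load the
        transformer if needed, and prime the scheduler with the seed, latent
        shape, and image-encoder output."""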
        self.gen_video_final = None
        self.get_video_segment_num()

        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.model = self.load_transformer()

        self.model.scheduler.prepare(seed=self.input_info.seed, latent_shape=self.input_info.latent_shape, image_encoder_output=self.inputs["image_encoder_output"])
        if self.config.get("model_cls") == "wan2.2" and self.config["task"] in ["i2v", "s2v"]:
            self.inputs["image_encoder_output"]["vae_encoder_out"] = None

        if hasattr(self, "sr_version") and self.sr_version is not None:
            self.lq_latents_shape = self.model.scheduler.latents.shape
            self.model_sr.set_scheduler(self.scheduler_sr)

            self.config_sr["is_sr_running"] = True
            self.inputs_sr = self.run_input_encoder()
            self.config_sr["is_sr_running"] = False

    @ProfilingContext4DebugL2("Run DiT")
    def run_main(self):
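        """Generate the video: initialize the run, then for each segment run the
        denoising loop and VAE decoder, and finally post-process and clean up."""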
        self.init_run()
        if self.config.get("compile", False) and hasattr(self.model, "compile"):
            self.model.select_graph_for_compile(self.input_info)
        for segment_idx in range(self.video_segment_num):
            logger.info(f"🔄 start segment {segment_idx + 1}/{self.video_segment_num}")
            with ProfilingContext4DebugL1(
                f"segment end2end {segment_idx + 1}/{self.video_segment_num}",
                recorder_mode=GET_RECORDER_MODE(),
                metrics_func=monitor_cli.lightx2v_run_segments_end2end_duration,
                metrics_labels=["DefaultRunner"],
            ):
                self.check_stop()
                # 1. per-segment setup (no-op by default)
                self.init_run_segment(segment_idx)
                # 2. main denoising loop
                latents = self.run_segment(segment_idx)
                # 3. VAE decode
                self.gen_video = self.run_vae_decoder(latents)
                # 4. per-segment teardown (no-op by default)
                self.end_run_segment(segment_idx)
        gen_video_final = self.process_images_after_vae_decoder()
        self.end_run()
        return gen_video_final

    @ProfilingContext4DebugL1("Run VAE Decoder", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_run_vae_decode_duration, metrics_labels=["DefaultRunner"])
    def run_vae_decoder(self, latents):
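        """Decode latents into frames, loading and unloading the VAE decoder on
        demand when lazy_load or unload_modules is enabled."""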
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.vae_decoder = self.load_vae_decoder()
        images = self.vae_decoder.decode(latents.to(GET_DTYPE()))
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.vae_decoder
            torch.cuda.empty_cache()
            gc.collect()
        return images

    def post_prompt_enhancer(self):
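        """Poll the prompt-enhancer sub-servers until one reports idle, then
        submit the prompt and return the enhanced version."""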
        while True:
            for url in self.config["sub_servers"]["prompt_enhancer"]:
                response = requests.get(f"{url}/v1/local/prompt_enhancer/generate/service_status").json()
                if response["service_status"] == "idle":
                    response = requests.post(
                        f"{url}/v1/local/prompt_enhancer/generate",
                        json={
                            "task_id": generate_task_id(),
                            "prompt": self.config["prompt"],
                        },
                    )
                    enhanced_prompt = response.json()["output"]
                    logger.info(f"Enhanced prompt: {enhanced_prompt}")
                    return enhanced_prompt

    def process_images_after_vae_decoder(self):
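        """Convert decoded frames to ComfyUI image format, optionally interpolate
        frames to a target fps, then return the tensor or save the video to disk."""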
        self.gen_video_final = vae_to_comfyui_image(self.gen_video_final)

        if "video_frame_interpolation" in self.config:
            assert self.vfi_model is not None and self.config["video_frame_interpolation"].get("target_fps", None) is not None
            target_fps = self.config["video_frame_interpolation"]["target_fps"]
            logger.info(f"Interpolating frames from {self.config.get('fps', 16)} to {target_fps}")
355
356
            self.gen_video_final = self.vfi_model.interpolate_frames(
                self.gen_video_final,
PengGao's avatar
PengGao committed
357
358
359
                source_fps=self.config.get("fps", 16),
                target_fps=target_fps,
            )
PengGao's avatar
PengGao committed
360

361
362
363
        if self.input_info.return_result_tensor:
            return {"video": self.gen_video_final}
        elif self.input_info.save_result_path is not None:
PengGao's avatar
PengGao committed
364
365
366
367
            if "video_frame_interpolation" in self.config and self.config["video_frame_interpolation"].get("target_fps"):
                fps = self.config["video_frame_interpolation"]["target_fps"]
            else:
                fps = self.config.get("fps", 16)
helloyongyang's avatar
helloyongyang committed
368

369
            if not dist.is_initialized() or dist.get_rank() == 0:
helloyongyang's avatar
helloyongyang committed
370
                logger.info(f"🎬 Start to save video 🎬")
371

372
373
374
375
                save_to_video(self.gen_video_final, self.input_info.save_result_path, fps=fps, method="ffmpeg")
                logger.info(f"✅ Video saved successfully to: {self.input_info.save_result_path} ✅")
            return {"video": None}

    @ProfilingContext4DebugL1("RUN pipeline", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_worker_request_duration, metrics_labels=["DefaultRunner"])
    def run_pipeline(self, input_info):
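        """End-to-end entry point: optionally enhance the prompt, run the input
        encoders, generate the video, and record request metrics."""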
        if GET_RECORDER_MODE():
            monitor_cli.lightx2v_worker_request_count.inc()
        self.input_info = input_info

        if self.config["use_prompt_enhancer"]:
            self.input_info.prompt_enhanced = self.post_prompt_enhancer()

        self.inputs = self.run_input_encoder()

        gen_video_final = self.run_main()

        if GET_RECORDER_MODE():
            monitor_cli.lightx2v_worker_request_success.inc()
        return gen_video_final