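"""Default end-to-end inference runner for lightx2v.

Wires together the text/image/VAE encoders, the DiT transformer, and the
scheduler, and drives segment-by-segment video generation, decoding, and
post-processing (optional frame interpolation and video export).
"""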
import gc
import os

import requests
import torch
import torch.distributed as dist
import torchvision.transforms.functional as TF
from PIL import Image
from loguru import logger
from requests.exceptions import RequestException

from lightx2v.server.metrics import monitor_cli
from lightx2v.utils.envs import *
from lightx2v.utils.generate_task_id import generate_task_id
from lightx2v.utils.global_paras import CALIB
from lightx2v.utils.memory_profiler import peak_memory_decorator
from lightx2v.utils.profiler import *
from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image

from .base_runner import BaseRunner


class DefaultRunner(BaseRunner):
    def __init__(self, config):
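        """Probe the optional prompt-enhancer sub-server (t2v only), then
        initialize the device and the scheduler."""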
        super().__init__(config)
        self.has_prompt_enhancer = False
        self.progress_callback = None
        if self.config["task"] == "t2v" and self.config.get("sub_servers", {}).get("prompt_enhancer") is not None:
            self.has_prompt_enhancer = True
            if not self.check_sub_servers("prompt_enhancer"):
                self.has_prompt_enhancer = False
                logger.warning("No prompt enhancer server available; disabling prompt enhancer.")
        if not self.has_prompt_enhancer:
            self.config["use_prompt_enhancer"] = False
        self.set_init_device()
        self.init_scheduler()

    def init_modules(self):
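        """Load model weights (unless deferred by lazy loading), attach the
        scheduler, select the task-specific input encoder, and lock the config."""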
        logger.info("Initializing runner modules...")
        if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
            self.load_model()
        elif self.config.get("lazy_load", False):
            assert self.config.get("cpu_offload", False)
        self.model.set_scheduler(self.scheduler)  # attach the scheduler to the model
        if self.config["task"] == "i2v":
            self.run_input_encoder = self._run_input_encoder_local_i2v
        elif self.config["task"] == "flf2v":
            self.run_input_encoder = self._run_input_encoder_local_flf2v
        elif self.config["task"] == "t2v":
            self.run_input_encoder = self._run_input_encoder_local_t2v
        elif self.config["task"] == "vace":
            self.run_input_encoder = self._run_input_encoder_local_vace
        elif self.config["task"] == "animate":
            self.run_input_encoder = self._run_input_encoder_local_animate
        elif self.config["task"] == "s2v":
            self.run_input_encoder = self._run_input_encoder_local_s2v
        self.config.lock()  # lock config to avoid modification
        if self.config.get("compile", False):
            logger.info(f"[Compile] Compile all shapes: {self.config.get('compile_shapes', [])}")
            self.model.compile(self.config.get("compile_shapes", []))

    def set_init_device(self):
        if self.config["cpu_offload"]:
            self.init_device = torch.device("cpu")
        else:
            self.init_device = torch.device(self.config.get("run_device", "cuda"))

    def load_vfi_model(self):
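        """Load the video-frame-interpolation model; only RIFE is supported."""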
        if self.config["video_frame_interpolation"].get("algo", None) == "rife":
            from lightx2v.models.vfi.rife.rife_comfyui_wrapper import RIFEWrapper

            logger.info("Loading RIFE model...")
            return RIFEWrapper(self.config["video_frame_interpolation"]["model_path"])
        else:
            raise ValueError(f"Unsupported VFI model: {self.config['video_frame_interpolation']['algo']}")

    def load_vsr_model(self):
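        """Load the video-super-resolution model if configured, else return None."""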
        if "video_super_resolution" in self.config:
            from lightx2v.models.runners.vsr.vsr_wrapper import VSRWrapper

            logger.info("Loading VSR model...")
            return VSRWrapper(self.config["video_super_resolution"]["model_path"])
        else:
            return None

    @ProfilingContext4DebugL2("Load models")
    def load_model(self):
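        """Instantiate all pipeline components: transformer, text/image encoders,
        VAE, and the optional frame-interpolation and super-resolution models."""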
        self.model = self.load_transformer()
        self.text_encoders = self.load_text_encoder()
        self.image_encoder = self.load_image_encoder()
        self.vae_encoder, self.vae_decoder = self.load_vae()
        self.vfi_model = self.load_vfi_model() if "video_frame_interpolation" in self.config else None
        self.vsr_model = self.load_vsr_model() if "video_super_resolution" in self.config else None

    def check_sub_servers(self, task_type):
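        """Probe the configured sub-servers for `task_type`, keep only those whose
        status endpoint answers with HTTP 200, and report whether any remain."""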
        urls = self.config.get("sub_servers", {}).get(task_type, [])
        available_servers = []
        for url in urls:
            try:
                status_url = f"{url}/v1/local/{task_type}/generate/service_status"
                response = requests.get(status_url, timeout=2)
                if response.status_code == 200:
                    available_servers.append(url)
                else:
                    logger.warning(f"Service {url} returned status code {response.status_code}")

            except RequestException as e:
                logger.warning(f"Failed to connect to {url}: {str(e)}")
                continue
        logger.info(f"{task_type} available servers: {available_servers}")
        self.config["sub_servers"][task_type] = available_servers
        return len(available_servers) > 0

    def set_inputs(self, inputs):
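        """Copy per-request fields (seed, prompts, optional media paths, and the
        output path) from `inputs` onto `self.input_info`."""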
        self.input_info.seed = inputs.get("seed", 42)
        self.input_info.prompt = inputs.get("prompt", "")
        if self.config["use_prompt_enhancer"]:
            self.input_info.prompt_enhanced = inputs.get("prompt_enhanced", "")
        self.input_info.negative_prompt = inputs.get("negative_prompt", "")
        if "image_path" in self.input_info.__dataclass_fields__:
            self.input_info.image_path = inputs.get("image_path", "")
        if "audio_path" in self.input_info.__dataclass_fields__:
            self.input_info.audio_path = inputs.get("audio_path", "")
        if "video_path" in self.input_info.__dataclass_fields__:
            self.input_info.video_path = inputs.get("video_path", "")
        self.input_info.save_result_path = inputs.get("save_result_path", "")

    def set_config(self, config_modify):
        logger.info(f"modify config: {config_modify}")
        with self.config.temporarily_unlocked():
            self.config.update(config_modify)

    def set_progress_callback(self, callback):
        self.progress_callback = callback

    @peak_memory_decorator
    def run_segment(self, segment_idx=0):
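        """Run the denoising loop for one video segment: step_pre -> infer ->
        step_post at every step, reporting progress across all segments."""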
        infer_steps = self.model.scheduler.infer_steps

        for step_index in range(infer_steps):
            # For single-segment runs, check the stop signal at every step.
            with ProfilingContext4DebugL1(
                f"Run Dit every step",
                recorder_mode=GET_RECORDER_MODE(),
                metrics_func=monitor_cli.lightx2v_run_per_step_dit_duration,
                metrics_labels=[step_index + 1, infer_steps],
            ):
                if self.video_segment_num == 1:
                    self.check_stop()
                logger.info(f"==> step_index: {step_index + 1} / {infer_steps}")

                with ProfilingContext4DebugL1("step_pre"):
                    self.model.scheduler.step_pre(step_index=step_index)

                with ProfilingContext4DebugL1("🚀 infer_main"):
                    self.model.infer(self.inputs)

                with ProfilingContext4DebugL1("step_post"):
                    self.model.scheduler.step_post()

                if self.progress_callback:
                    current_step = segment_idx * infer_steps + step_index + 1
                    total_all_steps = self.video_segment_num * infer_steps
                    self.progress_callback((current_step / total_all_steps) * 100, 100)

        return self.model.scheduler.latents

    def run_step(self):
        self.inputs = self.run_input_encoder()
        self.run_main()

    def end_run(self):
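        """Release per-run state: clear the scheduler and inputs, drop lazily
        loaded weights, optionally dump calibration data, and free GPU memory."""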
        self.model.scheduler.clear()
        del self.inputs
        self.input_info = None
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            if hasattr(self.model.transformer_infer, "weights_stream_mgr"):
                self.model.transformer_infer.weights_stream_mgr.clear()
            if hasattr(self.model.transformer_weights, "clear"):
                self.model.transformer_weights.clear()
            self.model.pre_weight.clear()
            del self.model
        if self.config.get("do_mm_calib", False):
            calib_path = os.path.join(os.getcwd(), "calib.pt")
            torch.save(CALIB, calib_path)
            logger.info(f"[CALIB] Saved calibration data successfully to: {calib_path}")
        torch.cuda.empty_cache()
        gc.collect()

    def read_image_input(self, img_path):
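        """Accept a PIL image or a file path; return a [-1, 1]-normalized CUDA
        tensor (NCHW) together with the original RGB image."""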
        if isinstance(img_path, Image.Image):
            img_ori = img_path
        else:
            img_ori = Image.open(img_path).convert("RGB")
        if GET_RECORDER_MODE():
            width, height = img_ori.size
            monitor_cli.lightx2v_input_image_len.observe(width * height)
        img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
        self.input_info.original_size = img_ori.size
        return img, img_ori

    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_i2v(self):
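        """Encode i2v inputs: optional CLIP image features, VAE latents of the
        conditioning image, and text embeddings."""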
        img, img_ori = self.read_image_input(self.input_info.image_path)
        clip_encoder_out = self.run_image_encoder(img) if self.config.get("use_image_encoder", True) else None
        vae_encode_out, latent_shape = self.run_vae_encoder(img_ori if self.vae_encoder_need_img_original else img)
        self.input_info.latent_shape = latent_shape  # Important: set latent_shape in input_info
        text_encoder_output = self.run_text_encoder(self.input_info)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)

    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_t2v(self):
        self.input_info.latent_shape = self.get_latent_shape_with_target_hw(self.config["target_height"], self.config["target_width"])  # Important: set latent_shape in input_info
        text_encoder_output = self.run_text_encoder(self.input_info)
        torch.cuda.empty_cache()
        gc.collect()
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": None,
        }

    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_flf2v(self):
        first_frame, _ = self.read_image_input(self.input_info.image_path)
        last_frame, _ = self.read_image_input(self.input_info.last_frame_path)
        clip_encoder_out = self.run_image_encoder(first_frame, last_frame) if self.config.get("use_image_encoder", True) else None
        vae_encode_out, latent_shape = self.run_vae_encoder(first_frame, last_frame)
        self.input_info.latent_shape = latent_shape  # Important: set latent_shape in input_info
        text_encoder_output = self.run_text_encoder(self.input_info)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output)

    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_vace(self):
        src_video = self.input_info.src_video
        src_mask = self.input_info.src_mask
        src_ref_images = self.input_info.src_ref_images
        src_video, src_mask, src_ref_images = self.prepare_source(
            [src_video],
            [src_mask],
            [None if src_ref_images is None else src_ref_images.split(",")],
            (self.config["target_width"], self.config["target_height"]),
        )
        self.src_ref_images = src_ref_images

        vae_encoder_out, latent_shape = self.run_vae_encoder(src_video, src_ref_images, src_mask)
        self.input_info.latent_shape = latent_shape  # Important: set latent_shape in input_info
        text_encoder_output = self.run_text_encoder(self.input_info)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(None, vae_encoder_out, text_encoder_output)

    @ProfilingContext4DebugL2("Run Text Encoder")
    def _run_input_encoder_local_animate(self):
        text_encoder_output = self.run_text_encoder(self.input_info)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(None, None, text_encoder_output, None)

    def _run_input_encoder_local_s2v(self):
        pass

    def init_run(self):
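        """Prepare a generation run: compute the segment count, lazily load the
        transformer if needed, and prime the scheduler with seed and latent shape."""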
        self.gen_video_final = None
        self.get_video_segment_num()

        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.model = self.load_transformer()

        self.model.scheduler.prepare(seed=self.input_info.seed, latent_shape=self.input_info.latent_shape, image_encoder_output=self.inputs["image_encoder_output"])
        if self.config.get("model_cls") == "wan2.2" and self.config["task"] in ["i2v", "s2v"]:
            self.inputs["image_encoder_output"]["vae_encoder_out"] = None

    @ProfilingContext4DebugL2("Run DiT")
    def run_main(self):
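        """Generate the video segment by segment, decode each segment with the
        VAE, post-process, and return the final result."""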
        self.init_run()
        if self.config.get("compile", False):
            self.model.select_graph_for_compile(self.input_info)
        for segment_idx in range(self.video_segment_num):
            logger.info(f"🔄 start segment {segment_idx + 1}/{self.video_segment_num}")
            with ProfilingContext4DebugL1(
                f"segment end2end {segment_idx + 1}/{self.video_segment_num}",
                recorder_mode=GET_RECORDER_MODE(),
                metrics_func=monitor_cli.lightx2v_run_segments_end2end_duration,
                metrics_labels=["DefaultRunner"],
            ):
                self.check_stop()
                # 1. Per-segment setup hook (no-op by default).
                self.init_run_segment(segment_idx)
                # 2. Main inference loop.
                latents = self.run_segment(segment_idx)
                # 3. Decode latents with the VAE.
                self.gen_video = self.run_vae_decoder(latents)
                # 4. Per-segment teardown hook (no-op by default).
                self.end_run_segment(segment_idx)
        gen_video_final = self.process_images_after_vae_decoder()
        self.end_run()
        return gen_video_final

    @ProfilingContext4DebugL1("Run VAE Decoder", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_run_vae_decode_duration, metrics_labels=["DefaultRunner"])
    def run_vae_decoder(self, latents):
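        """Decode latents into frames, loading and unloading the VAE decoder
        around the call when lazy loading is enabled."""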
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.vae_decoder = self.load_vae_decoder()
        images = self.vae_decoder.decode(latents.to(GET_DTYPE()))
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.vae_decoder
            torch.cuda.empty_cache()
            gc.collect()
        return images

    def post_prompt_enhancer(self):
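        """Poll the prompt-enhancer sub-servers until one is idle, then submit
        the prompt and return the enhanced version."""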
        while True:
            for url in self.config["sub_servers"]["prompt_enhancer"]:
                response = requests.get(f"{url}/v1/local/prompt_enhancer/generate/service_status").json()
                if response["service_status"] == "idle":
                    response = requests.post(
                        f"{url}/v1/local/prompt_enhancer/generate",
                        json={
                            "task_id": generate_task_id(),
                            "prompt": self.config["prompt"],
                        },
                    )
                    enhanced_prompt = response.json()["output"]
                    logger.info(f"Enhanced prompt: {enhanced_prompt}")
                    return enhanced_prompt

    def process_images_after_vae_decoder(self):
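        """Convert decoded frames to ComfyUI image format, optionally interpolate
        to a higher frame rate, then return the tensor or save the video (rank 0 only)."""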
        self.gen_video_final = vae_to_comfyui_image(self.gen_video_final)

        if "video_frame_interpolation" in self.config:
            assert self.vfi_model is not None and self.config["video_frame_interpolation"].get("target_fps", None) is not None
            target_fps = self.config["video_frame_interpolation"]["target_fps"]
            logger.info(f"Interpolating frames from {self.config.get('fps', 16)} to {target_fps}")
            self.gen_video_final = self.vfi_model.interpolate_frames(
                self.gen_video_final,
                source_fps=self.config.get("fps", 16),
                target_fps=target_fps,
            )

        if self.input_info.return_result_tensor:
            return {"video": self.gen_video_final}
        elif self.input_info.save_result_path is not None:
            if "video_frame_interpolation" in self.config and self.config["video_frame_interpolation"].get("target_fps"):
                fps = self.config["video_frame_interpolation"]["target_fps"]
            else:
                fps = self.config.get("fps", 16)

            if not dist.is_initialized() or dist.get_rank() == 0:
                logger.info(f"🎬 Start to save video 🎬")

                save_to_video(self.gen_video_final, self.input_info.save_result_path, fps=fps, method="ffmpeg")
                logger.info(f"✅ Video saved successfully to: {self.input_info.save_result_path} ✅")
            return {"video": None}

    @ProfilingContext4DebugL1("RUN pipeline", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_worker_request_duration, metrics_labels=["DefaultRunner"])
    def run_pipeline(self, input_info):
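        """End-to-end entry point: optionally enhance the prompt, encode inputs,
        run the main generation loop, and return the final video."""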
        if GET_RECORDER_MODE():
            monitor_cli.lightx2v_worker_request_count.inc()
        self.input_info = input_info

        if self.config["use_prompt_enhancer"]:
            self.input_info.prompt_enhanced = self.post_prompt_enhancer()

        self.inputs = self.run_input_encoder()

        gen_video_final = self.run_main()

        if GET_RECORDER_MODE():
            monitor_cli.lightx2v_worker_request_success.inc()
        return gen_video_final