default_runner.py 16.9 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
import gc
PengGao's avatar
PengGao committed
2

3
import requests
helloyongyang's avatar
helloyongyang committed
4
5
import torch
import torch.distributed as dist
helloyongyang's avatar
helloyongyang committed
6
import torchvision.transforms.functional as TF
PengGao's avatar
PengGao committed
7
8
9
from PIL import Image
from loguru import logger
from requests.exceptions import RequestException
PengGao's avatar
PengGao committed
10

yihuiwen's avatar
yihuiwen committed
11
from lightx2v.server.metrics import monitor_cli
helloyongyang's avatar
helloyongyang committed
12
from lightx2v.utils.envs import *
PengGao's avatar
PengGao committed
13
from lightx2v.utils.generate_task_id import generate_task_id
14
from lightx2v.utils.memory_profiler import peak_memory_decorator
15
from lightx2v.utils.profiler import *
helloyongyang's avatar
helloyongyang committed
16
from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image
PengGao's avatar
PengGao committed
17

PengGao's avatar
PengGao committed
18
from .base_runner import BaseRunner
19
20


PengGao's avatar
PengGao committed
21
class DefaultRunner(BaseRunner):
helloyongyang's avatar
helloyongyang committed
22
    def __init__(self, config):
        """Set up runner state, probe the optional prompt enhancer, then init device and scheduler.

        Args:
            config: Runner configuration. It is mutated here: ``use_prompt_enhancer``
                is forced to ``False`` whenever no prompt-enhancer server is usable.
        """
        super().__init__(config)
        self.progress_callback = None
        self.has_prompt_enhancer = False

        # The enhancer is only considered for t2v, and only when sub-servers are configured for it.
        enhancer_configured = self.config["task"] == "t2v" and self.config.get("sub_servers", {}).get("prompt_enhancer") is not None
        if enhancer_configured:
            if self.check_sub_servers("prompt_enhancer"):
                self.has_prompt_enhancer = True
            else:
                logger.warning("No prompt enhancer server available, disable prompt enhancer.")

        if not self.has_prompt_enhancer:
            self.config["use_prompt_enhancer"] = False

        self.set_init_device()
        self.init_scheduler()
35

36
    def init_modules(self):
        """Load models (unless deferred), attach the scheduler, select the input encoder, lock config.

        Models are loaded eagerly only when neither ``lazy_load`` nor
        ``unload_modules`` is enabled. ``lazy_load`` additionally requires
        ``cpu_offload``. ``run_input_encoder`` is bound according to
        ``config["task"]``; an unrecognised task leaves it untouched.
        """
        logger.info("Initializing runner modules...")

        if self.config.get("lazy_load", False):
            assert self.config.get("cpu_offload", False)
        elif not self.config.get("unload_modules", False):
            self.load_model()

        # NOTE(review): when loading is deferred, self.model must already exist here — confirm where it is set.
        self.model.set_scheduler(self.scheduler)  # set scheduler to model

        # Task name -> name of the local input-encoder method; resolved lazily via getattr.
        encoder_method_by_task = {
            "i2v": "_run_input_encoder_local_i2v",
            "flf2v": "_run_input_encoder_local_flf2v",
            "t2v": "_run_input_encoder_local_t2v",
            "vace": "_run_input_encoder_local_vace",
            "animate": "_run_input_encoder_local_animate",
            "s2v": "_run_input_encoder_local_s2v",
        }
        encoder_method = encoder_method_by_task.get(self.config["task"])
        if encoder_method is not None:
            self.run_input_encoder = getattr(self, encoder_method)

        self.config.lock()  # lock config to avoid modification

        if self.config.get("compile", False):
            compile_shapes = self.config.get("compile_shapes", [])
            logger.info(f"[Compile] Compile all shapes: {compile_shapes}")
            self.model.compile(compile_shapes)
59

60
    def set_init_device(self):
        """Set ``self.init_device`` to CPU when ``cpu_offload`` is enabled, otherwise to CUDA."""
        device_name = "cpu" if self.config["cpu_offload"] else "cuda"
        self.init_device = torch.device(device_name)
65

PengGao's avatar
PengGao committed
66
67
68
69
70
71
72
    def load_vfi_model(self):
        """Build the video-frame-interpolation model named by ``config["video_frame_interpolation"]``.

        Returns:
            A ``RIFEWrapper`` built from the configured ``model_path``.

        Raises:
            ValueError: If the configured ``algo`` is anything other than ``"rife"``.
        """
        vfi_config = self.config["video_frame_interpolation"]
        if vfi_config.get("algo", None) != "rife":
            raise ValueError(f"Unsupported VFI model: {vfi_config['algo']}")

        # Imported here so the RIFE module is only loaded when it is actually used.
        from lightx2v.models.vfi.rife.rife_comfyui_wrapper import RIFEWrapper

        logger.info("Loading RIFE model...")
        return RIFEWrapper(vfi_config["model_path"])
PengGao's avatar
PengGao committed
74

75
76
77
78
79
80
81
82
83
    def load_vsr_model(self):
        """Build the video-super-resolution wrapper, or return ``None`` when it is not configured."""
        if "video_super_resolution" not in self.config:
            return None

        # Imported here so the VSR module is only loaded when it is actually used.
        from lightx2v.models.runners.vsr.vsr_wrapper import VSRWrapper

        logger.info("Loading VSR model...")
        vsr_model_path = self.config["video_super_resolution"]["model_path"]
        return VSRWrapper(vsr_model_path)

84
    @ProfilingContext4DebugL2("Load models")
    def load_model(self):
        """Load the transformer, text/image encoders, VAE pair, and the optional VFI/VSR models."""
        self.model = self.load_transformer()
        self.text_encoders = self.load_text_encoder()
        self.image_encoder = self.load_image_encoder()

        # load_vae() yields the encoder first, then the decoder.
        vae_pair = self.load_vae()
        self.vae_encoder, self.vae_decoder = vae_pair

        # The post-processing models are optional and only loaded when configured.
        if "video_frame_interpolation" in self.config:
            self.vfi_model = self.load_vfi_model()
        else:
            self.vfi_model = None

        if "video_super_resolution" in self.config:
            self.vsr_model = self.load_vsr_model()
        else:
            self.vsr_model = None
Zhuguanyu Wu's avatar
Zhuguanyu Wu committed
92

93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
    def check_sub_servers(self, task_type):
        """Probe the sub-servers configured for ``task_type`` and keep only the reachable ones.

        A server counts as available when a GET on its ``service_status``
        endpoint answers with HTTP 200 within 2 seconds. The filtered list is
        written back to ``config["sub_servers"][task_type]``.

        Args:
            task_type: Key into ``config["sub_servers"]``; also used in the status URL.

        Returns:
            ``True`` if at least one server is available, otherwise ``False``.
        """
        reachable = []
        for url in self.config.get("sub_servers", {}).get(task_type, []):
            status_url = f"{url}/v1/local/{task_type}/generate/service_status"
            try:
                response = requests.get(status_url, timeout=2)
            except RequestException as e:
                logger.warning(f"Failed to connect to {url}: {str(e)}")
                continue

            if response.status_code != 200:
                logger.warning(f"Service {url} returned status code {response.status_code}")
                continue
            reachable.append(url)

        logger.info(f"{task_type} available servers: {reachable}")
        self.config["sub_servers"][task_type] = reachable
        return bool(reachable)

helloyongyang's avatar
helloyongyang committed
112
    def set_inputs(self, inputs):
        """Copy request fields from ``inputs`` onto ``self.input_info``, applying defaults.

        Args:
            inputs: Mapping of request fields. Missing keys default to ``42`` for
                ``seed`` and to ``""`` for every other field.
        """
        info = self.input_info
        info.seed = inputs.get("seed", 42)
        info.prompt = inputs.get("prompt", "")
        if self.config["use_prompt_enhancer"]:
            info.prompt_enhanced = inputs.get("prompt_enhanced", "")
        info.negative_prompt = inputs.get("negative_prompt", "")

        # Media paths are only set when the input_info dataclass declares the field.
        for optional_field in ("image_path", "audio_path", "video_path"):
            if optional_field in info.__dataclass_fields__:
                setattr(info, optional_field, inputs.get(optional_field, ""))

        info.save_result_path = inputs.get("save_result_path", "")

    def set_config(self, config_modify):
        """Merge ``config_modify`` into the runner config, unlocking it only for the update.

        Args:
            config_modify: Mapping of config overrides passed to ``config.update``.
        """
        logger.info(f"modify config: {config_modify}")
        config = self.config
        with config.temporarily_unlocked():
            config.update(config_modify)
helloyongyang's avatar
helloyongyang committed
130

PengGao's avatar
PengGao committed
131
132
133
    def set_progress_callback(self, callback):
        """Register the callable that ``run_segment`` invokes after each step as ``callback(percent, 100)``.

        Args:
            callback: Progress callable, or a falsy value to disable reporting.
        """
        self.progress_callback = callback

134
    @peak_memory_decorator
    def run_segment(self, total_steps=None):
        """Run the denoising loop for one segment and return the scheduler's latents.

        Args:
            total_steps: Number of steps to run; defaults to the scheduler's ``infer_steps``.

        Returns:
            ``self.model.scheduler.latents`` after the final step.
        """
        if total_steps is None:
            total_steps = self.model.scheduler.infer_steps

        for step_index in range(total_steps):
            step_number = step_index + 1  # 1-based, used for logging, metrics and progress
            with ProfilingContext4DebugL1(
                "Run Dit every step",
                recorder_mode=GET_RECORDER_MODE(),
                metrics_func=monitor_cli.lightx2v_run_per_step_dit_duration,
                metrics_labels=[step_number, total_steps],
            ):
                # only for single segment, check stop signal every step
                if self.video_segment_num == 1:
                    self.check_stop()
                logger.info(f"==> step_index: {step_number} / {total_steps}")

                with ProfilingContext4DebugL1("step_pre"):
                    self.model.scheduler.step_pre(step_index=step_index)

                with ProfilingContext4DebugL1("🚀 infer_main"):
                    self.model.infer(self.inputs)

                with ProfilingContext4DebugL1("step_post"):
                    self.model.scheduler.step_post()

                if self.progress_callback:
                    percent_done = (step_number / total_steps) * 100
                    self.progress_callback(percent_done, 100)

        return self.model.scheduler.latents
163

helloyongyang's avatar
helloyongyang committed
164
    def run_step(self):
        """Run the input encoders, then the main pipeline limited to a single step."""
        encoder_outputs = self.run_input_encoder()
        self.inputs = encoder_outputs
        self.run_main(total_steps=1)
helloyongyang's avatar
helloyongyang committed
167
168

    def end_run(self):
        """Release per-run state and, when modules are loaded per run, drop the model.

        Clears the scheduler, removes ``self.inputs``, resets ``self.input_info``
        and finally empties the CUDA cache and runs the garbage collector.
        """
        self.model.scheduler.clear()
        del self.inputs
        self.input_info = None
        # With lazy_load / unload_modules the transformer is reloaded in init_run,
        # so its weights are cleared and the model attribute is removed here.
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            # Both attributes are optional on the model; only clear what exists.
            if hasattr(self.model.transformer_infer, "weights_stream_mgr"):
                self.model.transformer_infer.weights_stream_mgr.clear()
            if hasattr(self.model.transformer_weights, "clear"):
                self.model.transformer_weights.clear()
            self.model.pre_weight.clear()
            del self.model
        # Done last so the references dropped above can actually be reclaimed.
        torch.cuda.empty_cache()
        gc.collect()
helloyongyang's avatar
helloyongyang committed
181

helloyongyang's avatar
helloyongyang committed
182
    def read_image_input(self, img_path):
        """Load an input image and return it as a normalised CUDA tensor plus the PIL image.

        The image is always normalised to RGB: inputs read from disk are
        converted on load, and in-memory PIL images in another mode (for
        example ``RGBA`` or ``L``) are converted too, so the tensor channel
        count does not depend on how the image was supplied. Also records
        ``original_size`` on ``self.input_info``.

        Args:
            img_path: A ``PIL.Image.Image`` or anything ``PIL.Image.open`` accepts.

        Returns:
            Tuple ``(img, img_ori)``: ``img`` is the ``to_tensor`` output rescaled
            with ``(x - 0.5) / 0.5``, given a leading dimension and moved to CUDA;
            ``img_ori`` is the RGB PIL image.
        """
        if isinstance(img_path, Image.Image):
            # Already-RGB images are passed through untouched (same object as before).
            img_ori = img_path if img_path.mode == "RGB" else img_path.convert("RGB")
        else:
            # convert() returns an independent image, so the file can be closed right away.
            with Image.open(img_path) as opened_image:
                img_ori = opened_image.convert("RGB")

        if GET_RECORDER_MODE():
            width, height = img_ori.size
            monitor_cli.lightx2v_input_image_len.observe(width * height)

        img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
        self.input_info.original_size = img_ori.size
        return img, img_ori
helloyongyang's avatar
helloyongyang committed
193

194
    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_i2v(self):
        """Run the image, VAE and text encoders for the i2v task and bundle their outputs.

        Returns:
            Whatever ``get_encoder_output_i2v`` builds from the three encoder
            outputs and the image tensor.
        """
        img, img_ori = self.read_image_input(self.input_info.image_path)

        # The image encoder is optional and on by default.
        if self.config.get("use_image_encoder", True):
            clip_encoder_out = self.run_image_encoder(img)
        else:
            clip_encoder_out = None

        # The VAE encoder takes either the original PIL image or the tensor.
        vae_input = img_ori if self.vae_encoder_need_img_original else img
        vae_encode_out, latent_shape = self.run_vae_encoder(vae_input)
        self.input_info.latent_shape = latent_shape  # Important: set latent_shape in input_info

        text_encoder_output = self.run_text_encoder(self.input_info)

        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)

205
    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_t2v(self):
        """Run the text encoder for the t2v task.

        Returns:
            Dict with ``text_encoder_output`` and a ``None`` ``image_encoder_output``.
        """
        target_height = self.config["target_height"]
        target_width = self.config["target_width"]
        # Important: set latent_shape in input_info
        self.input_info.latent_shape = self.get_latent_shape_with_target_hw(target_height, target_width)

        text_encoder_output = self.run_text_encoder(self.input_info)

        torch.cuda.empty_cache()
        gc.collect()

        encoder_outputs = {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": None,
        }
        return encoder_outputs
215

216
    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_flf2v(self):
        """Run the encoders for the flf2v task, using a first and a last frame.

        Returns:
            Whatever ``get_encoder_output_i2v`` builds from the image, VAE and
            text encoder outputs.
        """
        info = self.input_info
        first_frame, _ = self.read_image_input(info.image_path)
        last_frame, _ = self.read_image_input(info.last_frame_path)

        # The image encoder is optional and on by default.
        if self.config.get("use_image_encoder", True):
            clip_encoder_out = self.run_image_encoder(first_frame, last_frame)
        else:
            clip_encoder_out = None

        vae_encode_out, latent_shape = self.run_vae_encoder(first_frame, last_frame)
        self.input_info.latent_shape = latent_shape  # Important: set latent_shape in input_info

        text_encoder_output = self.run_text_encoder(self.input_info)

        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output)

228
    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_vace(self):
        """Prepare the vace source inputs, then run the VAE and text encoders.

        Reads ``src_video``, ``src_mask`` and ``src_ref_images`` from
        ``self.input_info`` and stores the prepared reference images on
        ``self.src_ref_images``.

        Returns:
            Whatever ``get_encoder_output_i2v`` builds with no image-encoder output.
        """
        raw_video = self.input_info.src_video
        raw_mask = self.input_info.src_mask
        raw_ref_images = self.input_info.src_ref_images

        # Reference images arrive as one comma-separated string (or None).
        ref_image_list = None if raw_ref_images is None else raw_ref_images.split(",")
        target_size = (self.config["target_width"], self.config["target_height"])

        # prepare_source takes each input wrapped in a single-element list.
        src_video, src_mask, src_ref_images = self.prepare_source([raw_video], [raw_mask], [ref_image_list], target_size)
        self.src_ref_images = src_ref_images

        vae_encoder_out, latent_shape = self.run_vae_encoder(src_video, src_ref_images, src_mask)
        self.input_info.latent_shape = latent_shape  # Important: set latent_shape in input_info

        text_encoder_output = self.run_text_encoder(self.input_info)

        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(None, vae_encoder_out, text_encoder_output)

248
249
    @ProfilingContext4DebugL2("Run Text Encoder")
    def _run_input_encoder_local_animate(self):
        """Run only the text encoder for the animate task.

        Returns:
            Whatever ``get_encoder_output_i2v`` builds with just the text encoder
            output; the image, VAE and image-tensor arguments are all ``None``.
        """
        text_features = self.run_text_encoder(self.input_info)

        torch.cuda.empty_cache()
        gc.collect()

        return self.get_encoder_output_i2v(None, None, text_features, None)

255
256
257
    def _run_input_encoder_local_s2v(self):
        """Placeholder for the s2v input encoder; does nothing here and returns ``None``."""
        return None

helloyongyang's avatar
helloyongyang committed
258
    def init_run(self):
        """Reset per-run state, reload the transformer if it is loaded per run, and prepare the scheduler."""
        self.gen_video_final = None
        self.get_video_segment_num()

        # With lazy_load / unload_modules the transformer is loaded fresh for every run.
        load_per_run = self.config.get("lazy_load", False) or self.config.get("unload_modules", False)
        if load_per_run:
            self.model = self.load_transformer()

        self.model.scheduler.prepare(
            seed=self.input_info.seed,
            latent_shape=self.input_info.latent_shape,
            image_encoder_output=self.inputs["image_encoder_output"],
        )

        # For wan2.2 i2v/s2v the VAE encoder output is cleared once the scheduler is prepared.
        # NOTE(review): presumably to free memory after the scheduler consumed it — confirm.
        is_wan22 = self.config.get("model_cls") == "wan2.2"
        if is_wan22 and self.config["task"] in ("i2v", "s2v"):
            self.inputs["image_encoder_output"]["vae_encoder_out"] = None
helloyongyang's avatar
helloyongyang committed
267

268
    @ProfilingContext4DebugL2("Run DiT")
    def run_main(self, total_steps=None):
        """Run every video segment through denoising and VAE decoding, then post-process.

        Args:
            total_steps: Optional step count forwarded to ``run_segment``.

        Returns:
            The result of ``process_images_after_vae_decoder``.
        """
        self.init_run()
        if self.config.get("compile", False):
            self.model.select_graph_for_compile(self.input_info)

        for segment_idx in range(self.video_segment_num):
            segment_label = f"{segment_idx + 1}/{self.video_segment_num}"
            logger.info(f"🔄 start segment {segment_label}")
            with ProfilingContext4DebugL1(
                f"segment end2end {segment_label}",
                recorder_mode=GET_RECORDER_MODE(),
                metrics_func=monitor_cli.lightx2v_run_segments_end2end_duration,
                metrics_labels=["DefaultRunner"],
            ):
                self.check_stop()
                # Per-segment setup hook.
                self.init_run_segment(segment_idx)
                # Main inference loop.
                segment_latents = self.run_segment(total_steps=total_steps)
                # Decode the latents for this segment.
                self.gen_video = self.run_vae_decoder(segment_latents)
                # Per-segment teardown hook.
                self.end_run_segment(segment_idx)

        # Post-processing must run before end_run(), which resets self.input_info.
        result = self.process_images_after_vae_decoder()
        self.end_run()
        return result
293

yihuiwen's avatar
yihuiwen committed
294
    @ProfilingContext4DebugL1("Run VAE Decoder", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_run_vae_decode_duration, metrics_labels=["DefaultRunner"])
    def run_vae_decoder(self, latents):
        """Decode latents with the VAE decoder, loading and releasing it around the call if needed.

        Args:
            latents: Latents to decode; cast to ``GET_DTYPE()`` before decoding.

        Returns:
            The output of ``vae_decoder.decode``.
        """
        # With lazy_load / unload_modules the decoder only lives for the duration of this call.
        transient_decoder = self.config.get("lazy_load", False) or self.config.get("unload_modules", False)

        if transient_decoder:
            self.vae_decoder = self.load_vae_decoder()

        images = self.vae_decoder.decode(latents.to(GET_DTYPE()))

        if transient_decoder:
            del self.vae_decoder
            torch.cuda.empty_cache()
            gc.collect()

        return images

305
306
307
308
309
    def post_prompt_enhancer(self):
        """Send the configured prompt to an idle prompt-enhancer server and return its rewrite.

        Polls every URL in ``config["sub_servers"]["prompt_enhancer"]`` until one
        reports ``service_status == "idle"``, then posts ``config["prompt"]`` to
        that server's ``generate`` endpoint.

        Compared with a bare ``while True`` poll, this version:
          * sleeps between polling rounds instead of spinning the CPU and
            flooding the servers while all of them are busy;
          * bounds the status probe with the same 2 s timeout that
            ``check_sub_servers`` uses, treating a probe that times out as
            "not idle" rather than blocking forever;
          * fails fast when no server is configured instead of looping forever.

        Returns:
            The ``"output"`` field of the enhancer's JSON response.

        Raises:
            RuntimeError: If the prompt-enhancer server list is empty.
            requests.exceptions.RequestException: Any request failure other than
                a status-probe timeout propagates unchanged.
        """
        import time

        status_timeout_s = 2  # matches the probe timeout in check_sub_servers
        poll_interval_s = 0.5  # pause between rounds when every server is busy

        while True:
            urls = self.config["sub_servers"]["prompt_enhancer"]
            if not urls:
                raise RuntimeError("No prompt enhancer server configured.")

            for url in urls:
                try:
                    status = requests.get(
                        f"{url}/v1/local/prompt_enhancer/generate/service_status",
                        timeout=status_timeout_s,
                    ).json()
                except requests.exceptions.Timeout:
                    # A server that cannot answer the probe in time is treated as busy.
                    logger.warning(f"Prompt enhancer {url} status check timed out, trying next server.")
                    continue

                if status["service_status"] != "idle":
                    continue

                # No timeout on the generate call: how long enhancement takes is not known here.
                response = requests.post(
                    f"{url}/v1/local/prompt_enhancer/generate",
                    json={
                        "task_id": generate_task_id(),
                        "prompt": self.config["prompt"],
                    },
                )
                enhanced_prompt = response.json()["output"]
                logger.info(f"Enhanced prompt: {enhanced_prompt}")
                return enhanced_prompt

            time.sleep(poll_interval_s)

321
322
    def process_images_after_vae_decoder(self):
        """Convert the decoded video, optionally interpolate frames, then return or save it.

        Returns:
            ``{"video": tensor}`` when ``input_info.return_result_tensor`` is set;
            ``{"video": None}`` after the save branch (taken when
            ``save_result_path`` is not ``None``); otherwise ``None``.
        """
        self.gen_video_final = vae_to_comfyui_image(self.gen_video_final)

        has_vfi = "video_frame_interpolation" in self.config
        if has_vfi:
            assert self.vfi_model is not None and self.config["video_frame_interpolation"].get("target_fps", None) is not None
            target_fps = self.config["video_frame_interpolation"]["target_fps"]
            logger.info(f"Interpolating frames from {self.config.get('fps', 16)} to {target_fps}")
            self.gen_video_final = self.vfi_model.interpolate_frames(
                self.gen_video_final,
                source_fps=self.config.get("fps", 16),
                target_fps=target_fps,
            )

        if self.input_info.return_result_tensor:
            return {"video": self.gen_video_final}

        if self.input_info.save_result_path is None:
            return None

        # Save at the interpolated frame rate when one is configured, else at the base fps.
        if has_vfi and self.config["video_frame_interpolation"].get("target_fps"):
            fps = self.config["video_frame_interpolation"]["target_fps"]
        else:
            fps = self.config.get("fps", 16)

        # In a distributed run only rank 0 writes the file.
        if not (dist.is_initialized() and dist.get_rank() != 0):
            logger.info("🎬 Start to save video 🎬")
            save_to_video(self.gen_video_final, self.input_info.save_result_path, fps=fps, method="ffmpeg")
            logger.info(f"✅ Video saved successfully to: {self.input_info.save_result_path} ✅")
        return {"video": None}

yihuiwen's avatar
yihuiwen committed
349
    @ProfilingContext4DebugL1("RUN pipeline", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_worker_request_duration, metrics_labels=["DefaultRunner"])
    def run_pipeline(self, input_info):
        """Run one request end to end: optional prompt enhancement, encoders, then the main loop.

        Args:
            input_info: Request description; stored on ``self.input_info`` for the run.

        Returns:
            The result of ``run_main``.
        """
        recording = GET_RECORDER_MODE()
        if recording:
            monitor_cli.lightx2v_worker_request_count.inc()

        self.input_info = input_info
        if self.config["use_prompt_enhancer"]:
            self.input_info.prompt_enhanced = self.post_prompt_enhancer()

        self.inputs = self.run_input_encoder()
        result = self.run_main()

        # Only reached when the run did not raise; re-read the mode as the original does.
        if GET_RECORDER_MODE():
            monitor_cli.lightx2v_worker_request_success.inc()
        return result