"vscode:/vscode.git/clone" did not exist on "e634f83fe931108d080936ee2b17f878fa3f1ba6"
default_runner.py 15.3 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
import gc

import requests
import torch
import torch.distributed as dist
import torchvision.transforms.functional as TF
from PIL import Image
from loguru import logger
from requests.exceptions import RequestException

from lightx2v.utils.envs import *
from lightx2v.utils.generate_task_id import generate_task_id
from lightx2v.utils.memory_profiler import peak_memory_decorator
from lightx2v.utils.profiler import *
from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image

from .base_runner import BaseRunner


class DefaultRunner(BaseRunner):
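    """Default end-to-end runner for LightX2V video generation.

    Wires the text/image/VAE encoders, the diffusion transformer, and the
    scheduler together, optionally with prompt-enhancer sub-servers and a
    RIFE frame-interpolation model, and drives the per-segment denoising
    loop from input encoding through VAE decoding and video saving.
    """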
    def __init__(self, config):
        super().__init__(config)
        self.has_prompt_enhancer = False
        self.progress_callback = None
        if self.config["task"] == "t2v" and self.config.get("sub_servers", {}).get("prompt_enhancer") is not None:
            self.has_prompt_enhancer = True
            if not self.check_sub_servers("prompt_enhancer"):
                self.has_prompt_enhancer = False
                logger.warning("No prompt enhancer server available, disable prompt enhancer.")
        if not self.has_prompt_enhancer:
            self.config["use_prompt_enhancer"] = False
        self.set_init_device()
        self.init_scheduler()

    def init_modules(self):
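        """Load models (unless lazy-loading), bind the scheduler to the model,
        and select the input-encoder routine matching the configured task."""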
        logger.info("Initializing runner modules...")
        if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
            self.load_model()
        elif self.config.get("lazy_load", False):
            assert self.config.get("cpu_offload", False), "lazy_load requires cpu_offload to be enabled"
        self.model.set_scheduler(self.scheduler)  # set scheduler to model
        if self.config["task"] == "i2v":
            self.run_input_encoder = self._run_input_encoder_local_i2v
        elif self.config["task"] == "flf2v":
            self.run_input_encoder = self._run_input_encoder_local_flf2v
        elif self.config["task"] == "t2v":
            self.run_input_encoder = self._run_input_encoder_local_t2v
        elif self.config["task"] == "vace":
            self.run_input_encoder = self._run_input_encoder_local_vace
        elif self.config["task"] == "animate":
            self.run_input_encoder = self._run_input_encoder_local_animate
        elif self.config["task"] == "s2v":
            self.run_input_encoder = self._run_input_encoder_local_s2v
        self.config.lock()  # lock config to avoid modification
        if self.config.get("compile", False):
            logger.info(f"[Compile] Compile all shapes: {self.config.get('compile_shapes', [])}")
            self.model.compile(self.config.get("compile_shapes", []))

    def set_init_device(self):
        if self.config["cpu_offload"]:
            self.init_device = torch.device("cpu")
        else:
            self.init_device = torch.device("cuda")

    def load_vfi_model(self):
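        """Load the video-frame-interpolation model named in the config; only RIFE is supported."""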
        if self.config["video_frame_interpolation"].get("algo", None) == "rife":
            from lightx2v.models.vfi.rife.rife_comfyui_wrapper import RIFEWrapper

            logger.info("Loading RIFE model...")
            return RIFEWrapper(self.config["video_frame_interpolation"]["model_path"])
        else:
            raise ValueError(f"Unsupported VFI model: {self.config['video_frame_interpolation']['algo']}")

    @ProfilingContext4DebugL2("Load models")
    def load_model(self):
        self.model = self.load_transformer()
        self.text_encoders = self.load_text_encoder()
        self.image_encoder = self.load_image_encoder()
        self.vae_encoder, self.vae_decoder = self.load_vae()
        self.vfi_model = self.load_vfi_model() if "video_frame_interpolation" in self.config else None

    def check_sub_servers(self, task_type):
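        """Probe each configured sub-server's status endpoint, keep only the
        reachable ones in the config, and report whether any survived."""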
        urls = self.config.get("sub_servers", {}).get(task_type, [])
        available_servers = []
        for url in urls:
            try:
                status_url = f"{url}/v1/local/{task_type}/generate/service_status"
                response = requests.get(status_url, timeout=2)
                if response.status_code == 200:
                    available_servers.append(url)
                else:
                    logger.warning(f"Service {url} returned status code {response.status_code}")

            except RequestException as e:
                logger.warning(f"Failed to connect to {url}: {str(e)}")
                continue
        logger.info(f"{task_type} available servers: {available_servers}")
        self.config["sub_servers"][task_type] = available_servers
        return len(available_servers) > 0

    def set_inputs(self, inputs):
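        """Copy request fields into ``input_info``, setting optional media paths
        only when the task's input dataclass actually declares those fields."""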
        self.input_info.seed = inputs.get("seed", 42)
        self.input_info.prompt = inputs.get("prompt", "")
        if self.config["use_prompt_enhancer"]:
            self.input_info.prompt_enhanced = inputs.get("prompt_enhanced", "")
        self.input_info.negative_prompt = inputs.get("negative_prompt", "")
        if "image_path" in self.input_info.__dataclass_fields__:
            self.input_info.image_path = inputs.get("image_path", "")
        if "audio_path" in self.input_info.__dataclass_fields__:
            self.input_info.audio_path = inputs.get("audio_path", "")
        if "video_path" in self.input_info.__dataclass_fields__:
            self.input_info.video_path = inputs.get("video_path", "")
        self.input_info.save_result_path = inputs.get("save_result_path", "")

    def set_config(self, config_modify):
        logger.info(f"modify config: {config_modify}")
        with self.config.temporarily_unlocked():
            self.config.update(config_modify)

    def set_progress_callback(self, callback):
        self.progress_callback = callback

    @peak_memory_decorator
    def run_segment(self, total_steps=None):
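        """Run the denoising loop for one video segment and return the final latents.

        Each step wraps the scheduler's ``step_pre``/``step_post`` around the
        model forward pass and reports progress via the optional callback.
        """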
        if total_steps is None:
            total_steps = self.model.scheduler.infer_steps
        for step_index in range(total_steps):
            # only for single segment, check stop signal every step
            if self.video_segment_num == 1:
                self.check_stop()
            logger.info(f"==> step_index: {step_index + 1} / {total_steps}")

            with ProfilingContext4DebugL1("step_pre"):
                self.model.scheduler.step_pre(step_index=step_index)

            with ProfilingContext4DebugL1("🚀 infer_main"):
                self.model.infer(self.inputs)

            with ProfilingContext4DebugL1("step_post"):
                self.model.scheduler.step_post()

            if self.progress_callback:
                self.progress_callback(((step_index + 1) / total_steps) * 100, 100)

        return self.model.scheduler.latents

    def run_step(self):
        self.inputs = self.run_input_encoder()
        self.run_main(total_steps=1)

    def end_run(self):
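        """Release per-run state and, under lazy_load/unload_modules, drop the
        transformer weights before emptying the CUDA cache."""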
        self.model.scheduler.clear()
        del self.inputs
        self.input_info = None
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            if hasattr(self.model.transformer_infer, "weights_stream_mgr"):
                self.model.transformer_infer.weights_stream_mgr.clear()
            if hasattr(self.model.transformer_weights, "clear"):
                self.model.transformer_weights.clear()
            self.model.pre_weight.clear()
            del self.model
        torch.cuda.empty_cache()
        gc.collect()

    def read_image_input(self, img_path):
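        """Open an image (path or PIL.Image), record its original size in
        ``input_info``, and return it as a [-1, 1] CUDA tensor plus the PIL image."""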
        if isinstance(img_path, Image.Image):
            img_ori = img_path
        else:
            img_ori = Image.open(img_path).convert("RGB")
        img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
        self.input_info.original_size = img_ori.size
        return img, img_ori

    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_i2v(self):
        img, img_ori = self.read_image_input(self.input_info.image_path)
        clip_encoder_out = self.run_image_encoder(img) if self.config.get("use_image_encoder", True) else None
        vae_encode_out, latent_shape = self.run_vae_encoder(img_ori if self.vae_encoder_need_img_original else img)
        self.input_info.latent_shape = latent_shape  # Important: set latent_shape in input_info
        text_encoder_output = self.run_text_encoder(self.input_info)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)

    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_t2v(self):
        self.input_info.latent_shape = self.get_latent_shape_with_target_hw(self.config["target_height"], self.config["target_width"])  # Important: set latent_shape in input_info
        text_encoder_output = self.run_text_encoder(self.input_info)
        torch.cuda.empty_cache()
        gc.collect()
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": None,
        }

    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_flf2v(self):
        first_frame, _ = self.read_image_input(self.input_info.image_path)
        last_frame, _ = self.read_image_input(self.input_info.last_frame_path)
        clip_encoder_out = self.run_image_encoder(first_frame, last_frame) if self.config.get("use_image_encoder", True) else None
        vae_encode_out, latent_shape = self.run_vae_encoder(first_frame, last_frame)
        self.input_info.latent_shape = latent_shape  # Important: set latent_shape in input_info
        text_encoder_output = self.run_text_encoder(self.input_info)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output)

    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_vace(self):
        src_video = self.input_info.src_video
        src_mask = self.input_info.src_mask
        src_ref_images = self.input_info.src_ref_images
        src_video, src_mask, src_ref_images = self.prepare_source(
            [src_video],
            [src_mask],
            [None if src_ref_images is None else src_ref_images.split(",")],
            (self.config["target_width"], self.config["target_height"]),
        )
        self.src_ref_images = src_ref_images

        vae_encoder_out, latent_shape = self.run_vae_encoder(src_video, src_ref_images, src_mask)
        self.input_info.latent_shape = latent_shape  # Important: set latent_shape in input_info
        text_encoder_output = self.run_text_encoder(self.input_info)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(None, vae_encoder_out, text_encoder_output)

    @ProfilingContext4DebugL2("Run Text Encoder")
    def _run_input_encoder_local_animate(self):
        text_encoder_output = self.run_text_encoder(self.input_info)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(None, None, text_encoder_output, None)

    def _run_input_encoder_local_s2v(self):
        pass

    def init_run(self):
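        """Prepare per-run state: segment count, a (re)loaded transformer when
        lazy-loading, and the scheduler's seed/latent-shape setup."""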
        self.gen_video_final = None
        self.get_video_segment_num()
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.model = self.load_transformer()

        self.model.scheduler.prepare(seed=self.input_info.seed, latent_shape=self.input_info.latent_shape, image_encoder_output=self.inputs["image_encoder_output"])
        if self.config.get("model_cls") == "wan2.2" and self.config["task"] in ["i2v", "s2v"]:
            self.inputs["image_encoder_output"]["vae_encoder_out"] = None

    @ProfilingContext4DebugL2("Run DiT")
    def run_main(self, total_steps=None):
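        """Full generation pass over all segments: denoise, VAE-decode, then
        post-process the frames and return the final result dict."""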
        self.init_run()
        if self.config.get("compile", False):
            self.model.select_graph_for_compile(self.input_info)
        for segment_idx in range(self.video_segment_num):
            logger.info(f"🔄 start segment {segment_idx + 1}/{self.video_segment_num}")
            with ProfilingContext4DebugL1(f"segment end2end {segment_idx + 1}/{self.video_segment_num}"):
                self.check_stop()
                # 1. default do nothing
                self.init_run_segment(segment_idx)
                # 2. main inference loop
                latents = self.run_segment(total_steps=total_steps)
                # 3. vae decoder
                self.gen_video = self.run_vae_decoder(latents)
                # 4. default do nothing
                self.end_run_segment(segment_idx)
        gen_video_final = self.process_images_after_vae_decoder()
        self.end_run()
        return {"video": gen_video_final}

    @ProfilingContext4DebugL1("Run VAE Decoder")
    def run_vae_decoder(self, latents):
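        """Decode latents to images, loading (and afterwards unloading) the VAE
        decoder on demand when lazy_load/unload_modules is enabled."""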
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.vae_decoder = self.load_vae_decoder()
        images = self.vae_decoder.decode(latents.to(GET_DTYPE()))
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.vae_decoder
            torch.cuda.empty_cache()
            gc.collect()
        return images

    def post_prompt_enhancer(self):
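        """Poll the prompt-enhancer sub-servers until one is idle, then submit
        the prompt and return the enhanced version. Note: this busy-waits."""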
        while True:
            for url in self.config["sub_servers"]["prompt_enhancer"]:
                response = requests.get(f"{url}/v1/local/prompt_enhancer/generate/service_status").json()
                if response["service_status"] == "idle":
                    response = requests.post(
                        f"{url}/v1/local/prompt_enhancer/generate",
                        json={
                            "task_id": generate_task_id(),
                            "prompt": self.config["prompt"],
                        },
                    )
                    enhanced_prompt = response.json()["output"]
                    logger.info(f"Enhanced prompt: {enhanced_prompt}")
                    return enhanced_prompt

    def process_images_after_vae_decoder(self):
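        """Convert decoded frames to ComfyUI image format, optionally interpolate
        to a higher fps, then either return the tensor or save it to disk
        (rank 0 only when torch.distributed is initialized)."""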
        self.gen_video_final = vae_to_comfyui_image(self.gen_video_final)

        if "video_frame_interpolation" in self.config:
            assert self.vfi_model is not None and self.config["video_frame_interpolation"].get("target_fps", None) is not None
            target_fps = self.config["video_frame_interpolation"]["target_fps"]
            logger.info(f"Interpolating frames from {self.config.get('fps', 16)} to {target_fps}")
            self.gen_video_final = self.vfi_model.interpolate_frames(
                self.gen_video_final,
                source_fps=self.config.get("fps", 16),
                target_fps=target_fps,
            )

        if self.input_info.return_result_tensor:
            return {"video": self.gen_video_final}
        elif self.input_info.save_result_path is not None:
            if "video_frame_interpolation" in self.config and self.config["video_frame_interpolation"].get("target_fps"):
                fps = self.config["video_frame_interpolation"]["target_fps"]
            else:
                fps = self.config.get("fps", 16)

            if not dist.is_initialized() or dist.get_rank() == 0:
                logger.info(f"🎬 Start to save video 🎬")

                save_to_video(self.gen_video_final, self.input_info.save_result_path, fps=fps, method="ffmpeg")
                logger.info(f"✅ Video saved successfully to: {self.input_info.save_result_path} ✅")
            return {"video": None}

    def run_pipeline(self, input_info):
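        """End-to-end entry point: optionally enhance the prompt, run the input
        encoders, then the main denoising/decoding loop."""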
        self.input_info = input_info

        if self.config["use_prompt_enhancer"]:
            self.input_info.prompt_enhanced = self.post_prompt_enhancer()

        self.inputs = self.run_input_encoder()

        gen_video_final = self.run_main()

        return gen_video_final
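

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The helpers below are hypothetical
# stand-ins: how the lockable config and the task-specific input dataclass are
# actually built depends on the wider lightx2v codebase.
#
#     config = load_config(...)        # hypothetical; needs at least "task",
#                                      # "cpu_offload", "use_prompt_enhancer"
#     runner = DefaultRunner(config)
#     runner.init_modules()
#     input_info = build_input_info(   # hypothetical; a dataclass with seed,
#         prompt="a cat surfing",      # prompt, save_result_path, ...
#         seed=42,
#         save_result_path="./output.mp4",
#     )
#     result = runner.run_pipeline(input_info)   # -> {"video": ...}
# ---------------------------------------------------------------------------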