import gc

import requests
import torch
import torch.distributed as dist
import torchvision.transforms.functional as TF
from PIL import Image
from loguru import logger
from requests.exceptions import RequestException

from lightx2v.utils.envs import *
from lightx2v.utils.generate_task_id import generate_task_id
from lightx2v.utils.memory_profiler import peak_memory_decorator
from lightx2v.utils.profiler import *
from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image

from .base_runner import BaseRunner


class DefaultRunner(BaseRunner):
    def __init__(self, config):
        super().__init__(config)
        self.has_prompt_enhancer = False
        self.progress_callback = None
        if self.config.task == "t2v" and self.config.get("sub_servers", {}).get("prompt_enhancer") is not None:
            self.has_prompt_enhancer = True
            if not self.check_sub_servers("prompt_enhancer"):
                self.has_prompt_enhancer = False
                logger.warning("No prompt enhancer server available, disable prompt enhancer.")
        if not self.has_prompt_enhancer:
            self.config.use_prompt_enhancer = False
        self.set_init_device()
        self.init_scheduler()

    def init_modules(self):
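        """Initialize runner modules: load model weights unless lazy loading
        or module unloading is enabled, attach the scheduler to the model,
        select the task-specific input-encoder routine, and optionally
        pre-compile the transformer for the configured shapes."""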
        logger.info("Initializing runner modules...")
        if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
            self.load_model()
        elif self.config.get("lazy_load", False):
            assert self.config.get("cpu_offload", False)
        self.model.set_scheduler(self.scheduler)  # set scheduler to model
        if self.config["task"] == "i2v":
            self.run_input_encoder = self._run_input_encoder_local_i2v
        elif self.config["task"] == "flf2v":
            self.run_input_encoder = self._run_input_encoder_local_flf2v
        elif self.config["task"] == "t2v":
            self.run_input_encoder = self._run_input_encoder_local_t2v
        elif self.config["task"] == "vace":
            self.run_input_encoder = self._run_input_encoder_local_vace
        if self.config.get("compile", False):
            logger.info(f"[Compile] Compile all shapes: {self.config.get('compile_shapes', [])}")
            self.model.compile(self.config.get("compile_shapes", []))

    def set_init_device(self):
        if self.config.cpu_offload:
            self.init_device = torch.device("cpu")
        else:
            self.init_device = torch.device("cuda")

    def load_vfi_model(self):
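        """Load the video-frame-interpolation (VFI) model configured under
        ``video_frame_interpolation``; only the RIFE algorithm is supported."""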
        if self.config["video_frame_interpolation"].get("algo", None) == "rife":
            from lightx2v.models.vfi.rife.rife_comfyui_wrapper import RIFEWrapper

            logger.info("Loading RIFE model...")
            return RIFEWrapper(self.config["video_frame_interpolation"]["model_path"])
        else:
            raise ValueError(f"Unsupported VFI model: {self.config['video_frame_interpolation']['algo']}")

    @ProfilingContext4DebugL2("Load models")
    def load_model(self):
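        """Load all pipeline modules: transformer, text encoders, image
        encoder, VAE encoder/decoder, and the optional VFI model."""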
        self.model = self.load_transformer()
        self.text_encoders = self.load_text_encoder()
        self.image_encoder = self.load_image_encoder()
        self.vae_encoder, self.vae_decoder = self.load_vae()
        self.vfi_model = self.load_vfi_model() if "video_frame_interpolation" in self.config else None

    def check_sub_servers(self, task_type):
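        """Probe each configured sub-server of ``task_type`` through its
        service_status endpoint, keep only the reachable servers in the
        config, and return True if at least one is available."""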
        urls = self.config.get("sub_servers", {}).get(task_type, [])
        available_servers = []
        for url in urls:
            try:
                status_url = f"{url}/v1/local/{task_type}/generate/service_status"
                response = requests.get(status_url, timeout=2)
                if response.status_code == 200:
                    available_servers.append(url)
                else:
                    logger.warning(f"Service {url} returned status code {response.status_code}")

            except RequestException as e:
                logger.warning(f"Failed to connect to {url}: {str(e)}")
                continue
        logger.info(f"{task_type} available servers: {available_servers}")
        self.config["sub_servers"][task_type] = available_servers
        return len(available_servers) > 0

    def set_inputs(self, inputs):
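        """Copy per-request inputs into the config, falling back to existing
        config values (or defaults) when a key is not provided."""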
        self.config["prompt"] = inputs.get("prompt", "")
        self.config["use_prompt_enhancer"] = False
        if self.has_prompt_enhancer:
            self.config["use_prompt_enhancer"] = inputs.get("use_prompt_enhancer", False)  # Reset use_prompt_enhancer from clinet side.
        self.config["negative_prompt"] = inputs.get("negative_prompt", "")
        self.config["image_path"] = inputs.get("image_path", "")
        self.config["save_video_path"] = inputs.get("save_video_path", "")
        self.config["infer_steps"] = inputs.get("infer_steps", self.config.get("infer_steps", 5))
        self.config["target_video_length"] = inputs.get("target_video_length", self.config.get("target_video_length", 81))
        self.config["seed"] = inputs.get("seed", self.config.get("seed", 42))
        self.config["audio_path"] = inputs.get("audio_path", "")  # for wan-audio
        self.config["video_duration"] = inputs.get("video_duration", 5)  # for wan-audio

        # self.config["sample_shift"] = inputs.get("sample_shift", self.config.get("sample_shift", 5))
        # self.config["sample_guide_scale"] = inputs.get("sample_guide_scale", self.config.get("sample_guide_scale", 5))

    def set_progress_callback(self, callback):
        self.progress_callback = callback

    @peak_memory_decorator
    def run_segment(self, total_steps=None):
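        """Run the denoising loop for the current segment: step_pre, model
        inference, and step_post for each step, reporting progress through
        the optional callback; returns the scheduler's final latents."""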
        if total_steps is None:
            total_steps = self.model.scheduler.infer_steps
        for step_index in range(total_steps):
            # when generating a single segment, check the stop signal at every step
            if self.video_segment_num == 1:
                self.check_stop()
            logger.info(f"==> step_index: {step_index + 1} / {total_steps}")

            with ProfilingContext4DebugL1("step_pre"):
                self.model.scheduler.step_pre(step_index=step_index)

            with ProfilingContext4DebugL1("🚀 infer_main"):
                self.model.infer(self.inputs)

            with ProfilingContext4DebugL1("step_post"):
                self.model.scheduler.step_post()

            if self.progress_callback:
                self.progress_callback(((step_index + 1) / total_steps) * 100, 100)

        return self.model.scheduler.latents

    def run_step(self):
        self.inputs = self.run_input_encoder()
        self.run_main(total_steps=1)

    def end_run(self):
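        """Release per-run state: clear the scheduler and inputs, drop the
        transformer weights when lazy loading or module unloading is
        enabled, and free CUDA memory."""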
        self.model.scheduler.clear()
        del self.inputs
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            if hasattr(self.model.transformer_infer, "weights_stream_mgr"):
                self.model.transformer_infer.weights_stream_mgr.clear()
            if hasattr(self.model.transformer_weights, "clear"):
                self.model.transformer_weights.clear()
            self.model.pre_weight.clear()
            del self.model
        torch.cuda.empty_cache()
        gc.collect()

    def read_image_input(self, img_path):
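        """Accept an image path or a PIL image and return a normalized CUDA
        tensor in [-1, 1] with shape (1, C, H, W) plus the original image."""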
        if isinstance(img_path, Image.Image):
            img_ori = img_path
        else:
            img_ori = Image.open(img_path).convert("RGB")
        img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
        return img, img_ori

    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_i2v(self):
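        """Encode image-to-video inputs: optional CLIP image features, VAE
        latents of the reference image, and text embeddings for the
        (possibly enhanced) prompt."""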
        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
        img, img_ori = self.read_image_input(self.config["image_path"])
        clip_encoder_out = self.run_image_encoder(img) if self.config.get("use_image_encoder", True) else None
        vae_encode_out = self.run_vae_encoder(img_ori if self.vae_encoder_need_img_original else img)
        text_encoder_output = self.run_text_encoder(prompt, img)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)

    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_t2v(self):
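        """Encode text-to-video inputs: text embeddings only; the image
        encoder output is None."""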
        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
        text_encoder_output = self.run_text_encoder(prompt, None)
        torch.cuda.empty_cache()
        gc.collect()
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": None,
        }

    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_flf2v(self):
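        """Encode first-last-frame-to-video inputs: both frames are passed
        to the image and VAE encoders, and the text encoder additionally
        receives the first frame."""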
        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
        first_frame, _ = self.read_image_input(self.config["image_path"])
        last_frame, _ = self.read_image_input(self.config["last_frame_path"])
        clip_encoder_out = self.run_image_encoder(first_frame, last_frame) if self.config.get("use_image_encoder", True) else None
        vae_encode_out = self.run_vae_encoder(first_frame, last_frame)
        text_encoder_output = self.run_text_encoder(prompt, first_frame)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output)

    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_vace(self):
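        """Encode VACE inputs: normalize the source video, mask, and
        reference images with prepare_source, then run the VAE and text
        encoders."""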
        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
        src_video = self.config.get("src_video", None)
        src_mask = self.config.get("src_mask", None)
        src_ref_images = self.config.get("src_ref_images", None)
        src_video, src_mask, src_ref_images = self.prepare_source(
            [src_video],
            [src_mask],
            [None if src_ref_images is None else src_ref_images.split(",")],
            (self.config.target_width, self.config.target_height),
        )
        self.src_ref_images = src_ref_images

        vae_encoder_out = self.run_vae_encoder(src_video, src_ref_images, src_mask)
        text_encoder_output = self.run_text_encoder(prompt)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(None, vae_encoder_out, text_encoder_output)

    def init_run(self):
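        """Prepare a generation run: compute the target latent shape and
        segment count, reload the transformer when lazy loading or module
        unloading is enabled, and prime the scheduler with the
        image-encoder output."""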
        self.set_target_shape()
        self.get_video_segment_num()
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.model = self.load_transformer()
        self.model.scheduler.prepare(self.inputs["image_encoder_output"])
        if self.config.get("model_cls") == "wan2.2" and self.config["task"] == "i2v":
            self.inputs["image_encoder_output"]["vae_encoder_out"] = None

    @ProfilingContext4DebugL2("Run DiT")
    def run_main(self, total_steps=None):
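        """Top-level generation loop: for each video segment, run the
        denoising steps and decode the resulting latents with the VAE."""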
        self.init_run()
        if self.config.get("compile", False):
            self.model.select_graph_for_compile()
        for segment_idx in range(self.video_segment_num):
            logger.info(f"🔄 start segment {segment_idx + 1}/{self.video_segment_num}")
            with ProfilingContext4DebugL1(f"segment end2end {segment_idx + 1}/{self.video_segment_num}"):
                self.check_stop()
                # 1. per-segment setup hook (default: no-op)
                self.init_run_segment(segment_idx)
                # 2. main denoising loop
                latents = self.run_segment(total_steps=total_steps)
                # 3. decode latents with the VAE
                self.gen_video = self.run_vae_decoder(latents)
                # 4. per-segment teardown hook (default: no-op)
                self.end_run_segment()
        self.end_run()

    @ProfilingContext4DebugL1("Run VAE Decoder")
    def run_vae_decoder(self, latents):
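        """Decode latents into images, loading the VAE decoder on demand and
        unloading it afterwards when lazy loading or module unloading is
        enabled."""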
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.vae_decoder = self.load_vae_decoder()
        images = self.vae_decoder.decode(latents.to(GET_DTYPE()))
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.vae_decoder
            torch.cuda.empty_cache()
            gc.collect()
        return images

    def post_prompt_enhancer(self):
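        """Poll the prompt-enhancer sub-servers until one reports idle, then
        submit the prompt and return the enhanced version."""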
        while True:
            for url in self.config["sub_servers"]["prompt_enhancer"]:
                response = requests.get(f"{url}/v1/local/prompt_enhancer/generate/service_status").json()
                if response["service_status"] == "idle":
                    response = requests.post(
                        f"{url}/v1/local/prompt_enhancer/generate",
                        json={
                            "task_id": generate_task_id(),
                            "prompt": self.config["prompt"],
                        },
                    )
                    enhanced_prompt = response.json()["output"]
                    logger.info(f"Enhanced prompt: {enhanced_prompt}")
                    return enhanced_prompt

    def process_images_after_vae_decoder(self, save_video=True):
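        """Convert the decoded frames to ComfyUI image format, optionally
        interpolate to the target FPS, and save the video on rank 0 (or in
        non-distributed runs)."""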
        self.gen_video = vae_to_comfyui_image(self.gen_video)

        if "video_frame_interpolation" in self.config:
            assert self.vfi_model is not None and self.config["video_frame_interpolation"].get("target_fps", None) is not None
            target_fps = self.config["video_frame_interpolation"]["target_fps"]
            logger.info(f"Interpolating frames from {self.config.get('fps', 16)} to {target_fps}")
            self.gen_video = self.vfi_model.interpolate_frames(
                self.gen_video,
                source_fps=self.config.get("fps", 16),
                target_fps=target_fps,
            )

        if save_video:
            if "video_frame_interpolation" in self.config and self.config["video_frame_interpolation"].get("target_fps"):
                fps = self.config["video_frame_interpolation"]["target_fps"]
            else:
                fps = self.config.get("fps", 16)

            if not dist.is_initialized() or dist.get_rank() == 0:
                logger.info(f"🎬 Start to save video 🎬")

                save_to_video(self.gen_video, self.config.save_video_path, fps=fps, method="ffmpeg")
                logger.info(f"✅ Video saved successfully to: {self.config.save_video_path} ✅")
        if self.config.get("return_video", False):
            return {"video": self.gen_video}
        return {"video": None}

    def run_pipeline(self, save_video=True):
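        """End-to-end pipeline: enhance the prompt if enabled, encode inputs,
        run the main generation loop, and post-process the decoded video."""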
        if self.config["use_prompt_enhancer"]:
            self.config["prompt_enhanced"] = self.post_prompt_enhancer()

        self.inputs = self.run_input_encoder()

        self.run_main()

        gen_video = self.process_images_after_vae_decoder(save_video=save_video)

        torch.cuda.empty_cache()
        gc.collect()

        return gen_video