default_runner.py 14.3 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
import gc
PengGao's avatar
PengGao committed
2

3
import requests
helloyongyang's avatar
helloyongyang committed
4
5
import torch
import torch.distributed as dist
helloyongyang's avatar
helloyongyang committed
6
import torchvision.transforms.functional as TF
PengGao's avatar
PengGao committed
7
8
9
from PIL import Image
from loguru import logger
from requests.exceptions import RequestException
PengGao's avatar
PengGao committed
10

helloyongyang's avatar
helloyongyang committed
11
from lightx2v.utils.envs import *
PengGao's avatar
PengGao committed
12
from lightx2v.utils.generate_task_id import generate_task_id
13
from lightx2v.utils.profiler import *
helloyongyang's avatar
helloyongyang committed
14
from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image
PengGao's avatar
PengGao committed
15

PengGao's avatar
PengGao committed
16
from .base_runner import BaseRunner
17
18


PengGao's avatar
PengGao committed
19
class DefaultRunner(BaseRunner):
helloyongyang's avatar
helloyongyang committed
20
    def __init__(self, config):
        """Initialize the runner and decide whether a remote prompt enhancer is usable.

        The prompt enhancer is only considered for t2v tasks with a
        ``sub_servers.prompt_enhancer`` entry in the config, and is disabled
        again if none of the configured servers respond.
        """
        super().__init__(config)
        self.progress_callback = None
        self.has_prompt_enhancer = False
        enhancer_servers = self.config.get("sub_servers", {}).get("prompt_enhancer")
        if self.config.task == "t2v" and enhancer_servers is not None:
            if self.check_sub_servers("prompt_enhancer"):
                self.has_prompt_enhancer = True
            else:
                logger.warning("No prompt enhancer server available, disable prompt enhancer.")
        if not self.has_prompt_enhancer:
            self.config.use_prompt_enhancer = False
        self.set_init_device()
        self.init_scheduler()
33

34
    def init_modules(self):
        """Wire up the model, scheduler, input-encoder entry point and optional compile."""
        logger.info("Initializing runner modules...")
        lazy = self.config.get("lazy_load", False)
        unload = self.config.get("unload_modules", False)
        if not lazy and not unload:
            self.load_model()
        elif lazy:
            # Lazy loading is only supported together with CPU offload.
            assert self.config.get("cpu_offload", False)
        self.model.set_scheduler(self.scheduler)  # set scheduler to model
        # Pick the input-encoder implementation matching the configured task.
        encoder_by_task = {
            "i2v": self._run_input_encoder_local_i2v,
            "flf2v": self._run_input_encoder_local_flf2v,
            "t2v": self._run_input_encoder_local_t2v,
            "vace": self._run_input_encoder_local_vace,
        }
        task = self.config["task"]
        if task in encoder_by_task:
            self.run_input_encoder = encoder_by_task[task]
        if self.config.get("compile", False):
            logger.info(f"[Compile] Compile all shapes: {self.config.get('compile_shapes', [])}")
            self.model.compile(self.config.get("compile_shapes", []))
52

53
    def set_init_device(self):
54
        if self.config.cpu_offload:
55
            self.init_device = torch.device("cpu")
56
        else:
57
            self.init_device = torch.device("cuda")
58

PengGao's avatar
PengGao committed
59
60
61
62
63
64
65
    def load_vfi_model(self):
        """Instantiate the video-frame-interpolation model named in the config.

        Returns:
            RIFEWrapper: the loaded interpolation model.

        Raises:
            ValueError: if the configured algorithm is not "rife".
        """
        vfi_cfg = self.config["video_frame_interpolation"]
        if vfi_cfg.get("algo", None) != "rife":
            raise ValueError(f"Unsupported VFI model: {self.config['video_frame_interpolation']['algo']}")
        from lightx2v.models.vfi.rife.rife_comfyui_wrapper import RIFEWrapper

        logger.info("Loading RIFE model...")
        return RIFEWrapper(vfi_cfg["model_path"])
PengGao's avatar
PengGao committed
67

68
    @ProfilingContext4DebugL2("Load models")
    def load_model(self):
        """Load every sub-model of the pipeline.

        Populates ``model`` (transformer), ``text_encoders``, ``image_encoder``,
        the VAE encoder/decoder pair, and the optional frame-interpolation model.
        """
        self.model = self.load_transformer()
        self.text_encoders = self.load_text_encoder()
        self.image_encoder = self.load_image_encoder()
        self.vae_encoder, self.vae_decoder = self.load_vae()
        if "video_frame_interpolation" in self.config:
            self.vfi_model = self.load_vfi_model()
        else:
            self.vfi_model = None
Zhuguanyu Wu's avatar
Zhuguanyu Wu committed
75

76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
    def check_sub_servers(self, task_type):
        """Probe each configured sub-server for *task_type* and keep only live ones.

        Rewrites ``config["sub_servers"][task_type]`` with the servers that
        answered HTTP 200 on their status endpoint.

        Returns:
            bool: True when at least one server is available.
        """
        candidates = self.config.get("sub_servers", {}).get(task_type, [])
        alive = []
        for base_url in candidates:
            probe_url = f"{base_url}/v1/local/{task_type}/generate/service_status"
            try:
                resp = requests.get(probe_url, timeout=2)
            except RequestException as exc:
                logger.warning(f"Failed to connect to {base_url}: {str(exc)}")
                continue
            if resp.status_code == 200:
                alive.append(base_url)
            else:
                logger.warning(f"Service {base_url} returned status code {resp.status_code}")
        logger.info(f"{task_type} available servers: {alive}")
        self.config["sub_servers"][task_type] = alive
        return len(alive) > 0

helloyongyang's avatar
helloyongyang committed
95
96
    def set_inputs(self, inputs):
        self.config["prompt"] = inputs.get("prompt", "")
Zhuguanyu Wu's avatar
Zhuguanyu Wu committed
97
        self.config["use_prompt_enhancer"] = False
98
        if self.has_prompt_enhancer:
99
            self.config["use_prompt_enhancer"] = inputs.get("use_prompt_enhancer", False)  # Reset use_prompt_enhancer from clinet side.
helloyongyang's avatar
helloyongyang committed
100
101
102
        self.config["negative_prompt"] = inputs.get("negative_prompt", "")
        self.config["image_path"] = inputs.get("image_path", "")
        self.config["save_video_path"] = inputs.get("save_video_path", "")
PengGao's avatar
PengGao committed
103
104
105
106
107
108
109
110
        self.config["infer_steps"] = inputs.get("infer_steps", self.config.get("infer_steps", 5))
        self.config["target_video_length"] = inputs.get("target_video_length", self.config.get("target_video_length", 81))
        self.config["seed"] = inputs.get("seed", self.config.get("seed", 42))
        self.config["audio_path"] = inputs.get("audio_path", "")  # for wan-audio
        self.config["video_duration"] = inputs.get("video_duration", 5)  # for wan-audio

        # self.config["sample_shift"] = inputs.get("sample_shift", self.config.get("sample_shift", 5))
        # self.config["sample_guide_scale"] = inputs.get("sample_guide_scale", self.config.get("sample_guide_scale", 5))
helloyongyang's avatar
helloyongyang committed
111

PengGao's avatar
PengGao committed
112
113
114
    def set_progress_callback(self, callback):
        self.progress_callback = callback

helloyongyang's avatar
helloyongyang committed
115
    def run_segment(self, total_steps=None):
        """Run the denoising loop and return the scheduler's final latents.

        Args:
            total_steps: number of steps to run; defaults to the scheduler's
                configured ``infer_steps``.
        """
        steps = self.model.scheduler.infer_steps if total_steps is None else total_steps
        for idx in range(steps):
            # only for single segment, check stop signal every step
            if self.video_segment_num == 1:
                self.check_stop()
            logger.info(f"==> step_index: {idx + 1} / {steps}")

            with ProfilingContext4DebugL1("step_pre"):
                self.model.scheduler.step_pre(step_index=idx)

            with ProfilingContext4DebugL1("🚀 infer_main"):
                self.model.infer(self.inputs)

            with ProfilingContext4DebugL1("step_post"):
                self.model.scheduler.step_post()

            if self.progress_callback:
                self.progress_callback(((idx + 1) / steps) * 100, 100)

        return self.model.scheduler.latents
137

helloyongyang's avatar
helloyongyang committed
138
    def run_step(self):
        """Encode the inputs and run a single denoising step (debug helper)."""
        self.inputs = self.run_input_encoder()
        self.run_main(total_steps=1)
helloyongyang's avatar
helloyongyang committed
141
142

    def end_run(self):
        """Tear down per-run state and release GPU memory after all segments finish.

        When lazy loading / module unloading is enabled, also clears the
        transformer weights and drops the model object entirely.
        """
        self.model.scheduler.clear()
        del self.inputs, self.model.scheduler
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            transformer_infer = self.model.transformer_infer
            if hasattr(transformer_infer, "weights_stream_mgr"):
                transformer_infer.weights_stream_mgr.clear()
            transformer_weights = self.model.transformer_weights
            if hasattr(transformer_weights, "clear"):
                transformer_weights.clear()
            self.model.pre_weight.clear()
            del self.model
        torch.cuda.empty_cache()
        gc.collect()
helloyongyang's avatar
helloyongyang committed
154

helloyongyang's avatar
helloyongyang committed
155
    def read_image_input(self, img_path):
        """Load an image and return (normalized CUDA tensor, original PIL image).

        Accepts either a file path or an already-open ``PIL.Image``. The tensor
        is 1xCxHxW, rescaled from [0, 1] to [-1, 1].
        """
        img_ori = img_path if isinstance(img_path, Image.Image) else Image.open(img_path).convert("RGB")
        img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
        return img, img_ori
helloyongyang's avatar
helloyongyang committed
162

163
    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_i2v(self):
        """Encode text and image conditioning for image-to-video generation."""
        if self.config["use_prompt_enhancer"]:
            prompt = self.config["prompt_enhanced"]
        else:
            prompt = self.config["prompt"]
        img, img_ori = self.read_image_input(self.config["image_path"])
        clip_encoder_out = None
        if self.config.get("use_image_encoder", True):
            clip_encoder_out = self.run_image_encoder(img)
        vae_input = img_ori if self.vae_encoder_need_img_original else img
        vae_encode_out = self.run_vae_encoder(vae_input)
        text_encoder_output = self.run_text_encoder(prompt, img)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)

174
    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_t2v(self):
        """Encode the (optionally enhanced) prompt for text-to-video generation."""
        prompt_key = "prompt_enhanced" if self.config["use_prompt_enhancer"] else "prompt"
        text_encoder_output = self.run_text_encoder(self.config[prompt_key], None)
        torch.cuda.empty_cache()
        gc.collect()
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": None,
        }
184

185
    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_flf2v(self):
        """Encode conditioning for first-frame/last-frame-to-video generation."""
        if self.config["use_prompt_enhancer"]:
            prompt = self.config["prompt_enhanced"]
        else:
            prompt = self.config["prompt"]
        first_frame, _ = self.read_image_input(self.config["image_path"])
        last_frame, _ = self.read_image_input(self.config["last_frame_path"])
        clip_encoder_out = None
        if self.config.get("use_image_encoder", True):
            clip_encoder_out = self.run_image_encoder(first_frame, last_frame)
        vae_encode_out = self.run_vae_encoder(first_frame, last_frame)
        text_encoder_output = self.run_text_encoder(prompt, first_frame)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output)

197
    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_vace(self):
        """Encode conditioning (source video, mask, reference images) for VACE."""
        if self.config["use_prompt_enhancer"]:
            prompt = self.config["prompt_enhanced"]
        else:
            prompt = self.config["prompt"]
        raw_video = self.config.get("src_video", None)
        raw_mask = self.config.get("src_mask", None)
        raw_refs = self.config.get("src_ref_images", None)
        # Reference images arrive as a comma-separated string of paths.
        ref_list = None if raw_refs is None else raw_refs.split(",")
        src_video, src_mask, src_ref_images = self.prepare_source(
            [raw_video],
            [raw_mask],
            [ref_list],
            (self.config.target_width, self.config.target_height),
        )
        self.src_ref_images = src_ref_images

        vae_encoder_out = self.run_vae_encoder(src_video, src_ref_images, src_mask)
        text_encoder_output = self.run_text_encoder(prompt)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(None, vae_encoder_out, text_encoder_output)

helloyongyang's avatar
helloyongyang committed
217
218
219
    def init_run(self):
        """Prepare model and scheduler state before the denoising loop starts."""
        self.set_target_shape()
        self.get_video_segment_num()
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.model = self.load_transformer()
        self.model.scheduler.prepare(self.inputs["image_encoder_output"])
        # NOTE(review): wan2.2 i2v drops the VAE encoder output right after
        # scheduler.prepare — presumably it is consumed there; confirm upstream.
        if self.config.get("model_cls") == "wan2.2" and self.config["task"] == "i2v":
            self.inputs["image_encoder_output"]["vae_encoder_out"] = None
helloyongyang's avatar
helloyongyang committed
225

226
    @ProfilingContext4DebugL2("Run DiT")
    def run_main(self, total_steps=None):
        """Drive the full denoising pipeline over every video segment.

        For each segment: run the per-segment setup hook, the denoising loop,
        the VAE decoder, and the per-segment teardown hook; finally release
        run-level state via ``end_run``.

        Args:
            total_steps: forwarded to ``run_segment``; None means use the
                scheduler's configured step count.
        """
        self.init_run()
        if self.config.get("compile", False):
            self.model.select_graph_for_compile()
        for segment_idx in range(self.video_segment_num):
            logger.info(f"🔄 start segment {segment_idx + 1}/{self.video_segment_num}")
            with ProfilingContext4DebugL1(f"segment end2end {segment_idx + 1}/{self.video_segment_num}"):
                # Abort promptly if a stop was requested between segments.
                self.check_stop()
                # 1. default do nothing
                self.init_run_segment(segment_idx)
                # 2. main inference loop
                latents = self.run_segment(total_steps=total_steps)
                # 3. vae decoder
                self.gen_video = self.run_vae_decoder(latents)
                # 4. default do nothing
                self.end_run_segment()
        self.end_run()

245
    @ProfilingContext4DebugL1("Run VAE Decoder")
    def run_vae_decoder(self, latents):
        """Decode latents into frames, loading/unloading the VAE on demand.

        Under lazy_load/unload_modules the decoder is created just for this
        call and released (with a CUDA cache flush) immediately afterwards.
        """
        on_demand = self.config.get("lazy_load", False) or self.config.get("unload_modules", False)
        if on_demand:
            self.vae_decoder = self.load_vae_decoder()
        images = self.vae_decoder.decode(latents.to(GET_DTYPE()))
        if on_demand:
            del self.vae_decoder
            torch.cuda.empty_cache()
            gc.collect()
        return images

256
257
258
259
260
    def post_prompt_enhancer(self):
        """Block until an idle prompt-enhancer server accepts the prompt.

        Polls every configured sub-server's status endpoint and sends the
        generation request to the first idle one.

        Returns:
            str: the enhanced prompt produced by the remote service.
        """
        import time

        while True:
            for url in self.config["sub_servers"]["prompt_enhancer"]:
                # timeout keeps a dead server from hanging the whole poll loop
                response = requests.get(
                    f"{url}/v1/local/prompt_enhancer/generate/service_status",
                    timeout=2,
                ).json()
                if response["service_status"] == "idle":
                    response = requests.post(
                        f"{url}/v1/local/prompt_enhancer/generate",
                        json={
                            "task_id": generate_task_id(),
                            "prompt": self.config["prompt"],
                        },
                    )
                    enhanced_prompt = response.json()["output"]
                    logger.info(f"Enhanced prompt: {enhanced_prompt}")
                    return enhanced_prompt
            # All servers busy: back off instead of hammering the status endpoint.
            time.sleep(0.5)

helloyongyang's avatar
helloyongyang committed
272
273
    def process_images_after_vae_decoder(self, save_video=True):
        """Convert decoded frames to ComfyUI layout, optionally interpolate and save.

        Args:
            save_video: when True, write the result to ``config.save_video_path``
                (only on rank 0 when torch.distributed is initialized).

        Returns:
            dict: {"video": frames in ComfyUI image format}.
        """
        self.gen_video = vae_to_comfyui_image(self.gen_video)

        if "video_frame_interpolation" in self.config:
            assert self.vfi_model is not None and self.config["video_frame_interpolation"].get("target_fps", None) is not None
            target_fps = self.config["video_frame_interpolation"]["target_fps"]
            logger.info(f"Interpolating frames from {self.config.get('fps', 16)} to {target_fps}")
            self.gen_video = self.vfi_model.interpolate_frames(
                self.gen_video,
                source_fps=self.config.get("fps", 16),
                target_fps=target_fps,
            )

        if save_video:
            # Save at the interpolated fps when VFI ran, otherwise the source fps.
            if "video_frame_interpolation" in self.config and self.config["video_frame_interpolation"].get("target_fps"):
                fps = self.config["video_frame_interpolation"]["target_fps"]
            else:
                fps = self.config.get("fps", 16)

            if not dist.is_initialized() or dist.get_rank() == 0:
                # was an f-string with no placeholders (ruff F541)
                logger.info("🎬 Start to save video 🎬")
                save_to_video(self.gen_video, self.config.save_video_path, fps=fps, method="ffmpeg")
                logger.info(f"✅ Video saved successfully to: {self.config.save_video_path} ✅")
        return {"video": self.gen_video}
PengGao's avatar
PengGao committed
297

helloyongyang's avatar
helloyongyang committed
298
299
300
301
302
303
    def run_pipeline(self, save_video=True):
        """End-to-end generation: enhance prompt, encode inputs, denoise, post-process.

        Returns:
            dict: {"video": generated frames} from the post-processing step.
        """
        if self.config["use_prompt_enhancer"]:
            self.config["prompt_enhanced"] = self.post_prompt_enhancer()

        self.inputs = self.run_input_encoder()
        self.run_main()
        result = self.process_images_after_vae_decoder(save_video=save_video)

        torch.cuda.empty_cache()
        gc.collect()

        return result