import gc
import time

import requests
import torch
import torch.distributed as dist
import torchvision.transforms.functional as TF
from loguru import logger
from PIL import Image
from requests.exceptions import RequestException

from lightx2v.utils.envs import *
from lightx2v.utils.generate_task_id import generate_task_id
from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image

from .base_runner import BaseRunner


class DefaultRunner(BaseRunner):
    """Default end-to-end runner: encodes inputs, runs the DiT denoising
    loop, decodes latents with the VAE, and post-processes/saves the video.
    """

    def __init__(self, config):
        """Initialize runner state and pick the initial device.

        Args:
            config: runner configuration; both attribute-style
                (``config.task``) and mapping-style (``config.get``) access
                are used throughout this class.
        """
        super().__init__(config)
        self.has_prompt_enhancer = False
        self.progress_callback = None
        # Prompt enhancement only applies to t2v and only when at least one
        # enhancer sub-server is configured AND reachable.
        if self.config.task == "t2v" and self.config.get("sub_servers", {}).get("prompt_enhancer") is not None:
            self.has_prompt_enhancer = True
            if not self.check_sub_servers("prompt_enhancer"):
                self.has_prompt_enhancer = False
                logger.warning("No prompt enhancer server available, disable prompt enhancer.")
        if not self.has_prompt_enhancer:
            self.config.use_prompt_enhancer = False
        self.set_init_device()

    def init_modules(self):
        """Load models (unless deferred) and bind the task-specific input encoder."""
        logger.info("Initializing runner modules...")
        lazy_load = self.config.get("lazy_load", False)
        unload_modules = self.config.get("unload_modules", False)
        if not lazy_load and not unload_modules:
            self.load_model()
        elif lazy_load:
            # Lazy loading is only supported together with CPU offload.
            assert self.config.get("cpu_offload", False)
        # Dispatch table: task name -> bound input-encoder implementation.
        encoder_by_task = {
            "i2v": self._run_input_encoder_local_i2v,
            "flf2v": self._run_input_encoder_local_flf2v,
            "t2v": self._run_input_encoder_local_t2v,
            "vace": self._run_input_encoder_local_vace,
        }
        task = self.config["task"]
        if task in encoder_by_task:
            self.run_input_encoder = encoder_by_task[task]

    def set_init_device(self):
        """Select where weights are first placed.

        CPU offload keeps weights on the CPU until needed; otherwise they
        are loaded straight onto the GPU.
        """
        device_name = "cpu" if self.config.cpu_offload else "cuda"
        self.init_device = torch.device(device_name)

    def load_vfi_model(self):
        """Instantiate the configured video-frame-interpolation (VFI) model.

        Returns:
            A RIFE wrapper instance.

        Raises:
            ValueError: if the configured algorithm is not supported.
        """
        vfi_cfg = self.config["video_frame_interpolation"]
        if vfi_cfg.get("algo", None) != "rife":
            raise ValueError(f"Unsupported VFI model: {self.config['video_frame_interpolation']['algo']}")
        # Imported lazily so the dependency is only needed when VFI is enabled.
        from lightx2v.models.vfi.rife.rife_comfyui_wrapper import RIFEWrapper

        logger.info("Loading RIFE model...")
        return RIFEWrapper(vfi_cfg["model_path"])

    @ProfilingContext("Load models")
    def load_model(self):
        """Load every sub-model used by the pipeline.

        Populates ``self.model`` (DiT transformer), ``self.text_encoders``,
        ``self.image_encoder``, the VAE encoder/decoder pair, and — only when
        configured — the frame-interpolation model.
        """
        self.model = self.load_transformer()
        self.text_encoders = self.load_text_encoder()
        self.image_encoder = self.load_image_encoder()
        self.vae_encoder, self.vae_decoder = self.load_vae()
        # VFI model is optional; None when interpolation is not configured.
        self.vfi_model = self.load_vfi_model() if "video_frame_interpolation" in self.config else None

    def check_sub_servers(self, task_type):
        """Probe the configured sub-servers for *task_type* and keep the live ones.

        Rewrites ``self.config["sub_servers"][task_type]`` in place so that it
        only contains servers that answered the status probe with HTTP 200.

        Returns:
            bool: True if at least one server is available.
        """
        candidates = self.config.get("sub_servers", {}).get(task_type, [])
        available_servers = []
        for url in candidates:
            status_url = f"{url}/v1/local/{task_type}/generate/service_status"
            try:
                response = requests.get(status_url, timeout=2)
            except RequestException as e:
                logger.warning(f"Failed to connect to {url}: {str(e)}")
                continue
            if response.status_code == 200:
                available_servers.append(url)
            else:
                logger.warning(f"Service {url} returned status code {response.status_code}")
        logger.info(f"{task_type} available servers: {available_servers}")
        self.config["sub_servers"][task_type] = available_servers
        return len(available_servers) > 0

    def set_inputs(self, inputs):
        """Copy per-request generation parameters from *inputs* into the config.

        Args:
            inputs: mapping of request fields; missing keys fall back to the
                existing config value or a built-in default.
        """
        self.config["prompt"] = inputs.get("prompt", "")
        # Default off; only honor the client's flag when an enhancer is available.
        self.config["use_prompt_enhancer"] = False
        if self.has_prompt_enhancer:
            self.config["use_prompt_enhancer"] = inputs.get("use_prompt_enhancer", False)  # Reset use_prompt_enhancer from client side.
        self.config["negative_prompt"] = inputs.get("negative_prompt", "")
        self.config["image_path"] = inputs.get("image_path", "")
        self.config["save_video_path"] = inputs.get("save_video_path", "")
        self.config["infer_steps"] = inputs.get("infer_steps", self.config.get("infer_steps", 5))
        self.config["target_video_length"] = inputs.get("target_video_length", self.config.get("target_video_length", 81))
        self.config["seed"] = inputs.get("seed", self.config.get("seed", 42))
        self.config["audio_path"] = inputs.get("audio_path", "")  # for wan-audio
        self.config["video_duration"] = inputs.get("video_duration", 5)  # for wan-audio

        # self.config["sample_shift"] = inputs.get("sample_shift", self.config.get("sample_shift", 5))
        # self.config["sample_guide_scale"] = inputs.get("sample_guide_scale", self.config.get("sample_guide_scale", 5))

    def set_progress_callback(self, callback):
        """Register a callback invoked as ``callback(progress, total)`` after each denoising step."""
        self.progress_callback = callback

    def run_segment(self, total_steps=None):
        """Run the denoising loop for the current segment.

        Args:
            total_steps: number of scheduler steps to run; defaults to the
                scheduler's configured ``infer_steps``.

        Returns:
            The scheduler's latents after the final step.
        """
        if total_steps is None:
            total_steps = self.model.scheduler.infer_steps
        for step_index in range(total_steps):
            # only for single segment, check stop signal every step
            if self.video_segment_num == 1:
                self.check_stop()
            logger.info(f"==> step_index: {step_index + 1} / {total_steps}")

            # Scheduler pre-step (e.g. timestep bookkeeping) ...
            with ProfilingContext4Debug("step_pre"):
                self.model.scheduler.step_pre(step_index=step_index)

            # ... the model forward pass ...
            with ProfilingContext4Debug("🚀 infer_main"):
                self.model.infer(self.inputs)

            # ... then the scheduler update.
            with ProfilingContext4Debug("step_post"):
                self.model.scheduler.step_post()

            # Report percentage progress (0-100) to any registered callback.
            if self.progress_callback:
                self.progress_callback(((step_index + 1) / total_steps) * 100, 100)

        return self.model.scheduler.latents

    def run_step(self):
        """Run the input encoders plus a single denoising step (debug helper)."""
        self.inputs = self.run_input_encoder()
        self.run_main(total_steps=1)

    def end_run(self):
        """Tear down per-run state and release GPU memory.

        With lazy loading / module unloading enabled, the transformer and its
        weight managers are also dropped so the next run reloads them fresh.
        """
        self.model.scheduler.clear()
        del self.inputs, self.model.scheduler
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            # hasattr guards: not every transformer variant has these members.
            if hasattr(self.model.transformer_infer, "weights_stream_mgr"):
                self.model.transformer_infer.weights_stream_mgr.clear()
            if hasattr(self.model.transformer_weights, "clear"):
                self.model.transformer_weights.clear()
            self.model.pre_weight.clear()
            del self.model
        # Free cached CUDA blocks and collect the now-unreferenced objects.
        torch.cuda.empty_cache()
        gc.collect()

    def read_image_input(self, img_path):
        """Load an image and convert it to a normalized CUDA tensor.

        Args:
            img_path: filesystem path to an image, or an already-open
                ``PIL.Image.Image``.

        Returns:
            tuple: ``(img, img_ori)`` — a ``1xCxHxW`` float tensor scaled to
            ``[-1, 1]`` on the GPU, and the original PIL image.
        """
        img_ori = img_path if isinstance(img_path, Image.Image) else Image.open(img_path).convert("RGB")
        img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
        return img, img_ori

    @ProfilingContext("Run Encoders")
    def _run_input_encoder_local_i2v(self):
        """Encode image + prompt for image-to-video generation.

        Returns:
            dict: combined encoder outputs via ``get_encoder_output_i2v``.
        """
        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
        img, img_ori = self.read_image_input(self.config["image_path"])
        # CLIP-style image features are optional.
        clip_encoder_out = self.run_image_encoder(img) if self.config.get("use_image_encoder", True) else None
        # Some VAE encoders want the raw PIL image rather than the tensor.
        vae_encode_out = self.run_vae_encoder(img_ori if self.vae_encoder_need_img_original else img)
        text_encoder_output = self.run_text_encoder(prompt, img)
        # Encoders can leave large temporaries on the GPU; reclaim them now.
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)

    @ProfilingContext("Run Encoders")
    def _run_input_encoder_local_t2v(self):
        """Encode the (possibly enhanced) prompt for text-to-video generation.

        Returns:
            dict: text encoder output plus a ``None`` image-encoder slot.
        """
        if self.config["use_prompt_enhancer"]:
            prompt = self.config["prompt_enhanced"]
        else:
            prompt = self.config["prompt"]
        text_encoder_output = self.run_text_encoder(prompt, None)
        # Encoders can leave large temporaries on the GPU; reclaim them now.
        torch.cuda.empty_cache()
        gc.collect()
        return {"text_encoder_output": text_encoder_output, "image_encoder_output": None}

    @ProfilingContext("Run Encoders")
    def _run_input_encoder_local_flf2v(self):
        """Encode first frame, last frame and prompt for first/last-frame-to-video.

        Returns:
            dict: combined encoder outputs via ``get_encoder_output_i2v``.
        """
        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
        first_frame, _ = self.read_image_input(self.config["image_path"])
        last_frame, _ = self.read_image_input(self.config["last_frame_path"])
        # Image encoder consumes both endpoint frames; it is optional.
        clip_encoder_out = self.run_image_encoder(first_frame, last_frame) if self.config.get("use_image_encoder", True) else None
        vae_encode_out = self.run_vae_encoder(first_frame, last_frame)
        text_encoder_output = self.run_text_encoder(prompt, first_frame)
        # Encoders can leave large temporaries on the GPU; reclaim them now.
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output)

    @ProfilingContext("Run Encoders")
    def _run_input_encoder_local_vace(self):
        """Encode source video/mask/reference images and prompt for VACE.

        Returns:
            dict: combined encoder outputs via ``get_encoder_output_i2v``
            (no CLIP image features for this task).
        """
        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
        src_video = self.config.get("src_video", None)
        src_mask = self.config.get("src_mask", None)
        src_ref_images = self.config.get("src_ref_images", None)
        # prepare_source expects batched lists; ref images arrive as a
        # comma-separated string of paths.
        src_video, src_mask, src_ref_images = self.prepare_source(
            [src_video],
            [src_mask],
            [None if src_ref_images is None else src_ref_images.split(",")],
            (self.config.target_width, self.config.target_height),
        )
        self.src_ref_images = src_ref_images

        vae_encoder_out = self.run_vae_encoder(src_video, src_ref_images, src_mask)
        text_encoder_output = self.run_text_encoder(prompt)
        # Encoders can leave large temporaries on the GPU; reclaim them now.
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(None, vae_encoder_out, text_encoder_output)

    def init_run(self):
        """Prepare scheduler and model state before the denoising loop."""
        self.set_target_shape()
        self.get_video_segment_num()
        # Under lazy loading the transformer was not loaded in init_modules.
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.model = self.load_transformer()
        self.init_scheduler()
        self.model.scheduler.prepare(self.inputs["image_encoder_output"])
        # wan2.2 i2v: drop the VAE-encoded image after scheduler.prepare has
        # consumed it — presumably to free memory; TODO confirm downstream use.
        if self.config.get("model_cls") == "wan2.2" and self.config["task"] == "i2v":
            self.inputs["image_encoder_output"]["vae_encoder_out"] = None

    @ProfilingContext("Run DiT")
    def run_main(self, total_steps=None):
        """Run the full denoising pipeline over all video segments.

        Args:
            total_steps: optional per-segment step override (used by
                ``run_step`` to execute a single step).
        """
        self.init_run()
        for segment_idx in range(self.video_segment_num):
            logger.info(f"🔄 segment_idx: {segment_idx + 1}/{self.video_segment_num}")
            with ProfilingContext(f"segment end2end {segment_idx + 1}/{self.video_segment_num}"):
                self.check_stop()
                # 1. default do nothing
                self.init_run_segment(segment_idx)
                # 2. main inference loop
                latents = self.run_segment(total_steps=total_steps)
                # 3. vae decoder
                self.gen_video = self.run_vae_decoder(latents)
                # 4. default do nothing
                self.end_run_segment()
        self.end_run()

    @ProfilingContext("Run VAE Decoder")
    def run_vae_decoder(self, latents):
        """Decode denoised latents into frames with the VAE decoder.

        When lazy loading / module unloading is enabled, the decoder is
        loaded just for this call and released immediately afterwards.

        Args:
            latents: latent tensor produced by the denoising loop.

        Returns:
            Decoded image frames.
        """
        transient_decoder = self.config.get("lazy_load", False) or self.config.get("unload_modules", False)
        if transient_decoder:
            self.vae_decoder = self.load_vae_decoder()
        images = self.vae_decoder.decode(latents.to(GET_DTYPE()))
        if transient_decoder:
            del self.vae_decoder
            torch.cuda.empty_cache()
            gc.collect()
        return images

    def post_prompt_enhancer(self):
        """Send the current prompt to an idle prompt-enhancer sub-server.

        Polls the configured enhancer servers until one reports ``idle``,
        then posts the prompt and returns the enhanced text. Servers that
        have become unreachable since startup are skipped instead of
        crashing the pipeline, and the loop backs off between rounds
        instead of busy-waiting.

        Returns:
            str: the enhanced prompt.
        """
        while True:
            for url in self.config["sub_servers"]["prompt_enhancer"]:
                try:
                    response = requests.get(f"{url}/v1/local/prompt_enhancer/generate/service_status").json()
                except RequestException as e:
                    # A server that was alive in check_sub_servers may have
                    # died since; skip it and try the next one.
                    logger.warning(f"Failed to connect to {url}: {str(e)}")
                    continue
                if response["service_status"] == "idle":
                    response = requests.post(
                        f"{url}/v1/local/prompt_enhancer/generate",
                        json={
                            "task_id": generate_task_id(),
                            "prompt": self.config["prompt"],
                        },
                    )
                    enhanced_prompt = response.json()["output"]
                    logger.info(f"Enhanced prompt: {enhanced_prompt}")
                    return enhanced_prompt
            # No idle server this round: brief back-off instead of hammering
            # the status endpoints in a tight loop.
            time.sleep(0.5)

    def process_images_after_vae_decoder(self, save_video=True):
        """Post-process decoded frames: optional frame interpolation and saving.

        Args:
            save_video: when True, write the video to
                ``self.config.save_video_path`` (rank 0 only under
                distributed execution).

        Returns:
            dict: ``{"video": frames}`` with the (possibly interpolated) frames.
        """
        self.gen_video = vae_to_comfyui_image(self.gen_video)

        if "video_frame_interpolation" in self.config:
            # Interpolation requires both a loaded VFI model and a target fps.
            assert self.vfi_model is not None and self.config["video_frame_interpolation"].get("target_fps", None) is not None
            target_fps = self.config["video_frame_interpolation"]["target_fps"]
            logger.info(f"Interpolating frames from {self.config.get('fps', 16)} to {target_fps}")
            self.gen_video = self.vfi_model.interpolate_frames(
                self.gen_video,
                source_fps=self.config.get("fps", 16),
                target_fps=target_fps,
            )

        if save_video:
            # Save at the interpolated rate when VFI ran, else the source rate.
            if "video_frame_interpolation" in self.config and self.config["video_frame_interpolation"].get("target_fps"):
                fps = self.config["video_frame_interpolation"]["target_fps"]
            else:
                fps = self.config.get("fps", 16)

            # Only rank 0 writes the file under distributed execution.
            if not dist.is_initialized() or dist.get_rank() == 0:
                logger.info("🎬 Start to save video 🎬")  # was an f-string with no placeholders (F541)
                save_to_video(self.gen_video, self.config.save_video_path, fps=fps, method="ffmpeg")
                logger.info(f"✅ Video saved successfully to: {self.config.save_video_path} ✅")
        return {"video": self.gen_video}

    def run_pipeline(self, save_video=True):
        """Run the full generation pipeline: enhance → encode → denoise → save.

        Args:
            save_video: forwarded to ``process_images_after_vae_decoder``.

        Returns:
            dict: ``{"video": frames}`` from post-processing.
        """
        if self.config["use_prompt_enhancer"]:
            self.config["prompt_enhanced"] = self.post_prompt_enhancer()

        self.inputs = self.run_input_encoder()

        self.run_main()

        gen_video = self.process_images_after_vae_decoder(save_video=save_video)

        # Final cleanup so back-to-back requests start from a low-memory state.
        torch.cuda.empty_cache()
        gc.collect()

        return gen_video