import gc

import requests
import torch
import torch.distributed as dist
import torchvision.transforms.functional as TF
from PIL import Image
from loguru import logger
from requests.exceptions import RequestException

from lightx2v.utils.envs import *
from lightx2v.utils.generate_task_id import generate_task_id
from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image

from .base_runner import BaseRunner


class DefaultRunner(BaseRunner):
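    """Default end-to-end runner for video generation.

    Orchestrates model loading, input encoding, the diffusion loop, VAE
    decoding, and post-processing for the t2v / i2v / flf2v / vace tasks.
    """
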
    def __init__(self, config):
        super().__init__(config)
        self.has_prompt_enhancer = False
        self.progress_callback = None
        if self.config.task == "t2v" and self.config.get("sub_servers", {}).get("prompt_enhancer") is not None:
            self.has_prompt_enhancer = True
            if not self.check_sub_servers("prompt_enhancer"):
                self.has_prompt_enhancer = False
                logger.warning("No prompt enhancer server available, disable prompt enhancer.")
        if not self.has_prompt_enhancer:
            self.config.use_prompt_enhancer = False
        self.set_init_device()

    def init_modules(self):
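        """Load models (unless lazy loading or module unloading is enabled) and select the input encoder for the configured task."""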
        logger.info("Initializing runner modules...")
        if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
            self.load_model()
        elif self.config.get("lazy_load", False):
            assert self.config.get("cpu_offload", False)
        if self.config["task"] == "i2v":
            self.run_input_encoder = self._run_input_encoder_local_i2v
        elif self.config["task"] == "flf2v":
            self.run_input_encoder = self._run_input_encoder_local_flf2v
        elif self.config["task"] == "t2v":
            self.run_input_encoder = self._run_input_encoder_local_t2v
        elif self.config["task"] == "vace":
            self.run_input_encoder = self._run_input_encoder_local_vace

    def set_init_device(self):
        if self.config.cpu_offload:
            self.init_device = torch.device("cpu")
        else:
            self.init_device = torch.device("cuda")

    def load_vfi_model(self):
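        """Load the video-frame-interpolation (VFI) model; only the RIFE algorithm is currently supported."""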
        if self.config["video_frame_interpolation"].get("algo", None) == "rife":
            from lightx2v.models.vfi.rife.rife_comfyui_wrapper import RIFEWrapper

            logger.info("Loading RIFE model...")
            return RIFEWrapper(self.config["video_frame_interpolation"]["model_path"])
        else:
            raise ValueError(f"Unsupported VFI model: {self.config['video_frame_interpolation']['algo']}")

    @ProfilingContext("Load models")
    def load_model(self):
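        """Load the transformer, text/image encoders, VAE, and the optional VFI model."""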
        self.model = self.load_transformer()
        self.text_encoders = self.load_text_encoder()
        self.image_encoder = self.load_image_encoder()
        self.vae_encoder, self.vae_decoder = self.load_vae()
        self.vfi_model = self.load_vfi_model() if "video_frame_interpolation" in self.config else None

    def check_sub_servers(self, task_type):
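        """Probe each configured sub-server's status endpoint and keep only the reachable ones.

        Returns True if at least one server is available.
        """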
        urls = self.config.get("sub_servers", {}).get(task_type, [])
        available_servers = []
        for url in urls:
            try:
                status_url = f"{url}/v1/local/{task_type}/generate/service_status"
                response = requests.get(status_url, timeout=2)
                if response.status_code == 200:
                    available_servers.append(url)
                else:
                    logger.warning(f"Service {url} returned status code {response.status_code}")

            except RequestException as e:
                logger.warning(f"Failed to connect to {url}: {str(e)}")
                continue
        logger.info(f"{task_type} available servers: {available_servers}")
        self.config["sub_servers"][task_type] = available_servers
        return len(available_servers) > 0

    def set_inputs(self, inputs):
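        """Populate self.config with per-request generation inputs, falling back to configured defaults."""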
        self.config["prompt"] = inputs.get("prompt", "")
        self.config["use_prompt_enhancer"] = False
        if self.has_prompt_enhancer:
            self.config["use_prompt_enhancer"] = inputs.get("use_prompt_enhancer", False)  # Reset use_prompt_enhancer from clinet side.
        self.config["negative_prompt"] = inputs.get("negative_prompt", "")
        self.config["image_path"] = inputs.get("image_path", "")
        self.config["save_video_path"] = inputs.get("save_video_path", "")
        self.config["infer_steps"] = inputs.get("infer_steps", self.config.get("infer_steps", 5))
        self.config["target_video_length"] = inputs.get("target_video_length", self.config.get("target_video_length", 81))
        self.config["seed"] = inputs.get("seed", self.config.get("seed", 42))
        self.config["audio_path"] = inputs.get("audio_path", "")  # for wan-audio
        self.config["video_duration"] = inputs.get("video_duration", 5)  # for wan-audio

        # self.config["sample_shift"] = inputs.get("sample_shift", self.config.get("sample_shift", 5))
        # self.config["sample_guide_scale"] = inputs.get("sample_guide_scale", self.config.get("sample_guide_scale", 5))

    def set_progress_callback(self, callback):
        self.progress_callback = callback

    def run_segment(self, total_steps=None):
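        """Run the denoising loop for one video segment and return the final latents and generator."""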
        if total_steps is None:
            total_steps = self.model.scheduler.infer_steps
        for step_index in range(total_steps):
            logger.info(f"==> step_index: {step_index + 1} / {total_steps}")

            with ProfilingContext4Debug("step_pre"):
                self.model.scheduler.step_pre(step_index=step_index)

            with ProfilingContext4Debug("🚀 infer_main"):
                self.model.infer(self.inputs)

            with ProfilingContext4Debug("step_post"):
                self.model.scheduler.step_post()

            if self.progress_callback:
                self.progress_callback(((step_index + 1) / total_steps) * 100, 100)

        return self.model.scheduler.latents, self.model.scheduler.generator

    def run_step(self):
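        """Run the input encoder plus a single denoising step; presumably a debugging/profiling entry point."""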
        self.inputs = self.run_input_encoder()
        self.run_main(total_steps=1)

    def end_run(self):
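        """Release the scheduler and inputs (and, under lazy loading, the model weights), then free GPU memory."""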
        self.model.scheduler.clear()
        del self.inputs, self.model.scheduler
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            if hasattr(self.model.transformer_infer, "weights_stream_mgr"):
                self.model.transformer_infer.weights_stream_mgr.clear()
            if hasattr(self.model.transformer_weights, "clear"):
                self.model.transformer_weights.clear()
            self.model.pre_weight.clear()
            del self.model
        torch.cuda.empty_cache()
        gc.collect()

    def read_image_input(self, img_path):
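        """Load an RGB image and return a [-1, 1]-normalized CUDA tensor along with the original PIL image."""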
        img_ori = Image.open(img_path).convert("RGB")
        img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
        return img, img_ori

    @ProfilingContext("Run Encoders")
    def _run_input_encoder_local_i2v(self):
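        """Encode the conditioning image and prompt for image-to-video generation."""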
        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
        img, img_ori = self.read_image_input(self.config["image_path"])
        clip_encoder_out = self.run_image_encoder(img) if self.config.get("use_image_encoder", True) else None
        vae_encode_out = self.run_vae_encoder(img_ori if self.vae_encoder_need_img_original else img)
        text_encoder_output = self.run_text_encoder(prompt, img)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)

    @ProfilingContext("Run Encoders")
    def _run_input_encoder_local_t2v(self):
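        """Encode the prompt for text-to-video generation; no image conditioning is used."""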
        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
        text_encoder_output = self.run_text_encoder(prompt, None)
        torch.cuda.empty_cache()
        gc.collect()
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": None,
        }

    @ProfilingContext("Run Encoders")
    def _run_input_encoder_local_flf2v(self):
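        """Encode the first and last frames plus the prompt for first-last-frame-to-video generation."""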
        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
        first_frame, _ = self.read_image_input(self.config["image_path"])
        last_frame, _ = self.read_image_input(self.config["last_frame_path"])
        clip_encoder_out = self.run_image_encoder(first_frame, last_frame) if self.config.get("use_image_encoder", True) else None
        vae_encode_out = self.run_vae_encoder(first_frame, last_frame)
        text_encoder_output = self.run_text_encoder(prompt, first_frame)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output)

    @ProfilingContext("Run Encoders")
    def _run_input_encoder_local_vace(self):
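        """Prepare the source video, mask, and reference images, then encode them for the VACE task."""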
        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
        src_video = self.config.get("src_video", None)
        src_mask = self.config.get("src_mask", None)
        src_ref_images = self.config.get("src_ref_images", None)
        src_video, src_mask, src_ref_images = self.prepare_source(
            [src_video],
            [src_mask],
            [None if src_ref_images is None else src_ref_images.split(",")],
            (self.config.target_width, self.config.target_height),
        )
        self.src_ref_images = src_ref_images

        vae_encoder_out = self.run_vae_encoder(src_video, src_ref_images, src_mask)
        text_encoder_output = self.run_text_encoder(prompt)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(None, vae_encoder_out, text_encoder_output)

    def init_run(self):
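        """Set up the target shape, segment count, and scheduler before the main inference loop."""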
        self.set_target_shape()
        self.get_video_segment_num()
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.model = self.load_transformer()
        self.init_scheduler()
        self.model.scheduler.prepare(self.inputs["image_encoder_output"])
        if self.config.get("model_cls") == "wan2.2" and self.config["task"] == "i2v":
            self.inputs["image_encoder_output"]["vae_encoder_out"] = None

    @ProfilingContext("Run DiT")
    def run_main(self, total_steps=None):
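        """Run the full DiT inference: per-segment denoising followed by VAE decoding."""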
        self.init_run()
        for segment_idx in range(self.video_segment_num):
            with ProfilingContext4Debug(f"segment end2end {segment_idx}"):
                # 1. default do nothing
                self.init_run_segment(segment_idx)
                # 2. main inference loop
                latents, generator = self.run_segment(total_steps=total_steps)
                # 3. vae decoder
                self.gen_video = self.run_vae_decoder(latents)
                # 4. default do nothing
                self.end_run_segment()
        self.end_run()

    @ProfilingContext("Run VAE Decoder")
    def run_vae_decoder(self, latents):
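        """Decode latents to images, loading (and afterwards unloading) the VAE decoder when lazy loading is enabled."""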
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.vae_decoder = self.load_vae_decoder()
        images = self.vae_decoder.decode(latents.to(GET_DTYPE()))
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.vae_decoder
            torch.cuda.empty_cache()
            gc.collect()
        return images

    def post_prompt_enhancer(self):
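        """Poll the prompt-enhancer sub-servers until one is idle, then submit the prompt and return the enhanced version.

        Note that this blocks indefinitely if no server ever reports an idle status.
        """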
        while True:
            for url in self.config["sub_servers"]["prompt_enhancer"]:
                response = requests.get(f"{url}/v1/local/prompt_enhancer/generate/service_status").json()
                if response["service_status"] == "idle":
                    response = requests.post(
                        f"{url}/v1/local/prompt_enhancer/generate",
                        json={
                            "task_id": generate_task_id(),
                            "prompt": self.config["prompt"],
                        },
                    )
                    enhanced_prompt = response.json()["output"]
                    logger.info(f"Enhanced prompt: {enhanced_prompt}")
                    return enhanced_prompt

    def process_images_after_vae_decoder(self, save_video=True):
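        """Convert decoded frames to ComfyUI image format, optionally interpolate to a higher FPS, and save the video (rank 0 only)."""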
        self.gen_video = vae_to_comfyui_image(self.gen_video)

        if "video_frame_interpolation" in self.config:
            assert self.vfi_model is not None and self.config["video_frame_interpolation"].get("target_fps", None) is not None
            target_fps = self.config["video_frame_interpolation"]["target_fps"]
            logger.info(f"Interpolating frames from {self.config.get('fps', 16)} to {target_fps}")
            self.gen_video = self.vfi_model.interpolate_frames(
                self.gen_video,
                source_fps=self.config.get("fps", 16),
                target_fps=target_fps,
            )

        if save_video:
            if "video_frame_interpolation" in self.config and self.config["video_frame_interpolation"].get("target_fps"):
                fps = self.config["video_frame_interpolation"]["target_fps"]
            else:
                fps = self.config.get("fps", 16)

            if not dist.is_initialized() or dist.get_rank() == 0:
                logger.info(f"🎬 Start to save video 🎬")

                save_to_video(self.gen_video, self.config.save_video_path, fps=fps, method="ffmpeg")
                logger.info(f"✅ Video saved successfully to: {self.config.save_video_path} ✅")
        return {"video": self.gen_video}

    def run_pipeline(self, save_video=True):
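        """End-to-end generation: optional prompt enhancement, input encoding, DiT inference, and post-processing."""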
        if self.config["use_prompt_enhancer"]:
            self.config["prompt_enhanced"] = self.post_prompt_enhancer()

        self.inputs = self.run_input_encoder()

        self.run_main()

        gen_video = self.process_images_after_vae_decoder(save_video=save_video)

        torch.cuda.empty_cache()
        gc.collect()

        return gen_video