default_runner.py 13.1 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
import gc
PengGao's avatar
PengGao committed
2

3
import requests
helloyongyang's avatar
helloyongyang committed
4
5
import torch
import torch.distributed as dist
PengGao's avatar
PengGao committed
6
7
8
from PIL import Image
from loguru import logger
from requests.exceptions import RequestException
PengGao's avatar
PengGao committed
9

helloyongyang's avatar
helloyongyang committed
10
from lightx2v.utils.envs import *
PengGao's avatar
PengGao committed
11
12
from lightx2v.utils.generate_task_id import generate_task_id
from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
helloyongyang's avatar
helloyongyang committed
13
from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image
PengGao's avatar
PengGao committed
14

PengGao's avatar
PengGao committed
15
from .base_runner import BaseRunner
16
17


PengGao's avatar
PengGao committed
18
class DefaultRunner(BaseRunner):
    """Default pipeline runner: input encoders -> DiT denoising -> VAE decode -> (optional VFI) -> save.

    The input encoder is chosen from ``config["task"]`` in :meth:`init_modules`.
    When ``lazy_load`` or ``unload_modules`` is set, the transformer and VAE decoder
    are loaded on demand right before use and released afterwards.
    """

    def __init__(self, config):
        super().__init__(config)
        self.has_prompt_enhancer = False
        self.progress_callback = None
        # The prompt enhancer is only considered for text-to-video, and only kept
        # enabled if at least one configured sub-server answers the status probe.
        if self.config.task == "t2v" and self.config.get("sub_servers", {}).get("prompt_enhancer") is not None:
            self.has_prompt_enhancer = True
            if not self.check_sub_servers("prompt_enhancer"):
                self.has_prompt_enhancer = False
                logger.warning("No prompt enhancer server available, disable prompt enhancer.")
        if not self.has_prompt_enhancer:
            self.config.use_prompt_enhancer = False
        self.set_init_device()

    def init_modules(self):
        """Load models (unless deferred) and bind the local stage implementations.

        With neither ``lazy_load`` nor ``unload_modules`` set, all models are loaded
        eagerly. ``lazy_load`` requires ``cpu_offload`` to be enabled.
        """
        logger.info("Initializing runner modules...")
        if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
            self.load_model()
        elif self.config.get("lazy_load", False):
            assert self.config.get("cpu_offload", False), "lazy_load requires cpu_offload to be enabled"
        self.run_dit = self._run_dit_local
        self.run_vae_decoder = self._run_vae_decoder_local
        # Unknown tasks leave run_input_encoder unset here; subclasses may bind their own.
        if self.config["task"] == "i2v":
            self.run_input_encoder = self._run_input_encoder_local_i2v
        elif self.config["task"] == "flf2v":
            self.run_input_encoder = self._run_input_encoder_local_flf2v
        elif self.config["task"] == "t2v":
            self.run_input_encoder = self._run_input_encoder_local_t2v
        elif self.config["task"] == "vace":
            self.run_input_encoder = self._run_input_encoder_local_vace

    def set_init_device(self):
        """Pick the device models are initialised on: CPU when offloading, else CUDA."""
        # Use .get() like init_modules does, so a config without the key does not crash.
        if self.config.get("cpu_offload", False):
            self.init_device = torch.device("cpu")
        else:
            self.init_device = torch.device("cuda")

    def load_vfi_model(self):
        """Build the video-frame-interpolation model described by ``config["video_frame_interpolation"]``.

        Raises:
            ValueError: if ``algo`` is missing or is not a supported algorithm (only "rife").
        """
        vfi_config = self.config["video_frame_interpolation"]
        algo = vfi_config.get("algo", None)
        if algo == "rife":
            from lightx2v.models.vfi.rife.rife_comfyui_wrapper import RIFEWrapper

            logger.info("Loading RIFE model...")
            return RIFEWrapper(vfi_config["model_path"])
        # Use the already-fetched value: indexing ["algo"] here would raise KeyError
        # instead of the intended ValueError when the key is absent.
        raise ValueError(f"Unsupported VFI model: {algo}")

    @ProfilingContext("Load models")
    def load_model(self):
        """Eagerly load the transformer, text/image encoders, VAE and optional VFI model."""
        self.model = self.load_transformer()
        self.text_encoders = self.load_text_encoder()
        self.image_encoder = self.load_image_encoder()
        self.vae_encoder, self.vae_decoder = self.load_vae()
        self.vfi_model = self.load_vfi_model() if "video_frame_interpolation" in self.config else None

    def check_sub_servers(self, task_type):
        """Probe the configured sub-servers for *task_type* and keep only the responsive ones.

        A server counts as available when its ``service_status`` endpoint returns
        HTTP 200 within 2 seconds. ``config["sub_servers"][task_type]`` is
        overwritten with the filtered list.

        Returns:
            True if at least one server is available.
        """
        urls = self.config.get("sub_servers", {}).get(task_type, [])
        available_servers = []
        for url in urls:
            try:
                status_url = f"{url}/v1/local/{task_type}/generate/service_status"
                response = requests.get(status_url, timeout=2)
                if response.status_code == 200:
                    available_servers.append(url)
                else:
                    logger.warning(f"Service {url} returned status code {response.status_code}")
            except RequestException as e:
                logger.warning(f"Failed to connect to {url}: {e}")
        logger.info(f"{task_type} available servers: {available_servers}")
        self.config["sub_servers"][task_type] = available_servers
        return len(available_servers) > 0

    def set_inputs(self, inputs):
        """Copy per-request generation parameters from *inputs* into the config.

        Unset sampling parameters (infer_steps, target_video_length, seed) fall back
        to the existing config value, then to a hard-coded default.
        """
        self.config["prompt"] = inputs.get("prompt", "")
        self.config["use_prompt_enhancer"] = False
        if self.has_prompt_enhancer:
            self.config["use_prompt_enhancer"] = inputs.get("use_prompt_enhancer", False)  # Reset use_prompt_enhancer from client side.
        self.config["negative_prompt"] = inputs.get("negative_prompt", "")
        self.config["image_path"] = inputs.get("image_path", "")
        self.config["save_video_path"] = inputs.get("save_video_path", "")
        self.config["infer_steps"] = inputs.get("infer_steps", self.config.get("infer_steps", 5))
        self.config["target_video_length"] = inputs.get("target_video_length", self.config.get("target_video_length", 81))
        self.config["seed"] = inputs.get("seed", self.config.get("seed", 42))
        self.config["audio_path"] = inputs.get("audio_path", "")  # for wan-audio
        self.config["video_duration"] = inputs.get("video_duration", 5)  # for wan-audio

    def set_progress_callback(self, callback):
        """Register ``callback(percent, 100)``, invoked after every denoising step."""
        self.progress_callback = callback

    def run(self, total_steps=None):
        """Run the denoising loop.

        Args:
            total_steps: number of steps; defaults to ``self.model.scheduler.infer_steps``.

        Returns:
            ``(latents, generator)`` taken from the scheduler after the last step.
        """
        if total_steps is None:
            total_steps = self.model.scheduler.infer_steps
        for step_index in range(total_steps):
            logger.info(f"==> step_index: {step_index + 1} / {total_steps}")

            with ProfilingContext4Debug("step_pre"):
                self.model.scheduler.step_pre(step_index=step_index)

            with ProfilingContext4Debug("🚀 infer_main"):
                self.model.infer(self.inputs)

            with ProfilingContext4Debug("step_post"):
                self.model.scheduler.step_post()

            if self.progress_callback:
                self.progress_callback(((step_index + 1) / total_steps) * 100, 100)

        return self.model.scheduler.latents, self.model.scheduler.generator

    def run_step(self):
        """Encode inputs and run a single DiT step (no VAE decode)."""
        self.inputs = self.run_input_encoder()
        self.set_target_shape()
        self.run_dit(total_steps=1)

    def end_run(self):
        """Drop the scheduler and inputs; under lazy_load/unload_modules also release the transformer."""
        self.model.scheduler.clear()
        del self.inputs, self.model.scheduler
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            if hasattr(self.model.transformer_infer, "weights_stream_mgr"):
                self.model.transformer_infer.weights_stream_mgr.clear()
            if hasattr(self.model.transformer_weights, "clear"):
                self.model.transformer_weights.clear()
            self.model.pre_weight.clear()
            self.model.post_weight.clear()
            del self.model
        torch.cuda.empty_cache()
        gc.collect()

    def _select_encoder_prompt(self):
        """Return the enhanced prompt when prompt enhancement is enabled, else the raw prompt."""
        if self.config.get("use_prompt_enhancer", False):
            return self.config["prompt_enhanced"]
        return self.config["prompt"]

    @staticmethod
    def _load_rgb_image(path):
        """Open the image at *path*, convert it to RGB, and close the underlying file."""
        # convert() fully loads the pixels into a new image, so it stays valid after the file is closed.
        with Image.open(path) as image:
            return image.convert("RGB")

    @ProfilingContext("Run Encoders")
    def _run_input_encoder_local_i2v(self):
        """Encode the prompt and the conditioning image for image-to-video."""
        prompt = self._select_encoder_prompt()
        img = self._load_rgb_image(self.config["image_path"])
        clip_encoder_out = self.run_image_encoder(img) if self.config.get("use_image_encoder", True) else None
        vae_encode_out = self.run_vae_encoder(img)
        text_encoder_output = self.run_text_encoder(prompt, img)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)

    @ProfilingContext("Run Encoders")
    def _run_input_encoder_local_t2v(self):
        """Encode the prompt for text-to-video; there is no image conditioning."""
        prompt = self._select_encoder_prompt()
        text_encoder_output = self.run_text_encoder(prompt, None)
        torch.cuda.empty_cache()
        gc.collect()
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": None,
        }

    @ProfilingContext("Run Encoders")
    def _run_input_encoder_local_flf2v(self):
        """Encode the prompt plus first and last frames for first-last-frame-to-video."""
        prompt = self._select_encoder_prompt()
        first_frame = self._load_rgb_image(self.config["image_path"])
        last_frame = self._load_rgb_image(self.config["last_frame_path"])
        clip_encoder_out = self.run_image_encoder(first_frame, last_frame) if self.config.get("use_image_encoder", True) else None
        vae_encode_out = self.run_vae_encoder(first_frame, last_frame)
        text_encoder_output = self.run_text_encoder(prompt, first_frame)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output)

    @ProfilingContext("Run Encoders")
    def _run_input_encoder_local_vace(self):
        """Prepare VACE sources (video, mask, comma-separated reference images) and encode them with the prompt."""
        prompt = self._select_encoder_prompt()
        src_video = self.config.get("src_video", None)
        src_mask = self.config.get("src_mask", None)
        src_ref_images = self.config.get("src_ref_images", None)
        src_video, src_mask, src_ref_images = self.prepare_source(
            [src_video],
            [src_mask],
            [None if src_ref_images is None else src_ref_images.split(",")],
            (self.config.target_width, self.config.target_height),
        )
        self.src_ref_images = src_ref_images

        vae_encoder_out = self.run_vae_encoder(src_video, src_ref_images, src_mask)
        text_encoder_output = self.run_text_encoder(prompt)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(None, vae_encoder_out, text_encoder_output)

    @ProfilingContext("Run DiT")
    def _run_dit_local(self, total_steps=None):
        """Run the denoising loop, loading the transformer first under lazy_load/unload_modules.

        Returns:
            ``(latents, generator)`` from :meth:`run`.
        """
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.model = self.load_transformer()
        self.init_scheduler()
        self.model.scheduler.prepare(self.inputs["image_encoder_output"])
        if self.config.get("model_cls") == "wan2.2" and self.config["task"] == "i2v":
            # NOTE(review): presumably the VAE latents are no longer needed once the
            # scheduler has been prepared for wan2.2 i2v — confirm.
            self.inputs["image_encoder_output"]["vae_encoder_out"] = None
        latents, generator = self.run(total_steps)
        self.end_run()
        return latents, generator

    @ProfilingContext("Run VAE Decoder")
    def _run_vae_decoder_local(self, latents, generator):
        """Decode latents to images; under lazy_load/unload_modules the decoder is loaded and freed around the call."""
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.vae_decoder = self.load_vae_decoder()
        images = self.vae_decoder.decode(latents, generator=generator, config=self.config)
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.vae_decoder
            torch.cuda.empty_cache()
            gc.collect()
        return images

    def post_prompt_enhancer(self, poll_interval=0.5):
        """Send ``config["prompt"]`` to the first idle prompt-enhancer server and return the result.

        Servers are polled in order. If none is idle (or reachable), the method
        waits *poll_interval* seconds and polls again. It blocks until a server
        accepts the request.

        Args:
            poll_interval: seconds to sleep between full polling rounds.

        Raises:
            RuntimeError: if no prompt-enhancer servers are configured.
        """
        import time

        urls = self.config["sub_servers"]["prompt_enhancer"]
        if not urls:
            # Without this guard the loop below would spin forever.
            raise RuntimeError("No prompt enhancer server available.")
        while True:
            for url in urls:
                try:
                    status = requests.get(f"{url}/v1/local/prompt_enhancer/generate/service_status", timeout=2).json()
                except RequestException as e:
                    # One unreachable server should not abort the search for an idle one.
                    logger.warning(f"Failed to query prompt enhancer status at {url}: {e}")
                    continue
                if status["service_status"] == "idle":
                    response = requests.post(
                        f"{url}/v1/local/prompt_enhancer/generate",
                        json={
                            "task_id": generate_task_id(),
                            "prompt": self.config["prompt"],
                        },
                    )
                    enhanced_prompt = response.json()["output"]
                    logger.info(f"Enhanced prompt: {enhanced_prompt}")
                    return enhanced_prompt
            # All servers busy or unreachable: back off instead of busy-spinning.
            time.sleep(poll_interval)

    def process_images_after_vae_decoder(self, images, save_video=True):
        """Convert decoder output to ComfyUI image layout, optionally interpolate frames, and save.

        Only rank 0 writes the video when torch.distributed is initialised.

        Returns:
            The converted (and possibly interpolated) frames.
        """
        images = vae_to_comfyui_image(images)

        if "video_frame_interpolation" in self.config:
            assert self.vfi_model is not None and self.config["video_frame_interpolation"].get("target_fps", None) is not None
            target_fps = self.config["video_frame_interpolation"]["target_fps"]
            logger.info(f"Interpolating frames from {self.config.get('fps', 16)} to {target_fps}")
            images = self.vfi_model.interpolate_frames(
                images,
                source_fps=self.config.get("fps", 16),
                target_fps=target_fps,
            )

        if save_video:
            # Interpolated videos are written at the interpolation target fps.
            if "video_frame_interpolation" in self.config and self.config["video_frame_interpolation"].get("target_fps"):
                fps = self.config["video_frame_interpolation"]["target_fps"]
            else:
                fps = self.config.get("fps", 16)

            if not dist.is_initialized() or dist.get_rank() == 0:
                logger.info("🎬 Start to save video 🎬")
                save_to_video(images, self.config.save_video_path, fps=fps, method="ffmpeg")
                logger.info(f"✅ Video saved successfully to: {self.config.save_video_path} ✅")

        return images

    def run_pipeline(self, save_video=True):
        """Run the full pipeline: optional prompt enhancement, encoders, DiT, VAE decode, post-processing.

        Returns:
            ``(images, None)``: the raw VAE decoder output and no audio.
        """
        if self.config.get("use_prompt_enhancer", False):
            self.config["prompt_enhanced"] = self.post_prompt_enhancer()

        self.inputs = self.run_input_encoder()
        self.set_target_shape()

        latents, generator = self.run_dit()

        images = self.run_vae_decoder(latents, generator)
        # NOTE(review): the post-processed frames are not returned here; callers get
        # the raw decoder output. Kept as-is for compatibility — confirm intended.
        self.process_images_after_vae_decoder(images, save_video=save_video)

        del latents, generator
        torch.cuda.empty_cache()
        gc.collect()

        # Return (images, audio) - audio is None for default runner
        return images, None