default_runner.py
import gc

import requests
import torch
import torch.distributed as dist
from PIL import Image
from loguru import logger
from requests.exceptions import RequestException

from lightx2v.utils.envs import *
from lightx2v.utils.generate_task_id import generate_task_id
from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
from lightx2v.utils.utils import cache_video, save_to_video, vae_to_comfyui_image

from .base_runner import BaseRunner


class DefaultRunner(BaseRunner):
    def __init__(self, config):
        super().__init__(config)
        self.has_prompt_enhancer = False
        self.progress_callback = None
        if self.config.task == "t2v" and self.config.get("sub_servers", {}).get("prompt_enhancer") is not None:
            self.has_prompt_enhancer = True
            if not self.check_sub_servers("prompt_enhancer"):
                self.has_prompt_enhancer = False
                logger.warning("No prompt enhancer server available, disable prompt enhancer.")
        if not self.has_prompt_enhancer:
            self.config.use_prompt_enhancer = False
        self.set_init_device()

    def init_modules(self):
        logger.info("Initializing runner modules...")
        if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
            self.load_model()
        elif self.config.get("lazy_load", False):
            assert self.config.get("cpu_offload", False), "lazy_load requires cpu_offload=True"
        self.run_dit = self._run_dit_local
        self.run_vae_decoder = self._run_vae_decoder_local
        if self.config["task"] == "i2v":
            self.run_input_encoder = self._run_input_encoder_local_i2v
        elif self.config["task"] == "flf2v":
            self.run_input_encoder = self._run_input_encoder_local_flf2v
        elif self.config["task"] == "t2v":
            self.run_input_encoder = self._run_input_encoder_local_t2v

    def set_init_device(self):
        if self.config.cpu_offload:
            self.init_device = torch.device("cpu")
        else:
            self.init_device = torch.device("cuda")

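    # Video frame interpolation (VFI): only the RIFE wrapper is wired up;
    # any other "algo" value raises a ValueError.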
    def load_vfi_model(self):
        if self.config["video_frame_interpolation"].get("algo", None) == "rife":
            from lightx2v.models.vfi.rife.rife_comfyui_wrapper import RIFEWrapper

            logger.info("Loading RIFE model...")
            return RIFEWrapper(self.config["video_frame_interpolation"]["model_path"])
        else:
            raise ValueError(f"Unsupported VFI model: {self.config['video_frame_interpolation']['algo']}")

    @ProfilingContext("Load models")
    def load_model(self):
        self.model = self.load_transformer()
        self.text_encoders = self.load_text_encoder()
        self.image_encoder = self.load_image_encoder()
        self.vae_encoder, self.vae_decoder = self.load_vae()
        self.vfi_model = self.load_vfi_model() if "video_frame_interpolation" in self.config else None

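    # Probe each configured sub-server's status endpoint and keep only the
    # reachable ones; returns True if at least one server answered.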
    def check_sub_servers(self, task_type):
        urls = self.config.get("sub_servers", {}).get(task_type, [])
        available_servers = []
        for url in urls:
            try:
                status_url = f"{url}/v1/local/{task_type}/generate/service_status"
                response = requests.get(status_url, timeout=2)
                if response.status_code == 200:
                    available_servers.append(url)
                else:
                    logger.warning(f"Service {url} returned status code {response.status_code}")

            except RequestException as e:
                logger.warning(f"Failed to connect to {url}: {str(e)}")
                continue
        logger.info(f"{task_type} available servers: {available_servers}")
        self.config["sub_servers"][task_type] = available_servers
        return len(available_servers) > 0

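    # Override per-request generation settings from the incoming request
    # dict, falling back to values already present in the config.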
    def set_inputs(self, inputs):
        self.config["prompt"] = inputs.get("prompt", "")
        self.config["use_prompt_enhancer"] = False
        if self.has_prompt_enhancer:
            self.config["use_prompt_enhancer"] = inputs.get("use_prompt_enhancer", False)  # Reset use_prompt_enhancer from clinet side.
        self.config["negative_prompt"] = inputs.get("negative_prompt", "")
        self.config["image_path"] = inputs.get("image_path", "")
        self.config["save_video_path"] = inputs.get("save_video_path", "")
        self.config["infer_steps"] = inputs.get("infer_steps", self.config.get("infer_steps", 5))
        self.config["target_video_length"] = inputs.get("target_video_length", self.config.get("target_video_length", 81))
        self.config["seed"] = inputs.get("seed", self.config.get("seed", 42))
        self.config["audio_path"] = inputs.get("audio_path", "")  # for wan-audio
        self.config["video_duration"] = inputs.get("video_duration", 5)  # for wan-audio

        # self.config["sample_shift"] = inputs.get("sample_shift", self.config.get("sample_shift", 5))
        # self.config["sample_guide_scale"] = inputs.get("sample_guide_scale", self.config.get("sample_guide_scale", 5))

    def set_progress_callback(self, callback):
        self.progress_callback = callback

    def run(self, total_steps=None):
        if total_steps is None:
            total_steps = self.model.scheduler.infer_steps
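        # Denoising loop: scheduler pre-step, DiT forward pass, scheduler
        # post-step, then an optional progress callback.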
        for step_index in range(total_steps):
            logger.info(f"==> step_index: {step_index + 1} / {total_steps}")

            with ProfilingContext4Debug("step_pre"):
                self.model.scheduler.step_pre(step_index=step_index)

            with ProfilingContext4Debug("🚀 infer_main"):
                self.model.infer(self.inputs)

            with ProfilingContext4Debug("step_post"):
                self.model.scheduler.step_post()

            if self.progress_callback:
                self.progress_callback(((step_index + 1) / total_steps) * 100, 100)

        return self.model.scheduler.latents, self.model.scheduler.generator

    def run_step(self):
        self.inputs = self.run_input_encoder()
        self.set_target_shape()
        self.run_dit(total_steps=1)

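    # Tear down per-run state; in lazy-load/unload mode also drop the
    # transformer weights so the next request reloads them from scratch.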
    def end_run(self):
        self.model.scheduler.clear()
        del self.inputs, self.model.scheduler
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            if hasattr(self.model.transformer_infer, "weights_stream_mgr"):
                self.model.transformer_infer.weights_stream_mgr.clear()
            if hasattr(self.model.transformer_weights, "clear"):
                self.model.transformer_weights.clear()
            self.model.pre_weight.clear()
            self.model.post_weight.clear()
            del self.model
        torch.cuda.empty_cache()
        gc.collect()

    @ProfilingContext("Run Encoders")
    def _run_input_encoder_local_i2v(self):
        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
        img = Image.open(self.config["image_path"]).convert("RGB")
        clip_encoder_out = self.run_image_encoder(img) if self.config.get("use_image_encoder", True) else None
        vae_encode_out = self.run_vae_encoder(img)
        text_encoder_output = self.run_text_encoder(prompt, img)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)

    @ProfilingContext("Run Encoders")
    def _run_input_encoder_local_t2v(self):
        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
        text_encoder_output = self.run_text_encoder(prompt, None)
        torch.cuda.empty_cache()
        gc.collect()
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": None,
        }

    @ProfilingContext("Run Encoders")
    def _run_input_encoder_local_flf2v(self):
        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
        first_frame = Image.open(self.config["image_path"]).convert("RGB")
        last_frame = Image.open(self.config["last_frame_path"]).convert("RGB")
        clip_encoder_out = self.run_image_encoder(first_frame, last_frame) if self.config.get("use_image_encoder", True) else None
        vae_encode_out = self.run_vae_encoder(first_frame, last_frame)
        text_encoder_output = self.run_text_encoder(prompt, first_frame)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output)

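    # In lazy-load/unload mode the transformer is (re)loaded here, right
    # before denoising, and released again inside end_run.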
    @ProfilingContext("Run DiT")
    def _run_dit_local(self, total_steps=None):
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.model = self.load_transformer()
        self.init_scheduler()
        self.model.scheduler.prepare(self.inputs["image_encoder_output"])
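        # wan2.2 i2v: the VAE encoder output is no longer needed after
        # scheduler.prepare() above, so it is dropped before the denoising loop.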
        if self.config.get("model_cls") == "wan2.2" and self.config["task"] == "i2v":
            self.inputs["image_encoder_output"]["vae_encoder_out"] = None
        latents, generator = self.run(total_steps)
        self.end_run()
        return latents, generator

    @ProfilingContext("Run VAE Decoder")
    def _run_vae_decoder_local(self, latents, generator):
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.vae_decoder = self.load_vae_decoder()
        images = self.vae_decoder.decode(latents, generator=generator, config=self.config)
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.vae_decoder
            torch.cuda.empty_cache()
            gc.collect()
        return images

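    # Poll the prompt-enhancer sub-servers until one reports "idle", then
    # submit the prompt and return the enhanced text. Note that this
    # busy-waits without a timeout while all servers are busy.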
    def post_prompt_enhancer(self):
        while True:
            for url in self.config["sub_servers"]["prompt_enhancer"]:
                response = requests.get(f"{url}/v1/local/prompt_enhancer/generate/service_status").json()
                if response["service_status"] == "idle":
                    response = requests.post(
                        f"{url}/v1/local/prompt_enhancer/generate",
                        json={
                            "task_id": generate_task_id(),
                            "prompt": self.config["prompt"],
                        },
                    )
                    enhanced_prompt = response.json()["output"]
                    logger.info(f"Enhanced prompt: {enhanced_prompt}")
                    return enhanced_prompt

    def run_pipeline(self, save_video=True):
        if self.config["use_prompt_enhancer"]:
            self.config["prompt_enhanced"] = self.post_prompt_enhancer()

        self.inputs = self.run_input_encoder()
        self.set_target_shape()
        latents, generator = self.run_dit()

        images = self.run_vae_decoder(latents, generator)
        if self.config["model_cls"] != "wan2.2":
            images = vae_to_comfyui_image(images)

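        # Optional frame interpolation (e.g. RIFE) to raise the output FPS.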
        if "video_frame_interpolation" in self.config:
            assert self.vfi_model is not None and self.config["video_frame_interpolation"].get("target_fps", None) is not None
            target_fps = self.config["video_frame_interpolation"]["target_fps"]
            logger.info(f"Interpolating frames from {self.config.get('fps', 16)} to {target_fps}")
            images = self.vfi_model.interpolate_frames(
                images,
                source_fps=self.config.get("fps", 16),
                target_fps=target_fps,
            )

        if save_video:
            if "video_frame_interpolation" in self.config and self.config["video_frame_interpolation"].get("target_fps"):
                fps = self.config["video_frame_interpolation"]["target_fps"]
            else:
                fps = self.config.get("fps", 16)

            if not dist.is_initialized() or dist.get_rank() == 0:
                logger.info(f"🎬 Start to save video 🎬")

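                # Non-wan2.2 outputs were already converted to ComfyUI images
                # and are written via ffmpeg; wan2.2 outputs stay as tensors
                # in [-1, 1] and go through cache_video.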
                if self.config["model_cls"] != "wan2.2":
                    save_to_video(images, self.config.save_video_path, fps=fps, method="ffmpeg")  # type: ignore
                else:
                    cache_video(tensor=images, save_file=self.config.save_video_path, fps=fps, nrow=1, normalize=True, value_range=(-1, 1))
                logger.info(f"✅ Video saved successfully to: {self.config.save_video_path} ✅")

        del latents, generator
        torch.cuda.empty_cache()
        gc.collect()

        # Return (images, audio) - audio is None for default runner
        return images, None
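

# A minimal usage sketch (hypothetical: the real entry points live in the
# lightx2v CLI/server code, and the exact config schema is defined there;
# `config` must support both attribute and dict-style access, e.g. EasyDict):
#
#     runner = DefaultRunner(config)
#     runner.init_modules()
#     runner.set_inputs({
#         "prompt": "a cat surfing a wave",
#         "image_path": "input.png",       # used for i2v
#         "save_video_path": "out.mp4",
#     })
#     images, audio = runner.run_pipeline(save_video=True)  # audio is None here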