"docs/backend/backend.md" did not exist on "282681b8a15affd7f7d9e16584c38954ba4e8413"
default_runner.py 15 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
import gc
PengGao's avatar
PengGao committed
2

3
import requests
helloyongyang's avatar
helloyongyang committed
4
5
import torch
import torch.distributed as dist
helloyongyang's avatar
helloyongyang committed
6
import torchvision.transforms.functional as TF
PengGao's avatar
PengGao committed
7
8
9
from PIL import Image
from loguru import logger
from requests.exceptions import RequestException
PengGao's avatar
PengGao committed
10

helloyongyang's avatar
helloyongyang committed
11
from lightx2v.utils.envs import *
PengGao's avatar
PengGao committed
12
from lightx2v.utils.generate_task_id import generate_task_id
13
from lightx2v.utils.memory_profiler import peak_memory_decorator
14
from lightx2v.utils.profiler import *
helloyongyang's avatar
helloyongyang committed
15
from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image
PengGao's avatar
PengGao committed
16

PengGao's avatar
PengGao committed
17
from .base_runner import BaseRunner
18
19


PengGao's avatar
PengGao committed
20
class DefaultRunner(BaseRunner):
helloyongyang's avatar
helloyongyang committed
21
    def __init__(self, config):
        """Initialize the runner: detect optional prompt-enhancer servers,
        pick the init device, and build the scheduler.

        Args:
            config: runner configuration (dict-like with attribute access).
        """
        super().__init__(config)
        self.progress_callback = None
        self.has_prompt_enhancer = False
        # Prompt enhancement is only offered for t2v and only when a
        # sub-server is configured AND at least one instance is reachable.
        enhancer_conf = self.config.get("sub_servers", {}).get("prompt_enhancer")
        if self.config.task == "t2v" and enhancer_conf is not None:
            if self.check_sub_servers("prompt_enhancer"):
                self.has_prompt_enhancer = True
            else:
                logger.warning("No prompt enhancer server available, disable prompt enhancer.")
        if not self.has_prompt_enhancer:
            self.config.use_prompt_enhancer = False
        self.set_init_device()
        self.init_scheduler()
34

35
    def init_modules(self):
        """Load models (unless lazy/unloadable), attach the scheduler, and
        select the input-encoder routine for the configured task."""
        logger.info("Initializing runner modules...")
        lazy_load = self.config.get("lazy_load", False)
        unload_modules = self.config.get("unload_modules", False)
        if not lazy_load and not unload_modules:
            self.load_model()
        elif lazy_load:
            # lazy loading is only supported together with CPU offload
            assert self.config.get("cpu_offload", False)
        self.model.set_scheduler(self.scheduler)

        # Dispatch table: task name -> input-encoder implementation.
        encoder_by_task = {
            "i2v": self._run_input_encoder_local_i2v,
            "flf2v": self._run_input_encoder_local_flf2v,
            "t2v": self._run_input_encoder_local_t2v,
            "vace": self._run_input_encoder_local_vace,
            "animate": self._run_input_encoder_local_animate,
        }
        task = self.config["task"]
        if task in encoder_by_task:
            self.run_input_encoder = encoder_by_task[task]

        if self.config.get("compile", False):
            logger.info(f"[Compile] Compile all shapes: {self.config.get('compile_shapes', [])}")
            self.model.compile(self.config.get("compile_shapes", []))
55

56
    def set_init_device(self):
57
        if self.config.cpu_offload:
58
            self.init_device = torch.device("cpu")
59
        else:
60
            self.init_device = torch.device("cuda")
61

PengGao's avatar
PengGao committed
62
63
64
65
66
67
68
    def load_vfi_model(self):
        """Instantiate the configured video-frame-interpolation (VFI) model.

        Only the "rife" algorithm is supported.

        Returns:
            A RIFEWrapper around the configured model weights.

        Raises:
            ValueError: if the configured algo is not "rife".
        """
        if self.config["video_frame_interpolation"].get("algo", None) != "rife":
            raise ValueError(f"Unsupported VFI model: {self.config['video_frame_interpolation']['algo']}")
        from lightx2v.models.vfi.rife.rife_comfyui_wrapper import RIFEWrapper

        logger.info("Loading RIFE model...")
        return RIFEWrapper(self.config["video_frame_interpolation"]["model_path"])
PengGao's avatar
PengGao committed
70

71
    @ProfilingContext4DebugL2("Load models")
    def load_model(self):
        """Load every sub-model of the pipeline: transformer, text encoders,
        image encoder, VAE encoder/decoder, and (optionally) the VFI model."""
        self.model = self.load_transformer()
        self.text_encoders = self.load_text_encoder()
        self.image_encoder = self.load_image_encoder()
        self.vae_encoder, self.vae_decoder = self.load_vae()
        if "video_frame_interpolation" in self.config:
            self.vfi_model = self.load_vfi_model()
        else:
            self.vfi_model = None
Zhuguanyu Wu's avatar
Zhuguanyu Wu committed
78

79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
    def check_sub_servers(self, task_type):
        """Probe the configured sub-servers for *task_type* and keep only
        the reachable ones in ``config["sub_servers"][task_type]``.

        Args:
            task_type: sub-server category, e.g. "prompt_enhancer".

        Returns:
            bool: True if at least one server answered the status probe.
        """
        candidates = self.config.get("sub_servers", {}).get(task_type, [])
        alive = []
        for url in candidates:
            status_url = f"{url}/v1/local/{task_type}/generate/service_status"
            try:
                response = requests.get(status_url, timeout=2)
            except RequestException as e:
                logger.warning(f"Failed to connect to {url}: {str(e)}")
                continue
            if response.status_code == 200:
                alive.append(url)
            else:
                logger.warning(f"Service {url} returned status code {response.status_code}")
        logger.info(f"{task_type} available servers: {alive}")
        self.config["sub_servers"][task_type] = alive
        return bool(alive)

helloyongyang's avatar
helloyongyang committed
98
99
    def set_inputs(self, inputs):
        self.config["prompt"] = inputs.get("prompt", "")
Zhuguanyu Wu's avatar
Zhuguanyu Wu committed
100
        self.config["use_prompt_enhancer"] = False
101
        if self.has_prompt_enhancer:
102
            self.config["use_prompt_enhancer"] = inputs.get("use_prompt_enhancer", False)  # Reset use_prompt_enhancer from clinet side.
helloyongyang's avatar
helloyongyang committed
103
104
105
        self.config["negative_prompt"] = inputs.get("negative_prompt", "")
        self.config["image_path"] = inputs.get("image_path", "")
        self.config["save_video_path"] = inputs.get("save_video_path", "")
PengGao's avatar
PengGao committed
106
107
108
109
110
111
112
113
        self.config["infer_steps"] = inputs.get("infer_steps", self.config.get("infer_steps", 5))
        self.config["target_video_length"] = inputs.get("target_video_length", self.config.get("target_video_length", 81))
        self.config["seed"] = inputs.get("seed", self.config.get("seed", 42))
        self.config["audio_path"] = inputs.get("audio_path", "")  # for wan-audio
        self.config["video_duration"] = inputs.get("video_duration", 5)  # for wan-audio

        # self.config["sample_shift"] = inputs.get("sample_shift", self.config.get("sample_shift", 5))
        # self.config["sample_guide_scale"] = inputs.get("sample_guide_scale", self.config.get("sample_guide_scale", 5))
helloyongyang's avatar
helloyongyang committed
114

PengGao's avatar
PengGao committed
115
116
117
    def set_progress_callback(self, callback):
        self.progress_callback = callback

118
    @peak_memory_decorator
    def run_segment(self, total_steps=None):
        """Execute the denoising loop for the current video segment.

        Args:
            total_steps: number of scheduler steps to run; defaults to the
                scheduler's configured ``infer_steps``.

        Returns:
            The scheduler's latents tensor after the final step.
        """
        steps = self.model.scheduler.infer_steps if total_steps is None else total_steps
        for step in range(steps):
            # With a single segment there is no per-segment stop check, so
            # poll the stop signal on every step instead.
            if self.video_segment_num == 1:
                self.check_stop()
            logger.info(f"==> step_index: {step + 1} / {steps}")

            with ProfilingContext4DebugL1("step_pre"):
                self.model.scheduler.step_pre(step_index=step)

            with ProfilingContext4DebugL1("🚀 infer_main"):
                self.model.infer(self.inputs)

            with ProfilingContext4DebugL1("step_post"):
                self.model.scheduler.step_post()

            if self.progress_callback:
                self.progress_callback(((step + 1) / steps) * 100, 100)

        return self.model.scheduler.latents
141

helloyongyang's avatar
helloyongyang committed
142
    def run_step(self):
        """Debug helper: encode inputs and run exactly one denoising step."""
        self.inputs = self.run_input_encoder()
        self.run_main(total_steps=1)
helloyongyang's avatar
helloyongyang committed
145
146

    def end_run(self):
        """Release per-run state after all segments have been generated."""
        self.model.scheduler.clear()
        del self.inputs
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            # In on-demand mode the transformer is torn down after every run:
            # drop streamed/offloaded weight buffers first, then the module
            # itself, so nothing keeps the GPU allocations alive.
            if hasattr(self.model.transformer_infer, "weights_stream_mgr"):
                self.model.transformer_infer.weights_stream_mgr.clear()
            if hasattr(self.model.transformer_weights, "clear"):
                self.model.transformer_weights.clear()
            self.model.pre_weight.clear()
            del self.model
        # Return freed blocks to the CUDA allocator and collect Python garbage.
        torch.cuda.empty_cache()
        gc.collect()
helloyongyang's avatar
helloyongyang committed
158

helloyongyang's avatar
helloyongyang committed
159
    def read_image_input(self, img_path):
        """Load an image (from a path or an in-memory PIL image) and build
        the normalized tensor the pipeline consumes.

        Args:
            img_path: filesystem path, or an already-loaded ``PIL.Image.Image``.

        Returns:
            tuple: ``(img, img_ori)`` where ``img`` is a float tensor of
            shape (1, C, H, W) on CUDA with values rescaled from [0, 1] to
            [-1, 1], and ``img_ori`` is the RGB PIL image.
        """
        if isinstance(img_path, Image.Image):
            # Fix: in-memory images previously skipped the RGB conversion
            # that path inputs get, so RGBA/L inputs produced a different
            # channel layout. RGB images still pass through unchanged.
            img_ori = img_path if img_path.mode == "RGB" else img_path.convert("RGB")
        else:
            img_ori = Image.open(img_path).convert("RGB")
        # to_tensor -> [0, 1]; sub/div by 0.5 -> [-1, 1]; unsqueeze adds batch dim.
        img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
        return img, img_ori
helloyongyang's avatar
helloyongyang committed
166

167
    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_i2v(self):
        """Encode the i2v inputs (conditioning image + prompt)."""
        if self.config["use_prompt_enhancer"]:
            prompt = self.config["prompt_enhanced"]
        else:
            prompt = self.config["prompt"]
        img, img_ori = self.read_image_input(self.config["image_path"])
        clip_encoder_out = None
        if self.config.get("use_image_encoder", True):
            clip_encoder_out = self.run_image_encoder(img)
        # Some VAE encoders want the original PIL image, others the tensor.
        vae_input = img_ori if self.vae_encoder_need_img_original else img
        vae_encode_out = self.run_vae_encoder(vae_input)
        text_encoder_output = self.run_text_encoder(prompt, img)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)

178
    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_t2v(self):
        """Encode the t2v prompt; t2v has no image conditioning."""
        if self.config["use_prompt_enhancer"]:
            prompt = self.config["prompt_enhanced"]
        else:
            prompt = self.config["prompt"]
        text_encoder_output = self.run_text_encoder(prompt, None)
        torch.cuda.empty_cache()
        gc.collect()
        return {"text_encoder_output": text_encoder_output, "image_encoder_output": None}
188

189
    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_flf2v(self):
        """Encode first/last-frame conditioning plus the prompt for flf2v."""
        if self.config["use_prompt_enhancer"]:
            prompt = self.config["prompt_enhanced"]
        else:
            prompt = self.config["prompt"]
        first_frame, _ = self.read_image_input(self.config["image_path"])
        last_frame, _ = self.read_image_input(self.config["last_frame_path"])
        if self.config.get("use_image_encoder", True):
            clip_encoder_out = self.run_image_encoder(first_frame, last_frame)
        else:
            clip_encoder_out = None
        vae_encode_out = self.run_vae_encoder(first_frame, last_frame)
        text_encoder_output = self.run_text_encoder(prompt, first_frame)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output)

201
    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_vace(self):
        """Prepare and encode VACE inputs (source video, mask, reference images)."""
        if self.config["use_prompt_enhancer"]:
            prompt = self.config["prompt_enhanced"]
        else:
            prompt = self.config["prompt"]
        raw_video = self.config.get("src_video", None)
        raw_mask = self.config.get("src_mask", None)
        raw_refs = self.config.get("src_ref_images", None)
        # src_ref_images arrives as a comma-separated string of paths.
        ref_list = None if raw_refs is None else raw_refs.split(",")
        src_video, src_mask, src_ref_images = self.prepare_source(
            [raw_video],
            [raw_mask],
            [ref_list],
            (self.config.target_width, self.config.target_height),
        )
        # Keep the prepared reference images for later pipeline stages.
        self.src_ref_images = src_ref_images

        vae_encoder_out = self.run_vae_encoder(src_video, src_ref_images, src_mask)
        text_encoder_output = self.run_text_encoder(prompt)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(None, vae_encoder_out, text_encoder_output)

221
222
223
224
225
226
227
228
    @ProfilingContext4DebugL2("Run Text Encoder")
    def _run_input_encoder_local_animate(self):
        """Encode only the prompt for the animate task."""
        if self.config["use_prompt_enhancer"]:
            prompt = self.config["prompt_enhanced"]
        else:
            prompt = self.config["prompt"]
        text_encoder_output = self.run_text_encoder(prompt, None)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(None, None, text_encoder_output, None)

helloyongyang's avatar
helloyongyang committed
229
230
231
    def init_run(self):
        """Prepare per-run state before the main denoising loop: target
        latent shape, segment count, (re)loaded transformer, and scheduler."""
        self.set_target_shape()
        self.get_video_segment_num()
        # Under lazy_load/unload_modules the transformer is not kept
        # resident between runs; reload it for this run.
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.model = self.load_transformer()
        self.model.scheduler.prepare(self.inputs["image_encoder_output"])
        # NOTE(review): for wan2.2 i2v the VAE-encoded image appears to be
        # consumed by scheduler.prepare only, so it is dropped here —
        # confirm no later stage reads it before relying on this.
        if self.config.get("model_cls") == "wan2.2" and self.config["task"] == "i2v":
            self.inputs["image_encoder_output"]["vae_encoder_out"] = None
helloyongyang's avatar
helloyongyang committed
237

238
    @ProfilingContext4DebugL2("Run DiT")
    def run_main(self, total_steps=None):
        """Run DiT inference across all video segments, decoding each
        segment's latents with the VAE as it finishes.

        Args:
            total_steps: optional override for the number of denoising
                steps per segment (used by ``run_step`` for debugging).
        """
        self.init_run()
        if self.config.get("compile", False):
            self.model.select_graph_for_compile()
        num_segments = self.video_segment_num
        for seg in range(num_segments):
            logger.info(f"🔄 start segment {seg + 1}/{num_segments}")
            with ProfilingContext4DebugL1(f"segment end2end {seg + 1}/{num_segments}"):
                self.check_stop()
                self.init_run_segment(seg)  # per-segment setup (no-op by default)
                latents = self.run_segment(total_steps=total_steps)
                self.gen_video = self.run_vae_decoder(latents)
                self.end_run_segment(seg)  # per-segment teardown (no-op by default)
        self.end_run()

257
    @ProfilingContext4DebugL1("Run VAE Decoder")
    def run_vae_decoder(self, latents):
        """Decode latents into image frames with the VAE decoder.

        Under lazy_load/unload_modules the decoder is constructed on demand
        and torn down again afterwards to free GPU memory.

        Args:
            latents: latent tensor produced by the denoising loop.

        Returns:
            Decoded image frames.
        """
        on_demand = self.config.get("lazy_load", False) or self.config.get("unload_modules", False)
        if on_demand:
            self.vae_decoder = self.load_vae_decoder()
        images = self.vae_decoder.decode(latents.to(GET_DTYPE()))
        if on_demand:
            del self.vae_decoder
            torch.cuda.empty_cache()
            gc.collect()
        return images

268
269
270
271
272
    def post_prompt_enhancer(self):
        """Dispatch the current prompt to an idle prompt-enhancer sub-server
        and return the enhanced prompt.

        Polls each configured server's status endpoint until one reports
        "idle", then posts the prompt there and blocks on the response.

        Returns:
            str: the enhanced prompt produced by the sub-server.
        """
        import time  # local import, matching the file's style for optional deps

        while True:
            for url in self.config["sub_servers"]["prompt_enhancer"]:
                # Fix: bound the status probe and skip unreachable servers
                # instead of hanging/crashing the whole polling loop
                # (consistent with check_sub_servers).
                try:
                    status = requests.get(
                        f"{url}/v1/local/prompt_enhancer/generate/service_status",
                        timeout=2,
                    ).json()
                except RequestException as e:
                    logger.warning(f"Failed to connect to {url}: {str(e)}")
                    continue
                if status["service_status"] == "idle":
                    response = requests.post(
                        f"{url}/v1/local/prompt_enhancer/generate",
                        json={
                            "task_id": generate_task_id(),
                            "prompt": self.config["prompt"],
                        },
                    )
                    enhanced_prompt = response.json()["output"]
                    logger.info(f"Enhanced prompt: {enhanced_prompt}")
                    return enhanced_prompt
            # Fix: avoid a hot busy-wait when every server is busy.
            time.sleep(0.5)

helloyongyang's avatar
helloyongyang committed
284
285
    def process_images_after_vae_decoder(self, save_video=True):
        """Convert decoded frames to ComfyUI image format, optionally run
        frame interpolation, and optionally save the result to disk.

        Args:
            save_video: when True, write the video file (rank 0 only under
                distributed execution).

        Returns:
            dict: ``{"video": frames}`` when config["return_video"] is set,
            otherwise ``{"video": None}``.
        """
        self.gen_video = vae_to_comfyui_image(self.gen_video)

        vfi_cfg = self.config["video_frame_interpolation"] if "video_frame_interpolation" in self.config else None
        source_fps = self.config.get("fps", 16)
        if vfi_cfg is not None:
            assert self.vfi_model is not None and vfi_cfg.get("target_fps", None) is not None
            target_fps = vfi_cfg["target_fps"]
            logger.info(f"Interpolating frames from {source_fps} to {target_fps}")
            self.gen_video = self.vfi_model.interpolate_frames(
                self.gen_video,
                source_fps=source_fps,
                target_fps=target_fps,
            )

        if save_video:
            # After interpolation the saved file uses the interpolated fps.
            if vfi_cfg is not None and vfi_cfg.get("target_fps"):
                fps = vfi_cfg["target_fps"]
            else:
                fps = source_fps
            # Only rank 0 writes the file when running distributed.
            if not dist.is_initialized() or dist.get_rank() == 0:
                logger.info(f"🎬 Start to save video 🎬")
                save_to_video(self.gen_video, self.config.save_video_path, fps=fps, method="ffmpeg")
                logger.info(f"✅ Video saved successfully to: {self.config.save_video_path} ✅")

        if self.config.get("return_video", False):
            return {"video": self.gen_video}
        return {"video": None}
PengGao's avatar
PengGao committed
311

helloyongyang's avatar
helloyongyang committed
312
313
314
315
316
    def run_pipeline(self, save_video=True):
        """End-to-end generation: optional prompt enhancement, input
        encoding, DiT inference, and post-processing.

        Args:
            save_video: forwarded to ``process_images_after_vae_decoder``.

        Returns:
            dict: ``{"video": frames-or-None}`` from post-processing.
        """
        if self.config["use_prompt_enhancer"]:
            self.config["prompt_enhanced"] = self.post_prompt_enhancer()

        self.inputs = self.run_input_encoder()
        self.run_main()
        result = self.process_images_after_vae_decoder(save_video=save_video)

        # Release cached GPU memory between requests.
        torch.cuda.empty_cache()
        gc.collect()
        return result