import asyncio
import gc

import aiohttp
import requests
from requests.exceptions import RequestException

import torch
import torch.distributed as dist
import torchvision.transforms.functional as TF
from PIL import Image

from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
from lightx2v.utils.utils import save_videos_grid, cache_video
from lightx2v.utils.generate_task_id import generate_task_id
from lightx2v.utils.envs import *
from lightx2v.utils.service_utils import TensorTransporter, ImageTransporter
from loguru import logger

from .async_wrapper import AsyncWrapper
from .base_runner import BaseRunner


class DefaultRunner(BaseRunner):
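    """Default end-to-end runner for video generation.

    Orchestrates prompt enhancement, text/image encoding, DiT denoising, and
    VAE decoding, either in-process or against remote sub-servers when
    ``mode`` is "split_server".
    """
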
    def __init__(self, config):
        super().__init__(config)
        self.has_prompt_enhancer = False
        self.progress_callback = None
        if self.config["task"] == "t2v" and self.config.get("sub_servers", {}).get("prompt_enhancer") is not None:
            self.has_prompt_enhancer = True
            if not self.check_sub_servers("prompt_enhancer"):
                self.has_prompt_enhancer = False
                logger.warning("No prompt enhancer server available, disable prompt enhancer.")
        if not self.has_prompt_enhancer:
            self.config["use_prompt_enhancer"] = False
        self.set_init_device()

    def init_modules(self):
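        """Initialize runner modules.

        In "split_server" mode, verify that the required dit, text encoder,
        image encoder (i2v only), and vae sub-servers are reachable; otherwise
        load all models in-process unless lazy loading is enabled.
        """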
        logger.info("Initializing runner modules...")
        if self.config["mode"] == "split_server":
            self.tensor_transporter = TensorTransporter()
            self.image_transporter = ImageTransporter()
            if not self.check_sub_servers("dit"):
                raise ValueError("No dit server available")
            if not self.check_sub_servers("text_encoders"):
                raise ValueError("No text encoder server available")
            if self.config["task"] == "i2v":
                if not self.check_sub_servers("image_encoder"):
                    raise ValueError("No image encoder server available")
            if not self.check_sub_servers("vae_model"):
                raise ValueError("No vae server available")
        else:
            if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
                self.load_model()

    def set_init_device(self):
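        """Select the initial device: pin each rank to its CUDA device when
        parallel attention is enabled, and start weights on CPU when
        cpu_offload is set."""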
        if self.config["parallel_attn_type"]:
            cur_rank = dist.get_rank()
            torch.cuda.set_device(cur_rank)
        if self.config.cpu_offload:
            self.init_device = torch.device("cpu")
        else:
            self.init_device = torch.device("cuda")

    @ProfilingContext("Load models")
    def load_model(self):
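        """Load all model components in-process: the DiT transformer, text
        encoders, image encoder, and VAE encoder/decoder."""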
        self.model = self.load_transformer()
        self.text_encoders = self.load_text_encoder()
        self.image_encoder = self.load_image_encoder()
        self.vae_encoder, self.vae_decoder = self.load_vae()

    def check_sub_servers(self, task_type):
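        """Probe the status endpoint of every configured sub-server for
        task_type and keep only the reachable ones in the config.

        Returns True if at least one server responded with HTTP 200.
        """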
        urls = self.config.get("sub_servers", {}).get(task_type, [])
        available_servers = []
        for url in urls:
            try:
                status_url = f"{url}/v1/local/{task_type}/generate/service_status"
                response = requests.get(status_url, timeout=2)
                if response.status_code == 200:
                    available_servers.append(url)
                else:
                    logger.warning(f"Service {url} returned status code {response.status_code}")

            except RequestException as e:
                logger.warning(f"Failed to connect to {url}: {str(e)}")
                continue
        logger.info(f"{task_type} available servers: {available_servers}")
        self.config["sub_servers"][task_type] = available_servers
        return len(available_servers) > 0

    def set_inputs(self, inputs):
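        """Copy per-request inputs (prompts, file paths, sampling options)
        into the shared config, falling back to existing config defaults."""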
        self.config["prompt"] = inputs.get("prompt", "")
        self.config["use_prompt_enhancer"] = False
        if self.has_prompt_enhancer:
            self.config["use_prompt_enhancer"] = inputs.get("use_prompt_enhancer", False)  # Reset use_prompt_enhancer from clinet side.
        self.config["negative_prompt"] = inputs.get("negative_prompt", "")
        self.config["image_path"] = inputs.get("image_path", "")
        self.config["save_video_path"] = inputs.get("save_video_path", "")
        self.config["infer_steps"] = inputs.get("infer_steps", self.config.get("infer_steps", 5))
        self.config["target_video_length"] = inputs.get("target_video_length", self.config.get("target_video_length", 81))
        self.config["seed"] = inputs.get("seed", self.config.get("seed", 42))
        self.config["audio_path"] = inputs.get("audio_path", "")  # for wan-audio
        self.config["video_duration"] = inputs.get("video_duration", 5)  # for wan-audio

        # self.config["sample_shift"] = inputs.get("sample_shift", self.config.get("sample_shift", 5))
        # self.config["sample_guide_scale"] = inputs.get("sample_guide_scale", self.config.get("sample_guide_scale", 5))

    def set_progress_callback(self, callback):
        self.progress_callback = callback

    def run(self):
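        """Run the denoising loop: for every scheduler step, execute
        step_pre, the DiT forward pass, and step_post, then report progress
        through the optional callback."""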
        total_steps = self.model.scheduler.infer_steps
        for step_index in range(total_steps):
            logger.info(f"==> step_index: {step_index + 1} / {total_steps}")

            with ProfilingContext4Debug("step_pre"):
                self.model.scheduler.step_pre(step_index=step_index)

            with ProfilingContext4Debug("infer"):
                self.model.infer(self.inputs)

            with ProfilingContext4Debug("step_post"):
                self.model.scheduler.step_post()

            if self.progress_callback:
                self.progress_callback(step_index + 1, total_steps)

        return self.model.scheduler.latents, self.model.scheduler.generator

    async def run_step(self, step_index=0):
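        """Run the input encoders and a single denoising step; useful for
        debugging or profiling one iteration in isolation."""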
        async with AsyncWrapper(self) as wrapper:
            self.init_scheduler()
            self.inputs = await wrapper.run_input_encoder()
            self.model.scheduler.prepare(self.inputs["image_encoder_output"])
            self.model.scheduler.step_pre(step_index=step_index)
            self.model.infer(self.inputs)
            self.model.scheduler.step_post()

    def end_run(self):
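        """Tear down after generation: clear the scheduler, drop lazily
        loaded weights if requested, and release GPU memory."""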
        self.model.scheduler.clear()
        del self.inputs, self.model.scheduler
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            if hasattr(self.model.transformer_infer, "weights_stream_mgr"):
                self.model.transformer_infer.weights_stream_mgr.clear()
            if hasattr(self.model.transformer_weights, "clear"):
                self.model.transformer_weights.clear()
            self.model.pre_weight.clear()
            self.model.post_weight.clear()
            del self.model
        torch.cuda.empty_cache()
        gc.collect()

    @ProfilingContext("Run Encoders")
    def _run_input_encoder_local_i2v(self):
        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
        img = Image.open(self.config["image_path"]).convert("RGB")
        clip_encoder_out = self.run_image_encoder(img)
        vae_encode_out, kwargs = self.run_vae_encoder(img)
        text_encoder_output = self.run_text_encoder(prompt, img)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)

    @ProfilingContext("Run Encoders")
    def _run_input_encoder_local_t2v(self):
        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
        text_encoder_output = self.run_text_encoder(prompt, None)
        torch.cuda.empty_cache()
        gc.collect()
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": None,
        }

    @ProfilingContext("Run DiT")
    def _run_dit_local(self, kwargs):
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.model = self.load_transformer()
        self.init_scheduler()
        self.model.scheduler.prepare(self.inputs["image_encoder_output"])
        latents, generator = self.run()
        self.end_run()
        return latents, generator

    @ProfilingContext("Run VAE Decoder")
    def _run_vae_decoder_local(self, latents, generator):
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.vae_decoder = self.load_vae_decoder()
        images = self.vae_decoder.decode(latents, generator=generator, config=self.config)
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.vae_decoder
            torch.cuda.empty_cache()
            gc.collect()
        return images

    @ProfilingContext("Save video")
    def save_video(self, images):
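        """Save the decoded frames to disk; under parallel inference, only
        rank 0 writes the file."""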
        if not self.config.parallel_attn_type or dist.get_rank() == 0:
            self.save_video_func(images)

    async def post_task(self, task_type, urls, message, device="cuda", max_retries=3, timeout=30):
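        """Send a task to the first idle sub-server, retrying all URLs with
        exponential backoff (capped at 10 s) for up to max_retries rounds;
        returns the deserialized output tensor on the requested device."""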
        for attempt in range(max_retries):
            for url in urls:
                try:
                    async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
                        try:
                            async with session.get(f"{url}/v1/local/{task_type}/generate/service_status") as response:
                                if response.status != 200:
                                    logger.warning(f"Service {url} returned status {response.status}")
                                    continue
                                status = await response.json()
                        except asyncio.TimeoutError:
                            logger.warning(f"Timeout checking status for {url}")
                            continue
                        except Exception as e:
                            logger.warning(f"Error checking status for {url}: {e}")
                            continue

                        if status.get("service_status") == "idle":
                            try:
                                async with session.post(f"{url}/v1/local/{task_type}/generate", json=message) as response:
                                    if response.status == 200:
                                        result = await response.json()
                                        if result.get("kwargs") is not None:
                                            for k, v in result["kwargs"].items():
                                                setattr(self.config, k, v)
                                        return self.tensor_transporter.load_tensor(result["output"], device)
                                    else:
                                        logger.warning(f"Task failed with status {response.status} for {url}")
                            except asyncio.TimeoutError:
                                logger.warning(f"Timeout posting task to {url}")
                            except Exception as e:
                                logger.error(f"Error posting task to {url}: {e}")

                except aiohttp.ClientError as e:
                    logger.warning(f"Client error for {url}: {e}")
                except Exception as e:
                    logger.error(f"Unexpected error for {url}: {e}")

            if attempt < max_retries - 1:
                wait_time = min(2**attempt, 10)
                logger.info(f"Retrying in {wait_time} seconds... (attempt {attempt + 1}/{max_retries})")
                await asyncio.sleep(wait_time)

        raise RuntimeError(f"Failed to complete task {task_type} after {max_retries} attempts")

    def post_prompt_enhancer(self):
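        """Synchronously poll the prompt-enhancer sub-servers until one is
        idle, then post the prompt and return the enhanced version. Note that
        this busy-polls and blocks until a server becomes available."""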
        while True:
            for url in self.config["sub_servers"]["prompt_enhancer"]:
                response = requests.get(f"{url}/v1/local/prompt_enhancer/generate/service_status").json()
                if response["service_status"] == "idle":
                    response = requests.post(
                        f"{url}/v1/local/prompt_enhancer/generate",
                        json={
                            "task_id": generate_task_id(),
                            "prompt": self.config["prompt"],
                        },
                    )
                    enhanced_prompt = response.json()["output"]
                    logger.info(f"Enhanced prompt: {enhanced_prompt}")
                    return enhanced_prompt

    async def post_encoders_i2v(self, prompt, img=None, n_prompt=None, i2v=False):
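        """Dispatch the image-encoder, VAE-encoder, and text-encoder requests
        concurrently for an i2v task and gather their outputs."""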
        tasks = []
        img_byte = self.image_transporter.prepare_image(img)
        tasks.append(
            asyncio.create_task(
                self.post_task(
                    task_type="image_encoder",
                    urls=self.config["sub_servers"]["image_encoder"],
                    message={"task_id": generate_task_id(), "img": img_byte},
                    device="cuda",
                )
            )
        )
        tasks.append(
            asyncio.create_task(
                self.post_task(
                    task_type="vae_model/encoder",
                    urls=self.config["sub_servers"]["vae_model"],
                    message={"task_id": generate_task_id(), "img": img_byte},
                    device="cuda",
                )
            )
        )
        tasks.append(
            asyncio.create_task(
                self.post_task(
                    task_type="text_encoders",
                    urls=self.config["sub_servers"]["text_encoders"],
                    message={
                        "task_id": generate_task_id(),
                        "text": prompt,
                        "img": img_byte,
                        "n_prompt": n_prompt,
                    },
                    device="cuda",
                )
            )
        )
        results = await asyncio.gather(*tasks)
        # clip_encoder, vae_encoder, text_encoders
        return results[0], results[1], results[2]

    async def post_encoders_t2v(self, prompt, n_prompt=None):
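        """Dispatch the text-encoder request for a t2v task."""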
        tasks = []
        tasks.append(
            asyncio.create_task(
                self.post_task(
                    task_type="text_encoders",
                    urls=self.config["sub_servers"]["text_encoders"],
                    message={
                        "task_id": generate_task_id(),
                        "text": prompt,
                        "img": None,
                        "n_prompt": n_prompt,
                    },
                    device="cuda",
                )
            )
        )
        results = await asyncio.gather(*tasks)
        # text_encoders
        return results[0]

    async def _run_input_encoder_server_i2v(self):
        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
        n_prompt = self.config.get("negative_prompt", "")
        img = Image.open(self.config["image_path"]).convert("RGB")
        (
            clip_encoder_out,
            vae_encode_out,
            text_encoder_output,
        ) = await self.post_encoders_i2v(prompt, img, n_prompt)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)

    async def _run_input_encoder_server_t2v(self):
        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
        n_prompt = self.config.get("negative_prompt", "")
        text_encoder_output = await self.post_encoders_t2v(prompt, n_prompt)
        torch.cuda.empty_cache()
        gc.collect()
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": None,
        }

    async def _run_dit_server(self, kwargs):
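        """Ship the encoded inputs to a remote DiT server and return its
        latents; the sampling generator stays server-side, hence None."""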
        if self.inputs.get("image_encoder_output", None) is not None:
            self.inputs["image_encoder_output"].pop("img", None)
        dit_output = await self.post_task(
            task_type="dit",
            urls=self.config["sub_servers"]["dit"],
            message={
                "task_id": generate_task_id(),
                "inputs": self.tensor_transporter.prepare_tensor(self.inputs),
                "kwargs": self.tensor_transporter.prepare_tensor(kwargs),
            },
            device="cuda",
        )
        return dit_output, None

    async def _run_vae_decoder_server(self, latents, generator):
        images = await self.post_task(
            task_type="vae_model/decoder",
            urls=self.config["sub_servers"]["vae_model"],
            message={
                "task_id": generate_task_id(),
                "latents": self.tensor_transporter.prepare_tensor(latents),
            },
            device="cpu",
        )
        return images

    async def run_pipeline(self, save_video=True):
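        """Execute the full generation pipeline: optional prompt enhancement,
        input encoding, DiT denoising, and VAE decoding; optionally saves the
        video and returns the decoded frames."""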
        async with AsyncWrapper(self) as wrapper:
            if self.config["use_prompt_enhancer"]:
                self.config["prompt_enhanced"] = await wrapper.run_prompt_enhancer()

            self.inputs = await wrapper.run_input_encoder()

            kwargs = self.set_target_shape()

            latents, generator = await wrapper.run_dit(kwargs)

            images = await wrapper.run_vae_decoder(latents, generator)

            if save_video:
                await wrapper.save_video(images)

            del latents, generator
            torch.cuda.empty_cache()
            gc.collect()

            return images