# default_runner.py
import asyncio
import gc
import aiohttp
import requests
from requests.exceptions import RequestException
import torch
import torch.distributed as dist
import torchvision.transforms.functional as TF
from PIL import Image
from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
from lightx2v.utils.utils import save_videos_grid, cache_video
from lightx2v.utils.generate_task_id import generate_task_id
from lightx2v.utils.envs import *
from lightx2v.utils.service_utils import TensorTransporter, ImageTransporter
from loguru import logger


class DefaultRunner:
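    """Default orchestrator for the video generation pipeline.

    Wires together input encoding, DiT denoising, and VAE decoding. Depending on
    ``config["mode"]``, each stage either runs in-process ("local") or is posted
    to HTTP sub-servers ("split_server"). Hooks referenced but not defined here
    (``load_transformer``, ``load_text_encoder``, ``load_vae``, ``init_scheduler``,
    ``set_target_shape``, ``save_video_func``, ...) are presumably implemented by
    model-specific subclasses.
    """
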
    def __init__(self, config):
        self.config = config
        self.has_prompt_enhancer = False
        if self.config["task"] == "t2v" and self.config.get("sub_servers", {}).get("prompt_enhancer") is not None:
            self.has_prompt_enhancer = True
            if not self.check_sub_servers("prompt_enhancer"):
                self.has_prompt_enhancer = False
                logger.warning("No prompt enhancer server available, disable prompt enhancer.")
        if not self.has_prompt_enhancer:
            self.config["use_prompt_enhancer"] = False
        self.set_init_device()

    def init_modules(self):
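        """Bind the stage runners (input encoder, DiT, VAE decoder) to their
        server-backed or in-process implementations based on ``config["mode"]``.
        In local mode, models are loaded eagerly unless lazy_load or
        unload_modules is set."""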
        logger.info("Initializing runner modules...")
        if self.config["mode"] == "split_server":
            self.tensor_transporter = TensorTransporter()
            self.image_transporter = ImageTransporter()
            if not self.check_sub_servers("dit"):
                raise ValueError("No dit server available")
            if not self.check_sub_servers("text_encoders"):
                raise ValueError("No text encoder server available")
            if self.config["task"] == "i2v":
                if not self.check_sub_servers("image_encoder"):
                    raise ValueError("No image encoder server available")
            if not self.check_sub_servers("vae_model"):
                raise ValueError("No vae server available")
            self.run_dit = self.run_dit_server
            self.run_vae_decoder = self.run_vae_decoder_server
            if self.config["task"] == "i2v":
                self.run_input_encoder = self.run_input_encoder_server_i2v
            else:
                self.run_input_encoder = self.run_input_encoder_server_t2v
        else:
            if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
                self.load_model()
            self.run_dit = self.run_dit_local
            self.run_vae_decoder = self.run_vae_decoder_local
            if self.config["task"] == "i2v":
                self.run_input_encoder = self.run_input_encoder_local_i2v
            else:
                self.run_input_encoder = self.run_input_encoder_local_t2v

    def set_init_device(self):
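        """Select the device used for model initialization: the rank-local GPU
        when a parallel attention scheme is configured, CPU when offloading,
        and the default CUDA device otherwise."""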
        if self.config["parallel_attn_type"]:
            cur_rank = dist.get_rank()
            torch.cuda.set_device(cur_rank)
        if self.config.cpu_offload:
            self.init_device = torch.device("cpu")
        else:
            self.init_device = torch.device("cuda")

    @ProfilingContext("Load models")
    def load_model(self):
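        """Load all model components via the subclass-provided loader hooks."""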
        self.model = self.load_transformer()
        self.text_encoders = self.load_text_encoder()
        self.image_encoder = self.load_image_encoder()
        self.vae_encoder, self.vae_decoder = self.load_vae()

    def check_sub_servers(self, task_type):
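        """Probe each configured sub-server of ``task_type`` via its
        service_status endpoint, keep only the reachable ones in the config,
        and return whether at least one is available."""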
        urls = self.config.get("sub_servers", {}).get(task_type, [])
        available_servers = []
        for url in urls:
            try:
                status_url = f"{url}/v1/local/{task_type}/generate/service_status"
                response = requests.get(status_url, timeout=2)
                if response.status_code == 200:
                    available_servers.append(url)
                else:
                    logger.warning(f"Service {url} returned status code {response.status_code}")
            except RequestException as e:
                logger.warning(f"Failed to connect to {url}: {str(e)}")
                continue
        logger.info(f"{task_type} available servers: {available_servers}")
        self.config["sub_servers"][task_type] = available_servers
        return len(available_servers) > 0

    def set_inputs(self, inputs):
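        """Copy per-request generation parameters from ``inputs`` into the
        shared config, falling back to existing config defaults where given."""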
        self.config["prompt"] = inputs.get("prompt", "")
        self.config["use_prompt_enhancer"] = False
        if self.has_prompt_enhancer:
            self.config["use_prompt_enhancer"] = inputs.get("use_prompt_enhancer", False)  # Reset use_prompt_enhancer from clinet side.
        self.config["negative_prompt"] = inputs.get("negative_prompt", "")
        self.config["image_path"] = inputs.get("image_path", "")
        self.config["save_video_path"] = inputs.get("save_video_path", "")
        self.config["infer_steps"] = inputs.get("infer_steps", self.config.get("infer_steps", 5))
        self.config["target_video_length"] = inputs.get("target_video_length", self.config.get("target_video_length", 81))
        self.config["seed"] = inputs.get("seed", self.config.get("seed", 42))
        self.config["audio_path"] = inputs.get("audio_path", "")  # for wan-audio
        self.config["video_duration"] = inputs.get("video_duration", 5)  # for wan-audio

        # self.config["sample_shift"] = inputs.get("sample_shift", self.config.get("sample_shift", 5))
        # self.config["sample_guide_scale"] = inputs.get("sample_guide_scale", self.config.get("sample_guide_scale", 5))

    def run(self):
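        """Run the full denoising loop: for each scheduler step, perform
        step_pre, model inference, and step_post. Returns the final latents
        and the RNG generator used for sampling."""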
        for step_index in range(self.model.scheduler.infer_steps):
            logger.info(f"==> step_index: {step_index + 1} / {self.model.scheduler.infer_steps}")

            with ProfilingContext4Debug("step_pre"):
                self.model.scheduler.step_pre(step_index=step_index)

            with ProfilingContext4Debug("infer"):
                self.model.infer(self.inputs)

            with ProfilingContext4Debug("step_post"):
                self.model.scheduler.step_post()

        return self.model.scheduler.latents, self.model.scheduler.generator

    async def run_step(self, step_index=0):
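        """Execute a single denoising step end-to-end (encoders included);
        presumably a debugging/profiling helper."""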
        self.init_scheduler()
        self.inputs = await self.run_input_encoder()
        self.model.scheduler.prepare(self.inputs["image_encoder_output"])
        self.model.scheduler.step_pre(step_index=step_index)
        self.model.infer(self.inputs)
        self.model.scheduler.step_post()

    def end_run(self):
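        """Release per-run state; when lazy_load or unload_modules is set, also
        drop the transformer weights and free GPU memory."""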
        self.model.scheduler.clear()
        del self.inputs, self.model.scheduler
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            if hasattr(self.model.transformer_infer, "weights_stream_mgr"):
                self.model.transformer_infer.weights_stream_mgr.clear()
            if hasattr(self.model.transformer_weights, "clear"):
                self.model.transformer_weights.clear()
            self.model.pre_weight.clear()
            self.model.post_weight.clear()
            del self.model
        torch.cuda.empty_cache()
        gc.collect()

    @ProfilingContext("Run Encoders")
    async def run_input_encoder_local_i2v(self):
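        """Encode the conditioning image (CLIP + VAE) and the prompt in-process,
        then pack the results via ``get_encoder_output_i2v``."""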
        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
        img = Image.open(self.config["image_path"]).convert("RGB")
        clip_encoder_out = self.run_image_encoder(img)
        vae_encode_out, kwargs = self.run_vae_encoder(img)
        text_encoder_output = self.run_text_encoder(prompt, img)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)

    @ProfilingContext("Run Encoders")
    async def run_input_encoder_local_t2v(self):
        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
        text_encoder_output = self.run_text_encoder(prompt, None)
        torch.cuda.empty_cache()
        gc.collect()
        return {"text_encoder_output": text_encoder_output, "image_encoder_output": None}

    @ProfilingContext("Run DiT")
    async def run_dit_local(self, kwargs):
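        """Run the DiT denoising loop in-process, loading the transformer on
        demand when lazy_load or unload_modules is set."""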
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.model = self.load_transformer()
        self.init_scheduler()
        self.model.scheduler.prepare(self.inputs["image_encoder_output"])
        latents, generator = self.run()
        self.end_run()
        return latents, generator

    @ProfilingContext("Run VAE Decoder")
    async def run_vae_decoder_local(self, latents, generator):
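        """Decode latents to frames in-process; with lazy_load or unload_modules,
        the VAE decoder is loaded just for this call and freed afterwards."""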
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.vae_decoder = self.load_vae_decoder()
        images = self.vae_decoder.decode(latents, generator=generator, config=self.config)
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.vae_decoder
            torch.cuda.empty_cache()
            gc.collect()
        return images

    @ProfilingContext("Save video")
    def save_video(self, images):
        if not self.config.parallel_attn_type or dist.get_rank() == 0:
            self.save_video_func(images)

    async def post_task(self, task_type, urls, message, device="cuda"):
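        """Poll ``urls`` until one reports an idle service_status, POST
        ``message`` to it, apply any returned kwargs to the config, and
        deserialize the returned tensor onto ``device``."""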
        while True:
            for url in urls:
                async with aiohttp.ClientSession() as session:
                    async with session.get(f"{url}/v1/local/{task_type}/generate/service_status") as response:
                        status = await response.json()
                    if status["service_status"] == "idle":
                        async with session.post(f"{url}/v1/local/{task_type}/generate", json=message) as response:
                            result = await response.json()
                            if result["kwargs"] is not None:
                                for k, v in result["kwargs"].items():
                                    setattr(self.config, k, v)
                            return self.tensor_transporter.load_tensor(result["output"], device)
            await asyncio.sleep(0.1)

    def post_prompt_enhancer(self):
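        """Block until an idle prompt-enhancer server accepts the request,
        then return the enhanced prompt."""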
        while True:
            for url in self.config["sub_servers"]["prompt_enhancer"]:
                response = requests.get(f"{url}/v1/local/prompt_enhancer/generate/service_status").json()
                if response["service_status"] == "idle":
                    response = requests.post(f"{url}/v1/local/prompt_enhancer/generate", json={"task_id": generate_task_id(), "prompt": self.config["prompt"]})
                    enhanced_prompt = response.json()["output"]
                    logger.info(f"Enhanced prompt: {enhanced_prompt}")
                    return enhanced_prompt

    async def post_encoders_i2v(self, prompt, img=None, n_prompt=None, i2v=False):
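        """Dispatch image-encoder, VAE-encoder, and text-encoder requests to
        their sub-servers concurrently; the ``i2v`` flag is apparently unused."""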
        tasks = []
        img_byte = self.image_transporter.prepare_image(img)
        tasks.append(
            asyncio.create_task(self.post_task(task_type="image_encoder", urls=self.config["sub_servers"]["image_encoder"], message={"task_id": generate_task_id(), "img": img_byte}, device="cuda"))
        )
        tasks.append(
            asyncio.create_task(self.post_task(task_type="vae_model/encoder", urls=self.config["sub_servers"]["vae_model"], message={"task_id": generate_task_id(), "img": img_byte}, device="cuda"))
        )
        tasks.append(
            asyncio.create_task(
                self.post_task(
                    task_type="text_encoders",
                    urls=self.config["sub_servers"]["text_encoders"],
                    message={"task_id": generate_task_id(), "text": prompt, "img": img_byte, "n_prompt": n_prompt},
                    device="cuda",
                )
            )
        )
        results = await asyncio.gather(*tasks)
        # clip_encoder, vae_encoder, text_encoders
        return results[0], results[1], results[2]

    async def post_encoders_t2v(self, prompt, n_prompt=None):
        tasks = []
        tasks.append(
            asyncio.create_task(
                self.post_task(
                    task_type="text_encoders",
                    urls=self.config["sub_servers"]["text_encoders"],
                    message={"task_id": generate_task_id(), "text": prompt, "img": None, "n_prompt": n_prompt},
                    device="cuda",
                )
            )
        )
        results = await asyncio.gather(*tasks)
        # text_encoders
        return results[0]

    async def run_input_encoder_server_i2v(self):
        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
        n_prompt = self.config.get("negative_prompt", "")
        img = Image.open(self.config["image_path"]).convert("RGB")
        clip_encoder_out, vae_encode_out, text_encoder_output = await self.post_encoders_i2v(prompt, img, n_prompt)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)

    async def run_input_encoder_server_t2v(self):
        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
        n_prompt = self.config.get("negative_prompt", "")
        text_encoder_output = await self.post_encoders_t2v(prompt, n_prompt)
        torch.cuda.empty_cache()
        gc.collect()
        return {"text_encoder_output": text_encoder_output, "image_encoder_output": None}

    async def run_dit_server(self, kwargs):
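        """Strip the raw image from the encoder output (presumably to shrink
        the payload), post inputs and kwargs to a DiT sub-server, and return
        the denoised latents. The generator stays server-side, hence None."""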
        if self.inputs.get("image_encoder_output", None) is not None:
            self.inputs["image_encoder_output"].pop("img", None)
        dit_output = await self.post_task(
            task_type="dit",
            urls=self.config["sub_servers"]["dit"],
            message={"task_id": generate_task_id(), "inputs": self.tensor_transporter.prepare_tensor(self.inputs), "kwargs": self.tensor_transporter.prepare_tensor(kwargs)},
            device="cuda",
        )
        return dit_output, None

    async def run_vae_decoder_server(self, latents, generator):
        images = await self.post_task(
            task_type="vae_model/decoder",
            urls=self.config["sub_servers"]["vae_model"],
            message={"task_id": generate_task_id(), "latents": self.tensor_transporter.prepare_tensor(latents)},
            device="cpu",
        )
        return images

    async def run_pipeline(self):
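        """End-to-end generation: optional prompt enhancement, input encoding,
        DiT denoising, VAE decoding, and saving the video."""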
        if self.config["use_prompt_enhancer"]:
            self.config["prompt_enhanced"] = self.post_prompt_enhancer()
        self.inputs = await self.run_input_encoder()
        kwargs = self.set_target_shape()
        latents, generator = await self.run_dit(kwargs)
        images = await self.run_vae_decoder(latents, generator)
        self.save_video(images)
        del latents, generator, images
        torch.cuda.empty_cache()
        gc.collect()
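

# --- Usage sketch ---
# A minimal, illustrative driver (not part of the original module). It assumes
# the config behaves like an EasyDict (supporting both dict-style and attribute
# access, as the code above requires) and that real deployments use a
# model-specific subclass providing the loader/scheduler hooks. The URL below
# is a placeholder.
if __name__ == "__main__":
    from easydict import EasyDict

    demo_config = EasyDict(
        {
            "task": "t2v",
            "mode": "split_server",
            "parallel_attn_type": None,  # no distributed attention configured
            "cpu_offload": True,  # keep initialization on CPU
            "sub_servers": {"prompt_enhancer": ["http://localhost:9000"]},
        }
    )
    # __init__ probes the prompt-enhancer server and disables enhancement with
    # a warning if it is unreachable; set_init_device() then picks the CPU here.
    runner = DefaultRunner(demo_config)
    logger.info(f"Prompt enhancer enabled: {runner.has_prompt_enhancer}")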