default_runner.py 13.2 KB
Newer Older
1
import asyncio
helloyongyang's avatar
helloyongyang committed
2
import gc
3
import aiohttp
4
5
import requests
from requests.exceptions import RequestException
helloyongyang's avatar
helloyongyang committed
6
7
import torch
import torch.distributed as dist
8
9
import torchvision.transforms.functional as TF
from PIL import Image
helloyongyang's avatar
helloyongyang committed
10
11
from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
from lightx2v.utils.utils import save_videos_grid, cache_video
12
from lightx2v.utils.generate_task_id import generate_task_id
helloyongyang's avatar
helloyongyang committed
13
from lightx2v.utils.envs import *
14
from lightx2v.utils.service_utils import TensorTransporter, ImageTransporter
15
from loguru import logger
16
17
18


class DefaultRunner:
helloyongyang's avatar
helloyongyang committed
19
20
    def __init__(self, config):
        self.config = config
21
        self.has_prompt_enhancer = False
22
23
24
25
26
27
        if self.config["task"] == "t2v" and self.config.get("sub_servers", {}).get("prompt_enhancer") is not None:
            self.has_prompt_enhancer = True
            if not self.check_sub_servers("prompt_enhancer"):
                self.has_prompt_enhancer = False
                logger.warning("No prompt enhancer server available, disable prompt enhancer.")

28
    def init_modules(self):
        """Select the initial device and bind the pipeline stage callables.

        In ``"split_server"`` mode the tensor/image transporters are created,
        every required sub-server group is checked (``ValueError`` if none is
        reachable) and the ``*_server`` stage implementations are bound.
        In any other mode the models are loaded eagerly unless
        ``config["lazy_load"]`` is truthy, and the ``*_local`` implementations
        are bound.

        Raises:
            ValueError: In split-server mode, when a required sub-server group
                has no available server.
        """
        logger.info("Initializing runner modules...")
        self.set_init_device()

        if self.config["mode"] != "split_server":
            if not self.config.get("lazy_load", False):
                self.load_model()
            self.run_dit = self.run_dit_local
            self.run_vae_decoder = self.run_vae_decoder_local
            self.run_input_encoder = self.run_input_encoder_local_i2v if self.config["task"] == "i2v" else self.run_input_encoder_local_t2v
            return

        self.tensor_transporter = TensorTransporter()
        self.image_transporter = ImageTransporter()

        def require(service, error_message):
            # Fail fast when no server of this kind answered its status probe.
            if not self.check_sub_servers(service):
                raise ValueError(error_message)

        require("dit", "No dit server available")
        require("text_encoders", "No text encoder server available")
        if self.config["task"] == "i2v":
            require("image_encoder", "No image encoder server available")
        require("vae_model", "No vae server available")

        self.run_dit = self.run_dit_server
        self.run_vae_decoder = self.run_vae_decoder_server
        self.run_input_encoder = self.run_input_encoder_server_i2v if self.config["task"] == "i2v" else self.run_input_encoder_server_t2v

59
    def set_init_device(self):
60
61
62
63
        if self.config["parallel_attn_type"]:
            cur_rank = dist.get_rank()
            torch.cuda.set_device(cur_rank)
        if self.config.cpu_offload:
64
            self.init_device = torch.device("cpu")
65
        else:
66
            self.init_device = torch.device("cuda")
67
68
69

    @ProfilingContext("Load models")
    def load_model(self):
        """Load every model component and store it on the runner.

        Populates ``self.model``, ``self.text_encoders``, ``self.image_encoder``
        and the ``self.vae_encoder`` / ``self.vae_decoder`` pair from the
        corresponding ``load_*`` methods, in that order.
        """
        self.model = self.load_transformer()
        self.text_encoders = self.load_text_encoder()
        self.image_encoder = self.load_image_encoder()
        # load_vae() yields the encoder and decoder as a pair.
        vae_pair = self.load_vae()
        self.vae_encoder, self.vae_decoder = vae_pair
Zhuguanyu Wu's avatar
Zhuguanyu Wu committed
74

75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
    def check_sub_servers(self, task_type):
        """Probe the configured sub-servers for ``task_type`` and keep the reachable ones.

        Every URL listed under ``config["sub_servers"][task_type]`` is queried at
        ``<url>/v1/local/<task_type>/generate/service_status`` with a 2 second
        timeout. A server counts as available when it answers with HTTP 200.
        The filtered list is written back to ``config["sub_servers"][task_type]``.

        Args:
            task_type: Key of the sub-server group, also used as the URL path segment.

        Returns:
            bool: ``True`` if at least one server is available.
        """
        # A missing or null entry is treated as "no servers configured".
        urls = self.config.get("sub_servers", {}).get(task_type) or []
        available_servers = []
        for url in urls:
            status_url = f"{url}/v1/local/{task_type}/generate/service_status"
            try:
                response = requests.get(status_url, timeout=2)
            except RequestException as e:
                logger.warning(f"Failed to connect to {url}: {str(e)}")
                continue
            if response.status_code == 200:
                available_servers.append(url)
            else:
                logger.warning(f"Service {url} returned status code {response.status_code}")

        logger.info(f"{task_type} available servers: {available_servers}")
        # The read above tolerates a missing "sub_servers" section, so the
        # write-back must too; otherwise callers see a KeyError instead of
        # their own "no server available" handling.
        if self.config.get("sub_servers") is None:
            self.config["sub_servers"] = {}
        self.config["sub_servers"][task_type] = available_servers
        return bool(available_servers)

helloyongyang's avatar
helloyongyang committed
94
95
    def set_inputs(self, inputs):
        self.config["prompt"] = inputs.get("prompt", "")
96
        if self.has_prompt_enhancer:
97
            self.config["use_prompt_enhancer"] = inputs.get("use_prompt_enhancer", False)  # Reset use_prompt_enhancer from clinet side.
helloyongyang's avatar
helloyongyang committed
98
99
100
101
        self.config["negative_prompt"] = inputs.get("negative_prompt", "")
        self.config["image_path"] = inputs.get("image_path", "")
        self.config["save_video_path"] = inputs.get("save_video_path", "")

102
103
    def run(self):
        """Execute the scheduler loop for ``scheduler.infer_steps`` iterations.

        Each iteration runs ``scheduler.step_pre``, ``model.infer`` on
        ``self.inputs`` and ``scheduler.step_post``, each inside its own debug
        profiling context.

        Returns:
            tuple: ``(scheduler.latents, scheduler.generator)`` after the last step.
        """
        model = self.model
        for idx in range(model.scheduler.infer_steps):
            logger.info(f"==> step_index: {idx + 1} / {model.scheduler.infer_steps}")

            with ProfilingContext4Debug("step_pre"):
                model.scheduler.step_pre(step_index=idx)

            with ProfilingContext4Debug("infer"):
                model.infer(self.inputs)

            with ProfilingContext4Debug("step_post"):
                model.scheduler.step_post()

        final_scheduler = model.scheduler
        return final_scheduler.latents, final_scheduler.generator

117
    async def run_step(self, step_index=0):
helloyongyang's avatar
helloyongyang committed
118
        self.init_scheduler()
119
        await self.run_input_encoder()
helloyongyang's avatar
helloyongyang committed
120
        self.model.scheduler.prepare(self.inputs["image_encoder_output"])
121
122
123
        self.model.scheduler.step_pre(step_index=step_index)
        self.model.infer(self.inputs)
        self.model.scheduler.step_post()
helloyongyang's avatar
helloyongyang committed
124
125

    def end_run(self):
Zhuguanyu Wu's avatar
Zhuguanyu Wu committed
126
127
        self.model.scheduler.clear()
        del self.inputs, self.model.scheduler
128
129
        if self.config.get("lazy_load", False):
            del self.model
Zhuguanyu Wu's avatar
Zhuguanyu Wu committed
130
        torch.cuda.empty_cache()
131
        gc.collect()
helloyongyang's avatar
helloyongyang committed
132

133
134
135
136
137
138
139
    @ProfilingContext("Run Encoders")
    async def run_input_encoder_local_i2v(self):
        """Encode the prompt and the input image with the locally loaded models.

        Uses ``config["prompt_enhanced"]`` when ``config["use_prompt_enhancer"]``
        is truthy, otherwise ``config["prompt"]``. The image at
        ``config["image_path"]`` is opened and converted to RGB.

        Returns:
            The value of ``get_encoder_output_i2v`` for the three encoder
            outputs and the image.
        """
        if self.config["use_prompt_enhancer"]:
            prompt = self.config["prompt_enhanced"]
        else:
            prompt = self.config["prompt"]

        img = Image.open(self.config["image_path"]).convert("RGB")
        clip_encoder_out = self.run_image_encoder(img)
        # The second element returned by the VAE encoder is not used here.
        vae_encode_out, _ = self.run_vae_encoder(img)
        text_encoder_output = self.run_text_encoder(prompt, img)

        torch.cuda.empty_cache()
        gc.collect()

        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)

    @ProfilingContext("Run Encoders")
    async def run_input_encoder_local_t2v(self):
        """Encode the prompt with the locally loaded text encoder.

        Uses ``config["prompt_enhanced"]`` when ``config["use_prompt_enhancer"]``
        is truthy, otherwise ``config["prompt"]``.

        Returns:
            dict: ``"text_encoder_output"`` with the encoder result and
            ``"image_encoder_output"`` set to ``None``.
        """
        if self.config["use_prompt_enhancer"]:
            prompt = self.config["prompt_enhanced"]
        else:
            prompt = self.config["prompt"]

        encoded_text = self.run_text_encoder(prompt, None)

        torch.cuda.empty_cache()
        gc.collect()

        return {"text_encoder_output": encoded_text, "image_encoder_output": None}

    @ProfilingContext("Run DiT")
    async def run_dit_local(self, kwargs):
        """Run the denoising loop in-process.

        With ``config["lazy_load"]`` truthy the transformer is loaded first.
        The scheduler is initialised and prepared with
        ``self.inputs["image_encoder_output"]``, ``run()`` is executed and
        ``end_run()`` releases the per-run state.

        Args:
            kwargs: Accepted for signature parity with the server variant; not
                read in this method.

        Returns:
            tuple: ``(latents, generator)`` as returned by ``run()``.
        """
        lazy_load = self.config.get("lazy_load", False)
        if lazy_load:
            self.model = self.load_transformer()

        self.init_scheduler()
        scheduler = self.model.scheduler
        scheduler.prepare(self.inputs["image_encoder_output"])

        latents, generator = self.run()
        self.end_run()
        return latents, generator

    @ProfilingContext("Run VAE Decoder")
    async def run_vae_decoder_local(self, latents, generator):
        """Decode latents with the local VAE decoder.

        With ``config["lazy_load"]`` truthy the decoder is loaded just before
        decoding and deleted afterwards, followed by a CUDA cache flush and a
        garbage collection pass.

        Args:
            latents: Latents passed to ``vae_decoder.decode``.
            generator: Forwarded to ``decode`` as ``generator=``.

        Returns:
            The value returned by ``vae_decoder.decode``.
        """
        lazy_load = self.config.get("lazy_load", False)
        if lazy_load:
            self.vae_decoder = self.load_vae_decoder()

        images = self.vae_decoder.decode(latents, generator=generator, config=self.config)
        if not lazy_load:
            return images

        # Lazy mode: drop the decoder again so it does not stay resident.
        del self.vae_decoder
        torch.cuda.empty_cache()
        gc.collect()
        return images

    @ProfilingContext("Save video")
    def save_video(self, images):
        """Save ``images`` through ``save_video_func``.

        When ``config.parallel_attn_type`` is truthy, only the process whose
        ``dist.get_rank()`` is 0 saves; otherwise the call always saves.

        Args:
            images: Decoded frames handed to ``save_video_func``.
        """
        if self.config.parallel_attn_type and dist.get_rank() != 0:
            return
        self.save_video_func(images)
helloyongyang's avatar
helloyongyang committed
177

178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
    async def post_task(self, task_type, urls, message, device="cuda"):
        while True:
            for url in urls:
                async with aiohttp.ClientSession() as session:
                    async with session.get(f"{url}/v1/local/{task_type}/generate/service_status") as response:
                        status = await response.json()
                    if status["service_status"] == "idle":
                        async with session.post(f"{url}/v1/local/{task_type}/generate", json=message) as response:
                            result = await response.json()
                            if result["kwargs"] is not None:
                                for k, v in result["kwargs"].items():
                                    setattr(self.config, k, v)
                            return self.tensor_transporter.load_tensor(result["output"], device)
            await asyncio.sleep(0.1)

193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
    def post_prompt_enhancer(self):
        while True:
            for url in self.config["sub_servers"]["prompt_enhancer"]:
                response = requests.get(f"{url}/v1/local/prompt_enhancer/generate/service_status").json()
                if response["service_status"] == "idle":
                    response = requests.post(f"{url}/v1/local/prompt_enhancer/generate", json={"task_id": generate_task_id(), "prompt": self.config["prompt"]})
                    enhanced_prompt = response.json()["output"]
                    logger.info(f"Enhanced prompt: {enhanced_prompt}")
                    return enhanced_prompt

    async def post_encoders_i2v(self, prompt, img=None, n_prompt=None, i2v=False):
        """Run the image encoder, VAE encoder and text encoders remotely, concurrently.

        Args:
            prompt: Text sent to the text-encoder service.
            img: Image passed through ``image_transporter.prepare_image`` and
                sent to all three services.
            n_prompt: Negative prompt sent to the text-encoder service.
            i2v: Not read in this method.

        Returns:
            tuple: ``(image_encoder_result, vae_encoder_result, text_encoders_result)``.
        """
        img_byte = self.image_transporter.prepare_image(img)

        clip_task = asyncio.create_task(
            self.post_task(task_type="image_encoder", urls=self.config["sub_servers"]["image_encoder"], message={"task_id": generate_task_id(), "img": img_byte}, device="cuda")
        )
        vae_task = asyncio.create_task(
            self.post_task(task_type="vae_model/encoder", urls=self.config["sub_servers"]["vae_model"], message={"task_id": generate_task_id(), "img": img_byte}, device="cuda")
        )
        text_task = asyncio.create_task(
            self.post_task(
                task_type="text_encoders",
                urls=self.config["sub_servers"]["text_encoders"],
                message={"task_id": generate_task_id(), "text": prompt, "img": img_byte, "n_prompt": n_prompt},
                device="cuda",
            )
        )

        # gather() preserves argument order: clip, vae, text.
        clip_encoder_out, vae_encode_out, text_encoder_output = await asyncio.gather(clip_task, vae_task, text_task)
        return clip_encoder_out, vae_encode_out, text_encoder_output

    async def post_encoders_t2v(self, prompt, n_prompt=None):
        """Run the text encoders remotely for a text-to-video request.

        Args:
            prompt: Text sent to the text-encoder service.
            n_prompt: Negative prompt sent alongside it.

        Returns:
            The result of the ``text_encoders`` ``post_task`` call.
        """
        request = self.post_task(
            task_type="text_encoders",
            urls=self.config["sub_servers"]["text_encoders"],
            message={"task_id": generate_task_id(), "text": prompt, "img": None, "n_prompt": n_prompt},
            device="cuda",
        )
        # Single task, so gather() yields a one-element list.
        (text_encoder_output,) = await asyncio.gather(asyncio.create_task(request))
        return text_encoder_output

    async def run_input_encoder_server_i2v(self):
        """Encode the prompt and the input image through the remote encoder services.

        Uses ``config["prompt_enhanced"]`` when ``config["use_prompt_enhancer"]``
        is truthy, otherwise ``config["prompt"]``. The image at
        ``config["image_path"]`` is opened and converted to RGB.

        Returns:
            The value of ``get_encoder_output_i2v`` for the three remote
            results and the image.
        """
        if self.config["use_prompt_enhancer"]:
            prompt = self.config["prompt_enhanced"]
        else:
            prompt = self.config["prompt"]
        n_prompt = self.config.get("negative_prompt", "")
        img = Image.open(self.config["image_path"]).convert("RGB")

        remote_outputs = await self.post_encoders_i2v(prompt, img, n_prompt)
        clip_encoder_out, vae_encode_out, text_encoder_output = remote_outputs

        torch.cuda.empty_cache()
        gc.collect()

        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)

    async def run_input_encoder_server_t2v(self):
        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
        n_prompt = self.config.get("negative_prompt", "")
        text_encoder_output = await self.post_encoders_t2v(prompt, n_prompt)
255
256
        torch.cuda.empty_cache()
        gc.collect()
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
        return {"text_encoder_output": text_encoder_output, "image_encoder_output": None}

    async def run_dit_server(self, kwargs):
        """Run the denoising stage on a remote ``dit`` server.

        Any ``"img"`` entry is removed from
        ``self.inputs["image_encoder_output"]`` before ``self.inputs`` and
        ``kwargs`` are passed through ``tensor_transporter.prepare_tensor``
        and posted.

        Args:
            kwargs: Extra arguments forwarded to the server.

        Returns:
            tuple: ``(dit_output, None)``; no generator is returned by this variant.
        """
        image_encoder_output = self.inputs.get("image_encoder_output", None)
        if image_encoder_output is not None:
            image_encoder_output.pop("img", None)

        return (
            await self.post_task(
                task_type="dit",
                urls=self.config["sub_servers"]["dit"],
                message={"task_id": generate_task_id(), "inputs": self.tensor_transporter.prepare_tensor(self.inputs), "kwargs": self.tensor_transporter.prepare_tensor(kwargs)},
                device="cuda",
            ),
            None,
        )

    async def run_vae_decoder_server(self, latents, generator):
        """Decode latents on a remote ``vae_model`` server.

        Args:
            latents: Latents passed through ``tensor_transporter.prepare_tensor``
                and posted.
            generator: Not read in this method.

        Returns:
            The ``post_task`` result, requested with ``device="cpu"``.
        """
        payload = {"task_id": generate_task_id(), "latents": self.tensor_transporter.prepare_tensor(latents)}
        return await self.post_task(
            task_type="vae_model/decoder",
            urls=self.config["sub_servers"]["vae_model"],
            message=payload,
            device="cpu",
        )

279
    async def run_pipeline(self):
helloyongyang's avatar
helloyongyang committed
280
        if self.config["use_prompt_enhancer"]:
281
            self.config["prompt_enhanced"] = self.post_prompt_enhancer()
282
283
284
285
        self.inputs = await self.run_input_encoder()
        kwargs = self.set_target_shape()
        latents, generator = await self.run_dit(kwargs)
        images = await self.run_vae_decoder(latents, generator)
helloyongyang's avatar
helloyongyang committed
286
        self.save_video(images)
Zhuguanyu Wu's avatar
Zhuguanyu Wu committed
287
288
        del latents, generator, images
        torch.cuda.empty_cache()
289
        gc.collect()