import asyncio
import gc
import aiohttp
import requests
from requests.exceptions import RequestException
import torch
import torch.distributed as dist
import torchvision.transforms.functional as TF
from PIL import Image
from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
from lightx2v.utils.utils import save_videos_grid, cache_video
from lightx2v.utils.generate_task_id import generate_task_id
from lightx2v.utils.envs import *
from lightx2v.utils.service_utils import TensorTransporter, ImageTransporter
from loguru import logger


class DefaultRunner:
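    """Pipeline runner that wires together the text/image encoders, the DiT
    diffusion model, and the VAE for video generation.

    Depending on ``config["mode"]``, the heavy stages either run in-process
    ("local") or are dispatched to HTTP sub-servers ("split_server").
    Model-specific hooks referenced below (``load_transformer``,
    ``load_text_encoder``, ``load_vae``, ``init_scheduler``,
    ``set_target_shape``, ``save_video_func``, ...) are not defined in this
    file and are presumably supplied by subclasses.
    """
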
    def __init__(self, config):
        self.config = config
        self.has_prompt_enhancer = False
        if self.config["task"] == "t2v" and self.config.get("sub_servers", {}).get("prompt_enhancer") is not None:
            self.has_prompt_enhancer = True
            if not self.check_sub_servers("prompt_enhancer"):
                self.has_prompt_enhancer = False
                logger.warning("No prompt enhancer server available, disable prompt enhancer.")

    def init_modules(self):
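        """Bind the encoder/DiT/VAE entry points for the configured mode.

        In ``split_server`` mode the stages are proxied to remote
        sub-servers; otherwise they run in-process (and, unless
        ``lazy_load`` is set, the weights are loaded immediately).
        """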
        self.set_init_device()
        if self.config["mode"] == "split_server":
            self.tensor_transporter = TensorTransporter()
            self.image_transporter = ImageTransporter()
            if not self.check_sub_servers("dit"):
                raise ValueError("No dit server available")
            if not self.check_sub_servers("text_encoders"):
                raise ValueError("No text encoder server available")
            if self.config["task"] == "i2v":
                if not self.check_sub_servers("image_encoder"):
                    raise ValueError("No image encoder server available")
            if not self.check_sub_servers("vae_model"):
                raise ValueError("No vae server available")
            self.run_dit = self.run_dit_server
            self.run_vae_decoder = self.run_vae_decoder_server
            if self.config["task"] == "i2v":
                self.run_input_encoder = self.run_input_encoder_server_i2v
            else:
                self.run_input_encoder = self.run_input_encoder_server_t2v
        else:
            if not self.config.get("lazy_load", False):
                self.load_model()
            self.run_dit = self.run_dit_local
            self.run_vae_decoder = self.run_vae_decoder_local
            if self.config["task"] == "i2v":
                self.run_input_encoder = self.run_input_encoder_local_i2v
            else:
                self.run_input_encoder = self.run_input_encoder_local_t2v

    def set_init_device(self):
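        """Pick the device weights are initialized on: one CUDA device per
        distributed rank when parallel attention is enabled, and CPU when
        ``cpu_offload`` is requested."""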
        if self.config["parallel_attn_type"]:
            cur_rank = dist.get_rank()
            torch.cuda.set_device(cur_rank)
        if self.config.cpu_offload:
            self.init_device = torch.device("cpu")
        else:
            self.init_device = torch.device("cuda")

    @ProfilingContext("Load models")
    def load_model(self):
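        """Load all submodels in-process via the loader hooks (expected from
        a subclass): transformer, text encoders, image encoder, and the
        VAE encoder/decoder pair."""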
        self.model = self.load_transformer()
        self.text_encoders = self.load_text_encoder()
        self.image_encoder = self.load_image_encoder()
        self.vae_encoder, self.vae_decoder = self.load_vae()

    def check_sub_servers(self, task_type):
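        """Probe each configured sub-server of ``task_type`` via its
        ``service_status`` endpoint, keep only the reachable ones in the
        config, and return whether any server is available."""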
        urls = self.config.get("sub_servers", {}).get(task_type, [])
        available_servers = []
        for url in urls:
            try:
                status_url = f"{url}/v1/local/{task_type}/generate/service_status"
                response = requests.get(status_url, timeout=2)
                if response.status_code == 200:
                    available_servers.append(url)
                else:
                    logger.warning(f"Service {url} returned status code {response.status_code}")

            except RequestException as e:
                logger.warning(f"Failed to connect to {url}: {str(e)}")
                continue
        logger.info(f"{task_type} available servers: {available_servers}")
        self.config["sub_servers"][task_type] = available_servers
        return len(available_servers) > 0

    def set_inputs(self, inputs):
        self.config["prompt"] = inputs.get("prompt", "")
        if self.has_prompt_enhancer:
            self.config["use_prompt_enhancer"] = inputs.get("use_prompt_enhancer", False)  # Reset use_prompt_enhancer from clinet side.
        self.config["negative_prompt"] = inputs.get("negative_prompt", "")
        self.config["image_path"] = inputs.get("image_path", "")
        self.config["save_video_path"] = inputs.get("save_video_path", "")

    def run(self):
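        """Run the denoising loop: for every scheduler step, advance the
        scheduler (``step_pre``), run the DiT forward pass, then finalize
        the step (``step_post``). Returns the final latents and the
        generator."""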
        for step_index in range(self.model.scheduler.infer_steps):
            logger.info(f"==> step_index: {step_index + 1} / {self.model.scheduler.infer_steps}")

            with ProfilingContext4Debug("step_pre"):
                self.model.scheduler.step_pre(step_index=step_index)

            with ProfilingContext4Debug("infer"):
                self.model.infer(self.inputs)

            with ProfilingContext4Debug("step_post"):
                self.model.scheduler.step_post()

        return self.model.scheduler.latents, self.model.scheduler.generator

    async def run_step(self, step_index=0):
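        """Run a single denoising step end-to-end (encoders included).
        This appears intended for debugging/profiling rather than full
        generation."""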
        self.init_scheduler()
        self.inputs = await self.run_input_encoder()  # store encoder outputs; scheduler.prepare reads them below
        self.model.scheduler.prepare(self.inputs["image_encoder_output"])
        self.model.scheduler.step_pre(step_index=step_index)
        self.model.infer(self.inputs)
        self.model.scheduler.step_post()

    def end_run(self):
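        """Release per-run state (scheduler, cached inputs, and the model
        itself under ``lazy_load``) and return memory to the CUDA
        allocator."""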
        self.model.scheduler.clear()
        del self.inputs, self.model.scheduler
        if self.config.get("lazy_load", False):
            del self.model
        torch.cuda.empty_cache()
        gc.collect()

    @ProfilingContext("Run Encoders")
    async def run_input_encoder_local_i2v(self):
        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
        img = Image.open(self.config["image_path"]).convert("RGB")
        clip_encoder_out = self.run_image_encoder(img)
        vae_encode_out, kwargs = self.run_vae_encoder(img)
        text_encoder_output = self.run_text_encoder(prompt, img)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)

    @ProfilingContext("Run Encoders")
    async def run_input_encoder_local_t2v(self):
        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
        text_encoder_output = self.run_text_encoder(prompt, None)
        torch.cuda.empty_cache()
        gc.collect()
        return {"text_encoder_output": text_encoder_output, "image_encoder_output": None}

    @ProfilingContext("Run DiT")
    async def run_dit_local(self, kwargs):
        if self.config.get("lazy_load", False):
            self.model = self.load_transformer()
        self.init_scheduler()
        self.model.scheduler.prepare(self.inputs["image_encoder_output"])
        latents, generator = self.run()
        self.end_run()
        return latents, generator

    @ProfilingContext("Run VAE Decoder")
    async def run_vae_decoder_local(self, latents, generator):
        if self.config.get("lazy_load", False):
            self.vae_decoder = self.load_vae_decoder()
        images = self.vae_decoder.decode(latents, generator=generator, config=self.config)
        if self.config.get("lazy_load", False):
            torch.cuda.empty_cache()
            gc.collect()
        return images

    @ProfilingContext("Save video")
    def save_video(self, images):
        # In distributed (parallel-attention) runs, only rank 0 writes the file.
        if not self.config.parallel_attn_type or dist.get_rank() == 0:
            self.save_video_func(images)

    async def post_task(self, task_type, urls, message, device="cuda"):
        while True:
            for url in urls:
                async with aiohttp.ClientSession() as session:
                    async with session.get(f"{url}/v1/local/{task_type}/generate/service_status") as response:
                        status = await response.json()
                    if status["service_status"] == "idle":
                        async with session.post(f"{url}/v1/local/{task_type}/generate", json=message) as response:
                            result = await response.json()
                            if result["kwargs"] is not None:
                                for k, v in result["kwargs"].items():
                                    setattr(self.config, k, v)
                            return self.tensor_transporter.load_tensor(result["output"], device)
            await asyncio.sleep(0.1)

    def post_prompt_enhancer(self):
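        """Synchronous counterpart of ``post_task`` for the prompt-enhancer
        sub-servers: polls until one is idle, then returns the enhanced
        prompt. Note this loop does not sleep between polling rounds."""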
        while True:
            for url in self.config["sub_servers"]["prompt_enhancer"]:
                response = requests.get(f"{url}/v1/local/prompt_enhancer/generate/service_status").json()
                if response["service_status"] == "idle":
                    response = requests.post(f"{url}/v1/local/prompt_enhancer/generate", json={"task_id": generate_task_id(), "prompt": self.config["prompt"]})
                    enhanced_prompt = response.json()["output"]
                    logger.info(f"Enhanced prompt: {enhanced_prompt}")
                    return enhanced_prompt

    async def post_encoders_i2v(self, prompt, img=None, n_prompt=None, i2v=False):
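        """Dispatch the image encoder, VAE encoder, and text encoders to
        their sub-servers concurrently and gather the three outputs."""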
        tasks = []
        img_byte = self.image_transporter.prepare_image(img)
        tasks.append(
            asyncio.create_task(self.post_task(task_type="image_encoder", urls=self.config["sub_servers"]["image_encoder"], message={"task_id": generate_task_id(), "img": img_byte}, device="cuda"))
        )
        tasks.append(
            asyncio.create_task(self.post_task(task_type="vae_model/encoder", urls=self.config["sub_servers"]["vae_model"], message={"task_id": generate_task_id(), "img": img_byte}, device="cuda"))
        )
        tasks.append(
            asyncio.create_task(
                self.post_task(
                    task_type="text_encoders",
                    urls=self.config["sub_servers"]["text_encoders"],
                    message={"task_id": generate_task_id(), "text": prompt, "img": img_byte, "n_prompt": n_prompt},
                    device="cuda",
                )
            )
        )
        results = await asyncio.gather(*tasks)
        # clip_encoder, vae_encoder, text_encoders
        return results[0], results[1], results[2]

    async def post_encoders_t2v(self, prompt, n_prompt=None):
        tasks = []
        tasks.append(
            asyncio.create_task(
                self.post_task(
                    task_type="text_encoders",
                    urls=self.config["sub_servers"]["text_encoders"],
                    message={"task_id": generate_task_id(), "text": prompt, "img": None, "n_prompt": n_prompt},
                    device="cuda",
                )
            )
        )
        results = await asyncio.gather(*tasks)
        # text_encoders
        return results[0]

    async def run_input_encoder_server_i2v(self):
        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
        n_prompt = self.config.get("negative_prompt", "")
        img = Image.open(self.config["image_path"]).convert("RGB")
        clip_encoder_out, vae_encode_out, text_encoder_output = await self.post_encoders_i2v(prompt, img, n_prompt)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)

    async def run_input_encoder_server_t2v(self):
        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
        n_prompt = self.config.get("negative_prompt", "")
        text_encoder_output = await self.post_encoders_t2v(prompt, n_prompt)
        torch.cuda.empty_cache()
        gc.collect()
        return {"text_encoder_output": text_encoder_output, "image_encoder_output": None}

    async def run_dit_server(self, kwargs):
        if self.inputs.get("image_encoder_output", None) is not None:
            self.inputs["image_encoder_output"].pop("img", None)
        dit_output = await self.post_task(
            task_type="dit",
            urls=self.config["sub_servers"]["dit"],
            message={"task_id": generate_task_id(), "inputs": self.tensor_transporter.prepare_tensor(self.inputs), "kwargs": self.tensor_transporter.prepare_tensor(kwargs)},
            device="cuda",
        )
        return dit_output, None

    async def run_vae_decoder_server(self, latents, generator):
        images = await self.post_task(
            task_type="vae_model/decoder",
            urls=self.config["sub_servers"]["vae_model"],
            message={"task_id": generate_task_id(), "latents": self.tensor_transporter.prepare_tensor(latents)},
            device="cpu",
        )
        return images

    async def run_pipeline(self):
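        """End-to-end generation: optional prompt enhancement, input
        encoding, DiT denoising, VAE decoding, and saving the video."""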
        if self.config["use_prompt_enhancer"]:
            self.config["prompt_enhanced"] = self.post_prompt_enhancer()
        self.inputs = await self.run_input_encoder()
        kwargs = self.set_target_shape()
        latents, generator = await self.run_dit(kwargs)
        images = await self.run_vae_decoder(latents, generator)
        self.save_video(images)
        del latents, generator, images
        torch.cuda.empty_cache()
        gc.collect()
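

# Minimal usage sketch (illustrative, not part of the original module). It
# assumes a dict-like config with attribute access (e.g. an EasyDict) carrying
# at least: task, mode, parallel_attn_type, cpu_offload, and
# use_prompt_enhancer; a model-specific subclass must supply the loader hooks
# referenced above (load_transformer, init_scheduler, set_target_shape, ...).
#
#   runner = DefaultRunner(config)
#   runner.init_modules()
#   runner.set_inputs({"prompt": "a cat surfing", "save_video_path": "out.mp4"})
#   asyncio.run(runner.run_pipeline())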