import asyncio
import gc
import aiohttp
import requests
from requests.exceptions import RequestException
import torch
import torch.distributed as dist
import torchvision.transforms.functional as TF
from PIL import Image
from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
from lightx2v.utils.utils import save_videos_grid, cache_video
from lightx2v.utils.generate_task_id import generate_task_id
from lightx2v.utils.envs import *
from lightx2v.utils.service_utils import TensorTransporter, ImageTransporter
from loguru import logger


class DefaultRunner:
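    """End-to-end runner for lightx2v video generation.

    Depending on ``config["mode"]``, the encoder/DiT/VAE stages either run
    in-process or are dispatched to HTTP sub-servers ("split_server").
    Model-specific hooks (``load_transformer``, ``init_scheduler``,
    ``set_target_shape``, ``save_video_func``, ...) are expected to be
    provided by subclasses.
    """
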
    def __init__(self, config):
        self.config = config
        self.has_prompt_enhancer = False
        if self.config["task"] == "t2v" and self.config.get("sub_servers", {}).get("prompt_enhancer") is not None:
            self.has_prompt_enhancer = True
            if not self.check_sub_servers("prompt_enhancer"):
                self.has_prompt_enhancer = False
                logger.warning("No prompt enhancer server available, disable prompt enhancer.")
        if not self.has_prompt_enhancer:
            self.config["use_prompt_enhancer"] = False
        self.set_init_device()

    def init_modules(self):
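        """Bind the encoder/DiT/VAE entry points to local or sub-server implementations."""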
        logger.info("Initializing runner modules...")
        if self.config["mode"] == "split_server":
            self.tensor_transporter = TensorTransporter()
            self.image_transporter = ImageTransporter()
            if not self.check_sub_servers("dit"):
                raise ValueError("No dit server available")
            if not self.check_sub_servers("text_encoders"):
                raise ValueError("No text encoder server available")
            if self.config["task"] == "i2v":
                if not self.check_sub_servers("image_encoder"):
                    raise ValueError("No image encoder server available")
            if not self.check_sub_servers("vae_model"):
                raise ValueError("No vae server available")
            self.run_dit = self.run_dit_server
            self.run_vae_decoder = self.run_vae_decoder_server
            if self.config["task"] == "i2v":
                self.run_input_encoder = self.run_input_encoder_server_i2v
            else:
                self.run_input_encoder = self.run_input_encoder_server_t2v
        else:
            if not self.config.get("lazy_load", False):
                self.load_model()
            self.run_dit = self.run_dit_local
            self.run_vae_decoder = self.run_vae_decoder_local
            if self.config["task"] == "i2v":
                self.run_input_encoder = self.run_input_encoder_local_i2v
            else:
                self.run_input_encoder = self.run_input_encoder_local_t2v

    def set_init_device(self):
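        """Pin each distributed rank to its GPU and choose the initial weight device."""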
        if self.config["parallel_attn_type"]:
            cur_rank = dist.get_rank()
            torch.cuda.set_device(cur_rank)
        if self.config.cpu_offload:
            self.init_device = torch.device("cpu")
        else:
            self.init_device = torch.device("cuda")

    @ProfilingContext("Load models")
    def load_model(self):
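        """Load the transformer, text/image encoders, and VAE encoder/decoder onto the init device."""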
        self.model = self.load_transformer()
        self.text_encoders = self.load_text_encoder()
        self.image_encoder = self.load_image_encoder()
        self.vae_encoder, self.vae_decoder = self.load_vae()

    def check_sub_servers(self, task_type):
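        """Probe the configured URLs for ``task_type``, keep only the servers
        whose status endpoint answers with HTTP 200, and return True if any remain.
        """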
        urls = self.config.get("sub_servers", {}).get(task_type, [])
        available_servers = []
        for url in urls:
            try:
                status_url = f"{url}/v1/local/{task_type}/generate/service_status"
                response = requests.get(status_url, timeout=2)
                if response.status_code == 200:
                    available_servers.append(url)
                else:
                    logger.warning(f"Service {url} returned status code {response.status_code}")

            except RequestException as e:
                logger.warning(f"Failed to connect to {url}: {str(e)}")
                continue
        logger.info(f"{task_type} available servers: {available_servers}")
        self.config["sub_servers"][task_type] = available_servers
        return len(available_servers) > 0

    def set_inputs(self, inputs):
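        """Copy per-request fields from the client payload into the shared config."""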
        self.config["prompt"] = inputs.get("prompt", "")
        self.config["use_prompt_enhancer"] = False
        if self.has_prompt_enhancer:
            self.config["use_prompt_enhancer"] = inputs.get("use_prompt_enhancer", False)  # Reset use_prompt_enhancer from clinet side.
        self.config["negative_prompt"] = inputs.get("negative_prompt", "")
        self.config["image_path"] = inputs.get("image_path", "")
        self.config["save_video_path"] = inputs.get("save_video_path", "")

    def run(self):
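        """Run the denoising loop and return the final latents and RNG generator."""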
        for step_index in range(self.model.scheduler.infer_steps):
            logger.info(f"==> step_index: {step_index + 1} / {self.model.scheduler.infer_steps}")

            with ProfilingContext4Debug("step_pre"):
                self.model.scheduler.step_pre(step_index=step_index)

            with ProfilingContext4Debug("infer"):
                self.model.infer(self.inputs)

            with ProfilingContext4Debug("step_post"):
                self.model.scheduler.step_post()

        return self.model.scheduler.latents, self.model.scheduler.generator

    async def run_step(self, step_index=0):
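        """Prepare the scheduler and inputs, then execute a single denoising step."""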
        self.init_scheduler()
        await self.run_input_encoder()
        self.model.scheduler.prepare(self.inputs["image_encoder_output"])
        self.model.scheduler.step_pre(step_index=step_index)
        self.model.infer(self.inputs)
        self.model.scheduler.step_post()

    def end_run(self):
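        """Tear down scheduler and (if lazily loaded) model state, then reclaim GPU memory."""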
        self.model.scheduler.clear()
        del self.inputs, self.model.scheduler
        if self.config.get("lazy_load", False):
            self.model.transformer_infer.weights_stream_mgr.clear()
            del self.model
        torch.cuda.empty_cache()
        gc.collect()

    @ProfilingContext("Run Encoders")
    async def run_input_encoder_local_i2v(self):
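        """Run the CLIP/VAE/text encoders in-process on the conditioning image and prompt."""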
        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
        img = Image.open(self.config["image_path"]).convert("RGB")
        clip_encoder_out = self.run_image_encoder(img)
        vae_encode_out, kwargs = self.run_vae_encoder(img)
        text_encoder_output = self.run_text_encoder(prompt, img)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)

    @ProfilingContext("Run Encoders")
    async def run_input_encoder_local_t2v(self):
        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
        text_encoder_output = self.run_text_encoder(prompt, None)
        torch.cuda.empty_cache()
        gc.collect()
        return {"text_encoder_output": text_encoder_output, "image_encoder_output": None}

    @ProfilingContext("Run DiT")
    async def run_dit_local(self, kwargs):
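        """Run DiT denoising in-process, loading the transformer on demand when lazy_load is set."""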
        if self.config.get("lazy_load", False):
            self.model = self.load_transformer()
        self.init_scheduler()
        self.model.scheduler.prepare(self.inputs["image_encoder_output"])
        latents, generator = self.run()
        self.end_run()
        return latents, generator

    @ProfilingContext("Run VAE Decoder")
    async def run_vae_decoder_local(self, latents, generator):
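        """Decode latents to frames, loading (and afterwards freeing) the VAE decoder when lazy_load is set."""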
        if self.config.get("lazy_load", False):
            self.vae_decoder = self.load_vae_decoder()
        images = self.vae_decoder.decode(latents, generator=generator, config=self.config)
        if self.config.get("lazy_load", False):
            del self.vae_decoder
            torch.cuda.empty_cache()
            gc.collect()
        return images

    @ProfilingContext("Save video")
    def save_video(self, images):
        if not self.config.parallel_attn_type or dist.get_rank() == 0:
            self.save_video_func(images)

    async def post_task(self, task_type, urls, message, device="cuda"):
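        """Poll ``urls`` until one reports an idle status, POST ``message`` to it,
        merge any returned kwargs into the config, and deserialize the output
        tensor onto ``device``.
        """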
        while True:
            for url in urls:
                async with aiohttp.ClientSession() as session:
                    async with session.get(f"{url}/v1/local/{task_type}/generate/service_status") as response:
                        status = await response.json()
                    if status["service_status"] == "idle":
                        async with session.post(f"{url}/v1/local/{task_type}/generate", json=message) as response:
                            result = await response.json()
                            if result["kwargs"] is not None:
                                for k, v in result["kwargs"].items():
                                    setattr(self.config, k, v)
                            return self.tensor_transporter.load_tensor(result["output"], device)
            await asyncio.sleep(0.1)

    def post_prompt_enhancer(self):
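        """Block until an idle prompt-enhancer server accepts the prompt and return the enhanced text."""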
        while True:
            for url in self.config["sub_servers"]["prompt_enhancer"]:
                response = requests.get(f"{url}/v1/local/prompt_enhancer/generate/service_status").json()
                if response["service_status"] == "idle":
                    response = requests.post(f"{url}/v1/local/prompt_enhancer/generate", json={"task_id": generate_task_id(), "prompt": self.config["prompt"]})
                    enhanced_prompt = response.json()["output"]
                    logger.info(f"Enhanced prompt: {enhanced_prompt}")
                    return enhanced_prompt

    async def post_encoders_i2v(self, prompt, img=None, n_prompt=None, i2v=False):
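        """Dispatch the image-encoder, VAE-encoder, and text-encoder requests to
        their sub-servers concurrently and return the three outputs.
        """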
        tasks = []
        img_byte = self.image_transporter.prepare_image(img)
        tasks.append(
            asyncio.create_task(self.post_task(task_type="image_encoder", urls=self.config["sub_servers"]["image_encoder"], message={"task_id": generate_task_id(), "img": img_byte}, device="cuda"))
        )
        tasks.append(
            asyncio.create_task(self.post_task(task_type="vae_model/encoder", urls=self.config["sub_servers"]["vae_model"], message={"task_id": generate_task_id(), "img": img_byte}, device="cuda"))
        )
        tasks.append(
            asyncio.create_task(
                self.post_task(
                    task_type="text_encoders",
                    urls=self.config["sub_servers"]["text_encoders"],
                    message={"task_id": generate_task_id(), "text": prompt, "img": img_byte, "n_prompt": n_prompt},
                    device="cuda",
                )
            )
        )
        results = await asyncio.gather(*tasks)
        # clip_encoder, vae_encoder, text_encoders
        return results[0], results[1], results[2]

    async def post_encoders_t2v(self, prompt, n_prompt=None):
        tasks = []
        tasks.append(
            asyncio.create_task(
                self.post_task(
                    task_type="text_encoders",
                    urls=self.config["sub_servers"]["text_encoders"],
                    message={"task_id": generate_task_id(), "text": prompt, "img": None, "n_prompt": n_prompt},
                    device="cuda",
                )
            )
        )
        results = await asyncio.gather(*tasks)
        # text_encoders
        return results[0]

    async def run_input_encoder_server_i2v(self):
        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
        n_prompt = self.config.get("negative_prompt", "")
        img = Image.open(self.config["image_path"]).convert("RGB")
        clip_encoder_out, vae_encode_out, text_encoder_output = await self.post_encoders_i2v(prompt, img, n_prompt)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)

    async def run_input_encoder_server_t2v(self):
        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
        n_prompt = self.config.get("negative_prompt", "")
        text_encoder_output = await self.post_encoders_t2v(prompt, n_prompt)
        torch.cuda.empty_cache()
        gc.collect()
        return {"text_encoder_output": text_encoder_output, "image_encoder_output": None}

    async def run_dit_server(self, kwargs):
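        """Strip raw image bytes from the inputs, then dispatch the denoising job to a dit sub-server."""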
        if self.inputs.get("image_encoder_output", None) is not None:
            self.inputs["image_encoder_output"].pop("img", None)
        dit_output = await self.post_task(
            task_type="dit",
            urls=self.config["sub_servers"]["dit"],
            message={"task_id": generate_task_id(), "inputs": self.tensor_transporter.prepare_tensor(self.inputs), "kwargs": self.tensor_transporter.prepare_tensor(kwargs)},
            device="cuda",
        )
        return dit_output, None

    async def run_vae_decoder_server(self, latents, generator):
        images = await self.post_task(
            task_type="vae_model/decoder",
            urls=self.config["sub_servers"]["vae_model"],
            message={"task_id": generate_task_id(), "latents": self.tensor_transporter.prepare_tensor(latents)},
            device="cpu",
        )
        return images

    async def run_pipeline(self):
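        """Run the full pipeline: optional prompt enhancement, input encoding,
        DiT denoising, VAE decoding, and saving the video.
        """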
        if self.config["use_prompt_enhancer"]:
            self.config["prompt_enhanced"] = self.post_prompt_enhancer()
        self.inputs = await self.run_input_encoder()
        kwargs = self.set_target_shape()
        latents, generator = await self.run_dit(kwargs)
        images = await self.run_vae_decoder(latents, generator)
        self.save_video(images)
        del latents, generator, images
        torch.cuda.empty_cache()
        gc.collect()
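

# Minimal usage sketch (illustrative only; ``MyRunner`` is a hypothetical
# subclass providing the hooks referenced above, e.g. load_transformer,
# load_text_encoder, load_image_encoder, load_vae, init_scheduler,
# set_target_shape, run_text_encoder, save_video_func):
#
#   runner = MyRunner(config)
#   runner.init_modules()
#   runner.set_inputs({"prompt": "a cat surfing", "save_video_path": "out.mp4"})
#   asyncio.run(runner.run_pipeline())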