"docker/china/requirements.txt" did not exist on "9ececf3a1ec53db36ea05ac9016160dcc49182fd"
default_runner.py 10.4 KB
Newer Older
1
import asyncio
helloyongyang's avatar
helloyongyang committed
2
import gc
3
import aiohttp
4
5
import requests
from requests.exceptions import RequestException
helloyongyang's avatar
helloyongyang committed
6
7
import torch
import torch.distributed as dist
8
9
import torchvision.transforms.functional as TF
from PIL import Image
helloyongyang's avatar
helloyongyang committed
10
11
from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
from lightx2v.utils.utils import save_videos_grid, cache_video
12
from lightx2v.utils.generate_task_id import generate_task_id
helloyongyang's avatar
helloyongyang committed
13
from lightx2v.utils.envs import *
14
from lightx2v.utils.service_utils import TensorTransporter, ImageTransporter
15
from loguru import logger
16
17
18


class DefaultRunner:
    def __init__(self, config):
        """Build the runner: probe optional sub-servers and load local components.

        Args:
            config: run configuration (dict-like); mutated in place later on.
        """
        self.config = config
        self.has_prompt_enhancer = self._probe_prompt_enhancer()
        if self.config["mode"] == "split_server":
            self._init_split_server()
        else:
            # Local mode: every component lives in this process.
            self.model, self.text_encoders, self.vae_model, self.image_encoder = self.load_model()

    def _probe_prompt_enhancer(self):
        # Prompt enhancement is opt-in: t2v task, configured server list,
        # and at least one server answering its health check.
        if self.config["task"] != "t2v":
            return False
        if self.config.get("sub_servers", {}).get("prompt_enhancer") is None:
            return False
        if self.check_sub_servers("prompt_enhancer"):
            return True
        logger.warning("No prompt enhancer server available, disable prompt enhancer.")
        return False

    def _init_split_server(self):
        # Split mode: only the transformer runs locally; encoders and VAE are
        # remote, so tensors/images must be (de)serialized for transport.
        self.model = self.load_transformer()
        self.text_encoders, self.vae_model, self.image_encoder = None, None, None
        self.tensor_transporter = TensorTransporter()
        self.image_transporter = ImageTransporter()
        if not self.check_sub_servers("text_encoders"):
            raise ValueError("No text encoder server available")
        if "wan2.1" in self.config["model_cls"] and not self.check_sub_servers("image_encoder"):
            raise ValueError("No image encoder server available")
        if not self.check_sub_servers("vae_model"):
            raise ValueError("No vae model server available")
Zhuguanyu Wu's avatar
Zhuguanyu Wu committed
41

42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
    def check_sub_servers(self, task_type):
        urls = self.config.get("sub_servers", {}).get(task_type, [])
        available_servers = []
        for url in urls:
            try:
                status_url = f"{url}/v1/local/{task_type}/generate/service_status"
                response = requests.get(status_url, timeout=2)
                if response.status_code == 200:
                    available_servers.append(url)
                else:
                    logger.warning(f"Service {url} returned status code {response.status_code}")

            except RequestException as e:
                logger.warning(f"Failed to connect to {url}: {str(e)}")
                continue
        logger.info(f"{task_type} available servers: {available_servers}")
        self.config["sub_servers"][task_type] = available_servers
        return len(available_servers) > 0

helloyongyang's avatar
helloyongyang committed
61
62
    def set_inputs(self, inputs):
        self.config["prompt"] = inputs.get("prompt", "")
63
        if self.has_prompt_enhancer:
64
            self.config["use_prompt_enhancer"] = inputs.get("use_prompt_enhancer", False)  # Reset use_prompt_enhancer from clinet side.
helloyongyang's avatar
helloyongyang committed
65
66
67
68
        self.config["negative_prompt"] = inputs.get("negative_prompt", "")
        self.config["image_path"] = inputs.get("image_path", "")
        self.config["save_video_path"] = inputs.get("save_video_path", "")

69
70
71
72
73
74
75
76
77
78
79
    def post_prompt_enhancer(self):
        """Send the current prompt to the first idle prompt-enhancer server.

        Polls the configured sub-servers until one reports "idle", posts the
        prompt there, stores the result in ``config["prompt_enhanced"]`` and
        returns it (so callers may assign the return value directly).

        Returns:
            str: the enhanced prompt.
        """
        while True:
            for url in self.config["sub_servers"]["prompt_enhancer"]:
                try:
                    status = requests.get(f"{url}/v1/local/prompt_enhancer/generate/service_status", timeout=2).json()
                except RequestException as e:
                    # Unreachable server: try the next one instead of crashing.
                    logger.warning(f"Failed to connect to {url}: {str(e)}")
                    continue
                if status["service_status"] == "idle":
                    response = requests.post(f"{url}/v1/local/prompt_enhancer/generate", json={"task_id": generate_task_id(), "prompt": self.config["prompt"]})
                    self.config["prompt_enhanced"] = response.json()["output"]
                    logger.info(f"Enhanced prompt: {self.config['prompt_enhanced']}")
                    return self.config["prompt_enhanced"]
            # All servers busy or unreachable: back off briefly before re-polling
            # (mirrors the 0.1 s back-off used by post_task).
            time.sleep(0.1)
    async def post_encoders(self, prompt, img=None, n_prompt=None, i2v=False):
        """Fan encoding work out to the remote encoder sub-servers.

        Args:
            prompt: positive text prompt (already enhanced when applicable).
            img: optional PIL image for i2v tasks; serialized before posting.
            n_prompt: negative prompt forwarded to the text encoders.
            i2v: True for image-to-video tasks.

        Returns:
            Tuple ``(clip_encoder_out, vae_encode_out, text_encoder_out)``;
            slots the current task/model combination does not produce are None.
        """
        tasks = []
        img_byte = self.image_transporter.prepare_image(img) if img is not None else None
        if i2v:
            # wan2.1 models additionally need a CLIP embedding of the image.
            if "wan2.1" in self.config["model_cls"]:
                tasks.append(
                    asyncio.create_task(
                        self.post_task(task_type="image_encoder", urls=self.config["sub_servers"]["image_encoder"], message={"task_id": generate_task_id(), "img": img_byte}, device="cuda")
                    )
                )
            tasks.append(
                asyncio.create_task(
                    self.post_task(task_type="vae_model/encoder", urls=self.config["sub_servers"]["vae_model"], message={"task_id": generate_task_id(), "img": img_byte}, device="cuda")
                )
            )
        tasks.append(
            asyncio.create_task(
                self.post_task(
                    task_type="text_encoders",
                    urls=self.config["sub_servers"]["text_encoders"],
                    message={"task_id": generate_task_id(), "text": prompt, "img": img_byte, "n_prompt": n_prompt},
                    device="cuda",
                )
            )
        )
        # gather() preserves submission order, so the unpacking below relies on
        # the exact append order above: [clip?], [vae?], text.
        results = await asyncio.gather(*tasks)
        # clip_encoder, vae_encoder, text_encoders
        if not i2v:
            return None, None, results[0]
        if "wan2.1" in self.config["model_cls"]:
            return results[0], results[1], results[2]
        else:
            return None, results[0], results[1]
    async def run_input_encoder(self):
        """Run the text/image encoders (locally or remotely) and stash outputs.

        Populates ``self.inputs`` with ``text_encoder_output`` and
        ``image_encoder_output`` and calls ``set_target_shape()``
        (presumably a subclass hook — defined elsewhere).
        """
        image_encoder_output = None
        # Prefer the server-enhanced prompt when prompt enhancement is enabled.
        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
        n_prompt = self.config.get("negative_prompt", "")
        i2v = self.config["task"] == "i2v"
        img = Image.open(self.config["image_path"]).convert("RGB") if i2v else None
        with ProfilingContext("Run Encoders"):
            if self.config["mode"] == "split_server":
                clip_encoder_out, vae_encode_out, text_encoder_output = await self.post_encoders(prompt, img, n_prompt, i2v)
                if i2v:
                    # Model families package the image conditioning differently.
                    if self.config["model_cls"] in ["hunyuan"]:
                        image_encoder_output = {"img": img, "img_latents": vae_encode_out}
                    elif "wan2.1" in self.config["model_cls"]:
                        image_encoder_output = {"clip_encoder_out": clip_encoder_out, "vae_encode_out": vae_encode_out}
                    else:
                        raise ValueError(f"Unsupported model class: {self.config['model_cls']}")
            else:
                if i2v:
                    image_encoder_output = self.run_image_encoder(self.config, self.image_encoder, self.vae_model)
                text_encoder_output = self.run_text_encoder(prompt, self.text_encoders, self.config, image_encoder_output)
        self.set_target_shape()
        self.inputs = {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}

        # Encoder intermediates can be large; reclaim memory before inference.
        gc.collect()
        torch.cuda.empty_cache()
    def run(self):
        """Execute the full denoising loop.

        Returns:
            Tuple ``(latents, generator)`` taken from the scheduler after the
            final step.
        """
        scheduler = self.model.scheduler
        for idx in range(scheduler.infer_steps):
            logger.info(f"==> step_index: {idx + 1} / {scheduler.infer_steps}")

            # Each phase is wrapped separately so per-phase timings show up
            # in the debug profile.
            with ProfilingContext4Debug("step_pre"):
                scheduler.step_pre(step_index=idx)
            with ProfilingContext4Debug("infer"):
                self.model.infer(self.inputs)
            with ProfilingContext4Debug("step_post"):
                scheduler.step_post()

        return scheduler.latents, scheduler.generator
    async def run_step(self, step_index=0):
        """Run encoding plus exactly one denoising step (single-step helper).

        Args:
            step_index: scheduler step to execute; defaults to the first.
        """
        self.init_scheduler()
        await self.run_input_encoder()
        scheduler = self.model.scheduler
        scheduler.prepare(self.inputs["image_encoder_output"])
        scheduler.step_pre(step_index=step_index)
        self.model.infer(self.inputs)
        scheduler.step_post()
    def end_run(self):
Zhuguanyu Wu's avatar
Zhuguanyu Wu committed
163
164
165
        self.model.scheduler.clear()
        del self.inputs, self.model.scheduler
        torch.cuda.empty_cache()
helloyongyang's avatar
helloyongyang committed
166
167

    @ProfilingContext("Run VAE")
168
169
170
171
172
173
174
175
176
177
    async def run_vae(self, latents, generator):
        if self.config["mode"] == "split_server":
            images = await self.post_task(
                task_type="vae_model/decoder",
                urls=self.config["sub_servers"]["vae_model"],
                message={"task_id": generate_task_id(), "latents": self.tensor_transporter.prepare_tensor(latents)},
                device="cpu",
            )
        else:
            images = self.vae_model.decode(latents, generator=generator, config=self.config)
helloyongyang's avatar
helloyongyang committed
178
179
180
181
182
        return images

    @ProfilingContext("Save video")
    def save_video(self, images):
        if not self.config.parallel_attn_type or (self.config.parallel_attn_type and dist.get_rank() == 0):
Zhuguanyu Wu's avatar
Zhuguanyu Wu committed
183
            if self.config.model_cls in ["wan2.1", "wan2.1_causvid", "wan2.1_skyreels_v2_df"]:
184
                cache_video(tensor=images, save_file=self.config.save_video_path, fps=self.config.get("fps", 16), nrow=1, normalize=True, value_range=(-1, 1))
helloyongyang's avatar
helloyongyang committed
185
            else:
186
                save_videos_grid(images, self.config.save_video_path, fps=self.config.get("fps", 24))
helloyongyang's avatar
helloyongyang committed
187

188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
    async def post_task(self, task_type, urls, message, device="cuda"):
        """Long-poll *urls* until one is idle, post *message* there, and return the output.

        Args:
            task_type: endpoint segment, e.g. "text_encoders" or "vae_model/decoder".
            urls: candidate sub-server base URLs.
            message: JSON-serializable payload (includes a task_id).
            device: device onto which the returned tensor is loaded.

        Returns:
            The deserialized output tensor on *device*.
        """
        while True:
            for url in urls:
                # One short-lived session per candidate; the POST below reuses
                # the same session that issued the status GET.
                async with aiohttp.ClientSession() as session:
                    async with session.get(f"{url}/v1/local/{task_type}/generate/service_status") as response:
                        status = await response.json()
                    if status["service_status"] == "idle":
                        async with session.post(f"{url}/v1/local/{task_type}/generate", json=message) as response:
                            result = await response.json()
                            # Servers may push config overrides back to the client.
                            if result["kwargs"] is not None:
                                for k, v in result["kwargs"].items():
                                    setattr(self.config, k, v)
                            return self.tensor_transporter.load_tensor(result["output"], device)
            # Every server was busy: back off briefly before the next sweep.
            await asyncio.sleep(0.1)
    async def run_pipeline(self):
        """End-to-end generation: enhance prompt, encode, denoise, decode, save."""
        if self.config["use_prompt_enhancer"]:
            # post_prompt_enhancer stores its result in config["prompt_enhanced"]
            # itself; assigning its return value here used to clobber that key
            # with None, wiping out the enhanced prompt.
            self.post_prompt_enhancer()
        self.init_scheduler()
        await self.run_input_encoder()
        self.model.scheduler.prepare(self.inputs["image_encoder_output"])
        latents, generator = self.run()
        self.end_run()
        images = await self.run_vae(latents, generator)
        self.save_video(images)
        # Drop the large intermediates before serving the next request.
        del latents, generator, images
        gc.collect()
        torch.cuda.empty_cache()