# default_runner.py
import gc
import requests
from requests.exceptions import RequestException
import torch
import torch.distributed as dist
import torchvision.transforms.functional as TF
from PIL import Image
from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
from lightx2v.utils.utils import save_videos_grid, cache_video
from lightx2v.utils.generate_task_id import generate_task_id
from lightx2v.utils.envs import *
from lightx2v.utils.service_utils import TensorTransporter, ImageTransporter
from loguru import logger
from .base_runner import BaseRunner


class DefaultRunner(BaseRunner):
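    """Default runner driving the local t2v/i2v generation pipeline: input encoding, DiT denoising, VAE decoding, and saving."""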
    def __init__(self, config):
        super().__init__(config)
        self.has_prompt_enhancer = False
        self.progress_callback = None
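        # The prompt enhancer is only used for t2v tasks, and only if a configured prompt_enhancer sub-server is reachable.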
        if self.config["task"] == "t2v" and self.config.get("sub_servers", {}).get("prompt_enhancer") is not None:
            self.has_prompt_enhancer = True
            if not self.check_sub_servers("prompt_enhancer"):
                self.has_prompt_enhancer = False
                logger.warning("No prompt enhancer server available, disable prompt enhancer.")
        if not self.has_prompt_enhancer:
            self.config["use_prompt_enhancer"] = False
        self.set_init_device()

    def init_modules(self):
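        """Bind the local encoder/DiT/VAE entry points and load model weights unless lazy_load or unload_modules is set."""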
        logger.info("Initializing runner modules...")
        if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
            self.load_model()
        self.run_dit = self._run_dit_local
        self.run_vae_decoder = self._run_vae_decoder_local
        if self.config["task"] == "i2v":
            self.run_input_encoder = self._run_input_encoder_local_i2v
        else:
            self.run_input_encoder = self._run_input_encoder_local_t2v

    def set_init_device(self):
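        """Select the initialization device: bind each rank to its GPU when parallel attention is enabled, and use CPU when cpu_offload is set."""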
        if self.config["parallel_attn_type"]:
            cur_rank = dist.get_rank()
            torch.cuda.set_device(cur_rank)
        if self.config.cpu_offload:
            self.init_device = torch.device("cpu")
        else:
            self.init_device = torch.device("cuda")

    @ProfilingContext("Load models")
    def load_model(self):
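        """Load the transformer, text encoder(s), image encoder, and VAE encoder/decoder."""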
        self.model = self.load_transformer()
        self.text_encoders = self.load_text_encoder()
        self.image_encoder = self.load_image_encoder()
        self.vae_encoder, self.vae_decoder = self.load_vae()

    def check_sub_servers(self, task_type):
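        """Probe the configured sub-servers for `task_type`, keep only the reachable ones in the config, and return whether any are available."""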
        urls = self.config.get("sub_servers", {}).get(task_type, [])
        available_servers = []
        for url in urls:
            try:
                status_url = f"{url}/v1/local/{task_type}/generate/service_status"
                response = requests.get(status_url, timeout=2)
                if response.status_code == 200:
                    available_servers.append(url)
                else:
                    logger.warning(f"Service {url} returned status code {response.status_code}")

            except RequestException as e:
                logger.warning(f"Failed to connect to {url}: {str(e)}")
                continue
        logger.info(f"{task_type} available servers: {available_servers}")
        self.config["sub_servers"][task_type] = available_servers
        return len(available_servers) > 0

    def set_inputs(self, inputs):
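        """Copy per-request inputs into the runner config, falling back to existing config values or defaults."""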
        self.config["prompt"] = inputs.get("prompt", "")
        self.config["use_prompt_enhancer"] = False
        if self.has_prompt_enhancer:
            self.config["use_prompt_enhancer"] = inputs.get("use_prompt_enhancer", False)  # Reset use_prompt_enhancer from clinet side.
        self.config["negative_prompt"] = inputs.get("negative_prompt", "")
        self.config["image_path"] = inputs.get("image_path", "")
        self.config["save_video_path"] = inputs.get("save_video_path", "")
        self.config["infer_steps"] = inputs.get("infer_steps", self.config.get("infer_steps", 5))
        self.config["target_video_length"] = inputs.get("target_video_length", self.config.get("target_video_length", 81))
        self.config["seed"] = inputs.get("seed", self.config.get("seed", 42))
        self.config["audio_path"] = inputs.get("audio_path", "")  # for wan-audio
        self.config["video_duration"] = inputs.get("video_duration", 5)  # for wan-audio

        # self.config["sample_shift"] = inputs.get("sample_shift", self.config.get("sample_shift", 5))
        # self.config["sample_guide_scale"] = inputs.get("sample_guide_scale", self.config.get("sample_guide_scale", 5))

    def set_progress_callback(self, callback):
        self.progress_callback = callback

    def run(self):
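        """Run the denoising loop: step_pre, model.infer, and step_post for every scheduler step, reporting progress if a callback is set."""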
        total_steps = self.model.scheduler.infer_steps
        for step_index in range(total_steps):
            logger.info(f"==> step_index: {step_index + 1} / {total_steps}")

            with ProfilingContext4Debug("step_pre"):
                self.model.scheduler.step_pre(step_index=step_index)

            with ProfilingContext4Debug("infer"):
                self.model.infer(self.inputs)

            with ProfilingContext4Debug("step_post"):
                self.model.scheduler.step_post()

            if self.progress_callback:
                self.progress_callback(step_index + 1, total_steps)

        return self.model.scheduler.latents, self.model.scheduler.generator

    def run_step(self, step_index=0):
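        """Run the input encoders and a single scheduler step."""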
        self.init_scheduler()
        self.inputs = self.run_input_encoder()
        self.model.scheduler.prepare(self.inputs["image_encoder_output"])
        self.model.scheduler.step_pre(step_index=step_index)
        self.model.infer(self.inputs)
        self.model.scheduler.step_post()

    def end_run(self):
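        """Clean up after inference: clear the scheduler, release model weights when lazy_load/unload_modules is set, and free GPU memory."""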
        self.model.scheduler.clear()
        del self.inputs, self.model.scheduler
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            if hasattr(self.model.transformer_infer, "weights_stream_mgr"):
                self.model.transformer_infer.weights_stream_mgr.clear()
            if hasattr(self.model.transformer_weights, "clear"):
                self.model.transformer_weights.clear()
            self.model.pre_weight.clear()
            self.model.post_weight.clear()
            del self.model
        torch.cuda.empty_cache()
        gc.collect()

    @ProfilingContext("Run Encoders")
    def _run_input_encoder_local_i2v(self):
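        """Encode the conditioning image (CLIP and VAE) and the prompt for image-to-video generation."""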
        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
        img = Image.open(self.config["image_path"]).convert("RGB")
        clip_encoder_out = self.run_image_encoder(img)
        vae_encode_out, kwargs = self.run_vae_encoder(img)
        text_encoder_output = self.run_text_encoder(prompt, img)
        torch.cuda.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)

    @ProfilingContext("Run Encoders")
    def _run_input_encoder_local_t2v(self):
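        """Encode the prompt for text-to-video generation; there is no image encoder output."""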
        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
        text_encoder_output = self.run_text_encoder(prompt, None)
        torch.cuda.empty_cache()
        gc.collect()
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": None,
        }

    @ProfilingContext("Run DiT")
    def _run_dit_local(self, kwargs):
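        """Run the DiT denoising stage, reloading the transformer first when lazy_load/unload_modules is set."""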
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.model = self.load_transformer()
        self.init_scheduler()
        self.model.scheduler.prepare(self.inputs["image_encoder_output"])
        latents, generator = self.run()
        self.end_run()
        return latents, generator

    @ProfilingContext("Run VAE Decoder")
    def _run_vae_decoder_local(self, latents, generator):
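        """Decode latents into images, loading and afterwards releasing the VAE decoder when lazy_load/unload_modules is set."""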
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.vae_decoder = self.load_vae_decoder()
        images = self.vae_decoder.decode(latents, generator=generator, config=self.config)
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.vae_decoder
            torch.cuda.empty_cache()
            gc.collect()
        return images

    @ProfilingContext("Save video")
    def save_video(self, images):
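        """Save the generated video; with parallel attention enabled, only rank 0 writes the output."""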
        if not self.config.parallel_attn_type or dist.get_rank() == 0:
            self.save_video_func(images)

    def post_prompt_enhancer(self):
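        """Poll the prompt_enhancer sub-servers until one is idle, then request and return the enhanced prompt."""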
        while True:
            for url in self.config["sub_servers"]["prompt_enhancer"]:
                response = requests.get(f"{url}/v1/local/prompt_enhancer/generate/service_status").json()
                if response["service_status"] == "idle":
                    response = requests.post(
                        f"{url}/v1/local/prompt_enhancer/generate",
                        json={
                            "task_id": generate_task_id(),
                            "prompt": self.config["prompt"],
                        },
                    )
                    enhanced_prompt = response.json()["output"]
                    logger.info(f"Enhanced prompt: {enhanced_prompt}")
                    return enhanced_prompt

    def run_pipeline(self, save_video=True):
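        """Full generation pipeline: optional prompt enhancement, input encoding, DiT denoising, VAE decoding, and (optionally) saving the video."""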
        if self.config["use_prompt_enhancer"]:
            self.config["prompt_enhanced"] = self.post_prompt_enhancer()

        self.inputs = self.run_input_encoder()

        kwargs = self.set_target_shape()

        latents, generator = self.run_dit(kwargs)

        images = self.run_vae_decoder(latents, generator)

        if save_video:
            self.save_video(images)

        del latents, generator
        torch.cuda.empty_cache()
        gc.collect()

        return images