import gc
import math

import torch
import torchvision.transforms.functional as TF
from PIL import Image
from loguru import logger

from lightx2v.models.input_encoders.hf.qwen25.qwen25_vlforconditionalgeneration import Qwen25_VLForConditionalGeneration_TextEncoder
from lightx2v.models.networks.qwen_image.model import QwenImageTransformerModel
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.qwen_image.scheduler import QwenImageScheduler
from lightx2v.models.video_encoders.hf.qwen_image.vae import AutoencoderKLQwenImageVAE
from lightx2v.server.metrics import monitor_cli
from lightx2v.utils.envs import *
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE

# Backend-specific torch module (e.g. torch.cuda), resolved from the configured AI_DEVICE name.
torch_device_module = getattr(torch, AI_DEVICE)


def calculate_dimensions(target_area, ratio):
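    """Compute a (width, height) pair with the requested aspect ratio and roughly the target
    pixel area, rounding both sides to multiples of 32; the third return value is unused here."""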
    width = math.sqrt(target_area * ratio)
    height = width / ratio

    width = round(width / 32) * 32
    height = round(height / 32) * 32

    return width, height, None


@RUNNER_REGISTER("qwen_image")
class QwenImageRunner(DefaultRunner):
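    """Runner for Qwen-Image text-to-image (t2i) and image-to-image (i2i) generation, wiring
    together the Qwen2.5-VL text encoder, the Qwen-Image transformer (DiT) and the Qwen-Image VAE."""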
    model_cpu_offload_seq = "text_encoder->transformer->vae"
    _callback_tensor_inputs = ["latents", "prompt_embeds"]

    def __init__(self, config):
        super().__init__(config)

    @ProfilingContext4DebugL2("Load models")
    def load_model(self):
        self.model = self.load_transformer()
        self.text_encoders = self.load_text_encoder()
        self.image_encoder = self.load_image_encoder()
        self.vae = self.load_vae()
        self.vfi_model = self.load_vfi_model() if "video_frame_interpolation" in self.config else None

    def load_transformer(self):
        model = QwenImageTransformerModel(self.config)
        return model

    def load_text_encoder(self):
        text_encoder = Qwen25_VLForConditionalGeneration_TextEncoder(self.config)
        text_encoders = [text_encoder]
        return text_encoders

    def load_image_encoder(self):
        pass

    def load_vae(self):
        vae = AutoencoderKLQwenImageVAE(self.config)
        return vae

    def init_modules(self):
        logger.info("Initializing runner modules...")
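        # Load every module up front unless lazy loading or module unloading is enabled;
        # lazy loading is only supported together with CPU offload.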
        if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
            self.load_model()
        elif self.config.get("lazy_load", False):
            assert self.config.get("cpu_offload", False)
        self.run_dit = self._run_dit_local
        if self.config["task"] == "t2i":
            self.run_input_encoder = self._run_input_encoder_local_t2i
        elif self.config["task"] == "i2i":
            self.run_input_encoder = self._run_input_encoder_local_i2i
        else:
            raise NotImplementedError(f"Unsupported task: {self.config['task']}")

        self.model.set_scheduler(self.scheduler)

    @ProfilingContext4DebugL2("Run DiT")
    def _run_dit_local(self, total_steps=None):
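        """Reload the transformer if it is not kept resident, prepare the scheduler for this
        request, and run the denoising loop."""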
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.model = self.load_transformer()
        self.model.scheduler.prepare(self.input_info)
        latents, generator = self.run(total_steps)
        return latents, generator

    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_t2i(self):
        prompt = self.input_info.prompt
        text_encoder_output = self.run_text_encoder(prompt, neg_prompt=self.input_info.negative_prompt)
        torch_device_module.empty_cache()
        gc.collect()
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": None,
        }

    def read_image_input(self, img_path):
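        """Load an image from a path or PIL.Image, record its original size, and return the
        [-1, 1]-normalized 1xCxHxW tensor on the target device together with the PIL image."""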
        if isinstance(img_path, Image.Image):
            img_ori = img_path
        else:
            img_ori = Image.open(img_path).convert("RGB")
        if GET_RECORDER_MODE():
            width, height = img_ori.size
            monitor_cli.lightx2v_input_image_len.observe(width * height)
        img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).to(AI_DEVICE)
        self.input_info.original_size.append(img_ori.size)
        return img, img_ori

    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_i2i(self):
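        """Encode the prompt together with the conditioning image(s): the text encoder also
        returns per-image info, and each preprocessed image is then encoded by the VAE."""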
        image_paths_list = self.input_info.image_path.split(",")
        images_list = []
        for image_path in image_paths_list:
            _, image = self.read_image_input(image_path)
            images_list.append(image)

        prompt = self.input_info.prompt
        text_encoder_output = self.run_text_encoder(prompt, images_list, neg_prompt=self.input_info.negative_prompt)

        image_encoder_output_list = []
        for vae_image in text_encoder_output["image_info"]["vae_image_list"]:
            image_encoder_output = self.run_vae_encoder(image=vae_image)
            image_encoder_output_list.append(image_encoder_output)
        torch_device_module.empty_cache()
        gc.collect()
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": image_encoder_output_list,
        }

    @ProfilingContext4DebugL1("Run Text Encoder", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_run_text_encode_duration, metrics_labels=["QwenImageRunner"])
    def run_text_encoder(self, text, image_list=None, neg_prompt=None):
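        """Encode the positive prompt (and, when true CFG is enabled, the negative prompt)
        with the Qwen2.5-VL text encoder; for i2i the conditioning images are passed along."""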
        if GET_RECORDER_MODE():
            monitor_cli.lightx2v_input_prompt_len.observe(len(text))
        text_encoder_output = {}
        if self.config["task"] == "t2i":
            prompt_embeds, prompt_embeds_mask, _ = self.text_encoders[0].infer([text])
            text_encoder_output["prompt_embeds"] = prompt_embeds
            text_encoder_output["prompt_embeds_mask"] = prompt_embeds_mask
            if self.config["do_true_cfg"] and neg_prompt is not None:
                neg_prompt_embeds, neg_prompt_embeds_mask, _ = self.text_encoders[0].infer([neg_prompt])
                text_encoder_output["negative_prompt_embeds"] = neg_prompt_embeds
                text_encoder_output["negative_prompt_embeds_mask"] = neg_prompt_embeds_mask
        elif self.config["task"] == "i2i":
            prompt_embeds, prompt_embeds_mask, image_info = self.text_encoders[0].infer([text], image_list)
            text_encoder_output["prompt_embeds"] = prompt_embeds
            text_encoder_output["prompt_embeds_mask"] = prompt_embeds_mask
            text_encoder_output["image_info"] = image_info
            if self.config["do_true_cfg"] and neg_prompt is not None:
                neg_prompt_embeds, neg_prompt_embeds_mask, _ = self.text_encoders[0].infer([neg_prompt], image_list)
                text_encoder_output["negative_prompt_embeds"] = neg_prompt_embeds
                text_encoder_output["negative_prompt_embeds_mask"] = neg_prompt_embeds_mask
        return text_encoder_output

    @ProfilingContext4DebugL1("Run VAE Encoder", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_run_vae_encoder_image_duration, metrics_labels=["QwenImageRunner"])
    def run_vae_encoder(self, image):
        image_latents = self.vae.encode_vae_image(image, self.input_info)
        return {"image_latents": image_latents}

    def run(self, total_steps=None):
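        """Denoising loop: for each step run scheduler pre-processing, the DiT forward pass,
        and scheduler post-processing, reporting progress via the optional callback."""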
        if total_steps is None:
            total_steps = self.model.scheduler.infer_steps
        for step_index in range(total_steps):
            logger.info(f"==> step_index: {step_index + 1} / {total_steps}")

            with ProfilingContext4DebugL1("step_pre"):
                self.model.scheduler.step_pre(step_index=step_index)

            with ProfilingContext4DebugL1("🚀 infer_main"):
                self.model.infer(self.inputs)

            with ProfilingContext4DebugL1("step_post"):
                self.model.scheduler.step_post()

            if self.progress_callback:
                self.progress_callback(((step_index + 1) / total_steps) * 100, 100)

        return self.model.scheduler.latents, self.model.scheduler.generator

    def set_target_shape(self):
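        """Derive the packed latent target shape (batch, 1, channels, H, W): from the configured
        aspect ratio, or (with _auto_resize) from the last input image rescaled to roughly a
        1024x1024 pixel area."""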
        if not self.config["_auto_resize"]:
            width, height = self.config["aspect_ratios"][self.config["aspect_ratio"]]
        else:
            width, height = self.input_info.original_size[-1]
            calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, width / height)
            multiple_of = self.vae.vae_scale_factor * 2
            width = calculated_width // multiple_of * multiple_of
            height = calculated_height // multiple_of * multiple_of
            self.input_info.auto_width = width
            self.input_info.auto_hight = height

        # VAE applies 8x compression on images but we must also account for packing which requires
        # latent height and width to be divisible by 2.
        height = 2 * (int(height) // (self.vae.vae_scale_factor * 2))
        width = 2 * (int(width) // (self.vae.vae_scale_factor * 2))
        num_channels_latents = self.model.in_channels // 4
        self.input_info.target_shape = (self.config["batchsize"], 1, num_channels_latents, height, width)

    def set_img_shapes(self):
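        """Collect the latent grid sizes (1, H // vae_scale_factor // 2, W // vae_scale_factor // 2)
        for the target image and, for i2i, each conditioning image, and store them in
        self.inputs["img_shapes"]."""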
        if self.config["task"] == "t2i":
            width, height = self.config["aspect_ratios"][self.config["aspect_ratio"]]
            img_shapes = [(1, height // self.config["vae_scale_factor"] // 2, width // self.config["vae_scale_factor"] // 2)] * self.config["batchsize"]
        elif self.config["task"] == "i2i":
            img_shapes = [[(1, self.input_info.auto_hight // self.config["vae_scale_factor"] // 2, self.input_info.auto_width // self.config["vae_scale_factor"] // 2)]]
            for image_height, image_width in self.inputs["text_encoder_output"]["image_info"]["vae_image_info_list"]:
                img_shapes[0].append((1, image_height // self.config["vae_scale_factor"] // 2, image_width // self.config["vae_scale_factor"] // 2))

        self.inputs["img_shapes"] = img_shapes

    def init_scheduler(self):
        self.scheduler = QwenImageScheduler(self.config)

    def get_encoder_output_i2v(self):
        pass

    def run_image_encoder(self):
        pass


    @ProfilingContext4DebugL1(
        "Run VAE Decoder",
        recorder_mode=GET_RECORDER_MODE(),
        metrics_func=monitor_cli.lightx2v_run_vae_decode_duration,
        metrics_labels=["QwenImageRunner"],
    )
    def run_vae_decoder(self, latents):
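        """Decode latents into images, reloading (and afterwards releasing) the VAE when lazy
        loading or module unloading is enabled."""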
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.vae = self.load_vae()
        images = self.vae.decode(latents, self.input_info)
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.vae
            torch_device_module.empty_cache()
            gc.collect()
        return images

    def run_pipeline(self, input_info):
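        """End-to-end generation: encode inputs, derive latent shapes, run the DiT denoising loop,
        decode with the VAE, save the first image, and free temporary buffers."""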
        self.input_info = input_info

        self.inputs = self.run_input_encoder()
        self.set_target_shape()
        self.set_img_shapes()

        latents, generator = self.run_dit()
        images = self.run_vae_decoder(latents)
        self.end_run()

        image = images[0]
        image.save(input_info.save_result_path)

        del latents, generator
        torch_device_module.empty_cache()
        gc.collect()

        # Return (images, audio) - audio is None for default runner
        return images, None