import gc
import math

import torch
import torchvision.transforms.functional as TF
from loguru import logger
from PIL import Image

from lightx2v.models.input_encoders.hf.qwen25.qwen25_vlforconditionalgeneration import Qwen25_VLForConditionalGeneration_TextEncoder
from lightx2v.models.networks.qwen_image.lora_adapter import QwenImageLoraWrapper
from lightx2v.models.networks.qwen_image.model import QwenImageTransformerModel
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.qwen_image.scheduler import QwenImageScheduler
from lightx2v.models.video_encoders.hf.qwen_image.vae import AutoencoderKLQwenImageVAE
from lightx2v.server.metrics import monitor_cli
from lightx2v.utils.envs import *
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE

# Device-specific torch module (e.g. torch.cuda / torch.npu) resolved from the
# configured AI device name; used for empty_cache() calls below.
torch_device_module = getattr(torch, AI_DEVICE)
def calculate_dimensions(target_area, ratio):
    """Fit a width/height pair of the given aspect ratio into ``target_area``.

    Both dimensions are snapped to the nearest multiple of 32 (the DiT patch
    granularity). Returns ``(width, height, None)``; the trailing ``None`` is
    kept for interface compatibility with callers expecting a 3-tuple.
    """
    raw_width = math.sqrt(target_area * ratio)
    raw_height = raw_width / ratio

    def _snap32(value):
        # Round to the closest multiple of 32.
        return round(value / 32) * 32

    return _snap32(raw_width), _snap32(raw_height), None


@RUNNER_REGISTER("qwen_image")
class QwenImageRunner(DefaultRunner):
    """Runner for the Qwen-Image diffusion pipeline (t2i and i2i tasks).

    Wires together the Qwen2.5-VL text encoder, the Qwen-Image DiT
    transformer and the Qwen-Image VAE; encodes the inputs, runs the
    denoising loop, decodes the latents and saves the resulting image to
    ``input_info.save_result_path``.
    """

    model_cpu_offload_seq = "text_encoder->transformer->vae"
    _callback_tensor_inputs = ["latents", "prompt_embeds"]

    def __init__(self, config):
        super().__init__(config)

    @ProfilingContext4DebugL2("Load models")
    def load_model(self):
        """Load every sub-model used by the pipeline.

        FIX: the original class defined ``load_model`` twice; the second
        definition silently shadowed the first, leaving the first as dead
        code. Only the effective (second) definition is kept.
        """
        self.model = self.load_transformer()
        self.text_encoders = self.load_text_encoder()
        self.image_encoder = self.load_image_encoder()
        self.vae = self.load_vae()
        # Optional frame-interpolation model; loader inherited from DefaultRunner.
        self.vfi_model = self.load_vfi_model() if "video_frame_interpolation" in self.config else None

    def load_transformer(self):
        """Build the DiT transformer and apply any configured LoRA weights."""
        model = QwenImageTransformerModel(self.config)
        if self.config.get("lora_configs") and self.config.lora_configs:
            # LoRA on top of a quantized DiT is not supported.
            assert not self.config.get("dit_quantized", False)
            lora_wrapper = QwenImageLoraWrapper(model)
            for lora_config in self.config.lora_configs:
                lora_path = lora_config["path"]
                strength = lora_config.get("strength", 1.0)
                lora_name = lora_wrapper.load_lora(lora_path)
                lora_wrapper.apply_lora(lora_name, strength)
                logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
        return model

    def load_text_encoder(self):
        """Return the list of text encoders (a single Qwen2.5-VL encoder)."""
        text_encoder = Qwen25_VLForConditionalGeneration_TextEncoder(self.config)
        text_encoders = [text_encoder]
        return text_encoders

    def load_image_encoder(self):
        """No separate image encoder for Qwen-Image (images go through the text encoder)."""
        pass

    def load_vae(self):
        """Build the Qwen-Image VAE, used for both encoding and decoding."""
        vae = AutoencoderKLQwenImageVAE(self.config)
        return vae

    def _release_cached_memory(self):
        """Release the device allocator cache and force a GC pass."""
        torch_device_module.empty_cache()
        gc.collect()

    def init_modules(self):
        """Bind task-specific entry points and eagerly load models unless lazy."""
        logger.info("Initializing runner modules...")
        if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
            self.load_model()
        elif self.config.get("lazy_load", False):
            # Lazy loading is only supported together with CPU offload.
            assert self.config.get("cpu_offload", False)
        self.run_dit = self._run_dit_local
        if self.config["task"] == "t2i":
            self.run_input_encoder = self._run_input_encoder_local_t2i
        elif self.config["task"] == "i2i":
            self.run_input_encoder = self._run_input_encoder_local_i2i
        else:
            # FIX: was ``assert NotImplementedError``, which always passes
            # because it asserts the truthiness of the exception class.
            raise NotImplementedError(f"Unsupported task: {self.config['task']}")
        # NOTE(review): under lazy_load, self.model is not loaded at this
        # point — presumably DefaultRunner provides it; confirm.
        self.model.set_scheduler(self.scheduler)

    @ProfilingContext4DebugL2("Run DiT")
    def _run_dit_local(self, total_steps=None):
        """Run the diffusion transformer loop; returns (latents, generator)."""
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.model = self.load_transformer()
        self.model.scheduler.prepare(self.input_info)
        latents, generator = self.run(total_steps)
        return latents, generator

    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_t2i(self):
        """Encode the text prompt for text-to-image generation."""
        prompt = self.input_info.prompt
        text_encoder_output = self.run_text_encoder(prompt, neg_prompt=self.input_info.negative_prompt)
        self._release_cached_memory()
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": None,
        }

    def read_image_input(self, img_path):
        """Load one conditioning image.

        Accepts a PIL image or a file path. Returns a ``(tensor, pil_image)``
        pair where the tensor is normalized to [-1, 1] with shape
        (1, C, H, W) on the AI device. Also appends the original
        (width, height) to ``input_info.original_size``.
        """
        if isinstance(img_path, Image.Image):
            img_ori = img_path
        else:
            img_ori = Image.open(img_path).convert("RGB")
        if GET_RECORDER_MODE():
            width, height = img_ori.size
            monitor_cli.lightx2v_input_image_len.observe(width * height)
        img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).to(AI_DEVICE)
        self.input_info.original_size.append(img_ori.size)
        return img, img_ori

    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_i2i(self):
        """Encode text plus conditioning images for image-to-image generation."""
        # ``image_path`` may hold several comma-separated paths.
        image_paths_list = self.input_info.image_path.split(",")
        images_list = []
        for image_path in image_paths_list:
            _, image = self.read_image_input(image_path)
            images_list.append(image)

        prompt = self.input_info.prompt
        text_encoder_output = self.run_text_encoder(prompt, images_list, neg_prompt=self.input_info.negative_prompt)

        # VAE-encode each preprocessed image emitted by the text-encoder stage.
        image_encoder_output_list = []
        for vae_image in text_encoder_output["image_info"]["vae_image_list"]:
            image_encoder_output = self.run_vae_encoder(image=vae_image)
            image_encoder_output_list.append(image_encoder_output)
        self._release_cached_memory()
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": image_encoder_output_list,
        }

    @ProfilingContext4DebugL1("Run Text Encoder", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_run_text_encode_duration, metrics_labels=["QwenImageRunner"])
    def run_text_encoder(self, text, image_list=None, neg_prompt=None):
        """Encode the prompt (and, for i2i, the images) into embeddings.

        Returns a dict with ``prompt_embeds``/``prompt_embeds_mask`` (plus
        ``image_info`` for i2i); when ``do_true_cfg`` is set and a negative
        prompt is given, the negative embeddings are produced as well.
        """
        if GET_RECORDER_MODE():
            monitor_cli.lightx2v_input_prompt_len.observe(len(text))
        text_encoder_output = {}
        if self.config["task"] == "t2i":
            prompt_embeds, prompt_embeds_mask, _ = self.text_encoders[0].infer([text])
            text_encoder_output["prompt_embeds"] = prompt_embeds
            text_encoder_output["prompt_embeds_mask"] = prompt_embeds_mask
            if self.config["do_true_cfg"] and neg_prompt is not None:
                neg_prompt_embeds, neg_prompt_embeds_mask, _ = self.text_encoders[0].infer([neg_prompt])
                text_encoder_output["negative_prompt_embeds"] = neg_prompt_embeds
                text_encoder_output["negative_prompt_embeds_mask"] = neg_prompt_embeds_mask
        elif self.config["task"] == "i2i":
            prompt_embeds, prompt_embeds_mask, image_info = self.text_encoders[0].infer([text], image_list)
            text_encoder_output["prompt_embeds"] = prompt_embeds
            text_encoder_output["prompt_embeds_mask"] = prompt_embeds_mask
            text_encoder_output["image_info"] = image_info
            if self.config["do_true_cfg"] and neg_prompt is not None:
                neg_prompt_embeds, neg_prompt_embeds_mask, _ = self.text_encoders[0].infer([neg_prompt], image_list)
                text_encoder_output["negative_prompt_embeds"] = neg_prompt_embeds
                text_encoder_output["negative_prompt_embeds_mask"] = neg_prompt_embeds_mask
        return text_encoder_output

    @ProfilingContext4DebugL1("Run VAE Encoder", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_run_vae_encoder_image_duration, metrics_labels=["QwenImageRunner"])
    def run_vae_encoder(self, image):
        """Encode one conditioning image into VAE latents."""
        image_latents = self.vae.encode_vae_image(image, self.input_info)
        return {"image_latents": image_latents}

    def run(self, total_steps=None):
        """Denoising loop: step_pre -> model.infer -> step_post per step.

        Returns the scheduler's final latents and its RNG generator.
        """
        if total_steps is None:
            total_steps = self.model.scheduler.infer_steps
        for step_index in range(total_steps):
            logger.info(f"==> step_index: {step_index + 1} / {total_steps}")

            with ProfilingContext4DebugL1("step_pre"):
                self.model.scheduler.step_pre(step_index=step_index)

            with ProfilingContext4DebugL1("🚀 infer_main"):
                self.model.infer(self.inputs)

            with ProfilingContext4DebugL1("step_post"):
                self.model.scheduler.step_post()

            if self.progress_callback:
                self.progress_callback(((step_index + 1) / total_steps) * 100, 100)

        return self.model.scheduler.latents, self.model.scheduler.generator

    def set_target_shape(self):
        """Compute the packed latent target shape and store it on input_info."""
        if not self.config["_auto_resize"]:
            width, height = self.config["aspect_ratios"][self.config["aspect_ratio"]]
        else:
            # Fit the last input image's aspect ratio into a ~1-megapixel canvas.
            width, height = self.input_info.original_size[-1]
            calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, width / height)
            multiple_of = self.vae.vae_scale_factor * 2
            width = calculated_width // multiple_of * multiple_of
            height = calculated_height // multiple_of * multiple_of
            self.input_info.auto_width = width
            # NOTE(review): "auto_hight" (sic) is read by set_img_shapes and
            # possibly by external callers; the misspelling is kept on purpose.
            self.input_info.auto_hight = height

        # VAE applies 8x compression on images but we must also account for packing which requires
        # latent height and width to be divisible by 2.
        height = 2 * (int(height) // (self.vae.vae_scale_factor * 2))
        width = 2 * (int(width) // (self.vae.vae_scale_factor * 2))
        num_channels_latents = self.model.in_channels // 4
        self.input_info.target_shape = (self.config["batchsize"], 1, num_channels_latents, height, width)

    def set_img_shapes(self):
        """Populate ``self.inputs["img_shapes"]`` with per-image latent grid sizes."""
        if self.config["task"] == "t2i":
            width, height = self.config["aspect_ratios"][self.config["aspect_ratio"]]
            img_shapes = [(1, height // self.config["vae_scale_factor"] // 2, width // self.config["vae_scale_factor"] // 2)] * self.config["batchsize"]
        elif self.config["task"] == "i2i":
            # First entry is the output canvas; then one entry per conditioning image.
            img_shapes = [[(1, self.input_info.auto_hight // self.config["vae_scale_factor"] // 2, self.input_info.auto_width // self.config["vae_scale_factor"] // 2)]]
            for image_height, image_width in self.inputs["text_encoder_output"]["image_info"]["vae_image_info_list"]:
                img_shapes[0].append((1, image_height // self.config["vae_scale_factor"] // 2, image_width // self.config["vae_scale_factor"] // 2))

        self.inputs["img_shapes"] = img_shapes

    def init_scheduler(self):
        """Create the Qwen-Image scheduler."""
        self.scheduler = QwenImageScheduler(self.config)

    def get_encoder_output_i2v(self):
        """Not used for image generation."""
        pass

    def run_image_encoder(self):
        """Not used for image generation."""
        pass

    @ProfilingContext4DebugL1(
        "Run VAE Decoder",
        recorder_mode=GET_RECORDER_MODE(),
        metrics_func=monitor_cli.lightx2v_run_vae_decode_duration,
        metrics_labels=["QwenImageRunner"],
    )
    def run_vae_decoder(self, latents):
        """Decode latents into images, honouring lazy-load/unload semantics.

        FIX: the original lazily loaded the VAE into ``self.vae_decoder`` but
        decoded through ``self.vae`` (unset under lazy_load) and then deleted
        the unused ``self.vae_decoder``. Load and release ``self.vae`` itself,
        mirroring ``_run_dit_local``.
        """
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.vae = self.load_vae()
        images = self.vae.decode(latents, self.input_info)
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.vae
            self._release_cached_memory()
        return images

    def run_pipeline(self, input_info):
        """End-to-end generation: encode inputs, denoise, decode, save image.

        Returns ``(images, None)`` — the second slot is the (absent) audio,
        matching the DefaultRunner return convention.
        """
        self.input_info = input_info

        self.inputs = self.run_input_encoder()
        self.set_target_shape()
        self.set_img_shapes()

        latents, generator = self.run_dit()
        images = self.run_vae_decoder(latents)
        self.end_run()

        image = images[0]
        image.save(f"{input_info.save_result_path}")

        del latents, generator
        self._release_cached_memory()

        # Return (images, audio) - audio is None for default runner
        return images, None