qwen_image_runner.py 11.6 KB
Newer Older
1
import gc
2
import math
3
4

import torch
5
6
import torchvision.transforms.functional as TF
from PIL import Image
7
8
9
10
11
12
13
from loguru import logger

from lightx2v.models.input_encoders.hf.qwen25.qwen25_vlforconditionalgeneration import Qwen25_VLForConditionalGeneration_TextEncoder
from lightx2v.models.networks.qwen_image.model import QwenImageTransformerModel
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.qwen_image.scheduler import QwenImageScheduler
from lightx2v.models.video_encoders.hf.qwen_image.vae import AutoencoderKLQwenImageVAE
yihuiwen's avatar
yihuiwen committed
14
from lightx2v.server.metrics import monitor_cli
yihuiwen's avatar
yihuiwen committed
15
from lightx2v.utils.envs import *
16
from lightx2v.utils.profiler import *
17
18
19
from lightx2v.utils.registry_factory import RUNNER_REGISTER


20
21
22
23
24
25
26
27
28
29
def calculate_dimensions(target_area, ratio):
    """Pick a (width, height) pair matching *ratio* with ~*target_area* pixels.

    Both dimensions are rounded to the nearest multiple of 32 (the model's
    spatial granularity). The trailing ``None`` keeps the historical 3-tuple
    return shape expected by callers.
    """

    def _snap32(value):
        # Round to the nearest multiple of 32.
        return round(value / 32) * 32

    exact_width = math.sqrt(target_area * ratio)
    return _snap32(exact_width), _snap32(exact_width / ratio), None


30
31
32
33
34
35
36
37
@RUNNER_REGISTER("qwen_image")
class QwenImageRunner(DefaultRunner):
    """Runner for the Qwen-Image diffusion pipeline (t2i and i2i tasks).

    Orchestrates text encoding (Qwen2.5-VL), optional VAE image encoding for
    conditioning images, the DiT denoising loop, and VAE decoding into final
    PIL images saved to ``input_info.save_result_path``.
    """

    # Offload order for sequential CPU offload of sub-models.
    model_cpu_offload_seq = "text_encoder->transformer->vae"
    _callback_tensor_inputs = ["latents", "prompt_embeds"]

    def __init__(self, config):
        super().__init__(config)

    def _empty_device_cache(self):
        """Release cached device memory if the backend exposes empty_cache().

        ``self.run_device`` is a device-type string (e.g. "cuda"); the matching
        torch submodule (torch.cuda, torch.mps, ...) provides empty_cache().
        """
        if hasattr(torch, self.run_device):
            getattr(torch, self.run_device).empty_cache()

    def load_transformer(self):
        """Build and return the Qwen-Image DiT transformer."""
        return QwenImageTransformerModel(self.config)

    def load_text_encoder(self):
        """Build the Qwen2.5-VL text encoder, returned as a one-element list."""
        return [Qwen25_VLForConditionalGeneration_TextEncoder(self.config)]

    def load_image_encoder(self):
        # Qwen-Image has no separate image encoder: conditioning images go
        # through the VL text encoder and the VAE instead.
        pass

    def load_vae(self):
        """Build the Qwen-Image VAE (used for both encoding and decoding)."""
        return AutoencoderKLQwenImageVAE(self.config)

    # BUGFIX: this class previously defined load_model twice; the earlier,
    # shorter definition (transformer/text_encoder/vae only) was dead code
    # silently shadowed by this one. Only the effective definition is kept.
    @ProfilingContext4DebugL2("Load models")
    def load_model(self):
        """Eagerly load all sub-models: transformer, text encoders, VAE, optional VFI."""
        self.model = self.load_transformer()
        self.text_encoders = self.load_text_encoder()
        self.image_encoder = self.load_image_encoder()  # always None for Qwen-Image
        self.vae = self.load_vae()
        self.vfi_model = self.load_vfi_model() if "video_frame_interpolation" in self.config else None

    def init_modules(self):
        """Wire task-specific encoder entry points and attach the scheduler.

        Raises:
            NotImplementedError: if ``config["task"]`` is neither "t2i" nor "i2i".
        """
        logger.info("Initializing runner modules...")
        if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
            self.load_model()
        elif self.config.get("lazy_load", False):
            # Lazy loading is only supported together with CPU offload.
            assert self.config.get("cpu_offload", False)
        self.run_dit = self._run_dit_local
        if self.config["task"] == "t2i":
            self.run_input_encoder = self._run_input_encoder_local_t2i
        elif self.config["task"] == "i2i":
            self.run_input_encoder = self._run_input_encoder_local_i2i
        else:
            # BUGFIX: was `assert NotImplementedError`, which asserts a truthy
            # class object and therefore never fired.
            raise NotImplementedError(f"Unsupported task: {self.config['task']}")
        # NOTE(review): under lazy_load, self.model may not exist here unless
        # DefaultRunner provides it — confirm before relying on this call.
        self.model.set_scheduler(self.scheduler)

    @ProfilingContext4DebugL2("Run DiT")
    def _run_dit_local(self, total_steps=None):
        """Prepare the scheduler and run the denoising loop.

        Reloads the transformer first when lazy_load/unload_modules is active.
        Returns (latents, generator) from the scheduler.
        """
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.model = self.load_transformer()
        self.model.scheduler.prepare(self.input_info)
        latents, generator = self.run(total_steps)
        return latents, generator

    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_t2i(self):
        """Encode the text prompt for text-to-image generation."""
        prompt = self.input_info.prompt
        text_encoder_output = self.run_text_encoder(prompt, neg_prompt=self.input_info.negative_prompt)
        self._empty_device_cache()
        gc.collect()
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": None,
        }

    def read_image_input(self, img_path):
        """Load one conditioning image and normalize it for the model.

        Args:
            img_path: a filesystem path or an already-open ``PIL.Image.Image``.

        Returns:
            (tensor, pil_image): the tensor is [0,1] -> [-1,1] normalized with
            a leading batch dim, moved to ``self.run_device``; the PIL image is
            the original. Also appends the original (width, height) to
            ``self.input_info.original_size``.
        """
        if isinstance(img_path, Image.Image):
            img_ori = img_path
        else:
            img_ori = Image.open(img_path).convert("RGB")
        if GET_RECORDER_MODE():
            width, height = img_ori.size
            monitor_cli.lightx2v_input_image_len.observe(width * height)
        img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).to(self.run_device)
        self.input_info.original_size.append(img_ori.size)
        return img, img_ori

    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_i2i(self):
        """Encode prompt plus conditioning images for image-to-image generation.

        ``input_info.image_path`` is a comma-separated list of image paths.
        """
        image_paths_list = self.input_info.image_path.split(",")
        images_list = []
        for image_path in image_paths_list:
            _, image = self.read_image_input(image_path)
            images_list.append(image)

        prompt = self.input_info.prompt
        text_encoder_output = self.run_text_encoder(prompt, images_list, neg_prompt=self.input_info.negative_prompt)

        # VAE-encode each image the text encoder prepared for conditioning.
        image_encoder_output_list = []
        for vae_image in text_encoder_output["image_info"]["vae_image_list"]:
            image_encoder_output_list.append(self.run_vae_encoder(image=vae_image))
        self._empty_device_cache()
        gc.collect()
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": image_encoder_output_list,
        }

    @ProfilingContext4DebugL1("Run Text Encoder", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_run_text_encode_duration, metrics_labels=["QwenImageRunner"])
    def run_text_encoder(self, text, image_list=None, neg_prompt=None):
        """Run the VL text encoder on the prompt (and images for i2i).

        Args:
            text: positive prompt string.
            image_list: conditioning images (i2i only).
            neg_prompt: negative prompt; encoded only when ``do_true_cfg`` is set.

        Returns:
            dict with "prompt_embeds"/"prompt_embeds_mask", optional negative
            variants, and (i2i only) "image_info" from the encoder.
        """
        if GET_RECORDER_MODE():
            monitor_cli.lightx2v_input_prompt_len.observe(len(text))
        text_encoder_output = {}
        if self.config["task"] == "t2i":
            prompt_embeds, prompt_embeds_mask, _ = self.text_encoders[0].infer([text])
            text_encoder_output["prompt_embeds"] = prompt_embeds
            text_encoder_output["prompt_embeds_mask"] = prompt_embeds_mask
            if self.config["do_true_cfg"] and neg_prompt is not None:
                neg_prompt_embeds, neg_prompt_embeds_mask, _ = self.text_encoders[0].infer([neg_prompt])
                text_encoder_output["negative_prompt_embeds"] = neg_prompt_embeds
                text_encoder_output["negative_prompt_embeds_mask"] = neg_prompt_embeds_mask
        elif self.config["task"] == "i2i":
            prompt_embeds, prompt_embeds_mask, image_info = self.text_encoders[0].infer([text], image_list)
            text_encoder_output["prompt_embeds"] = prompt_embeds
            text_encoder_output["prompt_embeds_mask"] = prompt_embeds_mask
            text_encoder_output["image_info"] = image_info
            if self.config["do_true_cfg"] and neg_prompt is not None:
                neg_prompt_embeds, neg_prompt_embeds_mask, _ = self.text_encoders[0].infer([neg_prompt], image_list)
                text_encoder_output["negative_prompt_embeds"] = neg_prompt_embeds
                text_encoder_output["negative_prompt_embeds_mask"] = neg_prompt_embeds_mask
        return text_encoder_output

    @ProfilingContext4DebugL1("Run VAE Encoder", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_run_vae_encoder_image_duration, metrics_labels=["QwenImageRunner"])
    def run_vae_encoder(self, image):
        """Encode one preprocessed image into VAE latents."""
        image_latents = self.vae.encode_vae_image(image, self.input_info)
        return {"image_latents": image_latents}

    def run(self, total_steps=None):
        """Denoising loop: step_pre -> transformer infer -> step_post per step.

        Args:
            total_steps: number of steps; defaults to the scheduler's infer_steps.

        Returns:
            (latents, generator) taken from the scheduler after the last step.
        """
        if total_steps is None:
            total_steps = self.model.scheduler.infer_steps
        for step_index in range(total_steps):
            logger.info(f"==> step_index: {step_index + 1} / {total_steps}")

            with ProfilingContext4DebugL1("step_pre"):
                self.model.scheduler.step_pre(step_index=step_index)

            with ProfilingContext4DebugL1("🚀 infer_main"):
                self.model.infer(self.inputs)

            with ProfilingContext4DebugL1("step_post"):
                self.model.scheduler.step_post()

            if self.progress_callback:
                self.progress_callback(((step_index + 1) / total_steps) * 100, 100)

        return self.model.scheduler.latents, self.model.scheduler.generator

    def set_target_shape(self):
        """Compute the packed latent target shape and store it on input_info."""
        if not self.config["_auto_resize"]:
            width, height = self.config["aspect_ratios"][self.config["aspect_ratio"]]
        else:
            # Derive an ~1MP size from the last input image's aspect ratio.
            width, height = self.input_info.original_size[-1]
            calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, width / height)
            multiple_of = self.vae.vae_scale_factor * 2
            width = calculated_width // multiple_of * multiple_of
            height = calculated_height // multiple_of * multiple_of
            self.input_info.auto_width = width
            # NOTE: "auto_hight" (sic) is the attribute name read elsewhere
            # (see set_img_shapes); spelling kept for compatibility.
            self.input_info.auto_hight = height

        # VAE applies 8x compression on images but we must also account for packing which requires
        # latent height and width to be divisible by 2.
        height = 2 * (int(height) // (self.vae.vae_scale_factor * 2))
        width = 2 * (int(width) // (self.vae.vae_scale_factor * 2))
        num_channels_latents = self.model.in_channels // 4
        self.input_info.target_shape = (self.config["batchsize"], 1, num_channels_latents, height, width)

    def set_img_shapes(self):
        """Populate ``self.inputs["img_shapes"]`` with packed-latent grid sizes."""
        if self.config["task"] == "t2i":
            width, height = self.config["aspect_ratios"][self.config["aspect_ratio"]]
            img_shapes = [(1, height // self.config["vae_scale_factor"] // 2, width // self.config["vae_scale_factor"] // 2)] * self.config["batchsize"]
        elif self.config["task"] == "i2i":
            # First entry is the output canvas; then one entry per conditioning image.
            img_shapes = [[(1, self.input_info.auto_hight // self.config["vae_scale_factor"] // 2, self.input_info.auto_width // self.config["vae_scale_factor"] // 2)]]
            for image_height, image_width in self.inputs["text_encoder_output"]["image_info"]["vae_image_info_list"]:
                img_shapes[0].append((1, image_height // self.config["vae_scale_factor"] // 2, image_width // self.config["vae_scale_factor"] // 2))
        else:
            # BUGFIX: previously fell through with img_shapes unbound (NameError).
            raise NotImplementedError(f"Unsupported task: {self.config['task']}")
        self.inputs["img_shapes"] = img_shapes

    def init_scheduler(self):
        """Create the Qwen-Image scheduler."""
        self.scheduler = QwenImageScheduler(self.config)

    def get_encoder_output_i2v(self):
        # i2v is not applicable to an image runner.
        pass

    def run_image_encoder(self):
        # No standalone image encoder for Qwen-Image (see load_image_encoder).
        pass

    @ProfilingContext4DebugL1(
        "Run VAE Decoder",
        recorder_mode=GET_RECORDER_MODE(),
        metrics_func=monitor_cli.lightx2v_run_vae_decode_duration,
        metrics_labels=["QwenImageRunner"],
    )
    def run_vae_decoder(self, latents):
        """Decode latents to images; honors lazy_load/unload_modules."""
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            # NOTE(review): reloads into self.vae_decoder, but decoding below
            # uses self.vae — looks inconsistent; confirm intended semantics.
            self.vae_decoder = self.load_vae()
        images = self.vae.decode(latents, self.input_info)
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.vae_decoder
            self._empty_device_cache()
            gc.collect()
        return images

    def run_pipeline(self, input_info):
        """End-to-end generation: encode -> denoise -> decode -> save.

        Args:
            input_info: request object carrying prompts, image paths, and the
                output path (``save_result_path``).

        Returns:
            (images, audio) — audio is always None for this runner.
        """
        self.input_info = input_info

        self.inputs = self.run_input_encoder()
        self.set_target_shape()
        self.set_img_shapes()

        latents, generator = self.run_dit()
        images = self.run_vae_decoder(latents)
        self.end_run()

        # Only the first image is persisted to disk.
        image = images[0]
        image.save(input_info.save_result_path)

        del latents, generator
        self._empty_device_cache()
        gc.collect()

        # Return (images, audio) - audio is None for default runner
        return images, None