import gc
import math

import torch
import torchvision.transforms.functional as TF
from PIL import Image
from loguru import logger

from lightx2v.models.input_encoders.hf.qwen25.qwen25_vlforconditionalgeneration import Qwen25_VLForConditionalGeneration_TextEncoder
from lightx2v.models.networks.qwen_image.model import QwenImageTransformerModel
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.qwen_image.scheduler import QwenImageScheduler
from lightx2v.models.video_encoders.hf.qwen_image.vae import AutoencoderKLQwenImageVAE
from lightx2v.server.metrics import monitor_cli
from lightx2v.utils.envs import *
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER


def calculate_dimensions(target_area, ratio):
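    """Derive a (width, height) pair covering roughly ``target_area`` pixels at aspect
    ratio ``ratio``, rounded to multiples of 32. The third return value is unused."""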
    width = math.sqrt(target_area * ratio)
    height = width / ratio

    width = round(width / 32) * 32
    height = round(height / 32) * 32

    return width, height, None


@RUNNER_REGISTER("qwen_image")
class QwenImageRunner(DefaultRunner):
    model_cpu_offload_seq = "text_encoder->transformer->vae"
    _callback_tensor_inputs = ["latents", "prompt_embeds"]

    def __init__(self, config):
        super().__init__(config)

    @ProfilingContext4DebugL2("Load models")
    def load_model(self):
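        """Instantiate the transformer, text encoder, and VAE modules."""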
        self.model = self.load_transformer()
        self.text_encoders = self.load_text_encoder()
        self.vae = self.load_vae()

    def load_transformer(self):
        model = QwenImageTransformerModel(self.config)
        return model

    def load_text_encoder(self):
        text_encoder = Qwen25_VLForConditionalGeneration_TextEncoder(self.config)
        text_encoders = [text_encoder]
        return text_encoders

    def load_image_encoder(self):
        pass

    def load_vae(self):
        vae = AutoencoderKLQwenImageVAE(self.config)
        return vae

    def init_modules(self):
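        """Load models (unless lazy loading / module unloading is enabled), bind the
        task-specific input-encoder entry point, and attach the scheduler to the model."""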
        logger.info("Initializing runner modules...")
        if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
            self.load_model()
        elif self.config.get("lazy_load", False):
            assert self.config.get("cpu_offload", False)
        self.run_dit = self._run_dit_local
        if self.config["task"] == "t2i":
            self.run_input_encoder = self._run_input_encoder_local_t2i
        elif self.config["task"] == "i2i":
            self.run_input_encoder = self._run_input_encoder_local_i2i
        else:
            raise NotImplementedError(f"Unsupported task: {self.config['task']}")

        self.model.set_scheduler(self.scheduler)

    @ProfilingContext4DebugL2("Run DiT")
    def _run_dit_local(self, total_steps=None):
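        """(Re)load the transformer when lazily loaded, prepare the scheduler for the
        current input, and run the denoising loop; returns latents and generator."""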
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.model = self.load_transformer()
        self.model.scheduler.prepare(self.input_info)
        latents, generator = self.run(total_steps)
        return latents, generator

    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_t2i(self):
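        """Encode the positive (and optional negative) prompt for text-to-image;
        no image encoder output is produced."""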
        prompt = self.input_info.prompt
        text_encoder_output = self.run_text_encoder(prompt, neg_prompt=self.input_info.negative_prompt)
        torch.cuda.empty_cache()
        gc.collect()
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": None,
        }

    def read_image_input(self, img_path):
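        """Load an image (file path or PIL.Image), record its original size, and return a
        normalized [-1, 1] CUDA tensor together with the original PIL image."""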
        if isinstance(img_path, Image.Image):
            img_ori = img_path
        else:
            img_ori = Image.open(img_path).convert("RGB")
        if GET_RECORDER_MODE():
            width, height = img_ori.size
            monitor_cli.lightx2v_input_image_len.observe(width * height)
        img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
        self.input_info.original_size.append(img_ori.size)
        return img, img_ori

    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_i2i(self):
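        """Encode the reference image(s) and prompt for image-to-image: run the vision-language
        text encoder first, then VAE-encode each preprocessed reference image."""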
        image_paths_list = self.input_info.image_path.split(",")
        images_list = []
        for image_path in image_paths_list:
            _, image = self.read_image_input(image_path)
            images_list.append(image)

        prompt = self.input_info.prompt
        text_encoder_output = self.run_text_encoder(prompt, images_list, neg_prompt=self.input_info.negative_prompt)

        image_encoder_output_list = []
        for vae_image in text_encoder_output["image_info"]["vae_image_list"]:
            image_encoder_output = self.run_vae_encoder(image=vae_image)
            image_encoder_output_list.append(image_encoder_output)

        torch.cuda.empty_cache()
        gc.collect()
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": image_encoder_output_list,
        }

    @ProfilingContext4DebugL1("Run Text Encoder", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_run_text_encode_duration, metrics_labels=["QwenImageRunner"])
    def run_text_encoder(self, text, image_list=None, neg_prompt=None):
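        """Run the Qwen2.5-VL text encoder to build prompt embeddings and masks (plus
        negative-prompt embeddings when true CFG is enabled); for i2i the reference images
        are encoded jointly with the prompt and per-image info is returned as well."""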
        if GET_RECORDER_MODE():
            monitor_cli.lightx2v_input_prompt_len.observe(len(text))
        text_encoder_output = {}
        if self.config["task"] == "t2i":
            prompt_embeds, prompt_embeds_mask, _ = self.text_encoders[0].infer([text])
            text_encoder_output["prompt_embeds"] = prompt_embeds
            text_encoder_output["prompt_embeds_mask"] = prompt_embeds_mask
            if self.config["do_true_cfg"] and neg_prompt is not None:
                neg_prompt_embeds, neg_prompt_embeds_mask, _ = self.text_encoders[0].infer([neg_prompt])
                text_encoder_output["negative_prompt_embeds"] = neg_prompt_embeds
                text_encoder_output["negative_prompt_embeds_mask"] = neg_prompt_embeds_mask
        elif self.config["task"] == "i2i":
            prompt_embeds, prompt_embeds_mask, image_info = self.text_encoders[0].infer([text], image_list)
            text_encoder_output["prompt_embeds"] = prompt_embeds
            text_encoder_output["prompt_embeds_mask"] = prompt_embeds_mask
            text_encoder_output["image_info"] = image_info
            if self.config["do_true_cfg"] and neg_prompt is not None:
                neg_prompt_embeds, neg_prompt_embeds_mask, _ = self.text_encoders[0].infer([neg_prompt], image_list)
                text_encoder_output["negative_prompt_embeds"] = neg_prompt_embeds
                text_encoder_output["negative_prompt_embeds_mask"] = neg_prompt_embeds_mask
        return text_encoder_output

    @ProfilingContext4DebugL1("Run VAE Encoder", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_run_vae_encoder_image_duration, metrics_labels=["QwenImageRunner"])
    def run_vae_encoder(self, image):
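        """Encode a reference image into VAE image latents."""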
        image_latents = self.vae.encode_vae_image(image, self.input_info)
        return {"image_latents": image_latents}

    def run(self, total_steps=None):
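        """Denoising loop: for each step run scheduler pre-processing, the transformer
        forward pass, and scheduler post-processing, reporting progress along the way;
        returns the final latents and the generator."""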
        if total_steps is None:
            total_steps = self.model.scheduler.infer_steps
        for step_index in range(total_steps):
            logger.info(f"==> step_index: {step_index + 1} / {total_steps}")

            with ProfilingContext4DebugL1("step_pre"):
                self.model.scheduler.step_pre(step_index=step_index)

            with ProfilingContext4DebugL1("🚀 infer_main"):
                self.model.infer(self.inputs)

            with ProfilingContext4DebugL1("step_post"):
                self.model.scheduler.step_post()

            if self.progress_callback:
                self.progress_callback(((step_index + 1) / total_steps) * 100, 100)

        return self.model.scheduler.latents, self.model.scheduler.generator

    def set_target_shape(self):
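        """Compute the latent target shape. Image size comes from the configured aspect ratio,
        or (with auto-resize) from the last input image rescaled to roughly 1024x1024 pixels;
        latent height/width account for VAE downsampling and 2x2 patch packing."""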
        if not self.config["_auto_resize"]:
            width, height = self.config["aspect_ratios"][self.config["aspect_ratio"]]
        else:
            width, height = self.input_info.original_size[-1]
            calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, width / height)
            multiple_of = self.vae.vae_scale_factor * 2
            width = calculated_width // multiple_of * multiple_of
            height = calculated_height // multiple_of * multiple_of
            self.input_info.auto_width = width
            self.input_info.auto_hight = height

        # VAE applies 8x compression on images but we must also account for packing which requires
        # latent height and width to be divisible by 2.
        height = 2 * (int(height) // (self.vae.vae_scale_factor * 2))
        width = 2 * (int(width) // (self.vae.vae_scale_factor * 2))
        num_channels_latents = self.model.in_channels // 4
        self.input_info.target_shape = (self.config["batchsize"], 1, num_channels_latents, height, width)

    def set_img_shapes(self):
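        """Build the per-sample latent grid shapes (1, height // scale // 2, width // scale // 2)
        consumed by the transformer: a single shape for t2i; for i2i the target shape is
        followed by one entry per reference image."""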
        if self.config["task"] == "t2i":
            width, height = self.config["aspect_ratios"][self.config["aspect_ratio"]]
            img_shapes = [(1, height // self.config["vae_scale_factor"] // 2, width // self.config["vae_scale_factor"] // 2)] * self.config["batchsize"]
        elif self.config["task"] == "i2i":
            img_shapes = [[(1, self.input_info.auto_hight // self.config["vae_scale_factor"] // 2, self.input_info.auto_width // self.config["vae_scale_factor"] // 2)]]
            for image_height, image_width in self.inputs["text_encoder_output"]["image_info"]["vae_image_info_list"]:
                img_shapes[0].append((1, image_height // self.config["vae_scale_factor"] // 2, image_width // self.config["vae_scale_factor"] // 2))

        self.inputs["img_shapes"] = img_shapes

    def init_scheduler(self):
        self.scheduler = QwenImageScheduler(self.config)

    def get_encoder_output_i2v(self):
        pass

    def run_image_encoder(self):
        pass

    @ProfilingContext4DebugL2("Load models")
    def load_model(self):
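        """Instantiate all modules, including the optional image encoder and video
        frame-interpolation model; this definition supersedes the load_model defined above."""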
        self.model = self.load_transformer()
        self.text_encoders = self.load_text_encoder()
        self.image_encoder = self.load_image_encoder()
        self.vae = self.load_vae()
        self.vfi_model = self.load_vfi_model() if "video_frame_interpolation" in self.config else None

    @ProfilingContext4DebugL1(
        "Run VAE Decoder",
        recorder_mode=GET_RECORDER_MODE(),
        metrics_func=monitor_cli.lightx2v_run_vae_decode_duration,
        metrics_labels=["QwenImageRunner"],
    )
    def run_vae_decoder(self, latents):
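        """Decode latents into images with the VAE, reloading it first (and releasing it
        afterwards) when lazy loading or module unloading is enabled."""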
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.vae = self.load_vae()
        images = self.vae.decode(latents, self.input_info)
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.vae
            torch.cuda.empty_cache()
            gc.collect()
        return images

    def run_pipeline(self, input_info):
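        """End-to-end image generation: encode inputs, derive latent shapes, run the
        diffusion transformer, decode with the VAE, save the first image, and return
        the decoded images."""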
        self.input_info = input_info

        self.inputs = self.run_input_encoder()
        self.set_target_shape()
        self.set_img_shapes()

        latents, generator = self.run_dit()
        images = self.run_vae_decoder(latents)
        self.end_run()

        image = images[0]
        image.save(input_info.save_result_path)

        del latents, generator
        torch.cuda.empty_cache()
        gc.collect()

        # Return (images, audio) - audio is None for default runner
        return images, None