qwen_image_runner.py 10.4 KB
Newer Older
1
import gc
2
import math
3
4
5
6
7
8
9
10
11

import torch
from loguru import logger

from lightx2v.models.input_encoders.hf.qwen25.qwen25_vlforconditionalgeneration import Qwen25_VLForConditionalGeneration_TextEncoder
from lightx2v.models.networks.qwen_image.model import QwenImageTransformerModel
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.qwen_image.scheduler import QwenImageScheduler
from lightx2v.models.video_encoders.hf.qwen_image.vae import AutoencoderKLQwenImageVAE
yihuiwen's avatar
yihuiwen committed
12
from lightx2v.server.metrics import monitor_cli
yihuiwen's avatar
yihuiwen committed
13
from lightx2v.utils.envs import *
14
from lightx2v.utils.profiler import *
15
16
17
from lightx2v.utils.registry_factory import RUNNER_REGISTER


18
19
20
21
22
23
24
25
26
27
def calculate_dimensions(target_area, ratio):
    """Pick a width/height near ``target_area`` pixels with aspect ``ratio``.

    Both sides are snapped to the nearest multiple of 32; the height is
    derived from the *un-snapped* width. Returns ``(width, height, None)``
    — the trailing ``None`` is kept for call-site compatibility.
    """
    raw_width = math.sqrt(target_area * ratio)
    raw_height = raw_width / ratio
    return round(raw_width / 32) * 32, round(raw_height / 32) * 32, None


28
29
30
31
32
33
34
35
@RUNNER_REGISTER("qwen_image")
class QwenImageRunner(DefaultRunner):
    """Runner that wires the Qwen2.5-VL text encoder, the Qwen-Image DiT
    transformer and the Qwen-Image VAE into t2i / i2i pipelines."""

    # Diffusers-style metadata; consumers are outside this file — keep as-is.
    model_cpu_offload_seq = "text_encoder->transformer->vae"
    _callback_tensor_inputs = ["latents", "prompt_embeds"]

    def __init__(self, config):
        super().__init__(config)

36
    @ProfilingContext4DebugL2("Load models")
    def load_model(self):
        """Load the transformer, text encoder(s) and VAE onto this runner.

        NOTE(review): a second ``load_model`` definition later in this class
        shadows this one, so this variant is effectively dead code — consider
        deleting one of the two.
        """
        for attr, loader in (
            ("model", self.load_transformer),
            ("text_encoders", self.load_text_encoder),
            ("vae", self.load_vae),
        ):
            setattr(self, attr, loader())

    def load_transformer(self):
        """Construct and return the Qwen-Image DiT transformer."""
        return QwenImageTransformerModel(self.config)

    def load_text_encoder(self):
        """Return the list of text encoders (a single Qwen2.5-VL encoder)."""
        return [Qwen25_VLForConditionalGeneration_TextEncoder(self.config)]

    def load_image_encoder(self):
        """Intentionally a no-op: image conditioning is handled by the
        Qwen2.5-VL text encoder (see ``run_text_encoder``)."""
        return None

    def load_vae(self):
        """Construct and return the Qwen-Image VAE."""
        return AutoencoderKLQwenImageVAE(self.config)

    def init_modules(self):
        """Wire up model loading and the task-specific encoder entry points.

        Expects ``self.scheduler`` to already exist (created by
        ``init_scheduler``) so it can be attached to the model at the end.

        Raises:
            NotImplementedError: if ``config["task"]`` is neither ``"t2i"``
                nor ``"i2i"``.
        """
        logger.info("Initializing runner modules...")
        lazy_load = self.config.get("lazy_load", False)
        if not lazy_load and not self.config.get("unload_modules", False):
            self.load_model()
        elif lazy_load:
            # Lazy loading re-creates modules per run; it is only supported
            # together with CPU offload.
            assert self.config.get("cpu_offload", False)
        self.run_dit = self._run_dit_local
        task = self.config["task"]
        if task == "t2i":
            self.run_input_encoder = self._run_input_encoder_local_t2i
        elif task == "i2i":
            self.run_input_encoder = self._run_input_encoder_local_i2i
        else:
            # BUG FIX: original had `assert NotImplementedError`, which is
            # always truthy and silently left run_input_encoder unset.
            raise NotImplementedError(f"Unsupported task: {task}")
        # NOTE(review): in the lazy-load path self.model may not exist here
        # unless the base runner provides it — confirm against DefaultRunner.
        self.model.set_scheduler(self.scheduler)

74
    @ProfilingContext4DebugL2("Run DiT")
    def _run_dit_local(self, total_steps=None):
        """Run the denoising loop, (re)loading the transformer first when
        lazy loading / module unloading is enabled.

        Returns the ``(latents, generator)`` pair produced by ``run``."""
        needs_reload = self.config.get("lazy_load", False) or self.config.get("unload_modules", False)
        if needs_reload:
            self.model = self.load_transformer()
        self.model.scheduler.prepare(self.input_info)
        return self.run(total_steps)

82
    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_t2i(self):
        """Encode the text prompt for text-to-image (no image branch)."""
        encoded_text = self.run_text_encoder(
            self.input_info.prompt,
            neg_prompt=self.input_info.negative_prompt,
        )
        # Release encoder scratch memory before the DiT runs.
        torch.cuda.empty_cache()
        gc.collect()
        return {"text_encoder_output": encoded_text, "image_encoder_output": None}

93
    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_i2i(self):
        """Encode prompt plus conditioning image for image-to-image.

        The text encoder also preprocesses the image; that preprocessed image
        feeds the VAE encoder, and the image geometry (``image_info``) is
        propagated alongside the latents."""
        _, image = self.read_image_input(self.input_info.image_path)
        encoded_text = self.run_text_encoder(
            self.input_info.prompt,
            image,
            neg_prompt=self.input_info.negative_prompt,
        )
        encoded_image = self.run_vae_encoder(image=encoded_text["preprocessed_image"])
        encoded_image["image_info"] = encoded_text["image_info"]
        # Release encoder scratch memory before the DiT runs.
        torch.cuda.empty_cache()
        gc.collect()
        return {
            "text_encoder_output": encoded_text,
            "image_encoder_output": encoded_image,
        }

yihuiwen's avatar
yihuiwen committed
107
    @ProfilingContext4DebugL1("Run Text Encoder", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_run_text_encode_duration, metrics_labels=["QwenImageRunner"])
    def run_text_encoder(self, text, image=None, neg_prompt=None):
        """Encode the positive (and optionally negative) prompt.

        For i2i the conditioning image goes through the encoder as well,
        which additionally yields the preprocessed image and its geometry.
        A negative prompt is encoded only when ``do_true_cfg`` is enabled.
        """
        if GET_RECORDER_MODE():
            monitor_cli.lightx2v_input_prompt_len.observe(len(text))
        out = {}
        encoder = self.text_encoders[0]
        task = self.config["task"]
        if task == "t2i":
            out["prompt_embeds"], out["prompt_embeds_mask"], _, _ = encoder.infer([text])
            if self.config["do_true_cfg"] and neg_prompt is not None:
                out["negative_prompt_embeds"], out["negative_prompt_embeds_mask"], _, _ = encoder.infer([neg_prompt])
        elif task == "i2i":
            (
                out["prompt_embeds"],
                out["prompt_embeds_mask"],
                out["preprocessed_image"],
                out["image_info"],
            ) = encoder.infer([text], image)
            if self.config["do_true_cfg"] and neg_prompt is not None:
                out["negative_prompt_embeds"], out["negative_prompt_embeds_mask"], _, _ = encoder.infer([neg_prompt], image)
        return out

132
    @ProfilingContext4DebugL1("Run VAE Encoder", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_run_vae_encoder_image_duration, metrics_labels=["QwenImageRunner"])
    def run_vae_encoder(self, image):
        """Encode the (already preprocessed) image into VAE latents."""
        return {"image_latents": self.vae.encode_vae_image(image, self.input_info)}

137
138
139
140
141
142
    def run(self, total_steps=None):
        """Denoising loop: scheduler pre-step, transformer inference, and
        scheduler post-step each iteration, with optional progress callback.

        Returns the final latents and the scheduler's RNG generator.
        """
        steps = self.model.scheduler.infer_steps if total_steps is None else total_steps
        for step in range(steps):
            logger.info(f"==> step_index: {step + 1} / {steps}")
            with ProfilingContext4DebugL1("step_pre"):
                self.model.scheduler.step_pre(step_index=step)
            with ProfilingContext4DebugL1("🚀 infer_main"):
                self.model.infer(self.inputs)
            with ProfilingContext4DebugL1("step_post"):
                self.model.scheduler.step_post()
            if self.progress_callback:
                self.progress_callback(((step + 1) / steps) * 100, 100)
        return self.model.scheduler.latents, self.model.scheduler.generator

157
    def set_target_shape(self):
        """Derive the latent target shape for generation.

        With ``_auto_resize`` the output size follows the input image's
        aspect ratio at ~1 megapixel, floored to a VAE-friendly multiple;
        otherwise it comes from the configured aspect-ratio table.
        """
        if not self.config["_auto_resize"]:
            width, height = self.config["aspect_ratios"][self.config["aspect_ratio"]]
        else:
            orig_w, orig_h = self.input_info.original_size
            calc_w, calc_h, _ = calculate_dimensions(1024 * 1024, orig_w / orig_h)
            # Floor to a multiple of (vae_scale_factor * 2): VAE compression
            # plus 2x latent packing per spatial axis.
            multiple_of = self.vae.vae_scale_factor * 2
            width = calc_w // multiple_of * multiple_of
            height = calc_h // multiple_of * multiple_of
            self.input_info.auto_width = width
            # NOTE(review): attribute spelled "auto_hight" (sic) — it is read
            # by set_img_shapes, so keep both in sync if ever renamed.
            self.input_info.auto_hight = height
        # VAE applies 8x compression on images but we must also account for packing which requires
        # latent height and width to be divisible by 2.
        latent_h = 2 * (int(height) // (self.vae.vae_scale_factor * 2))
        latent_w = 2 * (int(width) // (self.vae.vae_scale_factor * 2))
        # in_channels // 4: presumably undoes the 2x2 latent packing — confirm.
        num_channels_latents = self.model.in_channels // 4
        self.input_info.target_shape = (self.config["batchsize"], 1, num_channels_latents, latent_h, latent_w)

    def set_img_shapes(self):
        """Populate ``self.inputs["img_shapes"]`` with per-image latent grid
        shapes (downscaled by ``vae_scale_factor`` and the 2x packing).

        Raises:
            NotImplementedError: for tasks other than "t2i" / "i2i". The
                original left ``img_shapes`` unbound in that case, producing
                an opaque NameError on the assignment below.
        """
        vae_scale = self.config["vae_scale_factor"]
        task = self.config["task"]
        if task == "t2i":
            width, height = self.config["aspect_ratios"][self.config["aspect_ratio"]]
            img_shapes = [(1, height // vae_scale // 2, width // vae_scale // 2)] * self.config["batchsize"]
        elif task == "i2i":
            image_height, image_width = self.inputs["image_encoder_output"]["image_info"]
            # First entry: the generated image; second: the conditioning image.
            img_shapes = [
                [
                    (1, self.input_info.auto_hight // vae_scale // 2, self.input_info.auto_width // vae_scale // 2),
                    (1, image_height // vae_scale // 2, image_width // vae_scale // 2),
                ]
            ]
        else:
            raise NotImplementedError(f"Unsupported task: {task}")
        self.inputs["img_shapes"] = img_shapes
189
190

    def init_scheduler(self):
        """Create the Qwen-Image scheduler; it is attached to the model later
        by ``init_modules``."""
        self.scheduler = QwenImageScheduler(self.config)
192
193
194
195
196
197
198

    def get_encoder_output_i2v(self):
        """Intentionally a no-op — presumably part of the base-runner
        interface not used by the image pipeline; confirm against
        DefaultRunner."""
        return None

    def run_image_encoder(self):
        """Intentionally a no-op — there is no standalone image encoder in
        this runner (see ``load_image_encoder``)."""
        return None

199
    @ProfilingContext4DebugL2("Load models")
    def load_model(self):
        """Load every pipeline module onto this runner.

        NOTE(review): this redefines — and therefore shadows — the earlier
        ``load_model`` in this class; only this version is ever called.
        Consider deleting the duplicate.
        """
        self.model = self.load_transformer()
        self.text_encoders = self.load_text_encoder()
        self.image_encoder = self.load_image_encoder()
        self.vae = self.load_vae()
        # Optional frame-interpolation model, only when configured.
        if "video_frame_interpolation" in self.config:
            self.vfi_model = self.load_vfi_model()
        else:
            self.vfi_model = None

yihuiwen's avatar
yihuiwen committed
207
208
209
210
211
212
    @ProfilingContext4DebugL1(
        "Run VAE Decoder",
        recorder_mode=GET_RECORDER_MODE(),
        metrics_func=monitor_cli.lightx2v_run_vae_decode_duration,
        metrics_labels=["QwenImageRunner"],
    )
    def run_vae_decoder(self, latents):
        """Decode latents into images, honoring lazy-load / unload settings.

        BUG FIX: the original lazily loaded the VAE into ``self.vae_decoder``
        but then decoded with ``self.vae`` (unset in lazy mode) and deleted
        ``self.vae_decoder`` — the freshly loaded module was never used.
        Load into, decode with, and free ``self.vae`` consistently.
        """
        transient_vae = self.config.get("lazy_load", False) or self.config.get("unload_modules", False)
        if transient_vae:
            self.vae = self.load_vae()
        images = self.vae.decode(latents, self.input_info)
        if transient_vae:
            # Free the VAE immediately to keep peak GPU memory low.
            del self.vae
            torch.cuda.empty_cache()
            gc.collect()
        return images

223
224
    def run_pipeline(self, input_info):
        """End-to-end generation: encode inputs, run the DiT, decode, save.

        Saves the first decoded image to ``input_info.save_result_path`` and
        returns ``(images, None)`` — audio is always None for this runner.
        """
        self.input_info = input_info
        self.inputs = self.run_input_encoder()
        self.set_target_shape()
        self.set_img_shapes()
        latents, generator = self.run_dit()
        images = self.run_vae_decoder(latents)
        self.end_run()
        first_image = images[0]
        first_image.save(f"{input_info.save_result_path}")
        # Drop the large tensors before returning to cap peak memory.
        del latents, generator
        torch.cuda.empty_cache()
        gc.collect()
        # Return (images, audio) - audio is None for default runner
        return images, None