qwen_image_runner.py
import gc
import math
import os

import torch
import torch.distributed as dist
import torchvision.transforms.functional as TF
from PIL import Image
from loguru import logger

from lightx2v.models.input_encoders.hf.qwen25.qwen25_vlforconditionalgeneration import Qwen25_VLForConditionalGeneration_TextEncoder
from lightx2v.models.networks.lora_adapter import LoraAdapter
from lightx2v.models.networks.qwen_image.model import QwenImageTransformerModel
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.qwen_image.scheduler import QwenImageScheduler
from lightx2v.models.video_encoders.hf.qwen_image.vae import AutoencoderKLQwenImageVAE
from lightx2v.server.metrics import monitor_cli
from lightx2v.utils.envs import *
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE

torch_device_module = getattr(torch, AI_DEVICE)


def calculate_dimensions(target_area, ratio):
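    """Snap (width, height) to multiples of 32 for a target pixel area and aspect ratio.

    Illustrative numbers: target_area=1024*1024 with ratio=16/9 gives an unrounded
    width of ~1365.3 and height of 768.0, which round to 1376 x 768. The trailing
    None in the return value is a placeholder kept for the caller's 3-way unpacking.
    """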
    width = math.sqrt(target_area * ratio)
    height = width / ratio

    width = round(width / 32) * 32
    height = round(height / 32) * 32

    return width, height, None


def build_qwen_image_model_with_lora(qwen_module, config, model_kwargs, lora_configs):
    lora_dynamic_apply = config.get("lora_dynamic_apply", False)

    if lora_dynamic_apply:
        lora_path = lora_configs[0]["path"]
        lora_strength = lora_configs[0]["strength"]
        model_kwargs["lora_path"] = lora_path
        model_kwargs["lora_strength"] = lora_strength
        model = qwen_module(**model_kwargs)
    else:
        assert not config.get("dit_quantized", False), "Online LoRA only for quantized models; merging LoRA is unsupported."
        assert not config.get("lazy_load", False), "Lazy load mode does not support LoRA merging."
        model = qwen_module(**model_kwargs)
        lora_adapter = LoraAdapter(model)
        lora_adapter.apply_lora(lora_configs)
    return model
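
# Illustrative shape of a lora_configs entry, inferred from the key lookups above
# (the full schema is defined by the caller's config; the path is a placeholder):
#   lora_configs = [{"path": "/path/to/lora.safetensors", "strength": 1.0}]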


@RUNNER_REGISTER("qwen_image")
class QwenImageRunner(DefaultRunner):
    model_cpu_offload_seq = "text_encoder->transformer->vae"
    _callback_tensor_inputs = ["latents", "prompt_embeds"]

    def __init__(self, config):
        super().__init__(config)
        self.is_layered = self.config.get("layered", False)
        if self.is_layered:
            self.layers = self.config.get("layers", 4)
        self.resolution = self.config.get("resolution", 1024)

        # Text encoder type: "lightllm_service", "lightllm_kernel", or default (baseline)
        self.text_encoder_type = config.get("text_encoder_type", "baseline")

        if self.text_encoder_type in ["lightllm_service", "lightllm_kernel"]:
            logger.info(f"Using LightLLM text encoder: {self.text_encoder_type}")

    def load_transformer(self):
        qwen_image_model_kwargs = {
            "model_path": os.path.join(self.config["model_path"], "transformer"),
            "config": self.config,
            "device": self.init_device,
        }
        lora_configs = self.config.get("lora_configs")
        if not lora_configs:
            model = QwenImageTransformerModel(**qwen_image_model_kwargs)
        else:
            model = build_qwen_image_model_with_lora(QwenImageTransformerModel, self.config, qwen_image_model_kwargs, lora_configs)
        return model

    def load_text_encoder(self):
        """Load text encoder based on text_encoder_type configuration.

        Supported types:
        - "lightllm_service": LightLLM HTTP service mode
        - "lightllm_kernel": HuggingFace model with Triton kernel optimizations
        - "baseline" (default): HuggingFace baseline implementation
        """
        # Prepare encoder config by merging lightllm_config if present
        encoder_config = self.config.copy()
        lightllm_config = self.config.get("lightllm_config", {})
        encoder_config.update(lightllm_config)

        if self.text_encoder_type == "lightllm_service":
            from lightx2v.models.input_encoders.lightllm import LightLLMServiceTextEncoder

            logger.info("Loading LightLLM service-based text encoder")
            text_encoder = LightLLMServiceTextEncoder(encoder_config)
        elif self.text_encoder_type == "lightllm_kernel":
            from lightx2v.models.input_encoders.lightllm import LightLLMKernelTextEncoder

            logger.info("Loading LightLLM Kernel-optimized text encoder")
            text_encoder = LightLLMKernelTextEncoder(encoder_config)
        else:  # baseline or default
            logger.info("Loading HuggingFace baseline text encoder")
            text_encoder = Qwen25_VLForConditionalGeneration_TextEncoder(self.config)

        text_encoders = [text_encoder]
        return text_encoders

    def load_image_encoder(self):
        pass

    def load_vae(self):
        vae = AutoencoderKLQwenImageVAE(self.config)
        return vae

    def init_modules(self):
        logger.info("Initializing runner modules...")
        if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
            self.load_model()
            self.model.set_scheduler(self.scheduler)
        elif self.config.get("lazy_load", False):
            assert self.config.get("cpu_offload", False)
        self.run_dit = self._run_dit_local
        if self.config["task"] == "t2i":
            self.run_input_encoder = self._run_input_encoder_local_t2i
        elif self.config["task"] == "i2i":
            self.run_input_encoder = self._run_input_encoder_local_i2i
        else:
            raise NotImplementedError(f"Unsupported task: {self.config['task']}")

    @ProfilingContext4DebugL2("Run DiT")
    def _run_dit_local(self, total_steps=None):
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.model = self.load_transformer()
            self.model.set_scheduler(self.scheduler)
        self.model.scheduler.prepare(self.input_info)
        latents, generator = self.run(total_steps)
        return latents, generator

    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_t2i(self):
        prompt = self.input_info.prompt
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.text_encoders = self.load_text_encoder()
        text_encoder_output = self.run_text_encoder(prompt, neg_prompt=self.input_info.negative_prompt)
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.text_encoders[0]
        torch_device_module.empty_cache()
        gc.collect()
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": None,
        }

    def read_image_input(self, img_path):
        if isinstance(img_path, Image.Image):
            img_ori = img_path
        else:
            if self.config.get("layered", False):
                img_ori = Image.open(img_path).convert("RGBA")
            else:
                img_ori = Image.open(img_path).convert("RGB")
        if GET_RECORDER_MODE():
            width, height = img_ori.size
            monitor_cli.lightx2v_input_image_len.observe(width * height)
        img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).to(AI_DEVICE)
        self.input_info.original_size.append(img_ori.size)
        return img, img_ori

    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_i2i(self):
        image_paths_list = self.input_info.image_path.split(",")
        images_list = []
        for image_path in image_paths_list:
            _, image = self.read_image_input(image_path)
            images_list.append(image)

        prompt = self.input_info.prompt
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.text_encoders = self.load_text_encoder()
        text_encoder_output = self.run_text_encoder(prompt, images_list, neg_prompt=self.input_info.negative_prompt)
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            # Offload text encoder (service mode doesn't need offload)
            if self.text_encoder_type == "lightllm_service":
                pass  # Service mode: no local model to offload
            else:
                del self.text_encoders[0]
        image_encoder_output_list = []
        for vae_image in text_encoder_output["image_info"]["vae_image_list"]:
            image_encoder_output = self.run_vae_encoder(image=vae_image)
            image_encoder_output_list.append(image_encoder_output)
        torch_device_module.empty_cache()
        gc.collect()
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": image_encoder_output_list,
        }

    @ProfilingContext4DebugL1("Run Text Encoder", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_run_text_encode_duration, metrics_labels=["QwenImageRunner"])
    def run_text_encoder(self, text, image_list=None, neg_prompt=None):
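        """Encode the prompt (and optional negative prompt / reference images) into embeddings.

        Returns a dict with "prompt_embeds", plus "negative_prompt_embeds" when CFG is enabled
        and a negative prompt is given, and "image_info" for the i2i task.
        """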
        if GET_RECORDER_MODE():
            monitor_cli.lightx2v_input_prompt_len.observe(len(text))
        text_encoder_output = {}
        if self.config["task"] == "t2i":
            prompt_embeds, _, _ = self.text_encoders[0].infer([text])
            self.input_info.txt_seq_lens = [prompt_embeds.shape[1]]
            text_encoder_output["prompt_embeds"] = prompt_embeds
            if self.config["enable_cfg"] and neg_prompt is not None:
                neg_prompt_embeds, _, _ = self.text_encoders[0].infer([neg_prompt])
                self.input_info.txt_seq_lens.append(neg_prompt_embeds.shape[1])
                text_encoder_output["negative_prompt_embeds"] = neg_prompt_embeds
        elif self.config["task"] == "i2i":
            prompt_embeds, _, image_info = self.text_encoders[0].infer([text], image_list)
            self.input_info.txt_seq_lens = [prompt_embeds.shape[1]]
            text_encoder_output["prompt_embeds"] = prompt_embeds
            text_encoder_output["image_info"] = image_info
            if self.config["enable_cfg"] and neg_prompt is not None:
                neg_prompt_embeds, _, _ = self.text_encoders[0].infer([neg_prompt], image_list)
                self.input_info.txt_seq_lens.append(neg_prompt_embeds.shape[1])
                text_encoder_output["negative_prompt_embeds"] = neg_prompt_embeds
        return text_encoder_output

    @ProfilingContext4DebugL1("Run VAE Encoder", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_run_vae_encoder_image_duration, metrics_labels=["QwenImageRunner"])
    def run_vae_encoder(self, image):
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.vae = self.load_vae()
        image_latents = self.vae.encode_vae_image(image.to(GET_DTYPE()))
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.vae
            torch_device_module.empty_cache()
            gc.collect()
        return {"image_latents": image_latents}

    @ProfilingContext4DebugL1(
        "Run VAE Decoder",
        recorder_mode=GET_RECORDER_MODE(),
        metrics_func=monitor_cli.lightx2v_run_vae_decode_duration,
        metrics_labels=["QwenImageRunner"],
    )
    def run_vae_decoder(self, latents):
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.vae = self.load_vae()
        images = self.vae.decode(latents, self.input_info)
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.vae
            torch_device_module.empty_cache()
            gc.collect()
        return images

    def run(self, total_steps=None):
        if total_steps is None:
            total_steps = self.model.scheduler.infer_steps
        for step_index in range(total_steps):
            logger.info(f"==> step_index: {step_index + 1} / {total_steps}")

            with ProfilingContext4DebugL1("step_pre"):
                self.model.scheduler.step_pre(step_index=step_index)

            with ProfilingContext4DebugL1("🚀 infer_main"):
                self.model.infer(self.inputs)

            with ProfilingContext4DebugL1("step_post"):
                self.model.scheduler.step_post()

            if self.progress_callback:
                self.progress_callback(((step_index + 1) / total_steps) * 100, 100)

        return self.model.scheduler.latents, self.model.scheduler.generator

    def get_custom_shape(self):
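        """Resolve an explicit (width, height) for the request, if one was specified.

        Precedence: an explicit target_shape wins, then a named aspect ratio from the
        default/configured table; otherwise None is returned and set_target_shape falls
        back to the last input image size. Illustrative clamp: a 3000x2000 (WxH) request
        with max_custom_size=1664 is scaled down to 1664x1109.
        """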
        default_aspect_ratios = {
            "16:9": [1664, 928],
            "9:16": [928, 1664],
            "1:1": [1328, 1328],
            "4:3": [1472, 1140],
            "3:4": [768, 1024],
        }
        # Defaults provide the base table; user-configured aspect_ratios extend or override
        # them, without mutating the config dict in place.
        as_maps = {**default_aspect_ratios, **self.config.get("aspect_ratios", {})}
        max_size = self.config.get("max_custom_size", 1664)
        min_size = self.config.get("min_custom_size", 256)

        if len(self.input_info.target_shape) == 2:
            height, width = self.input_info.target_shape
            height, width = int(height), int(width)
            if width > max_size or height > max_size:
                scale = max_size / max(width, height)
                width, height = int(width * scale), int(height * scale)
                logger.warning(f"Custom shape is too large, scaled to {width}x{height}")
            width, height = max(width, min_size), max(height, min_size)
            logger.info(f"Qwen Image Runner got custom shape: {width}x{height}")
            return (width, height)

        aspect_ratio = self.input_info.aspect_ratio if self.input_info.aspect_ratio else self.config.get("aspect_ratio", None)
        if aspect_ratio in as_maps:
            logger.info(f"Qwen Image Runner got aspect ratio: {aspect_ratio}")
            width, height = as_maps[aspect_ratio]
            return (width, height)
        logger.warning(f"Invalid aspect ratio: {aspect_ratio}, not in {as_maps.keys()}")

        return None

    def set_target_shape(self):
        custom_shape = self.get_custom_shape()
        if custom_shape is not None:
            width, height = custom_shape
        else:
            width, height = self.input_info.original_size[-1]
            calculated_width, calculated_height, _ = calculate_dimensions(self.resolution * self.resolution, width / height)
            multiple_of = self.config["vae_scale_factor"] * 2
            width = calculated_width // multiple_of * multiple_of
            height = calculated_height // multiple_of * multiple_of
        logger.info(f"Qwen Image Runner set target shape: {width}x{height}")
        self.input_info.auto_width = width
        self.input_info.auto_height = height

        # VAE applies 8x compression on images but we must also account for packing which requires
        # latent height and width to be divisible by 2.
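        # Illustrative numbers: with vae_scale_factor=8 and a 1024x1024 target,
        # the latent grid becomes 2 * (1024 // 16) = 128 per side.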
        height = 2 * (int(height) // (self.config["vae_scale_factor"] * 2))
        width = 2 * (int(width) // (self.config["vae_scale_factor"] * 2))
        num_channels_latents = self.config["in_channels"] // 4
        if not self.is_layered:
            self.input_info.target_shape = (1, 1, num_channels_latents, height, width)
        else:
            self.input_info.target_shape = (1, self.layers + 1, num_channels_latents, height, width)

    def set_img_shapes(self):
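        """Record per-image latent-grid shapes as (frames, height, width) after VAE downscaling and the 2x packing factor."""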
        width, height = self.input_info.auto_width, self.input_info.auto_height
        if self.config["task"] == "t2i":
            image_shapes = [(1, height // self.config["vae_scale_factor"] // 2, width // self.config["vae_scale_factor"] // 2)]
        elif self.config["task"] == "i2i":
            if self.is_layered:
                image_shapes = [
                    [
                        *[(1, height // self.config["vae_scale_factor"] // 2, width // self.config["vae_scale_factor"] // 2) for _ in range(self.layers + 1)],
                        (1, height // self.config["vae_scale_factor"] // 2, width // self.config["vae_scale_factor"] // 2),
                    ]
                ]
            else:
                image_shapes = [[(1, height // self.config["vae_scale_factor"] // 2, width // self.config["vae_scale_factor"] // 2)]]
                for image_height, image_width in self.inputs["text_encoder_output"]["image_info"]["vae_image_info_list"]:
                    image_shapes[0].append((1, image_height // self.config["vae_scale_factor"] // 2, image_width // self.config["vae_scale_factor"] // 2))
        self.input_info.image_shapes = image_shapes

    def init_scheduler(self):
        self.scheduler = QwenImageScheduler(self.config)

    def get_encoder_output_i2v(self):
        pass

    def run_image_encoder(self):
        pass

    @ProfilingContext4DebugL2("Load models")
    def load_model(self):
        self.model = self.load_transformer()
        self.text_encoders = self.load_text_encoder()
        self.image_encoder = self.load_image_encoder()
        self.vae = self.load_vae()
        self.vfi_model = self.load_vfi_model() if "video_frame_interpolation" in self.config else None

    @ProfilingContext4DebugL1("RUN pipeline")
    def run_pipeline(self, input_info):
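        """End-to-end generation: encode inputs, set target shapes, run the DiT denoising loop, decode with the VAE, then save or return the images."""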
        self.input_info = input_info

        self.inputs = self.run_input_encoder()
        self.set_target_shape()
        self.set_img_shapes()
        logger.info(f"input_info: {self.input_info}")
        latents, generator = self.run_dit()
        images = self.run_vae_decoder(latents)
        self.end_run()

        if not dist.is_initialized() or dist.get_rank() == 0:
            if not input_info.return_result_tensor:
                image_prefix = input_info.save_result_path.rsplit(".", 1)[0]
                image_suffix = input_info.save_result_path.rsplit(".", 1)[1] if len(input_info.save_result_path.rsplit(".", 1)) > 1 else "png"
                if isinstance(images[0], list) and len(images[0]) > 1:
                    for idx, image in enumerate(images[0]):
                        image.save(f"{image_prefix}_{idx:05d}.{image_suffix}")
                        logger.info(f"Image saved: {image_prefix}_{idx:05d}.{image_suffix}")
                else:
                    image = images[0]
                    image.save(f"{image_prefix}.{image_suffix}")
                    logger.info(f"Image saved: {image_prefix}.{image_suffix}")

        del latents, generator
        torch_device_module.empty_cache()
        gc.collect()

        if input_info.return_result_tensor:
            return {"images": images}
        elif input_info.save_result_path is not None:
            return {"images": None}