z_image_runner.py 17.7 KB
Newer Older
litzh's avatar
litzh committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
import gc
import math

import torch
import torchvision.transforms.functional as TF
from PIL import Image
from loguru import logger

from lightx2v.models.input_encoders.hf.z_image.qwen3_model import Qwen3Model_TextEncoder
from lightx2v.models.networks.lora_adapter import LoraAdapter
from lightx2v.models.networks.z_image.model import ZImageTransformerModel
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.z_image.scheduler import ZImageScheduler
from lightx2v.models.video_encoders.hf.z_image.vae import AutoencoderKLZImageVAE
from lightx2v.server.metrics import monitor_cli
from lightx2v.utils.envs import *
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE

# Resolve, once at import time, the torch submodule whose name is stored in AI_DEVICE.
# The runner below uses it only to call ``empty_cache()`` after releasing models/tensors.
torch_device_module = getattr(torch, AI_DEVICE)


def calculate_dimensions(target_area, ratio):
    """Pick a width/height pair for a given pixel area and aspect ratio.

    The ideal (floating point) width is ``sqrt(target_area * ratio)`` and the
    ideal height is that width divided by ``ratio``; both are then rounded to
    the nearest multiple of 32 using Python's built-in ``round``.

    Args:
        target_area: Desired number of pixels (width * height).
        ratio: Desired width / height ratio.

    Returns:
        A 3-tuple ``(width, height, None)``; the third element is always None.
    """

    def _snap_to_32(value):
        # Nearest multiple of 32 (ties follow the built-in ``round`` rule).
        return round(value / 32) * 32

    ideal_width = math.sqrt(target_area * ratio)
    ideal_height = ideal_width / ratio

    return _snap_to_32(ideal_width), _snap_to_32(ideal_height), None


def build_z_image_model_with_lora(z_image_module, config, model_kwargs, lora_configs):
    """Instantiate the transformer and attach LoRA weights to it.

    Two modes are selected by ``config["lora_dynamic_apply"]`` (default False):

    * dynamic: the ``path`` and ``strength`` of the *first* entry of
      ``lora_configs`` are written into ``model_kwargs`` (which is mutated in
      place) as ``lora_path`` / ``lora_strength`` and the model is built with
      them.
    * merge: the model is built first and ``LoraAdapter.apply_lora`` is then
      called with the full ``lora_configs`` list. This mode asserts that
      neither ``dit_quantized`` nor ``lazy_load`` is enabled in ``config``.

    Args:
        z_image_module: Callable (model class) invoked as ``z_image_module(**model_kwargs)``.
        config: Mapping providing ``.get`` for the flags described above.
        model_kwargs: Keyword arguments for the model constructor.
        lora_configs: Sequence of LoRA config dicts.

    Returns:
        The constructed model.

    Raises:
        AssertionError: In merge mode when ``dit_quantized`` or ``lazy_load`` is set.
    """
    if config.get("lora_dynamic_apply", False):
        # Only the first LoRA entry is consulted in dynamic mode.
        primary_lora = lora_configs[0]
        path, strength = primary_lora["path"], primary_lora["strength"]
        model_kwargs.update(lora_path=path, lora_strength=strength)
        return z_image_module(**model_kwargs)

    assert not config.get("dit_quantized", False), "Online LoRA only for quantized models; merging LoRA is unsupported."
    assert not config.get("lazy_load", False), "Lazy load mode does not support LoRA merging."
    merged_model = z_image_module(**model_kwargs)
    LoraAdapter(merged_model).apply_lora(lora_configs)
    return merged_model


@RUNNER_REGISTER("z_image")
class ZImageRunner(DefaultRunner):
    """Runner for Z-Image text-to-image (``t2i``) and image-to-image (``i2i``) generation.

    One pipeline run encodes the prompt (plus reference images for ``i2i``),
    derives the latent target shape, runs the denoising loop on the
    transformer and decodes the final latents with the VAE.

    When ``lazy_load`` or ``unload_modules`` is enabled in the config, the text
    encoder and VAE are loaded right before they are used and dropped right
    after, and the transformer is (re)loaded at the start of the DiT stage.
    """

    # Not referenced inside this class; presumably consumed by shared
    # offload/callback machinery elsewhere -- TODO confirm.
    model_cpu_offload_seq = "text_encoder->transformer->vae"
    _callback_tensor_inputs = ["latents", "prompt_embeds"]

    def __init__(self, config):
        """Forward ``config`` to the base runner; no extra state is created here."""
        super().__init__(config)

    # ------------------------------------------------------------------ helpers

    def _uses_on_demand_loading(self):
        """Return True when sub-models are loaded per stage (``lazy_load`` or ``unload_modules``)."""
        return bool(self.config.get("lazy_load", False) or self.config.get("unload_modules", False))

    def _latent_size(self, pixels):
        """Convert a pixel extent to a latent extent that is guaranteed to be even.

        The extent is divided by ``2 * vae_scale_factor`` (floor) and doubled so
        the latent can later be packed into 2x2 patches.
        """
        vae_scale_factor = self.config["vae_scale_factor"]
        return 2 * (int(pixels) // (vae_scale_factor * 2))

    def _encode_prompts(self, text, neg_prompt, image_list=None):
        """Encode the prompt (and optional negative prompt) with the first text encoder.

        ``image_list`` is forwarded to ``infer`` only when it is not None, so the
        text-only call keeps the exact ``infer([text])`` form.

        Side effect: ``self.input_info.txt_seq_lens`` is replaced by the
        first-dimension length of the prompt embedding, followed by that of the
        negative embedding when one is produced.

        Returns:
            dict with ``prompt_embeds``, ``image_info`` (only when images were
            passed) and ``negative_prompt_embeds`` (only when ``enable_cfg`` is
            set and ``neg_prompt`` is not None).
        """
        encoder = self.text_encoders[0]
        image_args = () if image_list is None else (image_list,)
        output = {}

        # ``infer`` returns (embedding_list, image_info); one prompt in -> one embedding out.
        prompt_embeds_list, image_info = encoder.infer([text], *image_args)
        prompt_embeds = prompt_embeds_list[0]
        self.input_info.txt_seq_lens = [prompt_embeds.shape[0]]
        output["prompt_embeds"] = prompt_embeds
        if image_list is not None:
            output["image_info"] = image_info

        if self.config["enable_cfg"] and neg_prompt is not None:
            neg_prompt_embeds_list, _ = encoder.infer([neg_prompt], *image_args)
            neg_prompt_embeds = neg_prompt_embeds_list[0]
            self.input_info.txt_seq_lens.append(neg_prompt_embeds.shape[0])
            output["negative_prompt_embeds"] = neg_prompt_embeds
        return output

    # ------------------------------------------------------------------ loading

    @ProfilingContext4DebugL2("Load models")
    def load_model(self):
        """Load transformer, text encoder(s) and VAE onto ``self``."""
        self.model = self.load_transformer()
        self.text_encoders = self.load_text_encoder()
        self.vae = self.load_vae()

    def load_transformer(self):
        """Build the Z-Image transformer, applying LoRA when ``lora_configs`` is set."""
        z_image_model_kwargs = {
            "model_path": os.path.join(self.config["model_path"], "transformer"),
            "config": self.config,
            "device": self.init_device,
        }
        lora_configs = self.config.get("lora_configs")
        if not lora_configs:
            return ZImageTransformerModel(**z_image_model_kwargs)
        return build_z_image_model_with_lora(ZImageTransformerModel, self.config, z_image_model_kwargs, lora_configs)

    def load_text_encoder(self):
        """Return a one-element list holding the Qwen3 text encoder."""
        return [Qwen3Model_TextEncoder(self.config)]

    def load_image_encoder(self):
        """No separate image encoder is loaded by this runner."""
        pass

    def load_vae(self):
        """Build the Z-Image VAE."""
        return AutoencoderKLZImageVAE(self.config)

    def init_modules(self):
        """Load models (unless loading is deferred) and bind the task-specific entry points.

        Raises:
            AssertionError: ``lazy_load`` is enabled without ``cpu_offload``.
            NotImplementedError: ``config["task"]`` is neither ``"t2i"`` nor ``"i2i"``.
        """
        logger.info("Initializing runner modules...")
        if not self._uses_on_demand_loading():
            self.load_model()
            self.model.set_scheduler(self.scheduler)
        elif self.config.get("lazy_load", False):
            assert self.config.get("cpu_offload", False), "lazy_load requires cpu_offload to be enabled"

        self.run_dit = self._run_dit_local
        task = self.config["task"]
        if task == "t2i":
            self.run_input_encoder = self._run_input_encoder_local_t2i
        elif task == "i2i":
            self.run_input_encoder = self._run_input_encoder_local_i2i
        else:
            # The previous ``assert NotImplementedError`` never fired (a class object is
            # truthy), leaving ``run_input_encoder`` unset for unsupported tasks.
            raise NotImplementedError(f"ZImageRunner does not support task {task!r}; expected 't2i' or 'i2i'")

    # ------------------------------------------------------------------ stages

    @ProfilingContext4DebugL2("Run DiT")
    def _run_dit_local(self, total_steps=None):
        """Prepare the scheduler for the current input and run the denoising loop.

        Returns:
            ``(latents, generator)`` as returned by :meth:`run`.
        """
        if self._uses_on_demand_loading():
            self.model = self.load_transformer()
            self.model.set_scheduler(self.scheduler)
        self.model.scheduler.prepare(self.input_info)
        latents, generator = self.run(total_steps)
        return latents, generator

    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_t2i(self):
        """Encode the prompt for a ``t2i`` request; there is no image encoder output."""
        prompt = self.input_info.prompt
        on_demand = self._uses_on_demand_loading()
        if on_demand:
            self.text_encoders = self.load_text_encoder()
        text_encoder_output = self.run_text_encoder(prompt, neg_prompt=self.input_info.negative_prompt)
        if on_demand:
            del self.text_encoders[0]
        torch_device_module.empty_cache()
        gc.collect()
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": None,
        }

    def read_image_input(self, img_path):
        """Load one reference image and normalise it for the pipeline.

        Args:
            img_path: A ``PIL.Image.Image`` (used as-is) or a path that is opened
                and converted to RGB.

        Returns:
            ``(tensor, pil_image)`` where ``tensor`` is the image mapped from
            [0, 1] to [-1, 1] with a leading batch axis on ``AI_DEVICE``, and
            ``pil_image`` is the (possibly resized) PIL image.

        Side effect: appends the final PIL size to ``self.input_info.original_size``.
        """
        if isinstance(img_path, Image.Image):
            img_ori = img_path
        else:
            img_ori = Image.open(img_path).convert("RGB")

        width, height = img_ori.size

        if GET_RECORDER_MODE():
            monitor_cli.lightx2v_input_image_len.observe(width * height)

        # Both sides must be a multiple of 2 * vae_scale_factor.
        vae_scale = self.config["vae_scale_factor"] * 2
        if height % vae_scale != 0 or width % vae_scale != 0:
            logger.warning(f"Image dimensions ({height}, {width}) are not divisible by {vae_scale}. Resizing to nearest valid dimensions.")
            # Round down to the nearest valid size, but never below one unit.
            new_height = (height // vae_scale) * vae_scale
            new_width = (width // vae_scale) * vae_scale
            if new_height == 0:
                new_height = vae_scale
            if new_width == 0:
                new_width = vae_scale
            img_ori = img_ori.resize((new_width, new_height), Image.Resampling.LANCZOS)
            logger.info(f"Resized image to ({new_height}, {new_width})")

        img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).to(AI_DEVICE)
        self.input_info.original_size.append(img_ori.size)
        return img, img_ori

    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_i2i(self):
        """Encode prompt and reference images for an ``i2i`` request.

        ``input_info.image_path`` is a comma-separated list of paths. The VAE
        inputs come from ``text_encoder_output["image_info"]["vae_image_list"]``.
        """
        image_paths_list = self.input_info.image_path.split(",")
        # Only the PIL image (second return value) is passed on to the text encoder.
        images_list = [self.read_image_input(image_path)[1] for image_path in image_paths_list]

        prompt = self.input_info.prompt
        on_demand = self._uses_on_demand_loading()
        if on_demand:
            self.text_encoders = self.load_text_encoder()
        text_encoder_output = self.run_text_encoder(prompt, images_list, neg_prompt=self.input_info.negative_prompt)
        if on_demand:
            del self.text_encoders[0]

        image_encoder_output_list = [self.run_vae_encoder(image=vae_image) for vae_image in text_encoder_output["image_info"]["vae_image_list"]]
        torch_device_module.empty_cache()
        gc.collect()
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": image_encoder_output_list,
        }

    @ProfilingContext4DebugL1("Run Text Encoder", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_run_text_encode_duration, metrics_labels=["ZImageRunner"])
    def run_text_encoder(self, text, image_list=None, neg_prompt=None):
        """Encode ``text`` (and ``neg_prompt`` when CFG is enabled).

        ``image_list`` is used only when the task is ``i2i`` and the list is not
        None; every other combination is encoded as text only.

        Returns:
            dict as produced by :meth:`_encode_prompts`.
        """
        if GET_RECORDER_MODE():
            monitor_cli.lightx2v_input_prompt_len.observe(len(text))

        with_images = self.config["task"] == "i2i" and image_list is not None
        return self._encode_prompts(text, neg_prompt, image_list if with_images else None)

    @ProfilingContext4DebugL1("Run VAE Encoder", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_run_vae_encoder_image_duration, metrics_labels=["ZImageRunner"])
    def run_vae_encoder(self, image):
        """Encode one image with the VAE and return ``{"image_latents": ...}``."""
        on_demand = self._uses_on_demand_loading()
        if on_demand:
            self.vae = self.load_vae()
        image_latents = self.vae.encode_vae_image(image.to(GET_DTYPE()))
        if on_demand:
            del self.vae
            torch_device_module.empty_cache()
            gc.collect()
        return {"image_latents": image_latents}

    def run(self, total_steps=None):
        """Run the denoising loop.

        Args:
            total_steps: Number of steps; defaults to ``scheduler.infer_steps``.

        Returns:
            ``(scheduler.latents, scheduler.generator)`` after the last step.
        """
        if total_steps is None:
            total_steps = self.model.scheduler.infer_steps
        for step_index in range(total_steps):
            logger.info(f"==> step_index: {step_index + 1} / {total_steps}")

            with ProfilingContext4DebugL1("step_pre"):
                self.model.scheduler.step_pre(step_index=step_index)

            with ProfilingContext4DebugL1("🚀 infer_main"):
                self.model.infer(self.inputs)

            with ProfilingContext4DebugL1("step_post"):
                self.model.scheduler.step_post()

            if self.progress_callback:
                self.progress_callback(((step_index + 1) / total_steps) * 100, 100)

        return self.model.scheduler.latents, self.model.scheduler.generator

    # ------------------------------------------------------------------ shapes

    def get_input_target_shape(self):
        """Resolve the requested output size in pixels.

        A two-element ``input_info.target_shape`` is read as ``(height, width)``,
        scaled down proportionally if either side exceeds ``max_custom_size``
        (default 1664) and then clamped up to ``min_custom_size`` (default 256).
        Otherwise the aspect ratio from ``input_info`` (falling back to
        ``config["aspect_ratio"]``) is looked up in the aspect-ratio table.

        Returns:
            ``(width, height)`` -- note the order.

        Raises:
            NotImplementedError: No custom shape and no known aspect ratio.
        """
        default_aspect_ratios = {
            "16:9": [1664, 928],
            "9:16": [928, 1664],
            "1:1": [1328, 1328],
            "4:3": [1472, 1140],
            "3:4": [768, 1024],
        }
        # Merge into a fresh dict so the mapping stored in the config is never mutated.
        # Built-in entries still win over configured ones with the same key, as before.
        # NOTE(review): that precedence looks inverted for a user override -- confirm intent.
        as_maps = {**(self.config.get("aspect_ratios", {}) or {}), **default_aspect_ratios}
        max_size = self.config.get("max_custom_size", 1664)
        min_size = self.config.get("min_custom_size", 256)

        # ``target_shape`` may be absent/None when called from ``set_img_shapes``' fallback.
        target_shape = getattr(self.input_info, "target_shape", None)
        if target_shape is not None and len(target_shape) == 2:
            height, width = target_shape
            height, width = int(height), int(width)
            if width > max_size or height > max_size:
                scale = max_size / max(width, height)
                width, height = int(width * scale), int(height * scale)
                logger.warning(f"Custom shape is too large, scaled to {width}x{height}")
            width, height = max(width, min_size), max(height, min_size)
            logger.info(f"Z Image Runner got custom shape: {width}x{height}")
            return (width, height)

        aspect_ratio = self.input_info.aspect_ratio if self.input_info.aspect_ratio else self.config.get("aspect_ratio", None)
        if aspect_ratio in as_maps:
            logger.info(f"Z Image Runner got aspect ratio: {aspect_ratio}")
            width, height = as_maps[aspect_ratio]
            return (width, height)
        logger.warning(f"Invalid aspect ratio: {aspect_ratio}, not in {as_maps.keys()}")

        raise NotImplementedError(f"Unsupported aspect ratio: {aspect_ratio!r}")

    def set_target_shape(self):
        """Store the latent shape ``(1, C, H, W)`` in ``input_info.target_shape``.

        ``C`` is ``config["num_channels_latents"]`` (default 16); ``H`` and ``W``
        are the even latent extents derived from the requested pixel size.
        """
        # get_input_target_shape() returns (width, height); unpack in that order so the
        # latent H/W are not transposed.
        width, height = self.get_input_target_shape()

        # The VAE compresses images spatially and the latents are packed afterwards, so the
        # latent height and width must be divisible by 2 (see ``_latent_size``).
        latent_height = self._latent_size(height)
        latent_width = self._latent_size(width)
        num_channels_latents = self.config.get("num_channels_latents", 16)
        self.input_info.target_shape = (1, num_channels_latents, latent_height, latent_width)

    def set_img_shapes(self):
        """Store the patch grid ``[(1, H // patch, W // patch)]`` in ``input_info.image_shapes``.

        Uses the 4-D ``input_info.target_shape`` when present, otherwise derives
        the latent size from :meth:`get_input_target_shape`.

        Raises:
            ValueError: ``target_shape`` is set but is not 4-D.
        """
        if hasattr(self.input_info, "target_shape") and self.input_info.target_shape is not None:
            if len(self.input_info.target_shape) != 4:
                raise ValueError(f"target_shape must be 4D [B, C, H, W], got {len(self.input_info.target_shape)}D: {self.input_info.target_shape}")
            _, _, latent_height, latent_width = self.input_info.target_shape
        else:
            # Same (width, height) return order as in set_target_shape().
            width, height = self.get_input_target_shape()
            latent_height = self._latent_size(height)
            latent_width = self._latent_size(width)

        patch_size = self.config.get("patch_size", 2)
        patch_height = latent_height // patch_size
        patch_width = latent_width // patch_size

        self.input_info.image_shapes = [(1, patch_height, patch_width)]

    def init_scheduler(self):
        """Create the Z-Image scheduler from the config."""
        self.scheduler = ZImageScheduler(self.config)

    def get_encoder_output_i2v(self):
        """Not used by this runner."""
        pass

    def run_image_encoder(self):
        """Not used by this runner."""
        pass

    # ------------------------------------------------------------------ decode / pipeline

    @ProfilingContext4DebugL1(
        "Run VAE Decoder",
        recorder_mode=GET_RECORDER_MODE(),
        metrics_func=monitor_cli.lightx2v_run_vae_decode_duration,
        metrics_labels=["ZImageRunner"],
    )
    def run_vae_decoder(self, latents):
        """Decode ``latents`` with the VAE and return the decoded images."""
        on_demand = self._uses_on_demand_loading()
        if on_demand:
            self.vae = self.load_vae()
        images = self.vae.decode(latents, self.input_info)
        if on_demand:
            del self.vae
            torch_device_module.empty_cache()
            gc.collect()
        return images

    @ProfilingContext4DebugL1("RUN pipeline")
    def run_pipeline(self, input_info):
        """Run one full generation request.

        Args:
            input_info: Request object; this method reads ``return_result_tensor``
                and ``save_result_path`` from it and stores it as ``self.input_info``.

        Returns:
            ``{"images": images}`` when ``return_result_tensor`` is set; otherwise
            the first image is saved to ``save_result_path`` and
            ``{"images": None}`` is returned.
        """
        self.input_info = input_info

        self.inputs = self.run_input_encoder()
        # Store image_encoder_output in input_info for scheduler to access
        if self.config["task"] == "i2i" and "image_encoder_output" in self.inputs:
            self.input_info.image_encoder_output = self.inputs["image_encoder_output"]

        self.set_target_shape()
        self.set_img_shapes()
        logger.info(f"input_info: {self.input_info}")

        latents, generator = self.run_dit()
        images = self.run_vae_decoder(latents)
        self.end_run()

        if not input_info.return_result_tensor:
            image = images[0]
            image.save(input_info.save_result_path)
            logger.info(f"Image saved: {input_info.save_result_path}")

        # Drop the last references before asking the backend to release cached memory.
        del latents, generator
        torch_device_module.empty_cache()
        gc.collect()

        if input_info.return_result_tensor:
            return {"images": images}
        elif input_info.save_result_path is not None:
            return {"images": None}