Commit 79c3caa2 authored by Watebear, committed by GitHub

feature: support qwen-image-edit(i2i) (#234)

* feature: support qwen-image-edit(i2i)

* bugfix
parent ce07fb15
{
    "seed": 42,
    "batchsize": 1,
    "num_channels_latents": 16,
    "vae_scale_factor": 8,
    "infer_steps": 50,
    "guidance_embeds": false,
    "num_images_per_prompt": 1,
    "vae_latents_mean": [
        -0.7571,
        -0.7089,
        -0.9113,
        0.1075,
        -0.1745,
        0.9653,
        -0.1517,
        1.5508,
        0.4134,
        -0.0715,
        0.5517,
        -0.3632,
        -0.1922,
        -0.9497,
        0.2503,
        -0.2921
    ],
    "vae_latents_std": [
        2.8184,
        1.4541,
        2.3275,
        2.6558,
        1.2196,
        1.7708,
        2.6052,
        2.0743,
        3.2687,
        2.1526,
        2.8652,
        1.5579,
        1.6382,
        1.1253,
        2.8251,
        1.916
    ],
    "vae_z_dim": 16,
    "feature_caching": "NoCaching",
    "transformer_in_channels": 64,
    "prompt_template_encode": "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n",
    "prompt_template_encode_start_idx": 64,
    "_auto_resize": true
}
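
For reference, a minimal standalone sketch of what the `_auto_resize` path added in this commit does with an input image: it picks the entry of PREFERRED_QWENIMAGE_RESOLUTIONS whose aspect ratio is closest to the input, then snaps it to a multiple of vae_scale_factor * 2. The helper name pick_resolution and the 1920x1080 input are only illustrative; the logic mirrors preprocess_image in the text encoder below.

# Sketch of the _auto_resize resolution pick (mirrors preprocess_image below).
PREFERRED_QWENIMAGE_RESOLUTIONS = [
    (672, 1568), (688, 1504), (720, 1456), (752, 1392), (800, 1328),
    (832, 1248), (880, 1184), (944, 1104), (1024, 1024), (1104, 944),
    (1184, 880), (1248, 832), (1328, 800), (1392, 752), (1456, 720),
    (1504, 688), (1568, 672),
]

def pick_resolution(image_width, image_height, vae_scale_factor=8):
    # Choose the preferred (width, height) whose aspect ratio is closest to the input image.
    aspect_ratio = image_width / image_height
    _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_QWENIMAGE_RESOLUTIONS)
    # Snap to a multiple of vae_scale_factor * 2 so the packed latent grid stays even.
    multiple_of = vae_scale_factor * 2
    return width // multiple_of * multiple_of, height // multiple_of * multiple_of

print(pick_resolution(1920, 1080))  # (1392, 752) for a 16:9 input
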
@@ -52,5 +52,8 @@
         1.916
     ],
     "vae_z_dim": 16,
-    "feature_caching": "NoCaching"
+    "feature_caching": "NoCaching",
+    "prompt_template_encode": "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
+    "prompt_template_encode_start_idx": 34,
+    "_auto_resize": false
 }
@@ -57,7 +57,7 @@ def main():
         default="wan2.1",
     )
-    parser.add_argument("--task", type=str, choices=["t2v", "i2v", "t2i", "flf2v"], default="t2v")
+    parser.add_argument("--task", type=str, choices=["t2v", "i2v", "t2i", "i2i", "flf2v"], default="t2v")
     parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--config_json", type=str, required=True)
     parser.add_argument("--use_prompt_enhancer", action="store_true")
...
+import math
 import os
 import torch
 from transformers import Qwen2Tokenizer, Qwen2_5_VLForConditionalGeneration
+
+try:
+    from diffusers.image_processor import VaeImageProcessor
+    from transformers import Qwen2VLProcessor
+except ImportError:
+    VaeImageProcessor = None
+    Qwen2VLProcessor = None
+
+PREFERRED_QWENIMAGE_RESOLUTIONS = [
+    (672, 1568),
+    (688, 1504),
+    (720, 1456),
+    (752, 1392),
+    (800, 1328),
+    (832, 1248),
+    (880, 1184),
+    (944, 1104),
+    (1024, 1024),
+    (1104, 944),
+    (1184, 880),
+    (1248, 832),
+    (1328, 800),
+    (1392, 752),
+    (1456, 720),
+    (1504, 688),
+    (1568, 672),
+]
+
+
+def calculate_dimensions(target_area, ratio):
+    width = math.sqrt(target_area * ratio)
+    height = width / ratio
+    width = round(width / 32) * 32
+    height = round(height / 32) * 32
+    return width, height, None
+
+
 class Qwen25_VLForConditionalGeneration_TextEncoder:
     def __init__(self, config):
@@ -11,8 +49,12 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
         self.tokenizer = Qwen2Tokenizer.from_pretrained(os.path.join(config.model_path, "tokenizer"))
         self.tokenizer_max_length = 1024
-        self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
-        self.prompt_template_encode_start_idx = 34
+        self.prompt_template_encode = config.prompt_template_encode
+        self.prompt_template_encode_start_idx = config.prompt_template_encode_start_idx
+        if config.task == "i2i":
+            self.image_processor = VaeImageProcessor(vae_scale_factor=self.config["vae_scale_factor"] * 2)
+            self.processor = Qwen2VLProcessor.from_pretrained(os.path.join(config.model_path, "processor"))
         self.device = torch.device("cuda")
         self.dtype = torch.bfloat16
@@ -24,19 +66,63 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
         split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
         return split_result

-    def infer(self, text):
+    def preprocess_image(self, image):
+        image_size = image.size
+        width, height = image_size
+        calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, width / height)
+        height = height or calculated_height
+        width = width or calculated_width
+
+        multiple_of = self.config.vae_scale_factor * 2
+        width = width // multiple_of * multiple_of
+        height = height // multiple_of * multiple_of
+
+        if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
+            image_height, image_width = self.image_processor.get_default_height_width(image)
+            aspect_ratio = image_width / image_height
+            if self.config._auto_resize:
+                _, image_width, image_height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_QWENIMAGE_RESOLUTIONS)
+            image_width = image_width // multiple_of * multiple_of
+            image_height = image_height // multiple_of * multiple_of
+            image = self.image_processor.resize(image, image_height, image_width)
+
+        prompt_image = image
+        image = self.image_processor.preprocess(image, image_height, image_width)
+        image = image.unsqueeze(2)
+        return prompt_image, image, (image_height, image_width)
+
+    def infer(self, text, image=None):
         template = self.prompt_template_encode
         drop_idx = self.prompt_template_encode_start_idx
         txt = [template.format(e) for e in text]
-        txt_tokens = self.tokenizer(txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt").to(self.device)
-        encoder_hidden_states = self.text_encoder(
-            input_ids=txt_tokens.input_ids,
-            attention_mask=txt_tokens.attention_mask,
-            output_hidden_states=True,
-        )
+
+        if image is not None:
+            prompt_image, image, image_info = self.preprocess_image(image)
+            model_inputs = self.processor(
+                text=txt,
+                images=prompt_image,
+                padding=True,
+                return_tensors="pt",
+            ).to(self.device)
+            encoder_hidden_states = self.text_encoder(
+                input_ids=model_inputs.input_ids,
+                attention_mask=model_inputs.attention_mask,
+                pixel_values=model_inputs.pixel_values,
+                image_grid_thw=model_inputs.image_grid_thw,
+                output_hidden_states=True,
+            )
+        else:
+            prompt_image, image, image_info = None, None, None
+            model_inputs = self.tokenizer(txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt").to(self.device)
+            encoder_hidden_states = self.text_encoder(
+                input_ids=model_inputs.input_ids,
+                attention_mask=model_inputs.attention_mask,
+                output_hidden_states=True,
+            )
+
         hidden_states = encoder_hidden_states.hidden_states[-1]
-        split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
+        split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
         split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
         attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
         max_seq_len = max([e.size(0) for e in split_hidden_states])
@@ -51,4 +137,4 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
         prompt_embeds = prompt_embeds.view(self.config.batchsize * self.config.num_images_per_prompt, seq_len, -1)
         prompt_embeds_mask = prompt_embeds_mask.repeat(1, self.config.num_images_per_prompt, 1)
         prompt_embeds_mask = prompt_embeds_mask.view(self.config.batchsize * self.config.num_images_per_prompt, seq_len)
-        return prompt_embeds, prompt_embeds_mask
+        return prompt_embeds, prompt_embeds_mask, image, image_info
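
As context for the prompt_template_encode_start_idx values in the configs (34 for t2i, 64 for the new i2i template): infer() drops that many leading tokens, i.e. the fixed template preamble, from each sequence's hidden states before re-padding. A toy sketch of that step with dummy tensors (no real tokenizer; all sizes are made up):

import torch

drop_idx = 34                               # t2i value; the i2i template uses 64
hidden_states = torch.randn(2, 120, 3584)   # (batch, seq, hidden) - dummy sizes
attention_mask = torch.ones(2, 120, dtype=torch.long)
attention_mask[1, 100:] = 0                 # pretend the second prompt is shorter

# Same idea as _extract_masked_hidden plus the drop in infer():
# keep only valid tokens, then cut the template preamble from each sequence.
valid_lengths = attention_mask.sum(dim=1)
split = [h[:l] for h, l in zip(hidden_states, valid_lengths)]
split = [e[drop_idx:] for e in split]
print([e.size(0) for e in split])           # [86, 66]
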
@@ -27,7 +27,6 @@ class QwenImagePreInfer:
             guidance = guidance.to(hidden_states.dtype) * 1000
         temb = self.time_text_embed(timestep, hidden_states) if guidance is None else self.time_text_embed(timestep, guidance, hidden_states)
         image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
         return hidden_states, encoder_hidden_states, encoder_hidden_states_mask, (hidden_states_0, temb, image_rotary_emb)
@@ -53,6 +53,12 @@ class QwenImageTransformerModel:
     def infer(self, inputs):
         t = self.scheduler.timesteps[self.scheduler.step_index]
         latents = self.scheduler.latents
+
+        if self.config.task == "i2i":
+            image_latents = inputs["image_encoder_output"]["image_latents"]
+            latents_input = torch.cat([latents, image_latents], dim=1)
+        else:
+            latents_input = latents
+
         timestep = t.expand(latents.shape[0]).to(latents.dtype)
         img_shapes = self.scheduler.img_shapes
@@ -60,9 +66,8 @@ class QwenImageTransformerModel:
         prompt_embeds_mask = inputs["text_encoder_output"]["prompt_embeds_mask"]
         txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
         hidden_states, encoder_hidden_states, encoder_hidden_states_mask, pre_infer_out = self.pre_infer.infer(
-            hidden_states=latents,
+            hidden_states=latents_input,
             timestep=timestep / 1000,
             guidance=self.scheduler.guidance,
             encoder_hidden_states_mask=prompt_embeds_mask,
@@ -81,5 +86,7 @@
         )
         noise_pred = self.post_infer.infer(hidden_states, pre_infer_out[1])
+
+        if self.config.task == "i2i":
+            noise_pred = noise_pred[:, : latents.size(1)]
+
         self.scheduler.noise_pred = noise_pred
 import gc
+import math
 import torch
+from PIL import Image
 from loguru import logger
 from lightx2v.models.input_encoders.hf.qwen25.qwen25_vlforconditionalgeneration import Qwen25_VLForConditionalGeneration_TextEncoder
@@ -12,6 +14,16 @@ from lightx2v.utils.profiler import ProfilingContext
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
+
+
+def calculate_dimensions(target_area, ratio):
+    width = math.sqrt(target_area * ratio)
+    height = width / ratio
+    width = round(width / 32) * 32
+    height = round(height / 32) * 32
+    return width, height, None
+
+
 @RUNNER_REGISTER("qwen_image")
 class QwenImageRunner(DefaultRunner):
     model_cpu_offload_seq = "text_encoder->transformer->vae"
@@ -51,12 +63,14 @@ class QwenImageRunner(DefaultRunner):
         self.run_dit = self._run_dit_local
         self.run_vae_decoder = self._run_vae_decoder_local
         if self.config["task"] == "t2i":
-            self.run_input_encoder = self._run_input_encoder_local_i2v
+            self.run_input_encoder = self._run_input_encoder_local_t2i
+        elif self.config["task"] == "i2i":
+            self.run_input_encoder = self._run_input_encoder_local_i2i
         else:
             assert NotImplementedError

     @ProfilingContext("Run Encoders")
-    def _run_input_encoder_local_i2v(self):
+    def _run_input_encoder_local_t2i(self):
         prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
         text_encoder_output = self.run_text_encoder(prompt)
         torch.cuda.empty_cache()
@@ -66,20 +80,57 @@ class QwenImageRunner(DefaultRunner):
             "image_encoder_output": None,
         }

-    def run_text_encoder(self, text):
+    @ProfilingContext("Run Encoders")
+    def _run_input_encoder_local_i2i(self):
+        image = Image.open(self.config["image_path"])
+        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
+        text_encoder_output = self.run_text_encoder(prompt, image)
+        image_encoder_output = self.run_vae_encoder(image=text_encoder_output["preprocessed_image"])
+        image_encoder_output["image_info"] = text_encoder_output["image_info"]
+        torch.cuda.empty_cache()
+        gc.collect()
+        return {
+            "text_encoder_output": text_encoder_output,
+            "image_encoder_output": image_encoder_output,
+        }
+
+    def run_text_encoder(self, text, image=None):
         text_encoder_output = {}
-        prompt_embeds, prompt_embeds_mask = self.text_encoders[0].infer([text])
-        text_encoder_output["prompt_embeds"] = prompt_embeds
-        text_encoder_output["prompt_embeds_mask"] = prompt_embeds_mask
+        if self.config["task"] == "t2i":
+            prompt_embeds, prompt_embeds_mask, _, _ = self.text_encoders[0].infer([text])
+            text_encoder_output["prompt_embeds"] = prompt_embeds
+            text_encoder_output["prompt_embeds_mask"] = prompt_embeds_mask
+        elif self.config["task"] == "i2i":
+            prompt_embeds, prompt_embeds_mask, preprocessed_image, image_info = self.text_encoders[0].infer([text], image)
+            text_encoder_output["prompt_embeds"] = prompt_embeds
+            text_encoder_output["prompt_embeds_mask"] = prompt_embeds_mask
+            text_encoder_output["preprocessed_image"] = preprocessed_image
+            text_encoder_output["image_info"] = image_info
         return text_encoder_output

+    def run_vae_encoder(self, image):
+        image_latents = self.vae.encode_vae_image(image)
+        return {"image_latents": image_latents}
+
     def set_target_shape(self):
-        self.vae_scale_factor = self.vae.vae_scale_factor if getattr(self, "vae", None) else 8
-        width, height = self.config.aspect_ratios[self.config.aspect_ratio]
+        if not self.config._auto_resize:
+            width, height = self.config.aspect_ratios[self.config.aspect_ratio]
+        else:
+            image = Image.open(self.config.image_path).convert("RGB")
+            width, height = image.size
+            calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, width / height)
+            height = height or calculated_height
+            width = width or calculated_width
+            multiple_of = self.vae.vae_scale_factor * 2
+            width = width // multiple_of * multiple_of
+            height = height // multiple_of * multiple_of
+            self.config.auto_width = width
+            self.config.auto_hight = height
+
         # VAE applies 8x compression on images but we must also account for packing which requires
         # latent height and width to be divisible by 2.
-        height = 2 * (int(height) // (self.vae_scale_factor * 2))
-        width = 2 * (int(width) // (self.vae_scale_factor * 2))
+        height = 2 * (int(height) // (self.vae.vae_scale_factor * 2))
+        width = 2 * (int(width) // (self.vae.vae_scale_factor * 2))
         num_channels_latents = self.model.in_channels // 4
         self.config.target_shape = (self.config.batchsize, 1, num_channels_latents, height, width)
@@ -96,9 +147,6 @@ class QwenImageRunner(DefaultRunner):
     def run_image_encoder(self):
         pass

-    def run_vae_encoder(self):
-        pass
-
     @ProfilingContext("Load models")
     def load_model(self):
         self.model = self.load_transformer()
...
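
As a worked example of the set_target_shape arithmetic above (a sketch only; it plugs in the i2i config values vae_scale_factor=8 and transformer_in_channels=64, and reuses the hypothetical 1392x752 auto-resized size from earlier):

vae_scale_factor = 8
width, height = 1392, 752                                      # hypothetical auto-resized size

# Latent grid: 8x VAE compression, then forced divisible by 2 for the 2x2 packing.
latent_height = 2 * (int(height) // (vae_scale_factor * 2))    # 94
latent_width = 2 * (int(width) // (vae_scale_factor * 2))      # 174

num_channels_latents = 64 // 4                                 # transformer in_channels // 4 = 16
target_shape = (1, 1, num_channels_latents, latent_height, latent_width)
print(target_shape)  # (1, 1, 16, 94, 174)
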
@@ -129,10 +129,7 @@ class QwenImageScheduler(BaseScheduler):
     def prepare_latents(self):
         shape = self.config.target_shape
-        width, height = self.config.aspect_ratios[self.config.aspect_ratio]
-        self.vae_scale_factor = self.config.vae_scale_factor if getattr(self, "vae", None) else 8
-        height = 2 * (int(height) // (self.vae_scale_factor * 2))
-        width = 2 * (int(width) // (self.vae_scale_factor * 2))
+        width, height = shape[-1], shape[-2]
         latents = randn_tensor(shape, generator=self.generator, device=self.device, dtype=self.dtype)
         latents = self._pack_latents(latents, self.config.batchsize, self.config.num_channels_latents, height, width)
         latent_image_ids = self._prepare_latent_image_ids(self.config.batchsize, height // 2, width // 2, self.device, self.dtype)
@@ -175,14 +172,23 @@ class QwenImageScheduler(BaseScheduler):
             guidance = None
         self.guidance = guidance

-    def set_img_shapes(self):
-        width, height = self.config.aspect_ratios[self.config.aspect_ratio]
-        self.img_shapes = [(1, height // self.config.vae_scale_factor // 2, width // self.config.vae_scale_factor // 2)] * self.config.batchsize
+    def set_img_shapes(self, inputs):
+        if self.config.task == "t2i":
+            width, height = self.config.aspect_ratios[self.config.aspect_ratio]
+            self.img_shapes = [(1, height // self.config.vae_scale_factor // 2, width // self.config.vae_scale_factor // 2)] * self.config.batchsize
+        elif self.config.task == "i2i":
+            image_height, image_width = inputs["image_info"]
+            self.img_shapes = [
+                [
+                    (1, self.config.auto_hight // self.config.vae_scale_factor // 2, self.config.auto_width // self.config.vae_scale_factor // 2),
+                    (1, image_height // self.config.vae_scale_factor // 2, image_width // self.config.vae_scale_factor // 2),
+                ]
+            ]

-    def prepare(self, image_encoder_output):
+    def prepare(self, inputs):
         self.prepare_latents()
         self.prepare_guidance()
-        self.set_img_shapes()
+        self.set_img_shapes(inputs)
         self.set_timesteps()

     def step_post(self):
...
 import json
 import os
+from typing import Optional
 import torch
@@ -11,6 +12,18 @@ except ImportError:
     VaeImageProcessor = None
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 class AutoencoderKLQwenImageVAE:
     def __init__(self, config):
         self.config = config
@@ -19,27 +32,33 @@ class AutoencoderKLQwenImageVAE:
         with open(os.path.join(config.model_path, "vae", "config.json"), "r") as f:
             vae_config = json.load(f)
         self.vae_scale_factor = 2 ** len(vae_config["temperal_downsample"]) if "temperal_downsample" in vae_config else 8
+        self.generator = torch.Generator(device="cuda").manual_seed(config.seed)
         self.dtype = torch.bfloat16
+        self.device = torch.device("cuda")
+        self.latent_channels = config.vae_z_dim

     @staticmethod
     def _unpack_latents(latents, height, width, vae_scale_factor):
-        batch_size, num_patches, channels = latents.shape
+        batchsize, num_patches, channels = latents.shape
         # VAE applies 8x compression on images but we must also account for packing which requires
         # latent height and width to be divisible by 2.
         height = 2 * (int(height) // (vae_scale_factor * 2))
         width = 2 * (int(width) // (vae_scale_factor * 2))
-        latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+        latents = latents.view(batchsize, height // 2, width // 2, channels // 4, 2, 2)
         latents = latents.permute(0, 3, 1, 4, 2, 5)
-        latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
+        latents = latents.reshape(batchsize, channels // (2 * 2), 1, height, width)
         return latents

     @torch.no_grad()
     def decode(self, latents):
-        width, height = self.config.aspect_ratios[self.config.aspect_ratio]
+        if self.config.task == "t2i":
+            width, height = self.config.aspect_ratios[self.config.aspect_ratio]
+        elif self.config.task == "i2i":
+            width, height = self.config.auto_width, self.config.auto_hight
         latents = self._unpack_latents(latents, height, width, self.config.vae_scale_factor)
         latents = latents.to(self.dtype)
         latents_mean = torch.tensor(self.config.vae_latents_mean).view(1, self.config.vae_z_dim, 1, 1, 1).to(latents.device, latents.dtype)
@@ -48,3 +67,43 @@ class AutoencoderKLQwenImageVAE:
         images = self.model.decode(latents, return_dict=False)[0][:, :, 0]
         images = self.image_processor.postprocess(images, output_type="pil")
         return images
+
+    @staticmethod
+    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents
+    def _pack_latents(latents, batchsize, num_channels_latents, height, width):
+        latents = latents.view(batchsize, num_channels_latents, height // 2, 2, width // 2, 2)
+        latents = latents.permute(0, 2, 4, 1, 3, 5)
+        latents = latents.reshape(batchsize, (height // 2) * (width // 2), num_channels_latents * 4)
+        return latents
+
+    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+        if isinstance(generator, list):
+            image_latents = [retrieve_latents(self.model.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax") for i in range(image.shape[0])]
+            image_latents = torch.cat(image_latents, dim=0)
+        else:
+            image_latents = retrieve_latents(self.model.encode(image), generator=generator, sample_mode="argmax")
+
+        latents_mean = torch.tensor(self.model.config.latents_mean).view(1, self.latent_channels, 1, 1, 1).to(image_latents.device, image_latents.dtype)
+        latents_std = torch.tensor(self.model.config.latents_std).view(1, self.latent_channels, 1, 1, 1).to(image_latents.device, image_latents.dtype)
+        image_latents = (image_latents - latents_mean) / latents_std
+        return image_latents
+
+    def encode_vae_image(self, image):
+        num_channels_latents = self.config.transformer_in_channels // 4
+        image = image.to(self.device).to(self.dtype)
+        if image.shape[1] != self.latent_channels:
+            image_latents = self._encode_vae_image(image=image, generator=self.generator)
+        else:
+            image_latents = image
+
+        if self.config.batchsize > image_latents.shape[0] and self.config.batchsize % image_latents.shape[0] == 0:
+            # expand init_latents for batchsize
+            additional_image_per_prompt = self.config.batchsize // image_latents.shape[0]
+            image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+        elif self.config.batchsize > image_latents.shape[0] and self.config.batchsize % image_latents.shape[0] != 0:
+            raise ValueError(f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {self.config.batchsize} text prompts.")
+        else:
+            image_latents = torch.cat([image_latents], dim=0)
+
+        image_latent_height, image_latent_width = image_latents.shape[3:]
+        image_latents = self._pack_latents(image_latents, self.config.batchsize, num_channels_latents, image_latent_height, image_latent_width)
+        return image_latents
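
For orientation, a self-contained sketch of the 2x2 packing that _pack_latents applies to the encoded image latents. The 94x174 latent grid reuses the running example and vae_z_dim=16 comes from the config; everything else is illustrative.

import torch

image_latents = torch.randn(1, 16, 1, 94, 174)             # (batch, channels, frames=1, H, W)
b, c = 1, 16
h, w = image_latents.shape[3:]

# Same reshape sequence as _pack_latents: split H and W into 2x2 patches,
# then flatten each patch into a single token of c * 4 channels.
packed = image_latents.view(b, c, h // 2, 2, w // 2, 2)
packed = packed.permute(0, 2, 4, 1, 3, 5)
packed = packed.reshape(b, (h // 2) * (w // 2), c * 4)
print(packed.shape)  # torch.Size([1, 4089, 64])
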
#!/bin/bash
export CUDA_VISIBLE_DEVICES=
# set these paths first
export lightx2v_path=
export model_path=
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export DTYPE=BF16
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
python -m lightx2v.infer \
--model_cls qwen_image \
--task i2i \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/qwen_image/qwen_image_i2i.json \
--prompt "Change the rabbit's color to purple, with a flash light background." \
--image_path input.jpg \
--save_video_path ${lightx2v_path}/save_results/qwen_image_i2i.png