Unverified Commit 62789aa4 authored by Watebear, committed by GitHub

[feat]: refactor to adapt new config system and align precision for qwen-image & qwen-image-edit (#383)
parent 0aaab832
......@@ -3,7 +3,6 @@
"num_channels_latents": 16,
"vae_scale_factor": 8,
"infer_steps": 50,
"num_laysers": 60,
"guidance_embeds": false,
"num_images_per_prompt": 1,
"vae_latents_mean": [
......@@ -48,8 +47,6 @@
"prompt_template_encode": "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n",
"prompt_template_encode_start_idx": 64,
"_auto_resize": true,
"cpu_offload": true,
"offload_granularity": "block",
"num_layers": 60,
"attention_out_dim": 3072,
"attention_dim_head": 128,
......@@ -59,5 +56,10 @@
56
],
"_comment_attn": "in [torch_sdpa, flash_attn3, sage_attn2]",
"attn_type": "flash_attn3"
"attn_type": "flash_attn3",
"do_true_cfg": true,
"true_cfg_scale": 4.0,
"cpu_offload": true,
"offload_granularity": "block",
"mm_config": {}
}
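For reference, a minimal sketch of how these new config keys are consumed downstream (key names and defaults mirror the model code later in this diff; load_config and resolve_runtime_options are illustrative helpers, not part of the repo):

import json

def load_config(path: str) -> dict:
    # Plain JSON load; the real config system may layer defaults on top of this.
    with open(path, "r") as f:
        return json.load(f)

def resolve_runtime_options(config: dict) -> dict:
    mm_config = config.get("mm_config") or {}
    mm_type = mm_config.get("mm_type", "Default")
    return {
        "mm_type": mm_type,
        # "Default" means the DiT matmuls stay unquantized; any other mm_type enables quantization.
        "dit_quantized": mm_type != "Default",
        # True classifier-free guidance is opt-in per config file.
        "do_true_cfg": config.get("do_true_cfg", False),
        "true_cfg_scale": config.get("true_cfg_scale", 1.0),
        "cpu_offload": config.get("cpu_offload", False),
        "offload_granularity": config.get("offload_granularity", "block"),
    }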
......@@ -70,8 +70,6 @@
"prompt_template_encode": "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
"prompt_template_encode_start_idx": 34,
"_auto_resize": false,
"cpu_offload": true,
"offload_granularity": "block",
"num_layers": 60,
"attention_out_dim": 3072,
"attention_dim_head": 128,
......@@ -81,5 +79,9 @@
56
],
"_comment_attn": "in [torch_sdpa, flash_attn3, sage_attn2]",
"attn_type": "flash_attn3"
"attn_type": "flash_attn3",
"do_true_cfg": false,
"cpu_offload": true,
"offload_granularity": "block",
"mm_config": {}
}
......@@ -56,5 +56,8 @@
56
],
"_comment_attn": "in [torch_sdpa, flash_attn3, sage_attn2]",
"attn_type": "flash_attn3"
"attn_type": "flash_attn3",
"do_true_cfg": true,
"true_cfg_scale": 4.0,
"mm_config": {}
}
......@@ -79,5 +79,8 @@
56
],
"_comment_attn": "in [torch_sdpa, flash_attn3, sage_attn2]",
"attn_type": "flash_attn3"
"attn_type": "flash_attn3",
"do_true_cfg": true,
"true_cfg_scale": 4.0,
"mm_config": {}
}
......@@ -131,6 +131,8 @@ class RMSWeightFP32(RMSWeight):
variance = input_tensor.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = input_tensor * torch.rsqrt(variance + self.eps)
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
if self.weight is not None:
hidden_states = hidden_states * self.weight
hidden_states = hidden_states.to(input_dtype)
......
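As context for the two lines added above, the precision-aligned RMSNorm forward reads roughly as the sketch below: variance accumulated in fp32, activations cast back to the half-precision weight dtype before scaling. A sketch only, not the repo implementation.

from typing import Optional

import torch

def rms_norm_fp32(input_tensor: torch.Tensor, weight: Optional[torch.Tensor], eps: float = 1e-6) -> torch.Tensor:
    input_dtype = input_tensor.dtype
    # Accumulate the variance in float32 for numerical stability.
    variance = input_tensor.to(torch.float32).pow(2).mean(-1, keepdim=True)
    hidden_states = input_tensor * torch.rsqrt(variance + eps)
    if weight is not None:
        # Match the half-precision weight dtype before scaling (the cast added in this hunk).
        if weight.dtype in (torch.float16, torch.bfloat16):
            hidden_states = hidden_states.to(weight.dtype)
        hidden_states = hidden_states * weight
    return hidden_states.to(input_dtype)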
......@@ -47,8 +47,8 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
def __init__(self, config):
self.config = config
self.tokenizer_max_length = 1024
self.prompt_template_encode = config.prompt_template_encode
self.prompt_template_encode_start_idx = config.prompt_template_encode_start_idx
self.prompt_template_encode = config["prompt_template_encode"]
self.prompt_template_encode_start_idx = config["prompt_template_encode_start_idx"]
self.cpu_offload = config.get("cpu_offload", False)
if self.cpu_offload:
......@@ -60,11 +60,14 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
self.load()
def load(self):
self.text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(os.path.join(self.config.model_path, "text_encoder")).to(self.device).to(self.dtype)
self.tokenizer = Qwen2Tokenizer.from_pretrained(os.path.join(self.config.model_path, "tokenizer"))
if self.config.task == "i2i":
self.text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(os.path.join(self.config["model_path"], "text_encoder"), torch_dtype=torch.bfloat16)
if not self.cpu_offload:
self.text_encoder = self.text_encoder.to("cuda")
self.tokenizer = Qwen2Tokenizer.from_pretrained(os.path.join(self.config["model_path"], "tokenizer"))
if self.config["task"] == "i2i":
self.image_processor = VaeImageProcessor(vae_scale_factor=self.config["vae_scale_factor"] * 2)
self.processor = Qwen2VLProcessor.from_pretrained(os.path.join(self.config.model_path, "processor"))
self.processor = Qwen2VLProcessor.from_pretrained(os.path.join(self.config["model_path"], "processor"))
def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
bool_mask = mask.bool()
......@@ -81,23 +84,17 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
height = height or calculated_height
width = width or calculated_width
multiple_of = self.config.vae_scale_factor * 2
multiple_of = self.config["vae_scale_factor"] * 2
width = width // multiple_of * multiple_of
height = height // multiple_of * multiple_of
if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
image_height, image_width = self.image_processor.get_default_height_width(image)
aspect_ratio = image_width / image_height
if self.config._auto_resize:
_, image_width, image_height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_QWENIMAGE_RESOLUTIONS)
image_width = image_width // multiple_of * multiple_of
image_height = image_height // multiple_of * multiple_of
image = self.image_processor.resize(image, image_height, image_width)
image = self.image_processor.resize(image, calculated_height, calculated_width)
prompt_image = image
image = self.image_processor.preprocess(image, image_height, image_width)
image = self.image_processor.preprocess(image, calculated_height, calculated_width)
image = image.unsqueeze(2)
return prompt_image, image, (image_height, image_width)
return prompt_image, image, (calculated_height, calculated_width)
@torch.no_grad()
def infer(self, text, image=None):
......@@ -110,12 +107,14 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
if image is not None:
prompt_image, image, image_info = self.preprocess_image(image)
model_inputs = self.processor(
text=txt,
images=prompt_image,
padding=True,
return_tensors="pt",
).to(torch.device("cuda"))
encoder_hidden_states = self.text_encoder(
input_ids=model_inputs.input_ids,
attention_mask=model_inputs.attention_mask,
......@@ -133,6 +132,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
)
hidden_states = encoder_hidden_states.hidden_states[-1]
split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
......@@ -144,10 +144,10 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
prompt_embeds_mask = encoder_attention_mask
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, self.config.num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(self.config.batchsize * self.config.num_images_per_prompt, seq_len, -1)
prompt_embeds_mask = prompt_embeds_mask.repeat(1, self.config.num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(self.config.batchsize * self.config.num_images_per_prompt, seq_len)
prompt_embeds = prompt_embeds.repeat(1, self.config["num_images_per_prompt"], 1)
prompt_embeds = prompt_embeds.view(self.config["batchsize"] * self.config["num_images_per_prompt"], seq_len, -1)
prompt_embeds_mask = prompt_embeds_mask.repeat(1, self.config["num_images_per_prompt"], 1)
prompt_embeds_mask = prompt_embeds_mask.view(self.config["batchsize"] * self.config["num_images_per_prompt"], seq_len)
if self.cpu_offload:
self.text_encoder.to(torch.device("cpu"))
......
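For clarity, the template-prefix trimming the text encoder performs can be sketched as below, where drop_idx is prompt_template_encode_start_idx; the helper name is illustrative and not part of the repo.

import torch

def trim_template_tokens(hidden_states: torch.Tensor, attention_mask: torch.Tensor, drop_idx: int):
    # hidden_states: [batch, seq, dim], attention_mask: [batch, seq]
    bool_mask = attention_mask.bool()
    # Keep only the unmasked tokens of each sequence, then cut the prompt-template prefix.
    split_hidden = [h[m] for h, m in zip(hidden_states, bool_mask)]
    split_hidden = [h[drop_idx:] for h in split_hidden]
    # Rebuild per-sequence attention masks for the trimmed embeddings.
    new_masks = [torch.ones(h.size(0), dtype=torch.long, device=h.device) for h in split_hidden]
    return split_hidden, new_masks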
......@@ -163,7 +163,7 @@ class QwenImagePreInfer:
self.config = config
self.attention_kwargs = {}
self.cpu_offload = config.get("cpu_offload", False)
self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(config.axes_dims_rope), scale_rope=True)
self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(config["axes_dims_rope"]), scale_rope=True)
def set_scheduler(self, scheduler):
self.scheduler = scheduler
......@@ -174,6 +174,7 @@ class QwenImagePreInfer:
timestep = timestep.to(hidden_states.dtype)
encoder_hidden_states = encoder_hidden_states.squeeze(0)
encoder_hidden_states = weights.txt_norm.apply(encoder_hidden_states)
encoder_hidden_states = weights.txt_in.apply(encoder_hidden_states)
timesteps_proj = get_timestep_embedding(timestep).to(torch.bfloat16)
......
......@@ -191,14 +191,14 @@ class QwenImageTransformerInfer(BaseTransformerInfer):
# Process image stream - norm2 + MLP
img_normed2 = block_weight.img_norm2.apply(hidden_states)
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
img_mlp_output = F.silu(block_weight.img_mlp.mlp_0.apply(img_modulated2.squeeze(0)))
img_mlp_output = F.gelu(block_weight.img_mlp.mlp_0.apply(img_modulated2.squeeze(0)), approximate="tanh")
img_mlp_output = block_weight.img_mlp.mlp_2.apply(img_mlp_output)
hidden_states = hidden_states + img_gate2 * img_mlp_output
# Process text stream - norm2 + MLP
txt_normed2 = block_weight.txt_norm2.apply(encoder_hidden_states)
txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
txt_mlp_output = F.silu(block_weight.txt_mlp.mlp_0.apply(txt_modulated2.squeeze(0)))
txt_mlp_output = F.gelu(block_weight.txt_mlp.mlp_0.apply(txt_modulated2.squeeze(0)), approximate="tanh")
txt_mlp_output = block_weight.txt_mlp.mlp_2.apply(txt_mlp_output)
encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output
......
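The activation switch above (SiLU to tanh-approximated GELU) is what aligns the MLP output with the reference Qwen-Image transformer. A minimal sketch of the resulting feed-forward path, with hypothetical weight and bias tensors:

import torch
import torch.nn.functional as F

def feed_forward(x: torch.Tensor, w0: torch.Tensor, b0: torch.Tensor, w2: torch.Tensor, b2: torch.Tensor) -> torch.Tensor:
    # mlp_0 -> GELU(tanh) -> mlp_2, matching the corrected activation above.
    h = F.linear(x, w0, b0)
    h = F.gelu(h, approximate="tanh")  # previously F.silu, which did not match the reference model
    return F.linear(h, w2, b2)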
import glob
import json
import os
......@@ -23,17 +24,18 @@ class QwenImageTransformerModel:
def __init__(self, config):
self.config = config
self.model_path = os.path.join(config.model_path, "transformer")
self.model_path = os.path.join(config["model_path"], "transformer")
self.cpu_offload = config.get("cpu_offload", False)
self.offload_granularity = self.config.get("offload_granularity", "block")
self.device = torch.device("cpu") if self.cpu_offload else torch.device("cuda")
with open(os.path.join(config.model_path, "transformer", "config.json"), "r") as f:
with open(os.path.join(config["model_path"], "transformer", "config.json"), "r") as f:
transformer_config = json.load(f)
self.in_channels = transformer_config["in_channels"]
self.attention_kwargs = {}
self.dit_quantized = self.config["dit_quantized"]
self.dit_quantized = self.config["mm_config"].get("mm_type", "Default") != "Default"
self.weight_auto_quant = self.config["mm_config"].get("weight_auto_quant", False)
self._init_infer_class()
self._init_weights()
......@@ -61,7 +63,7 @@ class QwenImageTransformerModel:
if weight_dict is None:
is_weight_loader = self._should_load_weights()
if is_weight_loader:
if not self.dit_quantized:
if not self.dit_quantized or self.weight_auto_quant:
# Load original weights
weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
else:
......@@ -191,14 +193,14 @@ class QwenImageTransformerModel:
t = self.scheduler.timesteps[self.scheduler.step_index]
latents = self.scheduler.latents
if self.config.task == "i2i":
if self.config["task"] == "i2i":
image_latents = inputs["image_encoder_output"]["image_latents"]
latents_input = torch.cat([latents, image_latents], dim=1)
else:
latents_input = latents
timestep = t.expand(latents.shape[0]).to(latents.dtype)
img_shapes = self.scheduler.img_shapes
img_shapes = inputs["img_shapes"]
prompt_embeds = inputs["text_encoder_output"]["prompt_embeds"]
prompt_embeds_mask = inputs["text_encoder_output"]["prompt_embeds_mask"]
......@@ -226,7 +228,43 @@ class QwenImageTransformerModel:
noise_pred = self.post_infer.infer(self.post_weight, hidden_states, pre_infer_out[0])
if self.config.task == "i2i":
if self.config["do_true_cfg"]:
neg_prompt_embeds = inputs["text_encoder_output"]["negative_prompt_embeds"]
neg_prompt_embeds_mask = inputs["text_encoder_output"]["negative_prompt_embeds_mask"]
negative_txt_seq_lens = neg_prompt_embeds_mask.sum(dim=1).tolist() if neg_prompt_embeds_mask is not None else None
neg_hidden_states, neg_encoder_hidden_states, _, neg_pre_infer_out = self.pre_infer.infer(
weights=self.pre_weight,
hidden_states=latents_input,
timestep=timestep / 1000,
guidance=self.scheduler.guidance,
encoder_hidden_states_mask=neg_prompt_embeds_mask,
encoder_hidden_states=neg_prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=negative_txt_seq_lens,
attention_kwargs=self.attention_kwargs,
)
neg_encoder_hidden_states, neg_hidden_states = self.transformer_infer.infer(
block_weights=self.transformer_weights,
hidden_states=neg_hidden_states.unsqueeze(0),
encoder_hidden_states=neg_encoder_hidden_states.unsqueeze(0),
pre_infer_out=neg_pre_infer_out,
)
neg_noise_pred = self.post_infer.infer(self.post_weight, neg_hidden_states, neg_pre_infer_out[0])
if self.config["task"] == "i2i":
noise_pred = noise_pred[:, : latents.size(1)]
if self.config["do_true_cfg"]:
neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
comb_pred = neg_noise_pred + self.config["true_cfg_scale"] * (noise_pred - neg_noise_pred)
cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
noise_pred = comb_pred * (cond_norm / noise_norm)
noise_pred = noise_pred[:, : latents.size(1)]
self.scheduler.noise_pred = noise_pred
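The negative branch added above implements true classifier-free guidance with norm rescaling; condensed into a standalone sketch:

import torch

def combine_true_cfg(noise_pred: torch.Tensor, neg_noise_pred: torch.Tensor, true_cfg_scale: float) -> torch.Tensor:
    # Standard CFG blend of the conditional and negative-prompt predictions...
    comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
    # ...then rescale so the combined prediction keeps the per-token norm of the
    # conditional prediction, exactly as in the hunk above.
    cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
    noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
    return comb_pred * (cond_norm / noise_norm)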
......@@ -10,9 +10,11 @@ class QwenImagePostWeights(WeightModule):
super().__init__()
self.task = config["task"]
self.config = config
self.mm_type = config.get("dit_quant_scheme", "Default")
if self.mm_type != "Default":
assert config.get("dit_quantized") is True
if config["do_mm_calib"]:
self.mm_type = "Calib"
else:
self.mm_type = config["mm_config"].get("mm_type", "Default") if config["mm_config"] else "Default"
self.lazy_load = self.config.get("lazy_load", False)
if self.lazy_load:
assert NotImplementedError
......
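A compact restatement of the mm_type resolution introduced here (illustrative helper; same precedence as the hunk above):

def resolve_mm_type(config: dict) -> str:
    # Calibration mode overrides everything; otherwise fall back to the
    # (possibly empty) mm_config section, defaulting to unquantized matmuls.
    if config.get("do_mm_calib", False):
        return "Calib"
    mm_config = config.get("mm_config") or {}
    return mm_config.get("mm_type", "Default")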
......@@ -12,9 +12,11 @@ class QwenImageTransformerWeights(WeightModule):
self.blocks_num = config["num_layers"]
self.task = config["task"]
self.config = config
self.mm_type = config.get("dit_quant_scheme", "Default")
if self.mm_type != "Default":
assert config.get("dit_quantized") is True
if config["do_mm_calib"]:
self.mm_type = "Calib"
else:
self.mm_type = config["mm_config"].get("mm_type", "Default") if config["mm_config"] else "Default"
blocks = WeightModuleList(QwenImageTransformerAttentionBlock(i, self.task, self.mm_type, self.config, "transformer_blocks") for i in range(self.blocks_num))
self.add_module("blocks", blocks)
......@@ -27,10 +29,11 @@ class QwenImageTransformerAttentionBlock(WeightModule):
self.task = task
self.config = config
self.quant_method = config.get("quant_method", None)
self.sparge = config.get("sparge", False)
self.lazy_load = self.config.get("lazy_load", False)
if self.lazy_load:
lazy_load_path = os.path.join(self.config.dit_quantized_ckpt, f"block_{block_index}.safetensors")
lazy_load_path = os.path.join(self.config["dit_quantized_ckpt"], f"block_{block_index}.safetensors")
self.lazy_load_file = safe_open(lazy_load_path, framework="pt", device="cpu")
else:
self.lazy_load_file = None
......@@ -50,7 +53,7 @@ class QwenImageTransformerAttentionBlock(WeightModule):
LN_WEIGHT_REGISTER["Default"](eps=1e-6),
)
self.attn = QwenImageCrossAttention(
block_index=block_index, block_prefix="transformer_blocks", task=config.task, mm_type=mm_type, config=config, lazy_load=self.lazy_load, lazy_load_file=self.lazy_load_file
block_index=block_index, block_prefix="transformer_blocks", task=config["task"], mm_type=mm_type, config=config, lazy_load=self.lazy_load, lazy_load_file=self.lazy_load_file
)
self.add_module("attn", self.attn)
......@@ -62,7 +65,7 @@ class QwenImageTransformerAttentionBlock(WeightModule):
block_index=block_index,
block_prefix="transformer_blocks",
ffn_prefix="img_mlp",
task=config.task,
task=config["task"],
mm_type=mm_type,
config=config,
lazy_load=self.lazy_load,
......@@ -94,7 +97,7 @@ class QwenImageTransformerAttentionBlock(WeightModule):
block_index=block_index,
block_prefix="transformer_blocks",
ffn_prefix="txt_mlp",
task=config.task,
task=config["task"],
mm_type=mm_type,
config=config,
lazy_load=self.lazy_load,
......@@ -136,6 +139,7 @@ class QwenImageCrossAttention(WeightModule):
self.task = task
self.config = config
self.quant_method = config.get("quant_method", None)
self.sparge = config.get("sparge", False)
self.attn_type = config.get("attn_type", "flash_attn3")
self.heads = config["attention_out_dim"] // config["attention_dim_head"]
......
......@@ -2,7 +2,6 @@ import gc
import math
import torch
from PIL import Image
from loguru import logger
from lightx2v.models.input_encoders.hf.qwen25.qwen25_vlforconditionalgeneration import Qwen25_VLForConditionalGeneration_TextEncoder
......@@ -63,7 +62,6 @@ class QwenImageRunner(DefaultRunner):
elif self.config.get("lazy_load", False):
assert self.config.get("cpu_offload", False)
self.run_dit = self._run_dit_local
self.run_vae_decoder = self._run_vae_decoder_local
if self.config["task"] == "t2i":
self.run_input_encoder = self._run_input_encoder_local_t2i
elif self.config["task"] == "i2i":
......@@ -77,16 +75,14 @@ class QwenImageRunner(DefaultRunner):
def _run_dit_local(self, total_steps=None):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.model = self.load_transformer()
self.init_scheduler()
self.model.scheduler.prepare(self.inputs["image_encoder_output"])
self.model.scheduler.prepare(self.input_info)
latents, generator = self.run(total_steps)
self.end_run()
return latents, generator
@ProfilingContext4DebugL2("Run Encoders")
def _run_input_encoder_local_t2i(self):
prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
text_encoder_output = self.run_text_encoder(prompt)
prompt = self.input_info.prompt
text_encoder_output = self.run_text_encoder(prompt, neg_prompt=self.input_info.negative_prompt)
torch.cuda.empty_cache()
gc.collect()
return {
......@@ -96,9 +92,9 @@ class QwenImageRunner(DefaultRunner):
@ProfilingContext4DebugL2("Run Encoders")
def _run_input_encoder_local_i2i(self):
_, image = self.read_image_input(self.config["image_path"])
prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
text_encoder_output = self.run_text_encoder(prompt, image)
_, image = self.read_image_input(self.input_info.image_path)
prompt = self.input_info.prompt
text_encoder_output = self.run_text_encoder(prompt, image, neg_prompt=self.input_info.negative_prompt)
image_encoder_output = self.run_vae_encoder(image=text_encoder_output["preprocessed_image"])
image_encoder_output["image_info"] = text_encoder_output["image_info"]
torch.cuda.empty_cache()
......@@ -109,7 +105,7 @@ class QwenImageRunner(DefaultRunner):
}
@ProfilingContext4DebugL1("Run Text Encoder", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_run_text_encode_duration, metrics_labels=["QwenImageRunner"])
def run_text_encoder(self, text, image=None):
def run_text_encoder(self, text, image=None, neg_prompt=None):
if GET_RECORDER_MODE():
monitor_cli.lightx2v_input_prompt_len.observe(len(text))
text_encoder_output = {}
......@@ -117,17 +113,25 @@ class QwenImageRunner(DefaultRunner):
prompt_embeds, prompt_embeds_mask, _, _ = self.text_encoders[0].infer([text])
text_encoder_output["prompt_embeds"] = prompt_embeds
text_encoder_output["prompt_embeds_mask"] = prompt_embeds_mask
if self.config["do_true_cfg"] and neg_prompt is not None:
neg_prompt_embeds, neg_prompt_embeds_mask, _, _ = self.text_encoders[0].infer([neg_prompt])
text_encoder_output["negative_prompt_embeds"] = neg_prompt_embeds
text_encoder_output["negative_prompt_embeds_mask"] = neg_prompt_embeds_mask
elif self.config["task"] == "i2i":
prompt_embeds, prompt_embeds_mask, preprocessed_image, image_info = self.text_encoders[0].infer([text], image)
text_encoder_output["prompt_embeds"] = prompt_embeds
text_encoder_output["prompt_embeds_mask"] = prompt_embeds_mask
text_encoder_output["preprocessed_image"] = preprocessed_image
text_encoder_output["image_info"] = image_info
if self.config["do_true_cfg"] and neg_prompt is not None:
neg_prompt_embeds, neg_prompt_embeds_mask, _, _ = self.text_encoders[0].infer([neg_prompt], image)
text_encoder_output["negative_prompt_embeds"] = neg_prompt_embeds
text_encoder_output["negative_prompt_embeds_mask"] = neg_prompt_embeds_mask
return text_encoder_output
@ProfilingContext4DebugL1("Run VAE Encoder", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_run_vae_encoder_image_duration, metrics_labels=["QwenImageRunner"])
def run_vae_encoder(self, image):
image_latents = self.vae.encode_vae_image(image)
image_latents = self.vae.encode_vae_image(image, self.input_info)
return {"image_latents": image_latents}
def run(self, total_steps=None):
......@@ -151,26 +155,37 @@ class QwenImageRunner(DefaultRunner):
return self.model.scheduler.latents, self.model.scheduler.generator
def set_target_shape(self):
if not self.config._auto_resize:
width, height = self.config.aspect_ratios[self.config.aspect_ratio]
if not self.config["_auto_resize"]:
width, height = self.config["aspect_ratios"][self.config["aspect_ratio"]]
else:
image = Image.open(self.config.image_path).convert("RGB")
width, height = image.size
width, height = self.input_info.original_size
calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, width / height)
height = height or calculated_height
width = width or calculated_width
multiple_of = self.vae.vae_scale_factor * 2
width = width // multiple_of * multiple_of
height = height // multiple_of * multiple_of
self.config.auto_width = width
self.config.auto_hight = height
width = calculated_width // multiple_of * multiple_of
height = calculated_height // multiple_of * multiple_of
self.input_info.auto_width = width
self.input_info.auto_hight = height
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (self.vae.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae.vae_scale_factor * 2))
num_channels_latents = self.model.in_channels // 4
self.config.target_shape = (self.config.batchsize, 1, num_channels_latents, height, width)
self.input_info.target_shape = (self.config["batchsize"], 1, num_channels_latents, height, width)
def set_img_shapes(self):
if self.config["task"] == "t2i":
width, height = self.config["aspect_ratios"][self.config["aspect_ratio"]]
img_shapes = [(1, height // self.config["vae_scale_factor"] // 2, width // self.config["vae_scale_factor"] // 2)] * self.config["batchsize"]
elif self.config["task"] == "i2i":
image_height, image_width = self.inputs["image_encoder_output"]["image_info"]
img_shapes = [
[
(1, self.input_info.auto_hight // self.config["vae_scale_factor"] // 2, self.input_info.auto_width // self.config["vae_scale_factor"] // 2),
(1, image_height // self.config["vae_scale_factor"] // 2, image_width // self.config["vae_scale_factor"] // 2),
]
]
self.inputs["img_shapes"] = img_shapes
def init_scheduler(self):
self.scheduler = QwenImageScheduler(self.config)
......@@ -195,27 +210,29 @@ class QwenImageRunner(DefaultRunner):
metrics_func=monitor_cli.lightx2v_run_vae_decode_duration,
metrics_labels=["QwenImageRunner"],
)
def _run_vae_decoder_local(self, latents, generator):
def run_vae_decoder(self, latents):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.vae_decoder = self.load_vae()
images = self.vae.decode(latents)
images = self.vae.decode(latents, self.input_info)
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
del self.vae_decoder
torch.cuda.empty_cache()
gc.collect()
return images
def run_pipeline(self, save_image=True):
if self.config["use_prompt_enhancer"]:
self.config["prompt_enhanced"] = self.post_prompt_enhancer()
def run_pipeline(self, input_info):
self.input_info = input_info
self.inputs = self.run_input_encoder()
self.set_target_shape()
self.set_img_shapes()
latents, generator = self.run_dit()
images = self.run_vae_decoder(latents)
self.end_run()
images = self.run_vae_decoder(latents, generator)
image = images[0]
image.save(f"{self.config.save_result_path}")
image.save(f"{input_info.save_result_path}")
del latents, generator
torch.cuda.empty_cache()
......
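The img_shapes entries built above are latent-grid sizes: the VAE downsamples by vae_scale_factor and latent packing halves each spatial axis again. A small sketch; the example resolution is only illustrative.

def latent_grid_shape(height: int, width: int, vae_scale_factor: int = 8):
    # Each latent token covers (vae_scale_factor * 2) image pixels per axis.
    return (1, height // vae_scale_factor // 2, width // vae_scale_factor // 2)

# e.g. a 1328x1328 t2i request with the default vae_scale_factor of 8 -> (1, 83, 83)
print(latent_grid_shape(1328, 1328))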
import inspect
import json
import os
from typing import List, Optional, Union
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from diffusers.utils.torch_utils import randn_tensor
from lightx2v.models.schedulers.scheduler import BaseScheduler
......@@ -80,14 +79,60 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
def randn_tensor(
shape: Union[Tuple, List],
generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
device: Optional[Union[str, "torch.device"]] = None,
dtype: Optional["torch.dtype"] = None,
layout: Optional["torch.layout"] = None,
):
"""A helper function to create random tensors on the desired `device` with the desired `dtype`. When
passing a list of generators, you can seed each batch element individually. If CPU generators are passed, the tensor
is always created on the CPU.
"""
# by default, sample on the requested device (rand_device may fall back to CPU below)
if isinstance(device, str):
device = torch.device(device)
rand_device = device
batch_size = shape[0]
layout = layout or torch.strided
device = device or torch.device("cpu")
if generator is not None:
gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
if gen_device_type != device.type and gen_device_type == "cpu":
rand_device = "cpu"
if device != "mps":
print(
f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
f" slightly speed up this function by passing a generator that was created on the {device} device."
)
elif gen_device_type != device.type and gen_device_type == "cuda":
raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")
# make sure generator list of length 1 is treated like a non-list
if isinstance(generator, list) and len(generator) == 1:
generator = generator[0]
if isinstance(generator, list):
shape = (1,) + shape[1:]
latents = [torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout) for i in range(batch_size)]
latents = torch.cat(latents, dim=0).to(device)
else:
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)
return latents
class QwenImageScheduler(BaseScheduler):
def __init__(self, config):
super().__init__(config)
self.config = config
self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(os.path.join(config.model_path, "scheduler"))
with open(os.path.join(config.model_path, "scheduler", "scheduler_config.json"), "r") as f:
self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(os.path.join(config["model_path"], "scheduler"))
with open(os.path.join(config["model_path"], "scheduler", "scheduler_config.json"), "r") as f:
self.scheduler_config = json.load(f)
self.generator = torch.Generator(device="cuda").manual_seed(config.seed)
self.device = torch.device("cuda")
self.dtype = torch.bfloat16
self.guidance_scale = 1.0
......@@ -118,27 +163,29 @@ class QwenImageScheduler(BaseScheduler):
@staticmethod
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height, width, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids.reshape(latent_image_id_height * latent_image_id_width, latent_image_id_channels)
return latent_image_ids.to(device=device, dtype=dtype)
def prepare_latents(self):
shape = self.config.target_shape
def prepare_latents(self, input_info):
shape = input_info.target_shape
width, height = shape[-1], shape[-2]
latents = randn_tensor(shape, generator=self.generator, device=self.device, dtype=self.dtype)
latents = self._pack_latents(latents, self.config.batchsize, self.config.num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(self.config.batchsize, height // 2, width // 2, self.device, self.dtype)
latents = self._pack_latents(latents, self.config["batchsize"], self.config["num_channels_latents"], height, width)
latent_image_ids = self._prepare_latent_image_ids(self.config["batchsize"], height // 2, width // 2, self.device, self.dtype)
self.latents = latents
self.latent_image_ids = latent_image_ids
self.noise_pred = None
def set_timesteps(self):
sigmas = np.linspace(1.0, 1 / self.config.infer_steps, self.config.infer_steps)
sigmas = np.linspace(1.0, 1 / self.config["infer_steps"], self.config["infer_steps"])
image_seq_len = self.latents.shape[1]
mu = calculate_shift(
image_seq_len,
......@@ -147,7 +194,7 @@ class QwenImageScheduler(BaseScheduler):
self.scheduler_config.get("base_shift", 0.5),
self.scheduler_config.get("max_shift", 1.15),
)
num_inference_steps = self.config.infer_steps
num_inference_steps = self.config["infer_steps"]
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
......@@ -165,30 +212,20 @@ class QwenImageScheduler(BaseScheduler):
def prepare_guidance(self):
# handle guidance
if self.config.guidance_embeds:
if self.config["guidance_embeds"]:
guidance = torch.full([1], self.guidance_scale, device=self.device, dtype=torch.float32)
guidance = guidance.expand(self.latents.shape[0])
else:
guidance = None
self.guidance = guidance
def set_img_shapes(self, inputs):
if self.config.task == "t2i":
width, height = self.config.aspect_ratios[self.config.aspect_ratio]
self.img_shapes = [(1, height // self.config.vae_scale_factor // 2, width // self.config.vae_scale_factor // 2)] * self.config.batchsize
elif self.config.task == "i2i":
image_height, image_width = inputs["image_info"]
self.img_shapes = [
[
(1, self.config.auto_hight // self.config.vae_scale_factor // 2, self.config.auto_width // self.config.vae_scale_factor // 2),
(1, image_height // self.config.vae_scale_factor // 2, image_width // self.config.vae_scale_factor // 2),
]
]
def prepare(self, inputs):
self.prepare_latents()
def prepare(self, input_info):
if self.config["task"] == "i2i":
self.generator = torch.Generator().manual_seed(input_info.seed)
elif self.config["task"] == "t2i":
self.generator = torch.Generator(device="cuda").manual_seed(input_info.seed)
self.prepare_latents(input_info)
self.prepare_guidance()
self.set_img_shapes(inputs)
self.set_timesteps()
def step_post(self):
......
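The prepare() change above moves seeding onto the per-request input_info and picks the generator device per task. Sketched in isolation; the device split follows the hunk, while the rationale in the comment is my reading of the change, not something stated in the diff.

import torch

def make_generator(task: str, seed: int) -> torch.Generator:
    if task == "i2i":
        # CPU generator: randn_tensor above then samples on CPU and moves the noise
        # to CUDA, presumably to align i2i noise with the reference pipeline.
        return torch.Generator().manual_seed(seed)
    # t2i keeps sampling directly on the CUDA device.
    return torch.Generator(device="cuda").manual_seed(seed)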
......@@ -35,16 +35,15 @@ class AutoencoderKLQwenImageVAE:
else:
self.device = torch.device("cuda")
self.dtype = torch.bfloat16
self.latent_channels = config.vae_z_dim
self.latent_channels = config["vae_z_dim"]
self.load()
def load(self):
self.model = AutoencoderKLQwenImage.from_pretrained(os.path.join(self.config.model_path, "vae")).to(self.device).to(self.dtype)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.config.vae_scale_factor * 2)
with open(os.path.join(self.config.model_path, "vae", "config.json"), "r") as f:
self.model = AutoencoderKLQwenImage.from_pretrained(os.path.join(self.config["model_path"], "vae")).to(self.device).to(self.dtype)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.config["vae_scale_factor"] * 2)
with open(os.path.join(self.config["model_path"], "vae", "config.json"), "r") as f:
vae_config = json.load(f)
self.vae_scale_factor = 2 ** len(vae_config["temperal_downsample"]) if "temperal_downsample" in vae_config else 8
self.generator = torch.Generator(device="cuda").manual_seed(self.config.seed)
@staticmethod
def _unpack_latents(latents, height, width, vae_scale_factor):
......@@ -63,17 +62,17 @@ class AutoencoderKLQwenImageVAE:
return latents
@torch.no_grad()
def decode(self, latents):
def decode(self, latents, input_info):
if self.cpu_offload:
self.model.to(torch.device("cuda"))
if self.config.task == "t2i":
width, height = self.config.aspect_ratios[self.config.aspect_ratio]
elif self.config.task == "i2i":
width, height = self.config.auto_width, self.config.auto_hight
latents = self._unpack_latents(latents, height, width, self.config.vae_scale_factor)
if self.config["task"] == "t2i":
width, height = self.config["aspect_ratios"][self.config["aspect_ratio"]]
elif self.config["task"] == "i2i":
width, height = input_info.auto_width, input_info.auto_hight
latents = self._unpack_latents(latents, height, width, self.config["vae_scale_factor"])
latents = latents.to(self.dtype)
latents_mean = torch.tensor(self.config.vae_latents_mean).view(1, self.config.vae_z_dim, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = 1.0 / torch.tensor(self.config.vae_latents_std).view(1, self.config.vae_z_dim, 1, 1, 1).to(latents.device, latents.dtype)
latents_mean = torch.tensor(self.config["vae_latents_mean"]).view(1, self.config["vae_z_dim"], 1, 1, 1).to(latents.device, latents.dtype)
latents_std = 1.0 / torch.tensor(self.config["vae_latents_std"]).view(1, self.config["vae_z_dim"], 1, 1, 1).to(latents.device, latents.dtype)
latents = latents / latents_std + latents_mean
images = self.model.decode(latents, return_dict=False)[0][:, :, 0]
images = self.image_processor.postprocess(images, output_type="pil")
......@@ -97,33 +96,39 @@ class AutoencoderKLQwenImageVAE:
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(self.model.encode(image), generator=generator, sample_mode="argmax")
latents_mean = torch.tensor(self.model.config.latents_mean).view(1, self.latent_channels, 1, 1, 1).to(image_latents.device, image_latents.dtype)
latents_std = torch.tensor(self.model.config.latents_std).view(1, self.latent_channels, 1, 1, 1).to(image_latents.device, image_latents.dtype)
latents_mean = torch.tensor(self.model.config["latents_mean"]).view(1, self.latent_channels, 1, 1, 1).to(image_latents.device, image_latents.dtype)
latents_std = torch.tensor(self.model.config["latents_std"]).view(1, self.latent_channels, 1, 1, 1).to(image_latents.device, image_latents.dtype)
image_latents = (image_latents - latents_mean) / latents_std
return image_latents
@torch.no_grad()
def encode_vae_image(self, image):
def encode_vae_image(self, image, input_info):
if self.config["task"] == "i2i":
self.generator = torch.Generator().manual_seed(input_info.seed)
elif self.config["task"] == "t2i":
self.generator = torch.Generator(device="cuda").manual_seed(input_info.seed)
if self.cpu_offload:
self.model.to(torch.device("cuda"))
num_channels_latents = self.config.transformer_in_channels // 4
num_channels_latents = self.config["transformer_in_channels"] // 4
image = image.to(self.model.device).to(self.dtype)
if image.shape[1] != self.latent_channels:
image_latents = self._encode_vae_image(image=image, generator=self.generator)
else:
image_latents = image
if self.config.batchsize > image_latents.shape[0] and self.config.batchsize % image_latents.shape[0] == 0:
if self.config["batchsize"] > image_latents.shape[0] and self.config["batchsize"] % image_latents.shape[0] == 0:
# expand init_latents for batchsize
additional_image_per_prompt = self.config.batchsize // image_latents.shape[0]
additional_image_per_prompt = self.config["batchsize"] // image_latents.shape[0]
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
elif self.config.batchsize > image_latents.shape[0] and self.config.batchsize % image_latents.shape[0] != 0:
raise ValueError(f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {self.config.batchsize} text prompts.")
elif self.config["batchsize"] > image_latents.shape[0] and self.config["batchsize"] % image_latents.shape[0] != 0:
raise ValueError(f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {self.config['batchsize']} text prompts.")
else:
image_latents = torch.cat([image_latents], dim=0)
image_latent_height, image_latent_width = image_latents.shape[3:]
image_latents = self._pack_latents(image_latents, self.config.batchsize, num_channels_latents, image_latent_height, image_latent_width)
image_latents = self._pack_latents(image_latents, self.config["batchsize"], num_channels_latents, image_latent_height, image_latent_width)
if self.cpu_offload:
self.model.to(torch.device("cpu"))
torch.cuda.empty_cache()
......
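The encode and decode paths above apply inverse affine transforms to the latents. Written out side by side as a sketch (the mean/std lists correspond to the vae_latents_mean / vae_latents_std config keys):

from typing import Sequence

import torch

def normalize_latents(latents: torch.Tensor, mean: Sequence[float], std: Sequence[float], z_dim: int) -> torch.Tensor:
    # Encoding direction: (x - mean) / std, broadcast over the channel axis.
    m = torch.tensor(mean).view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
    s = torch.tensor(std).view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
    return (latents - m) / s

def denormalize_latents(latents: torch.Tensor, mean: Sequence[float], std: Sequence[float], z_dim: int) -> torch.Tensor:
    # Decoding direction: x * std + mean (decode() above keeps 1/std and divides,
    # which is the same operation).
    m = torch.tensor(mean).view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
    s = torch.tensor(std).view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
    return latents * s + m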
......@@ -103,6 +103,28 @@ class AnimateInputInfo:
target_shape: int = field(default_factory=int)
@dataclass
class T2IInputInfo:
seed: int = field(default_factory=int)
prompt: str = field(default_factory=str)
negative_prompt: str = field(default_factory=str)
save_result_path: str = field(default_factory=str)
# shape related
target_shape: int = field(default_factory=int)
@dataclass
class I2IInputInfo:
seed: int = field(default_factory=int)
prompt: str = field(default_factory=str)
negative_prompt: str = field(default_factory=str)
image_path: str = field(default_factory=str)
save_result_path: str = field(default_factory=str)
# shape related
target_shape: int = field(default_factory=int)
processed_image_size: int = field(default_factory=list)
def set_input_info(args):
if args.task == "t2v":
input_info = T2VInputInfo(
......@@ -161,10 +183,23 @@ def set_input_info(args):
save_result_path=args.save_result_path,
return_result_tensor=args.return_result_tensor,
)
elif args.task == "t2i":
input_info = T2IInputInfo(
seed=args.seed,
prompt=args.prompt,
negative_prompt=args.negative_prompt,
save_result_path=args.save_result_path,
)
elif args.task == "i2i":
input_info = I2IInputInfo(
seed=args.seed,
prompt=args.prompt,
negative_prompt=args.negative_prompt,
image_path=args.image_path,
save_result_path=args.save_result_path,
)
else:
raise ValueError(f"Unsupported task: {args.task}")
assert not (input_info.save_result_path and input_info.return_result_tensor), "save_result_path and return_result_tensor cannot be set at the same time"
return input_info
......
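An illustrative construction of the new i2i request object, using the I2IInputInfo dataclass defined in the hunk above; the field values are placeholders taken from the scripts below.

input_info = I2IInputInfo(
    seed=0,
    prompt="turn the style of the photo to vintage comic book",
    negative_prompt=" ",
    image_path="pie.png",
    save_result_path="save_results/qwen_image_i2i.png",
)
# handed to QwenImageRunner.run_pipeline(input_info); see the runner diff above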
......@@ -34,6 +34,8 @@ python -m lightx2v.infer \
--task i2i \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/qwen_image/qwen_image_i2i.json \
--prompt "Change the rabbit's color to purple, with a flash light background." \
--image_path input.jpg \
--save_result_path ${lightx2v_path}/save_results/qwen_image_i2i.png
--prompt "turn the style of the photo to vintage comic book" \
--negative_prompt " " \
--image_path pie.png \
--save_result_path ${lightx2v_path}/save_results/qwen_image_i2i.png \
--seed 0
......@@ -34,7 +34,9 @@ python -m lightx2v.infer \
--model_cls qwen_image \
--task i2i \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/offload/block/qwen_image_i2i_block.json \
--prompt "Change the rabbit's color to purple, with a flash light background." \
--image_path input.jpg \
--save_result_path ${lightx2v_path}/save_results/qwen_image_i2i.png
--config_json ${lightx2v_path}/configs/qwen_image/qwen_image_i2i.json \
--prompt "turn the style of the photo to vintage comic book" \
--negative_prompt " " \
--image_path pie.png \
--save_result_path ${lightx2v_path}/save_results/qwen_image_i2i.png \
--seed 0
......@@ -34,5 +34,7 @@ python -m lightx2v.infer \
--task t2i \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/qwen_image/qwen_image_t2i.json \
--prompt 'A coffee shop entrance features a chalkboard sign reading "Qwen Coffee 😊 $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". Ultra HD, 4K, cinematic compositionUltra HD, 4K, cinematic composition.' \
--save_result_path ${lightx2v_path}/save_results/qwen_image_t2i.png
--prompt 'A coffee shop entrance features a chalkboard sign reading "Qwen Coffee 😊 $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". Ultra HD, 4K, cinematic composition, Ultra HD, 4K, cinematic composition.' \
--negative_prompt " " \
--save_result_path ${lightx2v_path}/save_results/qwen_image_t2i.png \
--seed 42
......@@ -33,6 +33,8 @@ python -m lightx2v.infer \
--model_cls qwen_image \
--task t2i \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/offload/block/qwen_image_t2i_block.json \
--prompt 'A coffee shop entrance features a chalkboard sign reading "Qwen Coffee 😊 $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". Ultra HD, 4K, cinematic compositionUltra HD, 4K, cinematic composition.' \
--save_result_path ${lightx2v_path}/save_results/qwen_image_t2i.png
--config_json ${lightx2v_path}/configs/qwen_image/qwen_image_t2i.json \
--prompt 'A coffee shop entrance features a chalkboard sign reading "Qwen Coffee 😊 $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". Ultra HD, 4K, cinematic composition, Ultra HD, 4K, cinematic composition.' \
--negative_prompt " " \
--save_result_path ${lightx2v_path}/save_results/qwen_image_t2i.png \
--seed 42