Unverified Commit 62789aa4 authored by Watebear, committed by GitHub

[feat]: refactor to adapt new config system and align precision for qwen-image & qwen-image-edit (#383)
parent 0aaab832
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
"num_channels_latents": 16, "num_channels_latents": 16,
"vae_scale_factor": 8, "vae_scale_factor": 8,
"infer_steps": 50, "infer_steps": 50,
"num_laysers": 60,
"guidance_embeds": false, "guidance_embeds": false,
"num_images_per_prompt": 1, "num_images_per_prompt": 1,
"vae_latents_mean": [ "vae_latents_mean": [
...@@ -48,8 +47,6 @@ ...@@ -48,8 +47,6 @@
"prompt_template_encode": "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n", "prompt_template_encode": "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n",
"prompt_template_encode_start_idx": 64, "prompt_template_encode_start_idx": 64,
"_auto_resize": true, "_auto_resize": true,
"cpu_offload": true,
"offload_granularity": "block",
"num_layers": 60, "num_layers": 60,
"attention_out_dim": 3072, "attention_out_dim": 3072,
"attention_dim_head": 128, "attention_dim_head": 128,
...@@ -59,5 +56,10 @@ ...@@ -59,5 +56,10 @@
56 56
], ],
"_comment_attn": "in [torch_sdpa, flash_attn3, sage_attn2]", "_comment_attn": "in [torch_sdpa, flash_attn3, sage_attn2]",
"attn_type": "flash_attn3" "attn_type": "flash_attn3",
"do_true_cfg": true,
"true_cfg_scale": 4.0,
"cpu_offload": true,
"offload_granularity": "block",
"mm_config": {}
} }
...@@ -70,8 +70,6 @@ ...@@ -70,8 +70,6 @@
"prompt_template_encode": "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", "prompt_template_encode": "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
"prompt_template_encode_start_idx": 34, "prompt_template_encode_start_idx": 34,
"_auto_resize": false, "_auto_resize": false,
"cpu_offload": true,
"offload_granularity": "block",
"num_layers": 60, "num_layers": 60,
"attention_out_dim": 3072, "attention_out_dim": 3072,
"attention_dim_head": 128, "attention_dim_head": 128,
...@@ -81,5 +79,9 @@ ...@@ -81,5 +79,9 @@
56 56
], ],
"_comment_attn": "in [torch_sdpa, flash_attn3, sage_attn2]", "_comment_attn": "in [torch_sdpa, flash_attn3, sage_attn2]",
"attn_type": "flash_attn3" "attn_type": "flash_attn3",
"do_true_cfg": false,
"cpu_offload": true,
"offload_granularity": "block",
"mm_config": {}
} }
...@@ -56,5 +56,8 @@ ...@@ -56,5 +56,8 @@
56 56
], ],
"_comment_attn": "in [torch_sdpa, flash_attn3, sage_attn2]", "_comment_attn": "in [torch_sdpa, flash_attn3, sage_attn2]",
"attn_type": "flash_attn3" "attn_type": "flash_attn3",
"do_true_cfg": true,
"true_cfg_scale": 4.0,
"mm_config": {}
} }
...@@ -79,5 +79,8 @@ ...@@ -79,5 +79,8 @@
56 56
], ],
"_comment_attn": "in [torch_sdpa, flash_attn3, sage_attn2]", "_comment_attn": "in [torch_sdpa, flash_attn3, sage_attn2]",
"attn_type": "flash_attn3" "attn_type": "flash_attn3",
"do_true_cfg": true,
"true_cfg_scale": 4.0,
"mm_config": {}
} }
...@@ -131,6 +131,8 @@ class RMSWeightFP32(RMSWeight): ...@@ -131,6 +131,8 @@ class RMSWeightFP32(RMSWeight):
variance = input_tensor.to(torch.float32).pow(2).mean(-1, keepdim=True) variance = input_tensor.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = input_tensor * torch.rsqrt(variance + self.eps) hidden_states = input_tensor * torch.rsqrt(variance + self.eps)
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
if self.weight is not None: if self.weight is not None:
hidden_states = hidden_states * self.weight hidden_states = hidden_states * self.weight
hidden_states = hidden_states.to(input_dtype) hidden_states = hidden_states.to(input_dtype)
......
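The hunk above casts the normalized activations to the weight's dtype (fp16/bf16) before the elementwise scale, so the multiply runs in the same precision as the reference model. A minimal, self-contained sketch of that forward pass (function and argument names here are illustrative, not the repo's `RMSWeightFP32` API):

```python
from typing import Optional

import torch


def rms_norm_fp32(input_tensor: torch.Tensor, weight: Optional[torch.Tensor], eps: float = 1e-6) -> torch.Tensor:
    """RMSNorm: variance in float32, scale applied in the weight dtype, output in the input dtype."""
    input_dtype = input_tensor.dtype
    variance = input_tensor.to(torch.float32).pow(2).mean(-1, keepdim=True)
    hidden_states = input_tensor * torch.rsqrt(variance + eps)
    if weight is not None:
        if weight.dtype in (torch.float16, torch.bfloat16):
            # align precision with the learned scale before multiplying (the change in the hunk above)
            hidden_states = hidden_states.to(weight.dtype)
        hidden_states = hidden_states * weight
    return hidden_states.to(input_dtype)
```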
...@@ -47,8 +47,8 @@ class Qwen25_VLForConditionalGeneration_TextEncoder: ...@@ -47,8 +47,8 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
def __init__(self, config): def __init__(self, config):
self.config = config self.config = config
self.tokenizer_max_length = 1024 self.tokenizer_max_length = 1024
self.prompt_template_encode = config.prompt_template_encode self.prompt_template_encode = config["prompt_template_encode"]
self.prompt_template_encode_start_idx = config.prompt_template_encode_start_idx self.prompt_template_encode_start_idx = config["prompt_template_encode_start_idx"]
self.cpu_offload = config.get("cpu_offload", False) self.cpu_offload = config.get("cpu_offload", False)
if self.cpu_offload: if self.cpu_offload:
...@@ -60,11 +60,14 @@ class Qwen25_VLForConditionalGeneration_TextEncoder: ...@@ -60,11 +60,14 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
self.load() self.load()
def load(self): def load(self):
self.text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(os.path.join(self.config.model_path, "text_encoder")).to(self.device).to(self.dtype) self.text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(os.path.join(self.config["model_path"], "text_encoder"), torch_dtype=torch.bfloat16)
self.tokenizer = Qwen2Tokenizer.from_pretrained(os.path.join(self.config.model_path, "tokenizer")) if not self.cpu_offload:
if self.config.task == "i2i": self.text_encoder = self.text_encoder.to("cuda")
self.tokenizer = Qwen2Tokenizer.from_pretrained(os.path.join(self.config["model_path"], "tokenizer"))
if self.config["task"] == "i2i":
self.image_processor = VaeImageProcessor(vae_scale_factor=self.config["vae_scale_factor"] * 2) self.image_processor = VaeImageProcessor(vae_scale_factor=self.config["vae_scale_factor"] * 2)
self.processor = Qwen2VLProcessor.from_pretrained(os.path.join(self.config.model_path, "processor")) self.processor = Qwen2VLProcessor.from_pretrained(os.path.join(self.config["model_path"], "processor"))
def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
bool_mask = mask.bool() bool_mask = mask.bool()
...@@ -81,23 +84,17 @@ class Qwen25_VLForConditionalGeneration_TextEncoder: ...@@ -81,23 +84,17 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
height = height or calculated_height height = height or calculated_height
width = width or calculated_width width = width or calculated_width
multiple_of = self.config.vae_scale_factor * 2 multiple_of = self.config["vae_scale_factor"] * 2
width = width // multiple_of * multiple_of width = width // multiple_of * multiple_of
height = height // multiple_of * multiple_of height = height // multiple_of * multiple_of
if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
image_height, image_width = self.image_processor.get_default_height_width(image) image = self.image_processor.resize(image, calculated_height, calculated_width)
aspect_ratio = image_width / image_height
if self.config._auto_resize:
_, image_width, image_height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_QWENIMAGE_RESOLUTIONS)
image_width = image_width // multiple_of * multiple_of
image_height = image_height // multiple_of * multiple_of
image = self.image_processor.resize(image, image_height, image_width)
prompt_image = image prompt_image = image
image = self.image_processor.preprocess(image, image_height, image_width) image = self.image_processor.preprocess(image, calculated_height, calculated_width)
image = image.unsqueeze(2) image = image.unsqueeze(2)
return prompt_image, image, (image_height, image_width)
return prompt_image, image, (calculated_height, calculated_width)
@torch.no_grad() @torch.no_grad()
def infer(self, text, image=None): def infer(self, text, image=None):
...@@ -110,12 +107,14 @@ class Qwen25_VLForConditionalGeneration_TextEncoder: ...@@ -110,12 +107,14 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
if image is not None: if image is not None:
prompt_image, image, image_info = self.preprocess_image(image) prompt_image, image, image_info = self.preprocess_image(image)
model_inputs = self.processor( model_inputs = self.processor(
text=txt, text=txt,
images=prompt_image, images=prompt_image,
padding=True, padding=True,
return_tensors="pt", return_tensors="pt",
).to(torch.device("cuda")) ).to(torch.device("cuda"))
encoder_hidden_states = self.text_encoder( encoder_hidden_states = self.text_encoder(
input_ids=model_inputs.input_ids, input_ids=model_inputs.input_ids,
attention_mask=model_inputs.attention_mask, attention_mask=model_inputs.attention_mask,
...@@ -133,6 +132,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder: ...@@ -133,6 +132,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
) )
hidden_states = encoder_hidden_states.hidden_states[-1] hidden_states = encoder_hidden_states.hidden_states[-1]
split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
split_hidden_states = [e[drop_idx:] for e in split_hidden_states] split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
...@@ -144,10 +144,10 @@ class Qwen25_VLForConditionalGeneration_TextEncoder: ...@@ -144,10 +144,10 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
prompt_embeds_mask = encoder_attention_mask prompt_embeds_mask = encoder_attention_mask
_, seq_len, _ = prompt_embeds.shape _, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, self.config.num_images_per_prompt, 1) prompt_embeds = prompt_embeds.repeat(1, self.config["num_images_per_prompt"], 1)
prompt_embeds = prompt_embeds.view(self.config.batchsize * self.config.num_images_per_prompt, seq_len, -1) prompt_embeds = prompt_embeds.view(self.config["batchsize"] * self.config["num_images_per_prompt"], seq_len, -1)
prompt_embeds_mask = prompt_embeds_mask.repeat(1, self.config.num_images_per_prompt, 1) prompt_embeds_mask = prompt_embeds_mask.repeat(1, self.config["num_images_per_prompt"], 1)
prompt_embeds_mask = prompt_embeds_mask.view(self.config.batchsize * self.config.num_images_per_prompt, seq_len) prompt_embeds_mask = prompt_embeds_mask.view(self.config["batchsize"] * self.config["num_images_per_prompt"], seq_len)
if self.cpu_offload: if self.cpu_offload:
self.text_encoder.to(torch.device("cpu")) self.text_encoder.to(torch.device("cpu"))
......
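For reference, `_extract_masked_hidden` keeps only the token positions whose attention mask is set, and the `[drop_idx:]` slice above then removes the fixed chat-template prefix (`prompt_template_encode_start_idx` tokens). A minimal sketch of that helper, assuming a `(batch, seq, dim)` hidden-state tensor and a `(batch, seq)` 0/1 mask:

```python
import torch


def extract_masked_hidden(hidden_states: torch.Tensor, mask: torch.Tensor):
    """Split a padded batch into per-sample tensors containing only the unmasked tokens."""
    bool_mask = mask.bool()                        # (batch, seq)
    valid_lengths = bool_mask.sum(dim=1).tolist()  # real token count per sample
    selected = hidden_states[bool_mask]            # (sum(valid_lengths), dim)
    return torch.split(selected, valid_lengths, dim=0)


# usage, mirroring the encoder above:
# split_hidden_states = [e[drop_idx:] for e in extract_masked_hidden(hidden_states, attention_mask)]
```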
...@@ -163,7 +163,7 @@ class QwenImagePreInfer: ...@@ -163,7 +163,7 @@ class QwenImagePreInfer:
self.config = config self.config = config
self.attention_kwargs = {} self.attention_kwargs = {}
self.cpu_offload = config.get("cpu_offload", False) self.cpu_offload = config.get("cpu_offload", False)
self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(config.axes_dims_rope), scale_rope=True) self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(config["axes_dims_rope"]), scale_rope=True)
def set_scheduler(self, scheduler): def set_scheduler(self, scheduler):
self.scheduler = scheduler self.scheduler = scheduler
...@@ -174,6 +174,7 @@ class QwenImagePreInfer: ...@@ -174,6 +174,7 @@ class QwenImagePreInfer:
timestep = timestep.to(hidden_states.dtype) timestep = timestep.to(hidden_states.dtype)
encoder_hidden_states = encoder_hidden_states.squeeze(0) encoder_hidden_states = encoder_hidden_states.squeeze(0)
encoder_hidden_states = weights.txt_norm.apply(encoder_hidden_states) encoder_hidden_states = weights.txt_norm.apply(encoder_hidden_states)
encoder_hidden_states = weights.txt_in.apply(encoder_hidden_states) encoder_hidden_states = weights.txt_in.apply(encoder_hidden_states)
timesteps_proj = get_timestep_embedding(timestep).to(torch.bfloat16) timesteps_proj = get_timestep_embedding(timestep).to(torch.bfloat16)
......
...@@ -191,14 +191,14 @@ class QwenImageTransformerInfer(BaseTransformerInfer): ...@@ -191,14 +191,14 @@ class QwenImageTransformerInfer(BaseTransformerInfer):
# Process image stream - norm2 + MLP # Process image stream - norm2 + MLP
img_normed2 = block_weight.img_norm2.apply(hidden_states) img_normed2 = block_weight.img_norm2.apply(hidden_states)
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2) img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
img_mlp_output = F.silu(block_weight.img_mlp.mlp_0.apply(img_modulated2.squeeze(0))) img_mlp_output = F.gelu(block_weight.img_mlp.mlp_0.apply(img_modulated2.squeeze(0)), approximate="tanh")
img_mlp_output = block_weight.img_mlp.mlp_2.apply(img_mlp_output) img_mlp_output = block_weight.img_mlp.mlp_2.apply(img_mlp_output)
hidden_states = hidden_states + img_gate2 * img_mlp_output hidden_states = hidden_states + img_gate2 * img_mlp_output
# Process text stream - norm2 + MLP # Process text stream - norm2 + MLP
txt_normed2 = block_weight.txt_norm2.apply(encoder_hidden_states) txt_normed2 = block_weight.txt_norm2.apply(encoder_hidden_states)
txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2) txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
txt_mlp_output = F.silu(block_weight.txt_mlp.mlp_0.apply(txt_modulated2.squeeze(0))) txt_mlp_output = F.gelu(block_weight.txt_mlp.mlp_0.apply(txt_modulated2.squeeze(0)), approximate="tanh")
txt_mlp_output = block_weight.txt_mlp.mlp_2.apply(txt_mlp_output) txt_mlp_output = block_weight.txt_mlp.mlp_2.apply(txt_mlp_output)
encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output
......
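The activation in both MLP branches changes from SiLU to tanh-approximated GELU to match the reference Qwen-Image feed-forward. A standalone sketch of one gated norm2 + MLP step under that assumption (the shift/scale/gate modulation form is an assumption here, and the repo's weight objects are replaced by plain callables):

```python
import torch
import torch.nn.functional as F


def gated_mlp_step(x, norm, mlp_in, mlp_out, shift, scale, gate):
    """norm2 -> modulate -> GELU(tanh) MLP -> gated residual, mirroring the block above."""
    h = norm(x)
    h = h * (1 + scale) + shift                # assumed form of _modulate(shift, scale)
    h = F.gelu(mlp_in(h), approximate="tanh")  # tanh-approximated GELU (previously SiLU)
    h = mlp_out(h)
    return x + gate * h
```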
import glob
import json import json
import os import os
...@@ -23,17 +24,18 @@ class QwenImageTransformerModel: ...@@ -23,17 +24,18 @@ class QwenImageTransformerModel:
def __init__(self, config): def __init__(self, config):
self.config = config self.config = config
self.model_path = os.path.join(config.model_path, "transformer") self.model_path = os.path.join(config["model_path"], "transformer")
self.cpu_offload = config.get("cpu_offload", False) self.cpu_offload = config.get("cpu_offload", False)
self.offload_granularity = self.config.get("offload_granularity", "block") self.offload_granularity = self.config.get("offload_granularity", "block")
self.device = torch.device("cpu") if self.cpu_offload else torch.device("cuda") self.device = torch.device("cpu") if self.cpu_offload else torch.device("cuda")
with open(os.path.join(config.model_path, "transformer", "config.json"), "r") as f: with open(os.path.join(config["model_path"], "transformer", "config.json"), "r") as f:
transformer_config = json.load(f) transformer_config = json.load(f)
self.in_channels = transformer_config["in_channels"] self.in_channels = transformer_config["in_channels"]
self.attention_kwargs = {} self.attention_kwargs = {}
self.dit_quantized = self.config["dit_quantized"] self.dit_quantized = self.config["mm_config"].get("mm_type", "Default") != "Default"
self.weight_auto_quant = self.config["mm_config"].get("weight_auto_quant", False)
self._init_infer_class() self._init_infer_class()
self._init_weights() self._init_weights()
...@@ -61,7 +63,7 @@ class QwenImageTransformerModel: ...@@ -61,7 +63,7 @@ class QwenImageTransformerModel:
if weight_dict is None: if weight_dict is None:
is_weight_loader = self._should_load_weights() is_weight_loader = self._should_load_weights()
if is_weight_loader: if is_weight_loader:
if not self.dit_quantized: if not self.dit_quantized or self.weight_auto_quant:
# Load original weights # Load original weights
weight_dict = self._load_ckpt(unified_dtype, sensitive_layer) weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
else: else:
...@@ -191,14 +193,14 @@ class QwenImageTransformerModel: ...@@ -191,14 +193,14 @@ class QwenImageTransformerModel:
t = self.scheduler.timesteps[self.scheduler.step_index] t = self.scheduler.timesteps[self.scheduler.step_index]
latents = self.scheduler.latents latents = self.scheduler.latents
if self.config.task == "i2i": if self.config["task"] == "i2i":
image_latents = inputs["image_encoder_output"]["image_latents"] image_latents = inputs["image_encoder_output"]["image_latents"]
latents_input = torch.cat([latents, image_latents], dim=1) latents_input = torch.cat([latents, image_latents], dim=1)
else: else:
latents_input = latents latents_input = latents
timestep = t.expand(latents.shape[0]).to(latents.dtype) timestep = t.expand(latents.shape[0]).to(latents.dtype)
img_shapes = self.scheduler.img_shapes img_shapes = inputs["img_shapes"]
prompt_embeds = inputs["text_encoder_output"]["prompt_embeds"] prompt_embeds = inputs["text_encoder_output"]["prompt_embeds"]
prompt_embeds_mask = inputs["text_encoder_output"]["prompt_embeds_mask"] prompt_embeds_mask = inputs["text_encoder_output"]["prompt_embeds_mask"]
...@@ -226,7 +228,43 @@ class QwenImageTransformerModel: ...@@ -226,7 +228,43 @@ class QwenImageTransformerModel:
noise_pred = self.post_infer.infer(self.post_weight, hidden_states, pre_infer_out[0]) noise_pred = self.post_infer.infer(self.post_weight, hidden_states, pre_infer_out[0])
if self.config.task == "i2i": if self.config["do_true_cfg"]:
neg_prompt_embeds = inputs["text_encoder_output"]["negative_prompt_embeds"]
neg_prompt_embeds_mask = inputs["text_encoder_output"]["negative_prompt_embeds_mask"]
negative_txt_seq_lens = neg_prompt_embeds_mask.sum(dim=1).tolist() if neg_prompt_embeds_mask is not None else None
neg_hidden_states, neg_encoder_hidden_states, _, neg_pre_infer_out = self.pre_infer.infer(
weights=self.pre_weight,
hidden_states=latents_input,
timestep=timestep / 1000,
guidance=self.scheduler.guidance,
encoder_hidden_states_mask=neg_prompt_embeds_mask,
encoder_hidden_states=neg_prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=negative_txt_seq_lens,
attention_kwargs=self.attention_kwargs,
)
neg_encoder_hidden_states, neg_hidden_states = self.transformer_infer.infer(
block_weights=self.transformer_weights,
hidden_states=neg_hidden_states.unsqueeze(0),
encoder_hidden_states=neg_encoder_hidden_states.unsqueeze(0),
pre_infer_out=neg_pre_infer_out,
)
neg_noise_pred = self.post_infer.infer(self.post_weight, neg_hidden_states, neg_pre_infer_out[0])
if self.config["task"] == "i2i":
noise_pred = noise_pred[:, : latents.size(1)] noise_pred = noise_pred[:, : latents.size(1)]
if self.config["do_true_cfg"]:
neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
comb_pred = neg_noise_pred + self.config["true_cfg_scale"] * (noise_pred - neg_noise_pred)
cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
noise_pred = comb_pred * (cond_norm / noise_norm)
noise_pred = noise_pred[:, : latents.size(1)]
self.scheduler.noise_pred = noise_pred self.scheduler.noise_pred = noise_pred
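The true-CFG branch above runs a second forward pass with the negative prompt embeddings and combines the two predictions with `true_cfg_scale`, rescaling the result to the norm of the conditional prediction. The combination in isolation (a sketch of the formula, not the repo's API):

```python
import torch


def true_cfg_combine(noise_pred: torch.Tensor, neg_noise_pred: torch.Tensor, true_cfg_scale: float) -> torch.Tensor:
    """Classifier-free guidance with norm rescaling, as in the block above."""
    comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
    cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
    noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
    return comb_pred * (cond_norm / noise_norm)
```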
...@@ -10,9 +10,11 @@ class QwenImagePostWeights(WeightModule): ...@@ -10,9 +10,11 @@ class QwenImagePostWeights(WeightModule):
super().__init__() super().__init__()
self.task = config["task"] self.task = config["task"]
self.config = config self.config = config
self.mm_type = config.get("dit_quant_scheme", "Default") if config["do_mm_calib"]:
if self.mm_type != "Default": self.mm_type = "Calib"
assert config.get("dit_quantized") is True else:
self.mm_type = config["mm_config"].get("mm_type", "Default") if config["mm_config"] else "Default"
self.lazy_load = self.config.get("lazy_load", False) self.lazy_load = self.config.get("lazy_load", False)
if self.lazy_load: if self.lazy_load:
assert NotImplementedError assert NotImplementedError
......
...@@ -12,9 +12,11 @@ class QwenImageTransformerWeights(WeightModule): ...@@ -12,9 +12,11 @@ class QwenImageTransformerWeights(WeightModule):
self.blocks_num = config["num_layers"] self.blocks_num = config["num_layers"]
self.task = config["task"] self.task = config["task"]
self.config = config self.config = config
self.mm_type = config.get("dit_quant_scheme", "Default") if config["do_mm_calib"]:
if self.mm_type != "Default": self.mm_type = "Calib"
assert config.get("dit_quantized") is True else:
self.mm_type = config["mm_config"].get("mm_type", "Default") if config["mm_config"] else "Default"
blocks = WeightModuleList(QwenImageTransformerAttentionBlock(i, self.task, self.mm_type, self.config, "transformer_blocks") for i in range(self.blocks_num)) blocks = WeightModuleList(QwenImageTransformerAttentionBlock(i, self.task, self.mm_type, self.config, "transformer_blocks") for i in range(self.blocks_num))
self.add_module("blocks", blocks) self.add_module("blocks", blocks)
...@@ -27,10 +29,11 @@ class QwenImageTransformerAttentionBlock(WeightModule): ...@@ -27,10 +29,11 @@ class QwenImageTransformerAttentionBlock(WeightModule):
self.task = task self.task = task
self.config = config self.config = config
self.quant_method = config.get("quant_method", None) self.quant_method = config.get("quant_method", None)
self.sparge = config.get("sparge", False)
self.lazy_load = self.config.get("lazy_load", False) self.lazy_load = self.config.get("lazy_load", False)
if self.lazy_load: if self.lazy_load:
lazy_load_path = os.path.join(self.config.dit_quantized_ckpt, f"block_{block_index}.safetensors") lazy_load_path = os.path.join(self.config["dit_quantized_ckpt"], f"block_{block_index}.safetensors")
self.lazy_load_file = safe_open(lazy_load_path, framework="pt", device="cpu") self.lazy_load_file = safe_open(lazy_load_path, framework="pt", device="cpu")
else: else:
self.lazy_load_file = None self.lazy_load_file = None
...@@ -50,7 +53,7 @@ class QwenImageTransformerAttentionBlock(WeightModule): ...@@ -50,7 +53,7 @@ class QwenImageTransformerAttentionBlock(WeightModule):
LN_WEIGHT_REGISTER["Default"](eps=1e-6), LN_WEIGHT_REGISTER["Default"](eps=1e-6),
) )
self.attn = QwenImageCrossAttention( self.attn = QwenImageCrossAttention(
block_index=block_index, block_prefix="transformer_blocks", task=config.task, mm_type=mm_type, config=config, lazy_load=self.lazy_load, lazy_load_file=self.lazy_load_file block_index=block_index, block_prefix="transformer_blocks", task=config["task"], mm_type=mm_type, config=config, lazy_load=self.lazy_load, lazy_load_file=self.lazy_load_file
) )
self.add_module("attn", self.attn) self.add_module("attn", self.attn)
...@@ -62,7 +65,7 @@ class QwenImageTransformerAttentionBlock(WeightModule): ...@@ -62,7 +65,7 @@ class QwenImageTransformerAttentionBlock(WeightModule):
block_index=block_index, block_index=block_index,
block_prefix="transformer_blocks", block_prefix="transformer_blocks",
ffn_prefix="img_mlp", ffn_prefix="img_mlp",
task=config.task, task=config["task"],
mm_type=mm_type, mm_type=mm_type,
config=config, config=config,
lazy_load=self.lazy_load, lazy_load=self.lazy_load,
...@@ -94,7 +97,7 @@ class QwenImageTransformerAttentionBlock(WeightModule): ...@@ -94,7 +97,7 @@ class QwenImageTransformerAttentionBlock(WeightModule):
block_index=block_index, block_index=block_index,
block_prefix="transformer_blocks", block_prefix="transformer_blocks",
ffn_prefix="txt_mlp", ffn_prefix="txt_mlp",
task=config.task, task=config["task"],
mm_type=mm_type, mm_type=mm_type,
config=config, config=config,
lazy_load=self.lazy_load, lazy_load=self.lazy_load,
...@@ -136,6 +139,7 @@ class QwenImageCrossAttention(WeightModule): ...@@ -136,6 +139,7 @@ class QwenImageCrossAttention(WeightModule):
self.task = task self.task = task
self.config = config self.config = config
self.quant_method = config.get("quant_method", None) self.quant_method = config.get("quant_method", None)
self.sparge = config.get("sparge", False)
self.attn_type = config.get("attn_type", "flash_attn3") self.attn_type = config.get("attn_type", "flash_attn3")
self.heads = config["attention_out_dim"] // config["attention_dim_head"] self.heads = config["attention_out_dim"] // config["attention_dim_head"]
......
...@@ -2,7 +2,6 @@ import gc ...@@ -2,7 +2,6 @@ import gc
import math import math
import torch import torch
from PIL import Image
from loguru import logger from loguru import logger
from lightx2v.models.input_encoders.hf.qwen25.qwen25_vlforconditionalgeneration import Qwen25_VLForConditionalGeneration_TextEncoder from lightx2v.models.input_encoders.hf.qwen25.qwen25_vlforconditionalgeneration import Qwen25_VLForConditionalGeneration_TextEncoder
...@@ -63,7 +62,6 @@ class QwenImageRunner(DefaultRunner): ...@@ -63,7 +62,6 @@ class QwenImageRunner(DefaultRunner):
elif self.config.get("lazy_load", False): elif self.config.get("lazy_load", False):
assert self.config.get("cpu_offload", False) assert self.config.get("cpu_offload", False)
self.run_dit = self._run_dit_local self.run_dit = self._run_dit_local
self.run_vae_decoder = self._run_vae_decoder_local
if self.config["task"] == "t2i": if self.config["task"] == "t2i":
self.run_input_encoder = self._run_input_encoder_local_t2i self.run_input_encoder = self._run_input_encoder_local_t2i
elif self.config["task"] == "i2i": elif self.config["task"] == "i2i":
...@@ -77,16 +75,14 @@ class QwenImageRunner(DefaultRunner): ...@@ -77,16 +75,14 @@ class QwenImageRunner(DefaultRunner):
def _run_dit_local(self, total_steps=None): def _run_dit_local(self, total_steps=None):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.model = self.load_transformer() self.model = self.load_transformer()
self.init_scheduler() self.model.scheduler.prepare(self.input_info)
self.model.scheduler.prepare(self.inputs["image_encoder_output"])
latents, generator = self.run(total_steps) latents, generator = self.run(total_steps)
self.end_run()
return latents, generator return latents, generator
@ProfilingContext4DebugL2("Run Encoders") @ProfilingContext4DebugL2("Run Encoders")
def _run_input_encoder_local_t2i(self): def _run_input_encoder_local_t2i(self):
prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"] prompt = self.input_info.prompt
text_encoder_output = self.run_text_encoder(prompt) text_encoder_output = self.run_text_encoder(prompt, neg_prompt=self.input_info.negative_prompt)
torch.cuda.empty_cache() torch.cuda.empty_cache()
gc.collect() gc.collect()
return { return {
...@@ -96,9 +92,9 @@ class QwenImageRunner(DefaultRunner): ...@@ -96,9 +92,9 @@ class QwenImageRunner(DefaultRunner):
@ProfilingContext4DebugL2("Run Encoders") @ProfilingContext4DebugL2("Run Encoders")
def _run_input_encoder_local_i2i(self): def _run_input_encoder_local_i2i(self):
_, image = self.read_image_input(self.config["image_path"]) _, image = self.read_image_input(self.input_info.image_path)
prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"] prompt = self.input_info.prompt
text_encoder_output = self.run_text_encoder(prompt, image) text_encoder_output = self.run_text_encoder(prompt, image, neg_prompt=self.input_info.negative_prompt)
image_encoder_output = self.run_vae_encoder(image=text_encoder_output["preprocessed_image"]) image_encoder_output = self.run_vae_encoder(image=text_encoder_output["preprocessed_image"])
image_encoder_output["image_info"] = text_encoder_output["image_info"] image_encoder_output["image_info"] = text_encoder_output["image_info"]
torch.cuda.empty_cache() torch.cuda.empty_cache()
...@@ -109,7 +105,7 @@ class QwenImageRunner(DefaultRunner): ...@@ -109,7 +105,7 @@ class QwenImageRunner(DefaultRunner):
} }
@ProfilingContext4DebugL1("Run Text Encoder", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_run_text_encode_duration, metrics_labels=["QwenImageRunner"]) @ProfilingContext4DebugL1("Run Text Encoder", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_run_text_encode_duration, metrics_labels=["QwenImageRunner"])
def run_text_encoder(self, text, image=None): def run_text_encoder(self, text, image=None, neg_prompt=None):
if GET_RECORDER_MODE(): if GET_RECORDER_MODE():
monitor_cli.lightx2v_input_prompt_len.observe(len(text)) monitor_cli.lightx2v_input_prompt_len.observe(len(text))
text_encoder_output = {} text_encoder_output = {}
...@@ -117,17 +113,25 @@ class QwenImageRunner(DefaultRunner): ...@@ -117,17 +113,25 @@ class QwenImageRunner(DefaultRunner):
prompt_embeds, prompt_embeds_mask, _, _ = self.text_encoders[0].infer([text]) prompt_embeds, prompt_embeds_mask, _, _ = self.text_encoders[0].infer([text])
text_encoder_output["prompt_embeds"] = prompt_embeds text_encoder_output["prompt_embeds"] = prompt_embeds
text_encoder_output["prompt_embeds_mask"] = prompt_embeds_mask text_encoder_output["prompt_embeds_mask"] = prompt_embeds_mask
if self.config["do_true_cfg"] and neg_prompt is not None:
neg_prompt_embeds, neg_prompt_embeds_mask, _, _ = self.text_encoders[0].infer([neg_prompt])
text_encoder_output["negative_prompt_embeds"] = neg_prompt_embeds
text_encoder_output["negative_prompt_embeds_mask"] = neg_prompt_embeds_mask
elif self.config["task"] == "i2i": elif self.config["task"] == "i2i":
prompt_embeds, prompt_embeds_mask, preprocessed_image, image_info = self.text_encoders[0].infer([text], image) prompt_embeds, prompt_embeds_mask, preprocessed_image, image_info = self.text_encoders[0].infer([text], image)
text_encoder_output["prompt_embeds"] = prompt_embeds text_encoder_output["prompt_embeds"] = prompt_embeds
text_encoder_output["prompt_embeds_mask"] = prompt_embeds_mask text_encoder_output["prompt_embeds_mask"] = prompt_embeds_mask
text_encoder_output["preprocessed_image"] = preprocessed_image text_encoder_output["preprocessed_image"] = preprocessed_image
text_encoder_output["image_info"] = image_info text_encoder_output["image_info"] = image_info
if self.config["do_true_cfg"] and neg_prompt is not None:
neg_prompt_embeds, neg_prompt_embeds_mask, _, _ = self.text_encoders[0].infer([neg_prompt], image)
text_encoder_output["negative_prompt_embeds"] = neg_prompt_embeds
text_encoder_output["negative_prompt_embeds_mask"] = neg_prompt_embeds_mask
return text_encoder_output return text_encoder_output
@ProfilingContext4DebugL1("Run VAE Encoder", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_run_vae_encoder_image_duration, metrics_labels=["QwenImageRunner"]) @ProfilingContext4DebugL1("Run VAE Encoder", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_run_vae_encoder_image_duration, metrics_labels=["QwenImageRunner"])
def run_vae_encoder(self, image): def run_vae_encoder(self, image):
image_latents = self.vae.encode_vae_image(image) image_latents = self.vae.encode_vae_image(image, self.input_info)
return {"image_latents": image_latents} return {"image_latents": image_latents}
def run(self, total_steps=None): def run(self, total_steps=None):
...@@ -151,26 +155,37 @@ class QwenImageRunner(DefaultRunner): ...@@ -151,26 +155,37 @@ class QwenImageRunner(DefaultRunner):
return self.model.scheduler.latents, self.model.scheduler.generator return self.model.scheduler.latents, self.model.scheduler.generator
def set_target_shape(self): def set_target_shape(self):
if not self.config._auto_resize: if not self.config["_auto_resize"]:
width, height = self.config.aspect_ratios[self.config.aspect_ratio] width, height = self.config["aspect_ratios"][self.config["aspect_ratio"]]
else: else:
image = Image.open(self.config.image_path).convert("RGB") width, height = self.input_info.original_size
width, height = image.size
calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, width / height) calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, width / height)
height = height or calculated_height
width = width or calculated_width
multiple_of = self.vae.vae_scale_factor * 2 multiple_of = self.vae.vae_scale_factor * 2
width = width // multiple_of * multiple_of width = calculated_width // multiple_of * multiple_of
height = height // multiple_of * multiple_of height = calculated_height // multiple_of * multiple_of
self.config.auto_width = width self.input_info.auto_width = width
self.config.auto_hight = height self.input_info.auto_hight = height
# VAE applies 8x compression on images but we must also account for packing which requires # VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2. # latent height and width to be divisible by 2.
height = 2 * (int(height) // (self.vae.vae_scale_factor * 2)) height = 2 * (int(height) // (self.vae.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae.vae_scale_factor * 2)) width = 2 * (int(width) // (self.vae.vae_scale_factor * 2))
num_channels_latents = self.model.in_channels // 4 num_channels_latents = self.model.in_channels // 4
self.config.target_shape = (self.config.batchsize, 1, num_channels_latents, height, width) self.input_info.target_shape = (self.config["batchsize"], 1, num_channels_latents, height, width)
def set_img_shapes(self):
if self.config["task"] == "t2i":
width, height = self.config["aspect_ratios"][self.config["aspect_ratio"]]
img_shapes = [(1, height // self.config["vae_scale_factor"] // 2, width // self.config["vae_scale_factor"] // 2)] * self.config["batchsize"]
elif self.config["task"] == "i2i":
image_height, image_width = self.inputs["image_encoder_output"]["image_info"]
img_shapes = [
[
(1, self.input_info.auto_hight // self.config["vae_scale_factor"] // 2, self.input_info.auto_width // self.config["vae_scale_factor"] // 2),
(1, image_height // self.config["vae_scale_factor"] // 2, image_width // self.config["vae_scale_factor"] // 2),
]
]
self.inputs["img_shapes"] = img_shapes
def init_scheduler(self): def init_scheduler(self):
self.scheduler = QwenImageScheduler(self.config) self.scheduler = QwenImageScheduler(self.config)
...@@ -195,27 +210,29 @@ class QwenImageRunner(DefaultRunner): ...@@ -195,27 +210,29 @@ class QwenImageRunner(DefaultRunner):
metrics_func=monitor_cli.lightx2v_run_vae_decode_duration, metrics_func=monitor_cli.lightx2v_run_vae_decode_duration,
metrics_labels=["QwenImageRunner"], metrics_labels=["QwenImageRunner"],
) )
def _run_vae_decoder_local(self, latents, generator): def run_vae_decoder(self, latents):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.vae_decoder = self.load_vae() self.vae_decoder = self.load_vae()
images = self.vae.decode(latents) images = self.vae.decode(latents, self.input_info)
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
del self.vae_decoder del self.vae_decoder
torch.cuda.empty_cache() torch.cuda.empty_cache()
gc.collect() gc.collect()
return images return images
def run_pipeline(self, save_image=True): def run_pipeline(self, input_info):
if self.config["use_prompt_enhancer"]: self.input_info = input_info
self.config["prompt_enhanced"] = self.post_prompt_enhancer()
self.inputs = self.run_input_encoder() self.inputs = self.run_input_encoder()
self.set_target_shape() self.set_target_shape()
self.set_img_shapes()
latents, generator = self.run_dit() latents, generator = self.run_dit()
images = self.run_vae_decoder(latents)
self.end_run()
images = self.run_vae_decoder(latents, generator)
image = images[0] image = images[0]
image.save(f"{self.config.save_result_path}") image.save(f"{input_info.save_result_path}")
del latents, generator del latents, generator
torch.cuda.empty_cache() torch.cuda.empty_cache()
......
import inspect import inspect
import json import json
import os import os
from typing import List, Optional, Union from typing import List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from diffusers.utils.torch_utils import randn_tensor
from lightx2v.models.schedulers.scheduler import BaseScheduler from lightx2v.models.schedulers.scheduler import BaseScheduler
...@@ -80,14 +79,60 @@ def retrieve_timesteps( ...@@ -80,14 +79,60 @@ def retrieve_timesteps(
return timesteps, num_inference_steps return timesteps, num_inference_steps
def randn_tensor(
shape: Union[Tuple, List],
generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
device: Optional[Union[str, "torch.device"]] = None,
dtype: Optional["torch.dtype"] = None,
layout: Optional["torch.layout"] = None,
):
"""A helper function to create random tensors on the desired `device` with the desired `dtype`. When
passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor
is always created on the CPU.
"""
# device on which tensor is created defaults to device
if isinstance(device, str):
device = torch.device(device)
rand_device = device
batch_size = shape[0]
layout = layout or torch.strided
device = device or torch.device("cpu")
if generator is not None:
gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
if gen_device_type != device.type and gen_device_type == "cpu":
rand_device = "cpu"
if device != "mps":
print(
f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
f" slightly speed up this function by passing a generator that was created on the {device} device."
)
elif gen_device_type != device.type and gen_device_type == "cuda":
raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")
# make sure generator list of length 1 is treated like a non-list
if isinstance(generator, list) and len(generator) == 1:
generator = generator[0]
if isinstance(generator, list):
shape = (1,) + shape[1:]
latents = [torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout) for i in range(batch_size)]
latents = torch.cat(latents, dim=0).to(device)
else:
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)
return latents
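The vendored `randn_tensor` keeps the diffusers seeding semantics: a CPU generator samples on CPU and the result is then moved to the target device, and a list of generators seeds each batch element independently. A small usage example (shapes are illustrative):

```python
import torch

# a CPU generator: noise is drawn on CPU, then moved to CUDA (reproducible across devices)
gen_cpu = torch.Generator().manual_seed(0)
latents = randn_tensor((1, 1, 16, 64, 64), generator=gen_cpu, device="cuda", dtype=torch.bfloat16)

# a list of generators: each batch element gets its own seed
gens = [torch.Generator().manual_seed(i) for i in range(2)]
latents_batch = randn_tensor((2, 1, 16, 64, 64), generator=gens, device="cuda", dtype=torch.bfloat16)
```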
class QwenImageScheduler(BaseScheduler): class QwenImageScheduler(BaseScheduler):
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
self.config = config self.config = config
self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(os.path.join(config.model_path, "scheduler")) self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(os.path.join(config["model_path"], "scheduler"))
with open(os.path.join(config.model_path, "scheduler", "scheduler_config.json"), "r") as f: with open(os.path.join(config["model_path"], "scheduler", "scheduler_config.json"), "r") as f:
self.scheduler_config = json.load(f) self.scheduler_config = json.load(f)
self.generator = torch.Generator(device="cuda").manual_seed(config.seed)
self.device = torch.device("cuda") self.device = torch.device("cuda")
self.dtype = torch.bfloat16 self.dtype = torch.bfloat16
self.guidance_scale = 1.0 self.guidance_scale = 1.0
...@@ -118,27 +163,29 @@ class QwenImageScheduler(BaseScheduler): ...@@ -118,27 +163,29 @@ class QwenImageScheduler(BaseScheduler):
@staticmethod @staticmethod
def _prepare_latent_image_ids(batch_size, height, width, device, dtype): def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height, width, 3) latent_image_ids = torch.zeros(height, width, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids.reshape(latent_image_id_height * latent_image_id_width, latent_image_id_channels) latent_image_ids = latent_image_ids.reshape(latent_image_id_height * latent_image_id_width, latent_image_id_channels)
return latent_image_ids.to(device=device, dtype=dtype) return latent_image_ids.to(device=device, dtype=dtype)
def prepare_latents(self): def prepare_latents(self, input_info):
shape = self.config.target_shape shape = input_info.target_shape
width, height = shape[-1], shape[-2] width, height = shape[-1], shape[-2]
latents = randn_tensor(shape, generator=self.generator, device=self.device, dtype=self.dtype) latents = randn_tensor(shape, generator=self.generator, device=self.device, dtype=self.dtype)
latents = self._pack_latents(latents, self.config.batchsize, self.config.num_channels_latents, height, width) latents = self._pack_latents(latents, self.config["batchsize"], self.config["num_channels_latents"], height, width)
latent_image_ids = self._prepare_latent_image_ids(self.config.batchsize, height // 2, width // 2, self.device, self.dtype) latent_image_ids = self._prepare_latent_image_ids(self.config["batchsize"], height // 2, width // 2, self.device, self.dtype)
self.latents = latents self.latents = latents
self.latent_image_ids = latent_image_ids self.latent_image_ids = latent_image_ids
self.noise_pred = None self.noise_pred = None
def set_timesteps(self): def set_timesteps(self):
sigmas = np.linspace(1.0, 1 / self.config.infer_steps, self.config.infer_steps) sigmas = np.linspace(1.0, 1 / self.config["infer_steps"], self.config["infer_steps"])
image_seq_len = self.latents.shape[1] image_seq_len = self.latents.shape[1]
mu = calculate_shift( mu = calculate_shift(
image_seq_len, image_seq_len,
...@@ -147,7 +194,7 @@ class QwenImageScheduler(BaseScheduler): ...@@ -147,7 +194,7 @@ class QwenImageScheduler(BaseScheduler):
self.scheduler_config.get("base_shift", 0.5), self.scheduler_config.get("base_shift", 0.5),
self.scheduler_config.get("max_shift", 1.15), self.scheduler_config.get("max_shift", 1.15),
) )
num_inference_steps = self.config.infer_steps num_inference_steps = self.config["infer_steps"]
timesteps, num_inference_steps = retrieve_timesteps( timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, self.scheduler,
num_inference_steps, num_inference_steps,
...@@ -165,30 +212,20 @@ class QwenImageScheduler(BaseScheduler): ...@@ -165,30 +212,20 @@ class QwenImageScheduler(BaseScheduler):
def prepare_guidance(self): def prepare_guidance(self):
# handle guidance # handle guidance
if self.config.guidance_embeds: if self.config["guidance_embeds"]:
guidance = torch.full([1], self.guidance_scale, device=self.device, dtype=torch.float32) guidance = torch.full([1], self.guidance_scale, device=self.device, dtype=torch.float32)
guidance = guidance.expand(self.latents.shape[0]) guidance = guidance.expand(self.latents.shape[0])
else: else:
guidance = None guidance = None
self.guidance = guidance self.guidance = guidance
def set_img_shapes(self, inputs): def prepare(self, input_info):
if self.config.task == "t2i": if self.config["task"] == "i2i":
width, height = self.config.aspect_ratios[self.config.aspect_ratio] self.generator = torch.Generator().manual_seed(input_info.seed)
self.img_shapes = [(1, height // self.config.vae_scale_factor // 2, width // self.config.vae_scale_factor // 2)] * self.config.batchsize elif self.config["task"] == "t2i":
elif self.config.task == "i2i": self.generator = torch.Generator(device="cuda").manual_seed(input_info.seed)
image_height, image_width = inputs["image_info"] self.prepare_latents(input_info)
self.img_shapes = [
[
(1, self.config.auto_hight // self.config.vae_scale_factor // 2, self.config.auto_width // self.config.vae_scale_factor // 2),
(1, image_height // self.config.vae_scale_factor // 2, image_width // self.config.vae_scale_factor // 2),
]
]
def prepare(self, inputs):
self.prepare_latents()
self.prepare_guidance() self.prepare_guidance()
self.set_img_shapes(inputs)
self.set_timesteps() self.set_timesteps()
def step_post(self): def step_post(self):
......
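`calculate_shift` is not shown in this diff; in the diffusers flow-match setup it linearly maps the packed-latent sequence length to the timestep shift `mu` between `base_shift` and `max_shift`. A sketch under that assumption (parameter defaults follow the scheduler-config keys used above, not necessarily this repo's exact helper):

```python
def calculate_shift(image_seq_len: int,
                    base_seq_len: int = 256,
                    max_seq_len: int = 4096,
                    base_shift: float = 0.5,
                    max_shift: float = 1.15) -> float:
    """Linearly interpolate the timestep shift mu from the image sequence length."""
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b
```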
...@@ -35,16 +35,15 @@ class AutoencoderKLQwenImageVAE: ...@@ -35,16 +35,15 @@ class AutoencoderKLQwenImageVAE:
else: else:
self.device = torch.device("cuda") self.device = torch.device("cuda")
self.dtype = torch.bfloat16 self.dtype = torch.bfloat16
self.latent_channels = config.vae_z_dim self.latent_channels = config["vae_z_dim"]
self.load() self.load()
def load(self): def load(self):
self.model = AutoencoderKLQwenImage.from_pretrained(os.path.join(self.config.model_path, "vae")).to(self.device).to(self.dtype) self.model = AutoencoderKLQwenImage.from_pretrained(os.path.join(self.config["model_path"], "vae")).to(self.device).to(self.dtype)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.config.vae_scale_factor * 2) self.image_processor = VaeImageProcessor(vae_scale_factor=self.config["vae_scale_factor"] * 2)
with open(os.path.join(self.config.model_path, "vae", "config.json"), "r") as f: with open(os.path.join(self.config["model_path"], "vae", "config.json"), "r") as f:
vae_config = json.load(f) vae_config = json.load(f)
self.vae_scale_factor = 2 ** len(vae_config["temperal_downsample"]) if "temperal_downsample" in vae_config else 8 self.vae_scale_factor = 2 ** len(vae_config["temperal_downsample"]) if "temperal_downsample" in vae_config else 8
self.generator = torch.Generator(device="cuda").manual_seed(self.config.seed)
@staticmethod @staticmethod
def _unpack_latents(latents, height, width, vae_scale_factor): def _unpack_latents(latents, height, width, vae_scale_factor):
...@@ -63,17 +62,17 @@ class AutoencoderKLQwenImageVAE: ...@@ -63,17 +62,17 @@ class AutoencoderKLQwenImageVAE:
return latents return latents
@torch.no_grad() @torch.no_grad()
def decode(self, latents): def decode(self, latents, input_info):
if self.cpu_offload: if self.cpu_offload:
self.model.to(torch.device("cuda")) self.model.to(torch.device("cuda"))
if self.config.task == "t2i": if self.config["task"] == "t2i":
width, height = self.config.aspect_ratios[self.config.aspect_ratio] width, height = self.config["aspect_ratios"][self.config["aspect_ratio"]]
elif self.config.task == "i2i": elif self.config["task"] == "i2i":
width, height = self.config.auto_width, self.config.auto_hight width, height = input_info.auto_width, input_info.auto_hight
latents = self._unpack_latents(latents, height, width, self.config.vae_scale_factor) latents = self._unpack_latents(latents, height, width, self.config["vae_scale_factor"])
latents = latents.to(self.dtype) latents = latents.to(self.dtype)
latents_mean = torch.tensor(self.config.vae_latents_mean).view(1, self.config.vae_z_dim, 1, 1, 1).to(latents.device, latents.dtype) latents_mean = torch.tensor(self.config["vae_latents_mean"]).view(1, self.config["vae_z_dim"], 1, 1, 1).to(latents.device, latents.dtype)
latents_std = 1.0 / torch.tensor(self.config.vae_latents_std).view(1, self.config.vae_z_dim, 1, 1, 1).to(latents.device, latents.dtype) latents_std = 1.0 / torch.tensor(self.config["vae_latents_std"]).view(1, self.config["vae_z_dim"], 1, 1, 1).to(latents.device, latents.dtype)
latents = latents / latents_std + latents_mean latents = latents / latents_std + latents_mean
images = self.model.decode(latents, return_dict=False)[0][:, :, 0] images = self.model.decode(latents, return_dict=False)[0][:, :, 0]
images = self.image_processor.postprocess(images, output_type="pil") images = self.image_processor.postprocess(images, output_type="pil")
...@@ -97,33 +96,39 @@ class AutoencoderKLQwenImageVAE: ...@@ -97,33 +96,39 @@ class AutoencoderKLQwenImageVAE:
image_latents = torch.cat(image_latents, dim=0) image_latents = torch.cat(image_latents, dim=0)
else: else:
image_latents = retrieve_latents(self.model.encode(image), generator=generator, sample_mode="argmax") image_latents = retrieve_latents(self.model.encode(image), generator=generator, sample_mode="argmax")
latents_mean = torch.tensor(self.model.config.latents_mean).view(1, self.latent_channels, 1, 1, 1).to(image_latents.device, image_latents.dtype) latents_mean = torch.tensor(self.model.config["latents_mean"]).view(1, self.latent_channels, 1, 1, 1).to(image_latents.device, image_latents.dtype)
latents_std = torch.tensor(self.model.config.latents_std).view(1, self.latent_channels, 1, 1, 1).to(image_latents.device, image_latents.dtype) latents_std = torch.tensor(self.model.config["latents_std"]).view(1, self.latent_channels, 1, 1, 1).to(image_latents.device, image_latents.dtype)
image_latents = (image_latents - latents_mean) / latents_std image_latents = (image_latents - latents_mean) / latents_std
return image_latents return image_latents
@torch.no_grad() @torch.no_grad()
def encode_vae_image(self, image): def encode_vae_image(self, image, input_info):
if self.config["task"] == "i2i":
self.generator = torch.Generator().manual_seed(input_info.seed)
elif self.config["task"] == "t2i":
self.generator = torch.Generator(device="cuda").manual_seed(input_info.seed)
if self.cpu_offload: if self.cpu_offload:
self.model.to(torch.device("cuda")) self.model.to(torch.device("cuda"))
num_channels_latents = self.config.transformer_in_channels // 4 num_channels_latents = self.config["transformer_in_channels"] // 4
image = image.to(self.model.device).to(self.dtype) image = image.to(self.model.device).to(self.dtype)
if image.shape[1] != self.latent_channels: if image.shape[1] != self.latent_channels:
image_latents = self._encode_vae_image(image=image, generator=self.generator) image_latents = self._encode_vae_image(image=image, generator=self.generator)
else: else:
image_latents = image image_latents = image
if self.config.batchsize > image_latents.shape[0] and self.config.batchsize % image_latents.shape[0] == 0: if self.config["batchsize"] > image_latents.shape[0] and self.config["batchsize"] % image_latents.shape[0] == 0:
# expand init_latents for batchsize # expand init_latents for batchsize
additional_image_per_prompt = self.config.batchsize // image_latents.shape[0] additional_image_per_prompt = self.config["batchsize"] // image_latents.shape[0]
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
elif self.config.batchsize > image_latents.shape[0] and self.config.batchsize % image_latents.shape[0] != 0: elif self.config["batchsize"] > image_latents.shape[0] and self.config["batchsize"] % image_latents.shape[0] != 0:
raise ValueError(f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {self.config.batchsize} text prompts.") raise ValueError(f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {self.config['batchsize']} text prompts.")
else: else:
image_latents = torch.cat([image_latents], dim=0) image_latents = torch.cat([image_latents], dim=0)
image_latent_height, image_latent_width = image_latents.shape[3:] image_latent_height, image_latent_width = image_latents.shape[3:]
image_latents = self._pack_latents(image_latents, self.config.batchsize, num_channels_latents, image_latent_height, image_latent_width) image_latents = self._pack_latents(image_latents, self.config["batchsize"], num_channels_latents, image_latent_height, image_latent_width)
if self.cpu_offload: if self.cpu_offload:
self.model.to(torch.device("cpu")) self.model.to(torch.device("cpu"))
torch.cuda.empty_cache() torch.cuda.empty_cache()
......
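`_pack_latents` (not shown in this hunk) folds every 2x2 spatial latent patch into the channel dimension, which is why `num_channels_latents` is `transformer_in_channels // 4` and why heights and widths are rounded to multiples of `vae_scale_factor * 2`. A sketch of the packing under that assumption (mirroring the diffusers Flux/Qwen-Image helper, not necessarily this repo's exact implementation):

```python
import torch


def pack_latents(latents: torch.Tensor, batch_size: int, num_channels: int, height: int, width: int) -> torch.Tensor:
    """(B, [1,] C, H, W) -> (B, (H/2)*(W/2), C*4): each 2x2 spatial patch becomes 4 extra channels."""
    latents = latents.view(batch_size, num_channels, height // 2, 2, width // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(batch_size, (height // 2) * (width // 2), num_channels * 4)
```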
...@@ -103,6 +103,28 @@ class AnimateInputInfo: ...@@ -103,6 +103,28 @@ class AnimateInputInfo:
target_shape: int = field(default_factory=int) target_shape: int = field(default_factory=int)
@dataclass
class T2IInputInfo:
seed: int = field(default_factory=int)
prompt: str = field(default_factory=str)
negative_prompt: str = field(default_factory=str)
save_result_path: str = field(default_factory=str)
# shape related
target_shape: int = field(default_factory=int)
@dataclass
class I2IInputInfo:
seed: int = field(default_factory=int)
prompt: str = field(default_factory=str)
negative_prompt: str = field(default_factory=str)
image_path: str = field(default_factory=str)
save_result_path: str = field(default_factory=str)
# shape related
target_shape: int = field(default_factory=int)
processed_image_size: int = field(default_factory=list)
def set_input_info(args): def set_input_info(args):
if args.task == "t2v": if args.task == "t2v":
input_info = T2VInputInfo( input_info = T2VInputInfo(
...@@ -161,10 +183,23 @@ def set_input_info(args): ...@@ -161,10 +183,23 @@ def set_input_info(args):
save_result_path=args.save_result_path, save_result_path=args.save_result_path,
return_result_tensor=args.return_result_tensor, return_result_tensor=args.return_result_tensor,
) )
elif args.task == "t2i":
input_info = T2IInputInfo(
seed=args.seed,
prompt=args.prompt,
negative_prompt=args.negative_prompt,
save_result_path=args.save_result_path,
)
elif args.task == "i2i":
input_info = I2IInputInfo(
seed=args.seed,
prompt=args.prompt,
negative_prompt=args.negative_prompt,
image_path=args.image_path,
save_result_path=args.save_result_path,
)
else: else:
raise ValueError(f"Unsupported task: {args.task}") raise ValueError(f"Unsupported task: {args.task}")
assert not (input_info.save_result_path and input_info.return_result_tensor), "save_result_path and return_result_tensor cannot be set at the same time"
return input_info return input_info
......
...@@ -34,6 +34,8 @@ python -m lightx2v.infer \ ...@@ -34,6 +34,8 @@ python -m lightx2v.infer \
--task i2i \ --task i2i \
--model_path $model_path \ --model_path $model_path \
--config_json ${lightx2v_path}/configs/qwen_image/qwen_image_i2i.json \ --config_json ${lightx2v_path}/configs/qwen_image/qwen_image_i2i.json \
--prompt "Change the rabbit's color to purple, with a flash light background." \ --prompt "turn the style of the photo to vintage comic book" \
--image_path input.jpg \ --negative_prompt " " \
--save_result_path ${lightx2v_path}/save_results/qwen_image_i2i.png --image_path /data/nvme2/wushuo/qwen-image/pie.png \
--save_result_path ${lightx2v_path}/save_results/qwen_image_i2i.png \
--seed 0
...@@ -34,7 +34,9 @@ python -m lightx2v.infer \ ...@@ -34,7 +34,9 @@ python -m lightx2v.infer \
--model_cls qwen_image \ --model_cls qwen_image \
--task i2i \ --task i2i \
--model_path $model_path \ --model_path $model_path \
--config_json ${lightx2v_path}/configs/offload/block/qwen_image_i2i_block.json \ --config_json ${lightx2v_path}/configs/qwen_image/qwen_image_i2i.json \
--prompt "Change the rabbit's color to purple, with a flash light background." \ --prompt "turn the style of the photo to vintage comic book" \
--image_path input.jpg \ --negative_prompt " " \
--save_result_path ${lightx2v_path}/save_results/qwen_image_i2i.png --image_path pie.png \
--save_result_path ${lightx2v_path}/save_results/qwen_image_i2i.png \
--seed 0
...@@ -34,5 +34,7 @@ python -m lightx2v.infer \ ...@@ -34,5 +34,7 @@ python -m lightx2v.infer \
--task t2i \ --task t2i \
--model_path $model_path \ --model_path $model_path \
--config_json ${lightx2v_path}/configs/qwen_image/qwen_image_t2i.json \ --config_json ${lightx2v_path}/configs/qwen_image/qwen_image_t2i.json \
--prompt 'A coffee shop entrance features a chalkboard sign reading "Qwen Coffee 😊 $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". Ultra HD, 4K, cinematic compositionUltra HD, 4K, cinematic composition.' \ --prompt 'A coffee shop entrance features a chalkboard sign reading "Qwen Coffee 😊 $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". Ultra HD, 4K, cinematic composition, Ultra HD, 4K, cinematic composition.' \
--save_result_path ${lightx2v_path}/save_results/qwen_image_t2i.png --negative_prompt " " \
--save_result_path ${lightx2v_path}/save_results/qwen_image_t2i.png \
--seed 42
...@@ -33,6 +33,8 @@ python -m lightx2v.infer \ ...@@ -33,6 +33,8 @@ python -m lightx2v.infer \
--model_cls qwen_image \ --model_cls qwen_image \
--task t2i \ --task t2i \
--model_path $model_path \ --model_path $model_path \
--config_json ${lightx2v_path}/configs/offload/block/qwen_image_t2i_block.json \ --config_json ${lightx2v_path}/configs/qwen_image/qwen_image_t2i.json \
--prompt 'A coffee shop entrance features a chalkboard sign reading "Qwen Coffee 😊 $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". Ultra HD, 4K, cinematic compositionUltra HD, 4K, cinematic composition.' \ --prompt 'A coffee shop entrance features a chalkboard sign reading "Qwen Coffee 😊 $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". Ultra HD, 4K, cinematic composition, Ultra HD, 4K, cinematic composition.' \
--save_result_path ${lightx2v_path}/save_results/qwen_image_t2i.png --negative_prompt " " \
--save_result_path ${lightx2v_path}/save_results/qwen_image_t2i.png \
--seed 42