"tests/vscode:/vscode.git/clone" did not exist on "28972b8667653c10459544fa4d05551721d5b2c2"
Commit a40ffb3f authored by Watebear's avatar Watebear Committed by GitHub
Browse files

refactor qwen-image (#297)

parent 701075f4
from lightx2v.common.modules.weight_module import WeightModule
from lightx2v.utils.registry_factory import (
MM_WEIGHT_REGISTER,
RMS_WEIGHT_REGISTER,
)
class QwenImagePreWeights(WeightModule):
    """Input-side ("pre") weights of the Qwen-Image transformer.

    Registers the image/text input projections, the text RMS norm and the
    two timestep-embedder linears, each under its checkpoint key names so
    the weight loader can resolve them.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # img_in: matmul weights loaded from "img_in.weight" / "img_in.bias".
        self.add_module(
            "img_in",
            MM_WEIGHT_REGISTER["Default"]("img_in.weight", "img_in.bias"),
        )
        # txt_in: matmul weights for the text-stream input projection.
        self.add_module(
            "txt_in",
            MM_WEIGHT_REGISTER["Default"]("txt_in.weight", "txt_in.bias"),
        )
        # txt_norm: RMS norm using the "fp32_variance" registry variant.
        self.add_module("txt_norm", RMS_WEIGHT_REGISTER["fp32_variance"]("txt_norm.weight"))
        # time_text_embed: the two linears of the timestep embedder MLP.
        self.add_module(
            "time_text_embed_timestep_embedder_linear_1", MM_WEIGHT_REGISTER["Default"]("time_text_embed.timestep_embedder.linear_1.weight", "time_text_embed.timestep_embedder.linear_1.bias")
        )
        self.add_module(
            "time_text_embed_timestep_embedder_linear_2", MM_WEIGHT_REGISTER["Default"]("time_text_embed.timestep_embedder.linear_2.weight", "time_text_embed.timestep_embedder.linear_2.bias")
        )
import os
from safetensors import safe_open
from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, MM_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER
class QwenImageTransformerWeights(WeightModule):
    """Weight container for the full stack of Qwen-Image transformer blocks."""

    def __init__(self, config):
        super().__init__()
        self.blocks_num = config["num_layers"]
        self.task = config["task"]
        self.config = config
        # Calibration mode overrides whatever matmul implementation is configured.
        if config["do_mm_calib"]:
            self.mm_type = "Calib"
        else:
            mm_config = config["mm_config"]
            self.mm_type = mm_config.get("mm_type", "Default") if mm_config else "Default"
        block_list = []
        for layer_idx in range(self.blocks_num):
            block_list.append(QwenImageTransformerAttentionBlock(layer_idx, self.task, self.mm_type, self.config, "transformer_blocks"))
        self.add_module("blocks", WeightModuleList(block_list))
class QwenImageTransformerAttentionBlock(WeightModule):
    """Weights for one double-stream (image + text) Qwen-Image transformer block.

    Registers, per stream: a modulation linear ("*_mod"), two layer norms
    ("*_norm1"/"*_norm2"), a shared joint attention module and an FFN. When
    CPU offload is configured with "phase" granularity, the sub-modules are
    additionally grouped into three ComputePhase units.

    Args:
        block_index: position of this block within the transformer stack.
        task: task identifier (stored on the instance).
        mm_type: key into MM_WEIGHT_REGISTER selecting the matmul weight impl.
        config: global config (dict-like; also accessed via attributes below).
        block_prefix: checkpoint key prefix, default "transformer_blocks".
    """

    def __init__(self, block_index, task, mm_type, config, block_prefix="transformer_blocks"):
        super().__init__()
        self.block_index = block_index
        self.mm_type = mm_type
        self.task = task
        self.config = config
        self.quant_method = config.get("quant_method", None)
        self.sparge = config.get("sparge", False)
        self.lazy_load = self.config.get("lazy_load", False)
        if self.lazy_load:
            # Each block has its own safetensors shard; keep an open handle so
            # tensors can be read on demand instead of loading them all now.
            lazy_load_path = os.path.join(self.config.dit_quantized_ckpt, f"block_{block_index}.safetensors")
            self.lazy_load_file = safe_open(lazy_load_path, framework="pt", device="cpu")
        else:
            self.lazy_load_file = None
        # ---- Image stream ----
        # Modulation linear (checkpoint sub-module "img_mod.1").
        self.add_module(
            "img_mod",
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.img_mod.1.weight",
                f"{block_prefix}.{self.block_index}.img_mod.1.bias",
                self.lazy_load,
                self.lazy_load_file,
            ),
        )
        self.add_module(
            "img_norm1",
            LN_WEIGHT_REGISTER["Default"](eps=1e-6),
        )
        # Joint attention covering both streams in one computation.
        # NOTE(review): task is read as config.task here but as config["task"]
        # elsewhere -- presumably config supports attribute access too; confirm.
        self.attn = QwenImageCrossAttention(
            block_index=block_index, block_prefix="transformer_blocks", task=config.task, mm_type=mm_type, config=config, lazy_load=self.lazy_load, lazy_load_file=self.lazy_load_file
        )
        self.add_module("attn", self.attn)
        self.add_module(
            "img_norm2",
            LN_WEIGHT_REGISTER["Default"](eps=1e-6),
        )
        img_mlp = QwenImageFFN(
            block_index=block_index,
            block_prefix="transformer_blocks",
            ffn_prefix="img_mlp",
            task=config.task,
            mm_type=mm_type,
            config=config,
            lazy_load=self.lazy_load,
            lazy_load_file=self.lazy_load_file,
        )
        self.add_module("img_mlp", img_mlp)
        # ---- Text stream ----
        self.add_module(
            "txt_mod",
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.txt_mod.1.weight",
                f"{block_prefix}.{self.block_index}.txt_mod.1.bias",
                self.lazy_load,
                self.lazy_load_file,
            ),
        )
        self.add_module(
            "txt_norm1",
            LN_WEIGHT_REGISTER["Default"](eps=1e-6),
        )
        # Text doesn't need separate attention - it's handled by img_attn joint computation
        self.add_module(
            "txt_norm2",
            LN_WEIGHT_REGISTER["Default"](eps=1e-6),
        )
        txt_mlp = QwenImageFFN(
            block_index=block_index,
            block_prefix="transformer_blocks",
            ffn_prefix="txt_mlp",
            task=config.task,
            mm_type=mm_type,
            config=config,
            lazy_load=self.lazy_load,
            lazy_load_file=self.lazy_load_file,
        )
        self.add_module("txt_mlp", txt_mlp)
        self.cpu_offload = config["cpu_offload"]
        if self.cpu_offload:
            offload_granularity = config.get("offload_granularity", "block")
            if offload_granularity == "phase":
                # Phase 1: pre-attention modulation + first norms.
                # Phase 2: the joint attention.
                # Phase 3: post-attention norms + MLPs.
                # Each phase can then be moved between CPU and GPU as a unit.
                phase1_dict = {
                    "img_mod": self.img_mod,
                    "txt_mod": self.txt_mod,
                    "img_norm1": self.img_norm1,
                    "txt_norm1": self.txt_norm1,
                }
                phase2_dict = {"attn": self.attn}
                phase3_dict = {
                    "img_norm2": self.img_norm2,
                    "img_mlp": self.img_mlp,
                    "txt_norm2": self.txt_norm2,
                    "txt_mlp": self.txt_mlp,
                }
                compute_phases = [
                    ComputePhase(phase1_dict),
                    ComputePhase(phase2_dict),
                    ComputePhase(phase3_dict),
                ]
                # NOTE(review): a plain list is registered here (not a
                # WeightModuleList) -- confirm add_module accepts lists.
                self.add_module("compute_phases", compute_phases)
class QwenImageCrossAttention(WeightModule):
    """Joint (image + text) attention weights for one Qwen-Image block.

    Registers the query/key RMS norms for both streams, the linear
    projections of the image stream (to_q/to_k/to_v/to_out) and of the text
    stream (add_*_proj/to_add_out), plus the attention kernel implementation.

    Args:
        block_index: index of this block in the transformer stack.
        block_prefix: checkpoint key prefix (e.g. "transformer_blocks").
        task: task identifier (stored on the instance).
        mm_type: key into MM_WEIGHT_REGISTER selecting the matmul weight impl.
        config: global config (dict-like).
        lazy_load: whether weights are read on demand from lazy_load_file.
        lazy_load_file: open safetensors handle when lazy_load is set, else None.
    """

    def __init__(self, block_index, block_prefix, task, mm_type, config, lazy_load, lazy_load_file):
        super().__init__()
        self.block_index = block_index
        self.mm_type = mm_type
        self.task = task
        self.config = config
        self.quant_method = config.get("quant_method", None)
        self.sparge = config.get("sparge", False)
        self.attn_type = config.get("attn_type", "flash_attn3")
        self.heads = config["attention_out_dim"] // config["attention_dim_head"]
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file
        attn_prefix = f"{block_prefix}.{block_index}.attn"
        # Query/key RMS norms for the image stream.
        # NOTE(review): RMS weights are not wired for lazy loading here,
        # unlike the MM weights below -- confirm this is intentional.
        for norm_name in ("norm_q", "norm_k"):
            self.add_module(norm_name, RMS_WEIGHT_REGISTER["fp32_variance"](f"{attn_prefix}.{norm_name}.weight"))
        # Linear projections as (module name, checkpoint sub-key) pairs; the
        # previous version copy-pasted one add_module call per projection.
        # Registration order is preserved exactly.
        linear_specs = [
            ("to_q", "to_q"),
            ("to_k", "to_k"),
            ("to_v", "to_v"),
            ("add_q_proj", "add_q_proj"),
            ("add_k_proj", "add_k_proj"),
            ("add_v_proj", "add_v_proj"),
            ("to_out", "to_out.0"),
            ("to_add_out", "to_add_out"),
        ]
        for module_name, ckpt_name in linear_specs:
            self.add_module(
                module_name,
                MM_WEIGHT_REGISTER[self.mm_type](
                    f"{attn_prefix}.{ckpt_name}.weight",
                    f"{attn_prefix}.{ckpt_name}.bias",
                    self.lazy_load,
                    self.lazy_load_file,
                ),
            )
        # Query/key RMS norms for the text ("added") stream.
        for norm_name in ("norm_added_q", "norm_added_k"):
            self.add_module(norm_name, RMS_WEIGHT_REGISTER["fp32_variance"](f"{attn_prefix}.{norm_name}.weight"))
        # Attention kernel implementation selected by attn_type.
        self.add_module("calculate", ATTN_WEIGHT_REGISTER[self.attn_type]())

    def to_cpu(self, non_blocking=True):
        """Move every child weight module that supports it to CPU."""
        for module in self._modules.values():
            if module is not None and hasattr(module, "to_cpu"):
                module.to_cpu(non_blocking=non_blocking)

    def to_cuda(self, non_blocking=True):
        """Move every child weight module that supports it to CUDA."""
        for module in self._modules.values():
            if module is not None and hasattr(module, "to_cuda"):
                module.to_cuda(non_blocking=non_blocking)
class QwenImageFFN(WeightModule):
    """Feed-forward (MLP) weights for one stream of a Qwen-Image block.

    Registers the two linear layers of the MLP under checkpoint keys
    "<prefix>.net.0.proj.*" and "<prefix>.net.2.*".
    """

    def __init__(self, block_index, block_prefix, ffn_prefix, task, mm_type, config, lazy_load, lazy_load_file):
        super().__init__()
        self.block_index = block_index
        self.mm_type = mm_type
        self.task = task
        self.config = config
        self.quant_method = config.get("quant_method", None)
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file
        base = f"{block_prefix}.{self.block_index}.{ffn_prefix}"
        for module_name, ckpt_name in (("mlp_0", "net.0.proj"), ("mlp_2", "net.2")):
            self.add_module(
                module_name,
                MM_WEIGHT_REGISTER[self.mm_type](
                    f"{base}.{ckpt_name}.weight",
                    f"{base}.{ckpt_name}.bias",
                    self.lazy_load,
                    self.lazy_load_file,
                ),
            )

    def to_cpu(self, non_blocking=True):
        """Best-effort move of every child weight module to CPU."""
        for child in self._modules.values():
            if child is not None and hasattr(child, "to_cpu"):
                child.to_cpu(non_blocking=non_blocking)

    def to_cuda(self, non_blocking=True):
        """Best-effort move of every child weight module to CUDA."""
        for child in self._modules.values():
            if child is not None and hasattr(child, "to_cuda"):
                child.to_cuda(non_blocking=non_blocking)
class ComputePhase(WeightModule):
    """Groups a named set of sub-modules so they can be offloaded together."""

    def __init__(self, sub_module_dict):
        super().__init__()
        for name, sub_module in sub_module_dict.items():
            self.add_module(name, sub_module)

    def to_cpu(self, non_blocking=True):
        """Move every child that supports it to CPU."""
        for child in self._modules.values():
            if child is None or not hasattr(child, "to_cpu"):
                continue
            child.to_cpu(non_blocking=non_blocking)

    def to_cuda(self, non_blocking=True):
        """Move every child that supports it to CUDA."""
        for child in self._modules.values():
            if child is None or not hasattr(child, "to_cuda"):
                continue
            child.to_cuda(non_blocking=non_blocking)
......@@ -69,14 +69,14 @@ class QwenImageRunner(DefaultRunner):
else:
assert NotImplementedError
self.model.set_scheduler(self.scheduler)
@ProfilingContext4DebugL2("Run DiT")
def _run_dit_local(self, total_steps=None):
    """Run the DiT denoising loop locally and return (latents, generator).

    Args:
        total_steps: optional override for the number of inference steps,
            forwarded to self.run().
    """
    # With lazy_load / unload_modules the transformer is released after each
    # run, so it must be reloaded before inference.
    if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
        self.model = self.load_transformer()
    self.init_scheduler()
    self.model.scheduler.prepare(self.inputs["image_encoder_output"])
    # NOTE(review): wan2.2 i2v drops the VAE encoder output before running --
    # presumably it is only needed during scheduler preparation; confirm
    # against the wan2.2 pipeline.
    if self.config.get("model_cls") == "wan2.2" and self.config["task"] == "i2v":
        self.inputs["image_encoder_output"]["vae_encoder_out"] = None
    latents, generator = self.run(total_steps)
    self.end_run()
    return latents, generator
......@@ -167,11 +167,7 @@ class QwenImageRunner(DefaultRunner):
self.config.target_shape = (self.config.batchsize, 1, num_channels_latents, height, width)
def init_scheduler(self):
    """Create the QwenImage scheduler and share one instance everywhere.

    Previously a second QwenImageScheduler was constructed for
    self.scheduler, so the runner and the model held two different
    scheduler objects that could drift out of sync; now a single
    instance is installed in the model, all inference stages, and the
    runner.
    """
    scheduler = QwenImageScheduler(self.config)
    self.model.set_scheduler(scheduler)
    self.model.pre_infer.set_scheduler(scheduler)
    self.model.transformer_infer.set_scheduler(scheduler)
    self.model.post_infer.set_scheduler(scheduler)
    self.scheduler = scheduler
def get_encoder_output_i2v(self):
    """Intentional no-op.

    NOTE(review): presumably the Qwen-Image pipeline produces its encoder
    outputs elsewhere, so no i2v-specific encoder output is needed --
    confirm against DefaultRunner's expectations for this hook.
    """
    pass
......
import gc
import json
import os
from typing import Optional
......@@ -27,15 +28,23 @@ def retrieve_latents(encoder_output: torch.Tensor, generator: Optional[torch.Gen
class AutoencoderKLQwenImageVAE:
def __init__(self, config):
    """Set up device/dtype bookkeeping, then load the VAE via load().

    Fixes two merge artifacts: the model was loaded eagerly here AND again
    in load(), and self.device was unconditionally reset to "cuda" after
    the cpu_offload branch, making that branch dead. The device is now
    chosen once and load() is the single place the model is instantiated.

    Args:
        config: pipeline config providing model_path, vae_z_dim, seed, etc.
    """
    self.config = config
    self.cpu_offload = config.get("cpu_offload", False)
    # With cpu_offload the VAE stays resident on CPU between calls;
    # encode/decode move it to GPU temporarily.
    if self.cpu_offload:
        self.device = torch.device("cpu")
    else:
        self.device = torch.device("cuda")
    self.dtype = torch.bfloat16
    self.latent_channels = config.vae_z_dim
    self.load()
def load(self):
    """Instantiate the VAE model, image processor, scale factor and RNG."""
    vae_dir = os.path.join(self.config.model_path, "vae")
    self.model = AutoencoderKLQwenImage.from_pretrained(vae_dir).to(self.device).to(self.dtype)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.config.vae_scale_factor * 2)
    with open(os.path.join(vae_dir, "config.json"), "r") as f:
        vae_config = json.load(f)
    # "temperal_downsample" is the (sic) key name used by the VAE config.
    if "temperal_downsample" in vae_config:
        self.vae_scale_factor = 2 ** len(vae_config["temperal_downsample"])
    else:
        self.vae_scale_factor = 8
    self.generator = torch.Generator(device="cuda").manual_seed(self.config.seed)
@staticmethod
def _unpack_latents(latents, height, width, vae_scale_factor):
......@@ -55,6 +64,8 @@ class AutoencoderKLQwenImageVAE:
@torch.no_grad()
def decode(self, latents):
if self.cpu_offload:
self.model.to(torch.device("cuda"))
if self.config.task == "t2i":
width, height = self.config.aspect_ratios[self.config.aspect_ratio]
elif self.config.task == "i2i":
......@@ -66,6 +77,10 @@ class AutoencoderKLQwenImageVAE:
latents = latents / latents_std + latents_mean
images = self.model.decode(latents, return_dict=False)[0][:, :, 0]
images = self.image_processor.postprocess(images, output_type="pil")
if self.cpu_offload:
self.model.to(torch.device("cpu"))
torch.cuda.empty_cache()
gc.collect()
return images
@staticmethod
......@@ -88,9 +103,12 @@ class AutoencoderKLQwenImageVAE:
return image_latents
@torch.no_grad()
def encode_vae_image(self, image):
if self.cpu_offload:
self.model.to(torch.device("cuda"))
num_channels_latents = self.config.transformer_in_channels // 4
image = image.to(self.device).to(self.dtype)
image = image.to(self.model.device).to(self.dtype)
if image.shape[1] != self.latent_channels:
image_latents = self._encode_vae_image(image=image, generator=self.generator)
else:
......@@ -106,4 +124,8 @@ class AutoencoderKLQwenImageVAE:
image_latent_height, image_latent_width = image_latents.shape[3:]
image_latents = self._pack_latents(image_latents, self.config.batchsize, num_channels_latents, image_latent_height, image_latent_width)
if self.cpu_offload:
self.model.to(torch.device("cpu"))
torch.cuda.empty_cache()
gc.collect()
return image_latents
#!/bin/bash
# Launcher for Qwen-Image image-to-image inference with block-level offload.

export CUDA_VISIBLE_DEVICES=

# Set these two paths before running.
export lightx2v_path=
export model_path=

# Sanity checks.
if [[ -z "${CUDA_VISIBLE_DEVICES}" ]]; then
    cuda_devices=0
    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
    export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi

if [[ -z "${lightx2v_path}" ]]; then
    echo "Error: lightx2v_path is not set. Please set this variable first."
    exit 1
fi

if [[ -z "${model_path}" ]]; then
    echo "Error: model_path is not set. Please set this variable first."
    exit 1
fi

export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export DTYPE=BF16
export PROFILING_DEBUG_LEVEL=2
export ENABLE_GRAPH_MODE=false

python -m lightx2v.infer \
    --model_cls qwen_image \
    --task i2i \
    --model_path $model_path \
    --config_json ${lightx2v_path}/configs/offload/block/qwen_image_i2i_block.json \
    --prompt "Change the rabbit's color to purple, with a flash light background." \
    --image_path input.jpg \
    --save_video_path ${lightx2v_path}/save_results/qwen_image_i2i.png
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment