Commit a40ffb3f, authored by Watebear and committed by GitHub
Browse files

refactor qwen-image (#297)

parent 701075f4
from lightx2v.common.modules.weight_module import WeightModule
from lightx2v.utils.registry_factory import (
MM_WEIGHT_REGISTER,
RMS_WEIGHT_REGISTER,
)
class QwenImagePreWeights(WeightModule):
    """Weight container registering the ``img_in``, ``txt_in``, ``txt_norm``
    and ``time_text_embed`` timestep-embedder entries as named sub-modules.
    """

    def __init__(self, config):
        """Store ``config`` and register every weight entry via ``add_module``.

        Each entry is built from the checkpoint key names passed to the
        registry class (a weight key and, for the "Default" matmul entries,
        a bias key).
        """
        super().__init__()
        self.config = config

        # All matmul-style entries share the same registry class.
        linear = MM_WEIGHT_REGISTER["Default"]

        # img_in
        self.add_module("img_in", linear("img_in.weight", "img_in.bias"))

        # txt_in
        self.add_module("txt_in", linear("txt_in.weight", "txt_in.bias"))

        # txt_norm
        txt_norm = RMS_WEIGHT_REGISTER["fp32_variance"]("txt_norm.weight")
        self.add_module("txt_norm", txt_norm)

        # time_text_embed
        timestep_linear_1 = linear(
            "time_text_embed.timestep_embedder.linear_1.weight",
            "time_text_embed.timestep_embedder.linear_1.bias",
        )
        self.add_module("time_text_embed_timestep_embedder_linear_1", timestep_linear_1)

        timestep_linear_2 = linear(
            "time_text_embed.timestep_embedder.linear_2.weight",
            "time_text_embed.timestep_embedder.linear_2.bias",
        )
        self.add_module("time_text_embed_timestep_embedder_linear_2", timestep_linear_2)
......@@ -69,14 +69,14 @@ class QwenImageRunner(DefaultRunner):
else:
assert NotImplementedError
self.model.set_scheduler(self.scheduler)
@ProfilingContext4DebugL2("Run DiT")
def _run_dit_local(self, total_steps=None):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.model = self.load_transformer()
self.init_scheduler()
self.model.scheduler.prepare(self.inputs["image_encoder_output"])
if self.config.get("model_cls") == "wan2.2" and self.config["task"] == "i2v":
self.inputs["image_encoder_output"]["vae_encoder_out"] = None
latents, generator = self.run(total_steps)
self.end_run()
return latents, generator
......@@ -167,11 +167,7 @@ class QwenImageRunner(DefaultRunner):
self.config.target_shape = (self.config.batchsize, 1, num_channels_latents, height, width)
def init_scheduler(self):
scheduler = QwenImageScheduler(self.config)
self.model.set_scheduler(scheduler)
self.model.pre_infer.set_scheduler(scheduler)
self.model.transformer_infer.set_scheduler(scheduler)
self.model.post_infer.set_scheduler(scheduler)
self.scheduler = QwenImageScheduler(self.config)
def get_encoder_output_i2v(self):
    """No-op: performs no work and returns ``None``."""
    return None
......
import gc
import json
import os
from typing import Optional
......@@ -27,15 +28,23 @@ def retrieve_latents(encoder_output: torch.Tensor, generator: Optional[torch.Gen
class AutoencoderKLQwenImageVAE:
def __init__(self, config):
self.config = config
self.model = AutoencoderKLQwenImage.from_pretrained(os.path.join(config.model_path, "vae")).to(torch.device("cuda")).to(torch.bfloat16)
self.image_processor = VaeImageProcessor(vae_scale_factor=config.vae_scale_factor * 2)
with open(os.path.join(config.model_path, "vae", "config.json"), "r") as f:
vae_config = json.load(f)
self.vae_scale_factor = 2 ** len(vae_config["temperal_downsample"]) if "temperal_downsample" in vae_config else 8
self.generator = torch.Generator(device="cuda").manual_seed(config.seed)
self.dtype = torch.bfloat16
self.cpu_offload = config.get("cpu_offload", False)
if self.cpu_offload:
self.device = torch.device("cpu")
else:
self.device = torch.device("cuda")
self.dtype = torch.bfloat16
self.latent_channels = config.vae_z_dim
self.load()
def load(self):
    """Instantiate the VAE, its image processor, scale factor and RNG.

    Reads the pretrained weights and ``config.json`` from the ``vae``
    sub-directory of ``self.config.model_path``. The model is placed on
    ``self.device`` and cast to ``self.dtype``.
    """
    vae_dir = os.path.join(self.config.model_path, "vae")

    vae = AutoencoderKLQwenImage.from_pretrained(vae_dir)
    self.model = vae.to(self.device).to(self.dtype)

    self.image_processor = VaeImageProcessor(vae_scale_factor=self.config.vae_scale_factor * 2)

    with open(os.path.join(vae_dir, "config.json"), "r") as f:
        vae_config = json.load(f)

    # "temperal_downsample" is the key exactly as spelled in the VAE config;
    # fall back to 8 when the config does not provide it.
    if "temperal_downsample" in vae_config:
        self.vae_scale_factor = 2 ** len(vae_config["temperal_downsample"])
    else:
        self.vae_scale_factor = 8

    # The generator is always created on CUDA, independent of self.device.
    self.generator = torch.Generator(device="cuda").manual_seed(self.config.seed)
@staticmethod
def _unpack_latents(latents, height, width, vae_scale_factor):
......@@ -55,6 +64,8 @@ class AutoencoderKLQwenImageVAE:
@torch.no_grad()
def decode(self, latents):
if self.cpu_offload:
self.model.to(torch.device("cuda"))
if self.config.task == "t2i":
width, height = self.config.aspect_ratios[self.config.aspect_ratio]
elif self.config.task == "i2i":
......@@ -66,6 +77,10 @@ class AutoencoderKLQwenImageVAE:
latents = latents / latents_std + latents_mean
images = self.model.decode(latents, return_dict=False)[0][:, :, 0]
images = self.image_processor.postprocess(images, output_type="pil")
if self.cpu_offload:
self.model.to(torch.device("cpu"))
torch.cuda.empty_cache()
gc.collect()
return images
@staticmethod
......@@ -88,9 +103,12 @@ class AutoencoderKLQwenImageVAE:
return image_latents
@torch.no_grad()
def encode_vae_image(self, image):
if self.cpu_offload:
self.model.to(torch.device("cuda"))
num_channels_latents = self.config.transformer_in_channels // 4
image = image.to(self.device).to(self.dtype)
image = image.to(self.model.device).to(self.dtype)
if image.shape[1] != self.latent_channels:
image_latents = self._encode_vae_image(image=image, generator=self.generator)
else:
......@@ -106,4 +124,8 @@ class AutoencoderKLQwenImageVAE:
image_latent_height, image_latent_width = image_latents.shape[3:]
image_latents = self._pack_latents(image_latents, self.config.batchsize, num_channels_latents, image_latent_height, image_latent_width)
if self.cpu_offload:
self.model.to(torch.device("cpu"))
torch.cuda.empty_cache()
gc.collect()
return image_latents
#!/bin/bash

# Launch Qwen-Image image-to-image (i2i) inference through lightx2v.infer
# using the block-offload config.

# Keep a CUDA_VISIBLE_DEVICES value already exported by the caller; the
# warning below promises that setting the env variable works. To pin devices
# in this script instead, put the value after ":-".
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-}"

# set paths first
export lightx2v_path=
export model_path=

# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
    cuda_devices=0
    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
    export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi

if [ -z "${lightx2v_path}" ]; then
    # Fatal configuration errors go to stderr.
    echo "Error: lightx2v_path is not set. Please set this variable first." >&2
    exit 1
fi

if [ -z "${model_path}" ]; then
    echo "Error: model_path is not set. Please set this variable first." >&2
    exit 1
fi

export TOKENIZERS_PARALLELISM=false

# Append the previous PYTHONPATH only when it is non-empty: a trailing ':'
# would add an empty entry, which Python treats as the current directory.
export PYTHONPATH="${lightx2v_path}${PYTHONPATH:+:${PYTHONPATH}}"

export DTYPE=BF16
export PROFILING_DEBUG_LEVEL=2
export ENABLE_GRAPH_MODE=false

# Path expansions are quoted so directories containing spaces survive
# word splitting.
python -m lightx2v.infer \
    --model_cls qwen_image \
    --task i2i \
    --model_path "${model_path}" \
    --config_json "${lightx2v_path}/configs/offload/block/qwen_image_i2i_block.json" \
    --prompt "Change the rabbit's color to purple, with a flash light background." \
    --image_path input.jpg \
    --save_video_path "${lightx2v_path}/save_results/qwen_image_i2i.png"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment