"next_docs/vscode:/vscode.git/clone" did not exist on "7dc3b0a9a2df3281f42abbf69ba920367f94a750"
Commit cf04772a authored by Yang Yong (雍洋), committed by GitHub

Fix wan22 ti2v vae & update audio profiler (#246)

parent cb83f2f8
@@ -8,7 +8,7 @@ from lightx2v.models.runners.cogvideox.cogvidex_runner import CogvideoxRunner #
from lightx2v.models.runners.graph_runner import GraphRunner
from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner # noqa: F401
from lightx2v.models.runners.qwen_image.qwen_image_runner import QwenImageRunner # noqa: F401
-from lightx2v.models.runners.wan.wan_audio_runner import Wan22AudioRunner, Wan22MoeAudioRunner, WanAudioRunner # noqa: F401
+from lightx2v.models.runners.wan.wan_audio_runner import WanAudioRunner # noqa: F401
from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner # noqa: F401
from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner # noqa: F401
from lightx2v.models.runners.wan.wan_runner import Wan22MoeRunner, WanRunner # noqa: F401
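Note: these imports are unused where they appear; the "# noqa: F401" markers show they are kept only for their side effect of registering each runner class at import time. A minimal sketch of that pattern, assuming RUNNER_REGISTER acts as a simple name-to-class registry (illustrative, not the library's exact implementation):

class RunnerRegistry:
    """Hypothetical stand-in for lightx2v's RUNNER_REGISTER."""
    def __init__(self):
        self._runners = {}

    def __call__(self, name):
        # Used as a class decorator: @RUNNER_REGISTER("wan2.2_audio")
        def decorator(cls):
            self._runners[name] = cls
            return cls
        return decorator

    def get(self, name):
        return self._runners[name]

RUNNER_REGISTER = RunnerRegistry()

Under this reading, importing a runner module is what makes its model_cls name resolvable, so deleting the Wan22AudioRunner and Wan22MoeAudioRunner classes also requires trimming them from this import list.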
-import glob
-import os
from lightx2v.models.networks.wan.infer.audio.post_infer import WanAudioPostInfer
from lightx2v.models.networks.wan.infer.audio.pre_infer import WanAudioPreInfer
from lightx2v.models.networks.wan.infer.audio.transformer_infer import WanAudioTransformerInfer
@@ -29,13 +26,3 @@ class WanAudioModel(WanModel):
    def set_audio_adapter(self, audio_adapter):
        self.audio_adapter = audio_adapter
        self.transformer_infer.set_audio_adapter(self.audio_adapter)
-
-class Wan22MoeAudioModel(WanAudioModel):
-    def _load_ckpt(self, unified_dtype, sensitive_layer):
-        safetensors_files = glob.glob(os.path.join(self.model_path, "*.safetensors"))
-        weight_dict = {}
-        for file_path in safetensors_files:
-            file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
-            weight_dict.update(file_weights)
-        return weight_dict
@@ -11,6 +11,7 @@ class BaseRunner(ABC):
    def __init__(self, config):
        self.config = config
+        self.vae_encoder_need_img_original = False

    def load_transformer(self):
        """Load transformer model
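For orientation, the new vae_encoder_need_img_original flag threads through the runners as follows, in a condensed sketch that uses only names appearing in this diff: BaseRunner defaults it to False, Wan22DenseRunner (in the wan_runner.py hunk below) flips it to True, and DefaultRunner uses it to decide which image the VAE encoder sees.

class BaseRunner:
    def __init__(self, config):
        self.config = config
        self.vae_encoder_need_img_original = False  # default: VAE encoder gets the normalized tensor

class Wan22DenseRunner(BaseRunner):
    def __init__(self, config):
        super().__init__(config)
        self.vae_encoder_need_img_original = True  # Wan2.2 ti2v VAE consumes the original image

# In DefaultRunner._run_input_encoder_local_i2v (next hunk), the flag selects the VAE input:
#     vae_encode_out = self.run_vae_encoder(img_ori if self.vae_encoder_need_img_original else img)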
@@ -145,16 +145,16 @@ class DefaultRunner(BaseRunner):
        gc.collect()

    def read_image_input(self, img_path):
-        img = Image.open(img_path).convert("RGB")
-        img = TF.to_tensor(img).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
-        return img
+        img_ori = Image.open(img_path).convert("RGB")
+        img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
+        return img, img_ori

    @ProfilingContext("Run Encoders")
    def _run_input_encoder_local_i2v(self):
        prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
-        img = self.read_image_input(self.config["image_path"])
+        img, img_ori = self.read_image_input(self.config["image_path"])
        clip_encoder_out = self.run_image_encoder(img) if self.config.get("use_image_encoder", True) else None
-        vae_encode_out = self.run_vae_encoder(img)
+        vae_encode_out = self.run_vae_encoder(img_ori if self.vae_encoder_need_img_original else img)
        text_encoder_output = self.run_text_encoder(prompt, img)
        torch.cuda.empty_cache()
        gc.collect()
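The read_image_input change keeps the original normalization: TF.to_tensor maps 8-bit pixels into [0, 1], and the in-place sub_(0.5).div_(0.5) rescales that to [-1, 1]. A quick standalone check of the arithmetic:

import torch

x = torch.tensor([0.0, 0.25, 0.5, 1.0])  # TF.to_tensor output lies in [0, 1]
y = x.sub(0.5).div(0.5)                   # (x - 0.5) / 0.5 == 2 * x - 1
print(y)                                  # tensor([-1.0000, -0.5000,  0.0000,  1.0000])

The only behavioral change is that the untransformed PIL image (img_ori) is now returned alongside the tensor, so VAE encoders that need the raw image can receive it.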
@@ -14,19 +14,17 @@ from einops import rearrange
from loguru import logger
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize
-from transformers import AutoFeatureExtractor
from lightx2v.models.input_encoders.hf.seko_audio.audio_adapter import AudioAdapter
from lightx2v.models.input_encoders.hf.seko_audio.audio_encoder import SekoAudioEncoderModel
-from lightx2v.models.networks.wan.audio_model import Wan22MoeAudioModel, WanAudioModel
+from lightx2v.models.networks.wan.audio_model import WanAudioModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
-from lightx2v.models.runners.wan.wan_runner import MultiModelStruct, WanRunner
+from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.schedulers.wan.audio.scheduler import ConsistencyModelScheduler
-from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
from lightx2v.utils.envs import *
-from lightx2v.utils.profiler import ProfilingContext
+from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
from lightx2v.utils.registry_factory import RUNNER_REGISTER
-from lightx2v.utils.utils import find_torch_model_path, load_weights, save_to_video, vae_to_comfyui_image
+from lightx2v.utils.utils import load_weights, save_to_video, vae_to_comfyui_image
def get_optimal_patched_size_with_sp(patched_h, patched_w, sp_size):
@@ -398,6 +396,7 @@ class WanAudioRunner(WanRunner): # type:ignore
        self.cut_audio_list = []
        self.prev_video = None

+    @ProfilingContext4Debug("Init run segment")
    def init_run_segment(self, segment_idx):
        self.segment_idx = segment_idx
@@ -421,6 +420,7 @@ class WanAudioRunner(WanRunner): # type:ignore
self.inputs["previmg_encoder_output"] = self.prepare_prev_latents(self.prev_video, prev_frame_length=5)
@ProfilingContext4Debug("End run segment")
def end_run_segment(self):
self.gen_video = torch.clamp(self.gen_video, -1, 1).to(torch.float)
@@ -446,6 +446,7 @@ class WanAudioRunner(WanRunner): # type:ignore
        del self.gen_video
        torch.cuda.empty_cache()

+    @ProfilingContext4Debug("Process after vae decoder")
    def process_images_after_vae_decoder(self, save_video=True):
        # Merge results
        gen_lvideo = torch.cat(self.gen_video_list, dim=2).float()
@@ -599,89 +600,3 @@ class WanAudioRunner(WanRunner): # type:ignore
ret["target_shape"] = self.config.target_shape
return ret
@RUNNER_REGISTER("wan2.2_audio")
class Wan22AudioRunner(WanAudioRunner):
def __init__(self, config):
super().__init__(config)
def load_vae_decoder(self):
# offload config
vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
if vae_offload:
vae_device = torch.device("cpu")
else:
vae_device = torch.device("cuda")
vae_config = {
"vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
"device": vae_device,
"cpu_offload": vae_offload,
"offload_cache": self.config.get("vae_offload_cache", False),
}
vae_decoder = Wan2_2_VAE(**vae_config)
return vae_decoder
def load_vae_encoder(self):
# offload config
vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
if vae_offload:
vae_device = torch.device("cpu")
else:
vae_device = torch.device("cuda")
vae_config = {
"vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
"device": vae_device,
"cpu_offload": vae_offload,
"offload_cache": self.config.get("vae_offload_cache", False),
}
if self.config.task != "i2v":
return None
else:
return Wan2_2_VAE(**vae_config)
def load_vae(self):
vae_encoder = self.load_vae_encoder()
vae_decoder = self.load_vae_decoder()
return vae_encoder, vae_decoder
@RUNNER_REGISTER("wan2.2_moe_audio")
class Wan22MoeAudioRunner(WanAudioRunner):
def __init__(self, config):
super().__init__(config)
def load_transformer(self):
# encoder -> high_noise_model -> low_noise_model -> vae -> video_output
high_noise_model = Wan22MoeAudioModel(
os.path.join(self.config.model_path, "high_noise_model"),
self.config,
self.init_device,
)
low_noise_model = Wan22MoeAudioModel(
os.path.join(self.config.model_path, "low_noise_model"),
self.config,
self.init_device,
)
if self.config.get("lora_configs") and self.config.lora_configs:
assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
for lora_config in self.config.lora_configs:
lora_path = lora_config["path"]
strength = lora_config.get("strength", 1.0)
if lora_config.name == "high_noise_model":
lora_wrapper = WanLoraWrapper(high_noise_model)
lora_name = lora_wrapper.load_lora(lora_path)
lora_wrapper.apply_lora(lora_name, strength)
logger.info(f"{lora_config.name} Loaded LoRA: {lora_name} with strength: {strength}")
if lora_config.name == "low_noise_model":
lora_wrapper = WanLoraWrapper(low_noise_model)
lora_name = lora_wrapper.load_lora(lora_path)
lora_wrapper.apply_lora(lora_name, strength)
logger.info(f"{lora_config.name} Loaded LoRA: {lora_name} with strength: {strength}")
# XXX: trick
self._audio_preprocess = AutoFeatureExtractor.from_pretrained(self.config["model_path"], subfolder="audio_encoder")
return MultiModelStruct([high_noise_model, low_noise_model], self.config, self.config.boundary)
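The profiler half of this commit adds ProfilingContext4Debug decorators around the per-segment audio stages above (init, end, post-VAE processing). A hedged sketch of what such a decorator-style profiling context typically looks like; this is an illustration, not lightx2v's actual implementation:

import functools
import time

from loguru import logger

class ProfilingContext4Debug:
    """Illustrative stand-in: logs wall-clock time of the decorated method."""
    def __init__(self, name):
        self.name = name

    def __call__(self, fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            start = time.perf_counter()
            try:
                return fn(*args, **kwargs)
            finally:
                logger.debug(f"[{self.name}] took {time.perf_counter() - start:.3f}s")
        return wrapper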
@@ -430,6 +430,7 @@ class Wan22MoeRunner(WanRunner):
class Wan22DenseRunner(WanRunner):
    def __init__(self, config):
        super().__init__(config)
+        self.vae_encoder_need_img_original = True

    def load_vae_decoder(self):
        # offload config
#!/bin/bash
# set paths first
lightx2v_path=
model_path=
export CUDA_VISIBLE_DEVICES=0
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
export TORCH_CUDA_ARCH_LIST="9.0"
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
python -m lightx2v.infer \
    --model_cls wan2.2_moe_audio \
    --task i2v \
    --model_path $model_path \
    --config_json ${lightx2v_path}/configs/wan22/wan_moe_i2v_audio.json \
    --prompt "The video features an old lady saying something and knitting a sweater." \
    --negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
    --image_path ${lightx2v_path}/assets/inputs/audio/15.png \
    --audio_path ${lightx2v_path}/assets/inputs/audio/15.wav \
    --save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan22_moe_i2v_audio.mp4
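Two of the exported variables deserve a note: PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True enables PyTorch's expandable-segments CUDA allocator, which reduces memory fragmentation for workloads like this one that allocate differently sized segments, and TORCH_CUDA_ARCH_LIST="9.0" limits kernel builds to compute capability 9.0 GPUs (Hopper, e.g. H100). The lightx2v_path and model_path variables at the top are intentionally left empty and must be filled in before running.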