Commit cf04772a authored by Yang Yong (雍洋), committed by GitHub

Fix wan22 ti2v vae & update audio profiler (#246)

parent cb83f2f8
@@ -8,7 +8,7 @@ from lightx2v.models.runners.cogvideox.cogvidex_runner import CogvideoxRunner #
 from lightx2v.models.runners.graph_runner import GraphRunner
 from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner # noqa: F401
 from lightx2v.models.runners.qwen_image.qwen_image_runner import QwenImageRunner # noqa: F401
-from lightx2v.models.runners.wan.wan_audio_runner import Wan22AudioRunner, Wan22MoeAudioRunner, WanAudioRunner # noqa: F401
+from lightx2v.models.runners.wan.wan_audio_runner import WanAudioRunner # noqa: F401
 from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner # noqa: F401
 from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner # noqa: F401
 from lightx2v.models.runners.wan.wan_runner import Wan22MoeRunner, WanRunner # noqa: F401
...
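The import block above exists for its side effects: each runner module is imported (hence the # noqa: F401 suppressions) so that its @RUNNER_REGISTER("...") decorator executes and the class becomes resolvable from the --model_cls string. The real registry lives in lightx2v.utils.registry_factory and is not part of this diff; a minimal sketch of the pattern, with hypothetical names only, looks like:

_RUNNERS = {}

def RUNNER_REGISTER(name):
    # Decorator factory: map a model_cls string to its runner class.
    def _wrap(cls):
        _RUNNERS[name] = cls
        return cls
    return _wrap

@RUNNER_REGISTER("my_model_cls")
class MyRunner:
    pass

runner_cls = _RUNNERS["my_model_cls"]  # looked up later from the CLI argument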
-import glob
-import os
 from lightx2v.models.networks.wan.infer.audio.post_infer import WanAudioPostInfer
 from lightx2v.models.networks.wan.infer.audio.pre_infer import WanAudioPreInfer
 from lightx2v.models.networks.wan.infer.audio.transformer_infer import WanAudioTransformerInfer
@@ -29,13 +26,3 @@ class WanAudioModel(WanModel):
     def set_audio_adapter(self, audio_adapter):
         self.audio_adapter = audio_adapter
         self.transformer_infer.set_audio_adapter(self.audio_adapter)
-
-
-class Wan22MoeAudioModel(WanAudioModel):
-    def _load_ckpt(self, unified_dtype, sensitive_layer):
-        safetensors_files = glob.glob(os.path.join(self.model_path, "*.safetensors"))
-        weight_dict = {}
-        for file_path in safetensors_files:
-            file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
-            weight_dict.update(file_weights)
-        return weight_dict
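The deleted Wan22MoeAudioModel only overrode checkpoint loading so that every *.safetensors shard in the model directory is merged into a single weight dict. The helper _load_safetensor_to_dict it calls is internal to WanModel and not shown here; a self-contained sketch of the same merge, using the safetensors library directly, would be:

import glob
import os

from safetensors.torch import load_file


def load_sharded_safetensors(model_path):
    # Merge all *.safetensors shards under model_path into one state dict.
    weight_dict = {}
    for file_path in sorted(glob.glob(os.path.join(model_path, "*.safetensors"))):
        weight_dict.update(load_file(file_path))  # {tensor_name: torch.Tensor} per shard
    return weight_dict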
@@ -11,6 +11,7 @@ class BaseRunner(ABC):
     def __init__(self, config):
         self.config = config
+        self.vae_encoder_need_img_original = False

     def load_transformer(self):
         """Load transformer model
...
@@ -145,16 +145,16 @@ class DefaultRunner(BaseRunner):
         gc.collect()

     def read_image_input(self, img_path):
-        img = Image.open(img_path).convert("RGB")
-        img = TF.to_tensor(img).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
-        return img
+        img_ori = Image.open(img_path).convert("RGB")
+        img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
+        return img, img_ori

     @ProfilingContext("Run Encoders")
     def _run_input_encoder_local_i2v(self):
         prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
-        img = self.read_image_input(self.config["image_path"])
+        img, img_ori = self.read_image_input(self.config["image_path"])
         clip_encoder_out = self.run_image_encoder(img) if self.config.get("use_image_encoder", True) else None
-        vae_encode_out = self.run_vae_encoder(img)
+        vae_encode_out = self.run_vae_encoder(img_ori if self.vae_encoder_need_img_original else img)
         text_encoder_output = self.run_text_encoder(prompt, img)
         torch.cuda.empty_cache()
         gc.collect()
...
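The i2v encoder path now returns both the normalized tensor and the original PIL image from read_image_input, and routes the original image to the VAE encoder whenever a runner sets vae_encoder_need_img_original (declared in BaseRunner above, enabled by Wan22DenseRunner below). A reduced sketch of that control flow, with run_vae_encoder stubbed out since its real implementation lives elsewhere in the runners:

import torchvision.transforms.functional as TF
from PIL import Image


class MiniRunner:
    def __init__(self, need_img_original=False):
        # Wan2.2 dense (ti2v) runners opt in so their VAE sees the raw image.
        self.vae_encoder_need_img_original = need_img_original

    def read_image_input(self, img_path):
        img_ori = Image.open(img_path).convert("RGB")
        img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0)  # (1, 3, H, W) in [-1, 1]
        return img, img_ori

    def run_vae_encoder(self, image):
        ...  # stub: encode `image` to latents

    def encode(self, img_path):
        img, img_ori = self.read_image_input(img_path)
        return self.run_vae_encoder(img_ori if self.vae_encoder_need_img_original else img)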
@@ -14,19 +14,17 @@ from einops import rearrange
 from loguru import logger
 from torchvision.transforms import InterpolationMode
 from torchvision.transforms.functional import resize
-from transformers import AutoFeatureExtractor

 from lightx2v.models.input_encoders.hf.seko_audio.audio_adapter import AudioAdapter
 from lightx2v.models.input_encoders.hf.seko_audio.audio_encoder import SekoAudioEncoderModel
-from lightx2v.models.networks.wan.audio_model import Wan22MoeAudioModel, WanAudioModel
+from lightx2v.models.networks.wan.audio_model import WanAudioModel
 from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
-from lightx2v.models.runners.wan.wan_runner import MultiModelStruct, WanRunner
+from lightx2v.models.runners.wan.wan_runner import WanRunner
 from lightx2v.models.schedulers.wan.audio.scheduler import ConsistencyModelScheduler
-from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
 from lightx2v.utils.envs import *
-from lightx2v.utils.profiler import ProfilingContext
+from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
-from lightx2v.utils.utils import find_torch_model_path, load_weights, save_to_video, vae_to_comfyui_image
+from lightx2v.utils.utils import load_weights, save_to_video, vae_to_comfyui_image


 def get_optimal_patched_size_with_sp(patched_h, patched_w, sp_size):
@@ -398,6 +396,7 @@ class WanAudioRunner(WanRunner): # type:ignore
         self.cut_audio_list = []
         self.prev_video = None

+    @ProfilingContext4Debug("Init run segment")
     def init_run_segment(self, segment_idx):
         self.segment_idx = segment_idx
@@ -421,6 +420,7 @@ class WanAudioRunner(WanRunner): # type:ignore
         self.inputs["previmg_encoder_output"] = self.prepare_prev_latents(self.prev_video, prev_frame_length=5)

+    @ProfilingContext4Debug("End run segment")
     def end_run_segment(self):
         self.gen_video = torch.clamp(self.gen_video, -1, 1).to(torch.float)
@@ -446,6 +446,7 @@ class WanAudioRunner(WanRunner): # type:ignore
         del self.gen_video
         torch.cuda.empty_cache()

+    @ProfilingContext4Debug("Process after vae decoder")
     def process_images_after_vae_decoder(self, save_video=True):
         # Merge results
         gen_lvideo = torch.cat(self.gen_video_list, dim=2).float()
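The three new ProfilingContext4Debug annotations wrap the per-segment hooks (init_run_segment, end_run_segment, process_images_after_vae_decoder) so their wall-clock time shows up in the audio profiler output. The real class lives in lightx2v.utils.profiler and is not part of this diff; as an illustration only, a context manager that also works as a decorator (the pattern these annotations rely on) can be built on contextlib.ContextDecorator:

import time
from contextlib import ContextDecorator

from loguru import logger


class ProfilingSketch(ContextDecorator):
    # Hypothetical stand-in for ProfilingContext4Debug, not the real implementation.
    def __init__(self, name):
        self.name = name

    def __enter__(self):
        self.start = time.perf_counter()
        return self

    def __exit__(self, *exc):
        logger.info(f"[{self.name}] took {time.perf_counter() - self.start:.3f}s")
        return False


@ProfilingSketch("Init run segment")
def init_run_segment(segment_idx):
    ...  # timed exactly like the decorated methods above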
@@ -599,89 +600,3 @@ class WanAudioRunner(WanRunner): # type:ignore
         ret["target_shape"] = self.config.target_shape
         return ret
-
-
-@RUNNER_REGISTER("wan2.2_audio")
-class Wan22AudioRunner(WanAudioRunner):
-    def __init__(self, config):
-        super().__init__(config)
-
-    def load_vae_decoder(self):
-        # offload config
-        vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
-        if vae_offload:
-            vae_device = torch.device("cpu")
-        else:
-            vae_device = torch.device("cuda")
-        vae_config = {
-            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
-            "device": vae_device,
-            "cpu_offload": vae_offload,
-            "offload_cache": self.config.get("vae_offload_cache", False),
-        }
-        vae_decoder = Wan2_2_VAE(**vae_config)
-        return vae_decoder
-
-    def load_vae_encoder(self):
-        # offload config
-        vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
-        if vae_offload:
-            vae_device = torch.device("cpu")
-        else:
-            vae_device = torch.device("cuda")
-        vae_config = {
-            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
-            "device": vae_device,
-            "cpu_offload": vae_offload,
-            "offload_cache": self.config.get("vae_offload_cache", False),
-        }
-        if self.config.task != "i2v":
-            return None
-        else:
-            return Wan2_2_VAE(**vae_config)
-
-    def load_vae(self):
-        vae_encoder = self.load_vae_encoder()
-        vae_decoder = self.load_vae_decoder()
-        return vae_encoder, vae_decoder
-
-
-@RUNNER_REGISTER("wan2.2_moe_audio")
-class Wan22MoeAudioRunner(WanAudioRunner):
-    def __init__(self, config):
-        super().__init__(config)
-
-    def load_transformer(self):
-        # encoder -> high_noise_model -> low_noise_model -> vae -> video_output
-        high_noise_model = Wan22MoeAudioModel(
-            os.path.join(self.config.model_path, "high_noise_model"),
-            self.config,
-            self.init_device,
-        )
-        low_noise_model = Wan22MoeAudioModel(
-            os.path.join(self.config.model_path, "low_noise_model"),
-            self.config,
-            self.init_device,
-        )
-
-        if self.config.get("lora_configs") and self.config.lora_configs:
-            assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
-            for lora_config in self.config.lora_configs:
-                lora_path = lora_config["path"]
-                strength = lora_config.get("strength", 1.0)
-                if lora_config.name == "high_noise_model":
-                    lora_wrapper = WanLoraWrapper(high_noise_model)
-                    lora_name = lora_wrapper.load_lora(lora_path)
-                    lora_wrapper.apply_lora(lora_name, strength)
-                    logger.info(f"{lora_config.name} Loaded LoRA: {lora_name} with strength: {strength}")
-                if lora_config.name == "low_noise_model":
-                    lora_wrapper = WanLoraWrapper(low_noise_model)
-                    lora_name = lora_wrapper.load_lora(lora_path)
-                    lora_wrapper.apply_lora(lora_name, strength)
-                    logger.info(f"{lora_config.name} Loaded LoRA: {lora_name} with strength: {strength}")
-
-        # XXX: trick
-        self._audio_preprocess = AutoFeatureExtractor.from_pretrained(self.config["model_path"], subfolder="audio_encoder")
-        return MultiModelStruct([high_noise_model, low_noise_model], self.config, self.config.boundary)
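The removed Wan22MoeAudioRunner.load_transformer captured the Wan2.2 MoE pattern: a high-noise and a low-noise expert are loaded from sibling sub-directories, each optionally patched with its own LoRA, and both are wrapped in MultiModelStruct with the noise boundary. A reduced sketch of the per-expert LoRA selection, reusing only the WanLoraWrapper calls visible above (the surrounding config handling is simplified to plain dicts):

from loguru import logger

from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper


def apply_expert_loras(models_by_name, lora_configs):
    # models_by_name: {"high_noise_model": model, "low_noise_model": model}
    for lora_config in lora_configs:
        target = models_by_name.get(lora_config["name"])
        if target is None:
            continue
        wrapper = WanLoraWrapper(target)
        lora_name = wrapper.load_lora(lora_config["path"])
        wrapper.apply_lora(lora_name, lora_config.get("strength", 1.0))
        logger.info(f"{lora_config['name']} loaded LoRA {lora_name}")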
@@ -430,6 +430,7 @@ class Wan22MoeRunner(WanRunner):
 class Wan22DenseRunner(WanRunner):
     def __init__(self, config):
         super().__init__(config)
+        self.vae_encoder_need_img_original = True

     def load_vae_decoder(self):
         # offload config
...
#!/bin/bash
# set path first
lightx2v_path=
model_path=

export CUDA_VISIBLE_DEVICES=0

# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
export TORCH_CUDA_ARCH_LIST="9.0"
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

python -m lightx2v.infer \
    --model_cls wan2.2_moe_audio \
    --task i2v \
    --model_path $model_path \
    --config_json ${lightx2v_path}/configs/wan22/wan_moe_i2v_audio.json \
    --prompt "The video features a old lady is saying something and knitting a sweater." \
    --negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
    --image_path ${lightx2v_path}/assets/inputs/audio/15.png \
    --audio_path ${lightx2v_path}/assets/inputs/audio/15.wav \
    --save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan22_moe_i2v_audio.mp4