Commit 8bdefedf authored by wangshankun

add ti2v audio

parent 7516ad2a
{
    "infer_steps": 4,
    "target_fps": 16,
    "video_duration": 12,
    "audio_sr": 16000,
    "target_video_length": 121,
    "text_len": 512,
    "target_height": 704,
    "target_width": 1280,
    "num_channels_latents": 48,
    "vae_stride": [4, 16, 16],
    "self_attn_1_type": "flash_attn3",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": 1.0,
    "sample_shift": 5.0,
    "enable_cfg": false,
    "cpu_offload": false,
    "offload_granularity": "model",
    "fps": 24,
    "use_image_encoder": false,
    "lora_configs": [
        {
            "path": "/data/nvme0/models/wan_ti2v_5b_ref/20250812/model_ema.safetensors",
            "strength": 0.125
        }
    ]
}
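For orientation, the latent target shape implied by this config follows from vae_stride and the target dimensions. A minimal sketch of the arithmetic, assuming the usual Wan mapping of frames and resolution into latent space (the exact formula lives in set_target_shape further down and is truncated in this diff):

    # Hypothetical derivation from the config values above -- not part of the commit.
    vae_stride = (4, 16, 16)
    target_video_length, target_height, target_width = 121, 704, 1280
    num_channels_latents = 48

    lat_t = (target_video_length - 1) // vae_stride[0] + 1   # 31 latent frames
    lat_h = target_height // vae_stride[1]                    # 44
    lat_w = target_width // vae_stride[2]                     # 80
    target_shape = (num_channels_latents, lat_t, lat_h, lat_w)  # (48, 31, 44, 80)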
......@@ -8,7 +8,7 @@ from lightx2v.common.ops import *
from lightx2v.models.runners.cogvideox.cogvidex_runner import CogvideoxRunner # noqa: F401
from lightx2v.models.runners.graph_runner import GraphRunner
from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner # noqa: F401
from lightx2v.models.runners.wan.wan_audio_runner import Wan22MoeAudioRunner, WanAudioRunner # noqa: F401
from lightx2v.models.runners.wan.wan_audio_runner import Wan22MoeAudioRunner, WanAudioRunner, Wan22AudioRunner # noqa: F401
from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner # noqa: F401
from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner # noqa: F401
from lightx2v.models.runners.wan.wan_runner import Wan22MoeRunner, WanRunner # noqa: F401
......@@ -39,7 +39,7 @@ def main():
"--model_cls",
type=str,
required=True,
choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox", "wan2.1_audio", "wan2.2_moe", "wan2.2_moe_audio", "wan2.2", "wan2.2_moe_distill"],
choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox", "wan2.1_audio", "wan2.2_moe", "wan2.2_moe_audio", "wan2.2_audio", "wan2.2", "wan2.2_moe_distill"],
default="wan2.1",
)
......
......@@ -3,8 +3,8 @@ import torch
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from lightx2v.utils.envs import *
from ..utils import rope_params, sinusoidal_embedding_1d
from ..utils import rope_params, sinusoidal_embedding_1d, masks_like
from loguru import logger
class WanAudioPreInfer(WanPreInfer):
    def __init__(self, config):
......@@ -28,12 +28,18 @@ class WanAudioPreInfer(WanPreInfer):
        self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()

    def infer(self, weights, inputs, positive):
        prev_latents = inputs["previmg_encoder_output"]["prev_latents"].unsqueeze(0)
        prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
        hidden_states = self.scheduler.latents.unsqueeze(0)
        hidden_states = torch.cat([hidden_states, prev_mask, prev_latents], dim=1)
        hidden_states = hidden_states.squeeze(0)
        prev_latents = inputs["previmg_encoder_output"]["prev_latents"]
        if self.config.model_cls == "wan2.2_audio":
            hidden_states = self.scheduler.latents
            mask1, mask2 = masks_like([hidden_states], zero=True, prev_length=hidden_states.shape[1])
            hidden_states = (1. - mask2[0]) * prev_latents + mask2[0] * hidden_states
        else:
            prev_latents = prev_latents.unsqueeze(0)
            prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
            hidden_states = self.scheduler.latents.unsqueeze(0)
            hidden_states = torch.cat([hidden_states, prev_mask, prev_latents], dim=1)
            hidden_states = hidden_states.squeeze(0)

        x = [hidden_states]
        t = torch.stack([self.scheduler.timesteps[self.scheduler.step_index]])
......@@ -46,7 +52,7 @@ class WanAudioPreInfer(WanPreInfer):
"timestep": t,
}
audio_dit_blocks.append(inputs["audio_adapter_pipe"](**audio_model_input))
##audio_dit_blocks = None##Debug Drop Audio
audio_dit_blocks = None##Debug Drop Audio
if positive:
context = inputs["text_encoder_output"]["context"]
......@@ -55,11 +61,11 @@ class WanAudioPreInfer(WanPreInfer):
seq_len = self.scheduler.seq_len
clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]
ref_image_encoder = inputs["image_encoder_output"]["vae_encoder_out"]
ref_image_encoder = inputs["image_encoder_output"]["vae_encoder_out"].to(self.scheduler.latents.dtype)
batch_size = len(x)
num_channels, _, height, width = x[0].shape
_, ref_num_channels, ref_num_frames, _, _ = ref_image_encoder.shape
if ref_num_channels != num_channels:
zero_padding = torch.zeros(
(batch_size, num_channels - ref_num_channels, ref_num_frames, height, width),
......@@ -77,6 +83,7 @@ class WanAudioPreInfer(WanPreInfer):
assert seq_lens.max() <= seq_len
x = torch.cat([torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) for u in x])
valid_patch_length = x[0].size(0)
y = [weights.patch_embedding.apply(u.unsqueeze(0)) for u in y]
# y_grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in y])
y = [u.flatten(2).transpose(1, 2).squeeze(0) for u in y]
......
......@@ -3,6 +3,38 @@ import torch
from lightx2v.utils.envs import *
def masks_like(tensor, zero=False, generator=None, p=0.2, prev_length=1):
    assert isinstance(tensor, list)
    out1 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor]
    out2 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor]

    if prev_length == 0:
        return out1, out2

    if zero:
        if generator is not None:
            for u, v in zip(out1, out2):
                random_num = torch.rand(
                    1, generator=generator, device=generator.device).item()
                if random_num < p:
                    u[:, :prev_length] = torch.normal(
                        mean=-3.5,
                        std=0.5,
                        size=(1,),
                        device=u.device,
                        generator=generator).expand_as(u[:, :prev_length]).exp()
                    v[:, :prev_length] = torch.zeros_like(v[:, :prev_length])
                else:
                    u[:, :prev_length] = u[:, :prev_length]
                    v[:, :prev_length] = v[:, :prev_length]
        else:
            for u, v in zip(out1, out2):
                u[:, :prev_length] = torch.zeros_like(u[:, :prev_length])
                v[:, :prev_length] = torch.zeros_like(v[:, :prev_length])

    return out1, out2
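# Usage sketch (illustrative, not part of this commit): with zero=True and no generator,
# masks_like zeroes the first `prev_length` frames of both masks, and WanAudioPreInfer.infer
# uses mask2 to splice previous-clip latents into the scheduler's current latents:
#
#   mask1, mask2 = masks_like([latents], zero=True, prev_length=latents.shape[1])
#   hidden_states = (1. - mask2[0]) * prev_latents + mask2[0] * latents
#
# Positions where the mask is 0 are taken from prev_latents; the rest keep the current
# latents. The generator branch instead, with probability p, fills out1 with small values
# (exp of a normal with mean -3.5) while still zeroing out2 over those frames.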
def compute_freqs(c, grid_sizes, freqs):
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
    f, h, w = grid_sizes[0]
......
......@@ -231,9 +231,6 @@ class WanModel:
        self.post_weight.load(self.original_weight_dict)
        self.transformer_weights.load(self.original_weight_dict)
        del self.original_weight_dict
        torch.cuda.empty_cache()

    def _load_weights_distribute(self, weight_dict, is_weight_loader):
        global_src_rank = 0
        target_device = "cpu" if self.cpu_offload else "cuda"
......
......@@ -24,8 +24,8 @@ from lightx2v.models.schedulers.wan.audio.scheduler import ConsistencyModelSched
from lightx2v.utils.envs import *
from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image
from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image, find_torch_model_path
from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
@contextmanager
def memory_efficient_inference():
......@@ -322,7 +322,11 @@ class VideoGenerator:
        if segment_idx == 0:
            # First segment - create zero frames
            prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
            prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
            if self.config.model_cls == 'wan2.2_audio':
                prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config).to(dtype)
            else:
                prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
            prev_len = 0
        else:
            # Subsequent segments - use previous video
......@@ -333,7 +337,10 @@
            else:
                # Fallback to zeros if prepare_prev_latents fails
                prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
                prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
                if self.config.model_cls == 'wan2.2_audio':
                    prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config).to(dtype)
                else:
                    prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
                prev_len = 0

        # Create mask for prev_latents
......@@ -613,7 +620,7 @@ class WanAudioRunner(WanRunner): # type:ignore
    def load_transformer(self):
        """Load transformer with LoRA support"""
        base_model = WanAudioModel(self.config.model_path, self.config, self.init_device, self.seq_p_group)
        logger.info(f"Loaded base model: {self.config.model_path}")

        if self.config.get("lora_configs") and self.config.lora_configs:
            assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
            lora_wrapper = WanLoraWrapper(base_model)
......@@ -673,8 +680,12 @@ class WanAudioRunner(WanRunner): # type:ignore
        # vae encode
        cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
        vae_encoder_out = vae_model.encode(cond_frms.to(torch.float), config)
        if isinstance(vae_encoder_out, list):
            vae_encoder_out = torch.stack(vae_encoder_out, dim=0).to(GET_DTYPE())
        if self.config.model_cls == "wan2.2_audio":
            vae_encoder_out = vae_encoder_out.unsqueeze(0).to(GET_DTYPE())
        else:
            if isinstance(vae_encoder_out, list):
                vae_encoder_out = torch.stack(vae_encoder_out, dim=0).to(GET_DTYPE())

        return vae_encoder_out, clip_encoder_out
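# Shape note (assumed, for clarity -- not part of the commit): both branches normalize the
# encoder output to a batched latent tensor. The Wan2.1-style VAE encode returns a list of
# per-sample latents, which torch.stack turns into [B, C, T, H, W]; the Wan2.2 VAE used by
# wan2.2_audio returns a single [C, T, H, W] tensor, so unsqueeze(0) adds the batch dim instead.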
......@@ -682,6 +693,9 @@ class WanAudioRunner(WanRunner): # type:ignore
"""Set target shape for generation"""
ret = {}
num_channels_latents = 16
if self.config.model_cls == "wan2.2_audio":
num_channels_latents = self.config.num_channels_latents
if self.config.task == "i2v":
self.config.target_shape = (
num_channels_latents,
......@@ -755,6 +769,50 @@ class WanAudioRunner(WanRunner): # type:ignore
        self.end_run()
@RUNNER_REGISTER("wan2.2_audio")
class Wan22AudioRunner(WanAudioRunner):
def __init__(self, config):
super().__init__(config)
def load_vae_decoder(self):
# offload config
vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
if vae_offload:
vae_device = torch.device("cpu")
else:
vae_device = torch.device("cuda")
vae_config = {
"vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
"device": vae_device,
"cpu_offload": vae_offload,
"offload_cache": self.config.get("vae_offload_cache", False),
}
vae_decoder = Wan2_2_VAE(**vae_config)
return vae_decoder
def load_vae_encoder(self):
# offload config
vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
if vae_offload:
vae_device = torch.device("cpu")
else:
vae_device = torch.device("cuda")
vae_config = {
"vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
"device": vae_device,
"cpu_offload": vae_offload,
"offload_cache": self.config.get("vae_offload_cache", False),
}
if self.config.task != "i2v":
return None
else:
return Wan2_2_VAE(**vae_config)
def load_vae(self):
vae_encoder = self.load_vae_encoder()
vae_decoder = self.load_vae_decoder()
return vae_encoder, vae_decoder
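# Design note: RUNNER_REGISTER("wan2.2_audio") ties this runner to the new
# `--model_cls wan2.2_audio` CLI choice added in lightx2v/infer.py above. Both VAE paths
# resolve their checkpoint via find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
# and cpu_offload / vae_cpu_offload decide whether the VAE sits on CPU or CUDA.
# Rough selection sketch (hypothetical lookup syntax; the actual registry API may differ):
#
#   runner_cls = RUNNER_REGISTER["wan2.2_audio"]
#   runner = runner_cls(config)
#   vae_encoder, vae_decoder = runner.load_vae()  # encoder is None unless task == "i2v"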
@RUNNER_REGISTER("wan2.2_moe_audio")
class Wan22MoeAudioRunner(WanAudioRunner):
def __init__(self, config):
......
......@@ -7,6 +7,7 @@ import torch.nn.functional as F
from einops import rearrange
from lightx2v.utils.utils import load_weights
from loguru import logger
__all__ = [
    "Wan2_2_VAE",
......@@ -256,6 +257,10 @@ class AttentionBlock(nn.Module):
def patchify(x, patch_size):
    if patch_size == 1:
        return x

    if x.dim() == 6:
        x = x.squeeze(0)

    if x.dim() == 4:
        x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size)
    elif x.dim() == 5:
......@@ -828,7 +833,7 @@ def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", cpu_offloa
    # load checkpoint
    logging.info(f"loading {pretrained_path}")
    weights_dict = load_weights(pretrained_path, cpu_offload=cpu_offload)
    model.load_state_dict(weights_dict)
    model.load_state_dict(weights_dict, assign=True)

    return model
......
#!/bin/bash
# set lightx2v_path and model_path first
lightx2v_path=
model_path=
export CUDA_VISIBLE_DEVICES=0
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
python -m lightx2v.infer \
--model_cls wan2.2_audio \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/audio_driven/wan22_ti2v_i2v_audio.json \
--prompt "The video features a old lady is saying something and knitting a sweater." \
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--image_path ${lightx2v_path}/assets/inputs/audio/15.png \
--audio_path ${lightx2v_path}/assets/inputs/audio/15.wav \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_audio.mp4