Commit d91e8d68 authored by sandy, committed by GitHub

Audio/r2v 5b deb (#251)



* [Feat] audio driven for wan2.2-r2v-5b

* refactor: reduce vae args dim && timestep init

* update audio pre_infer

* update runner

* update scripts

* [Fix] move ref timestep padding to Scheduler

* fix wan i2v

---------
Co-authored-by: helloyongyang <yongyang1030@163.com>
parent d71f936d
{
"infer_steps": 4,
"target_fps": 16,
"target_fps": 24,
"video_duration": 12,
"audio_sr": 16000,
"target_video_length": 121,
......@@ -20,9 +20,11 @@
"offload_granularity": "model",
"fps": 24,
"use_image_encoder": false,
"adaptive_resize": true,
"use_31_block": false,
"lora_configs": [
{
"path": "/data/nvme0/models/wan_ti2v_5b_ref/20250812/model_ema.safetensors",
"path": "/mnt/aigc/rtxiang/pretrain/qianhai_weights/lora_model.safetensors",
"strength": 0.125
}
]
......
......@@ -8,7 +8,7 @@ from lightx2v.models.runners.cogvideox.cogvidex_runner import CogvideoxRunner #
from lightx2v.models.runners.graph_runner import GraphRunner
from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner # noqa: F401
from lightx2v.models.runners.qwen_image.qwen_image_runner import QwenImageRunner # noqa: F401
from lightx2v.models.runners.wan.wan_audio_runner import WanAudioRunner # noqa: F401
from lightx2v.models.runners.wan.wan_audio_runner import Wan22AudioRunner, WanAudioRunner # noqa: F401
from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner # noqa: F401
from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner # noqa: F401
from lightx2v.models.runners.wan.wan_runner import Wan22MoeRunner, WanRunner # noqa: F401
......
......@@ -11,6 +11,7 @@ class WanAudioPostInfer(WanPostInfer):
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
def infer(self, x, pre_infer_out):
x = x[: pre_infer_out.seq_lens[0]]
pre_infer_out.grid_sizes[:, 0] -= 1
x = self.unpatchify(x, pre_infer_out.grid_sizes)
if self.clean_cuda_cache:
......
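A standalone sketch of the shape bookkeeping implied by the two added lines above: pre-infer appends one extra "frame" of reference-image tokens and bumps the temporal grid size, so post-infer slices them off and restores the grid before unpatchify. All sizes below are made-up examples, not the model's real dimensions.

import torch

f, h, w = 31, 30, 52                         # patched latent grid (frames, height, width), example values
valid_len = f * h * w                        # tokens that belong to the video itself
ref_len = h * w                              # one extra "frame" of reference tokens

x = torch.randn(valid_len + ref_len, 1536)   # stand-in transformer output, incl. reference tokens
grid_sizes = torch.tensor([[f + 1, h, w]])   # temporal dim was incremented in pre_infer

x = x[:valid_len]                            # keep only the denoised video tokens
grid_sizes[:, 0] -= 1                        # back to the real number of latent frames
assert x.shape[0] == int(grid_sizes[0].prod())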
......@@ -24,19 +24,17 @@ class WanAudioPreInfer(WanPreInfer):
self.freq_dim = config["freq_dim"]
self.dim = config["dim"]
self.text_len = config["text_len"]
self.rope_t_dim = d // 2 - 2 * (d // 6)
self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
self.infer_dtype = GET_DTYPE()
self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
if config.parallel:
self.sp_size = config.parallel.get("seq_p_size", 1)
else:
self.sp_size = 1
def infer(self, weights, inputs):
prev_latents = inputs["previmg_encoder_output"]["prev_latents"]
prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
hidden_states = torch.cat([self.scheduler.latents, prev_mask, prev_latents], dim=0)
hidden_states = self.scheduler.latents
if self.config.model_cls != "wan2.2_audio":
prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
hidden_states = torch.cat([hidden_states, prev_mask, prev_latents], dim=0)
x = hidden_states
t = self.scheduler.timestep_input
......@@ -45,11 +43,10 @@ class WanAudioPreInfer(WanPreInfer):
context = inputs["text_encoder_output"]["context"]
else:
context = inputs["text_encoder_output"]["context_null"]
# seq_len = self.scheduler.seq_len
clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]
ref_image_encoder = inputs["image_encoder_output"]["vae_encoder_out"].to(self.scheduler.latents.dtype)
# batch_size = len(x)
num_channels, _, height, width = x.shape
ref_num_channels, ref_num_frames, _, _ = ref_image_encoder.shape
......@@ -60,7 +57,7 @@ class WanAudioPreInfer(WanPreInfer):
device=self.scheduler.latents.device,
)
ref_image_encoder = torch.concat([ref_image_encoder, zero_padding], dim=0)
y = ref_image_encoder # the first batch dimension becomes a list
y = ref_image_encoder
# embeddings
x = weights.patch_embedding.apply(x.unsqueeze(0))
......@@ -70,8 +67,11 @@ class WanAudioPreInfer(WanPreInfer):
y = weights.patch_embedding.apply(y.unsqueeze(0))
y = y.flatten(2).transpose(1, 2).contiguous()
x = torch.cat([x, y], dim=1).squeeze(0)
x = torch.cat([x, y], dim=1)
# for r2v: zero the temporal component corresponding to the ref embeddings
self.freqs[grid_sizes[0][0] :, : self.rope_t_dim] = 0
grid_sizes[:, 0] += 1
embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten())
if self.sensitive_layer_dtype != self.infer_dtype:
......@@ -117,7 +117,7 @@ class WanAudioPreInfer(WanPreInfer):
return WanPreInferModuleOutput(
embed=embed,
grid_sizes=grid_sizes,
x=x.squeeze(0),
x=x,
embed0=embed0.squeeze(0),
seq_lens=seq_lens,
freqs=self.freqs,
......
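The r2v-specific lines in this pre-infer hunk append the reference-image embeddings as one extra latent frame, zero their temporal RoPE component so they carry spatial position only, and increment the temporal grid size to account for them. A standalone sketch of the zeroing, with assumed dimensions (d, max_pos, f are examples):

import torch

d = 128                          # per-head dimension (example)
c = d // 2                       # rope operates on d // 2 frequency pairs
rope_t_dim = c - 2 * (c // 3)    # temporal share of the rope dims, cf. self.rope_t_dim
max_pos, f = 1024, 31            # table length and number of video latent frames (examples)

freqs = torch.randn(max_pos, c)  # stand-in for the precomputed rope table
# frame indices >= f are only ever used by the appended reference tokens,
# so their temporal phase is removed while the height/width phases stay intact
freqs[f:, :rope_t_dim] = 0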
......@@ -3,7 +3,6 @@ import torch.distributed as dist
from lightx2v.models.input_encoders.hf.seko_audio.audio_adapter import get_q_lens_audio_range
from lightx2v.models.networks.wan.infer.offload.transformer_infer import WanOffloadTransformerInfer
from lightx2v.models.networks.wan.infer.utils import compute_freqs_audio, compute_freqs_audio_dist
class WanAudioTransformerInfer(WanOffloadTransformerInfer):
......@@ -15,21 +14,15 @@ class WanAudioTransformerInfer(WanOffloadTransformerInfer):
def set_audio_adapter(self, audio_adapter):
self.audio_adapter = audio_adapter
@torch.no_grad()
def compute_freqs(self, q, grid_sizes, freqs):
if self.config["seq_parallel"]:
freqs_i = compute_freqs_audio_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs, self.seq_p_group)
else:
freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
return freqs_i
@torch.no_grad()
def post_process(self, x, y, c_gate_msa, pre_infer_out):
x = super().post_process(x, y, c_gate_msa, pre_infer_out)
audio_grid_sizes = [row.clone() for row in pre_infer_out.grid_sizes]
audio_grid_sizes[0][0] -= 1
x = self.modify_hidden_states(
hidden_states=x.to(self.infer_dtype),
grid_sizes=pre_infer_out.grid_sizes,
grid_sizes=audio_grid_sizes,
ca_block=self.audio_adapter.ca[self.block_idx],
audio_encoder_output=pre_infer_out.adapter_output["audio_encoder_output"],
t_emb=self.scheduler.audio_adapter_t_emb,
......
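For the audio cross-attention the reference frame must be excluded again, so the block works on a per-call copy of the grid sizes instead of mutating the shared pre_infer_out. A tiny illustration of that pattern (sizes are examples):

import torch

grid_sizes = torch.tensor([[32, 30, 52]])              # includes the +1 reference frame
audio_grid_sizes = [row.clone() for row in grid_sizes]
audio_grid_sizes[0][0] -= 1                            # audio features align with video frames only
assert grid_sizes[0][0].item() == 32                   # the shared tensor is left untouched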
......@@ -20,26 +20,6 @@ def compute_freqs(c, grid_sizes, freqs):
return freqs_i
def compute_freqs_audio(c, grid_sizes, freqs):
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0]
valid_token_length = f * h * w
f = f + 1  # for r2v: add one frame for the reference tokens
seq_len = f * h * w
freqs_i = torch.cat(
[
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), # temporal (frame) encoding
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), # spatial (height) encoding
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1), # spatial (width) encoding
],
dim=-1,
).reshape(seq_len, 1, -1)
freqs_i[valid_token_length:, :, :f] = 0  # for r2v: zero the temporal component corresponding to the ref embeddings
return freqs_i
def compute_freqs_dist(s, c, grid_sizes, freqs, seq_p_group):
world_size = dist.get_world_size(seq_p_group)
cur_rank = dist.get_rank(seq_p_group)
......@@ -61,31 +41,6 @@ def compute_freqs_dist(s, c, grid_sizes, freqs, seq_p_group):
return freqs_i_rank
def compute_freqs_audio_dist(s, c, grid_sizes, freqs, seq_p_group):
world_size = dist.get_world_size(seq_p_group)
cur_rank = dist.get_rank(seq_p_group)
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0]
valid_token_length = f * h * w
f = f + 1
seq_len = f * h * w
freqs_i = torch.cat(
[
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
],
dim=-1,
).reshape(seq_len, 1, -1)
freqs_i[valid_token_length:, :, :f] = 0
freqs_i = pad_freqs(freqs_i, s * world_size)
s_per_rank = s
freqs_i_rank = freqs_i[(cur_rank * s_per_rank) : ((cur_rank + 1) * s_per_rank), :, :]
return freqs_i_rank
def compute_freqs_causvid(c, grid_sizes, freqs, start_frame=0):
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0]
......
......@@ -21,10 +21,11 @@ from lightx2v.models.networks.wan.audio_model import WanAudioModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.schedulers.wan.audio.scheduler import ConsistencyModelScheduler
from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
from lightx2v.utils.envs import *
from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import load_weights, save_to_video, vae_to_comfyui_image
from lightx2v.utils.utils import find_torch_model_path, load_weights, save_to_video, vae_to_comfyui_image
def get_optimal_patched_size_with_sp(patched_h, patched_w, sp_size):
......@@ -300,7 +301,8 @@ class WanAudioRunner(WanRunner): # type:ignore
self.vae_encoder = self.load_vae_encoder()
img = rearrange(img, "1 C H W -> 1 C 1 H W")
vae_encoder_out = self.vae_encoder.encode(img.to(torch.float))[0]
vae_encoder_out = self.vae_encoder.encode(img.to(torch.float)).to(GET_DTYPE())
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
del self.vae_encoder
torch.cuda.empty_cache()
......@@ -339,40 +341,42 @@ class WanAudioRunner(WanRunner): # type:ignore
if prev_video is not None:
# Extract and process last frames
last_frames = prev_video[:, :, -prev_frame_length:].clone().to(device)
last_frames = self.frame_preprocessor.process_prev_frames(last_frames)
if self.config.model_cls != "wan2.2_audio":
last_frames = self.frame_preprocessor.process_prev_frames(last_frames)
prev_frames[:, :, :prev_frame_length] = last_frames
prev_len = (prev_frame_length - 1) // 4 + 1
else:
prev_len = 0
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.vae_encoder = self.load_vae_encoder()
_, nframe, height, width = self.model.scheduler.latents.shape
if self.config.model_cls == "wan2.2_audio":
prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype)).to(dtype)
_, prev_mask = self._wan22_masks_like([self.model.scheduler.latents], zero=True, prev_length=prev_latents.shape[1])
else:
prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype))[0].to(dtype)
if prev_video is not None:
prev_token_length = (prev_frame_length - 1) // 4 + 1
prev_frame_len = max((prev_token_length - 1) * 4 + 1, 0)
prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype)).to(dtype)
else:
prev_frame_len = 0
prev_latents = None
prev_mask = self.model.scheduler.mask
else:
prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype)).to(dtype)
frames_n = (nframe - 1) * 4 + 1
prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
prev_mask[:, prev_frame_len:] = 0
prev_mask[:, prev_len:] = 0
prev_mask = self._wan_mask_rearrange(prev_mask)
if prev_latents.shape[-2:] != (height, width):
logger.warning(f"Size mismatch: prev_latents {prev_latents.shape} vs scheduler latents (H={height}, W={width}). Config tgt_h={self.config.tgt_h}, tgt_w={self.config.tgt_w}")
prev_latents = torch.nn.functional.interpolate(prev_latents, size=(height, width), mode="bilinear", align_corners=False)
if prev_latents is not None:
if prev_latents.shape[-2:] != (height, width):
logger.warning(f"Size mismatch: prev_latents {prev_latents.shape} vs scheduler latents (H={height}, W={width}). Config tgt_h={self.config.tgt_h}, tgt_w={self.config.tgt_w}")
prev_latents = torch.nn.functional.interpolate(prev_latents, size=(height, width), mode="bilinear", align_corners=False)
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
del self.vae_encoder
torch.cuda.empty_cache()
gc.collect()
return {"prev_latents": prev_latents, "prev_mask": prev_mask}
return {"prev_latents": prev_latents, "prev_mask": prev_mask, "prev_len": prev_len}
def _wan_mask_rearrange(self, mask: torch.Tensor) -> torch.Tensor:
"""Rearrange mask for WAN model"""
......@@ -413,12 +417,11 @@ class WanAudioRunner(WanRunner): # type:ignore
audio_features = self.audio_adapter.forward_audio_proj(audio_features, self.model.scheduler.latents.shape[1])
self.inputs["audio_encoder_output"] = audio_features
self.inputs["previmg_encoder_output"] = self.prepare_prev_latents(self.prev_video, prev_frame_length=5)
# Reset scheduler for non-first segments
if segment_idx > 0:
self.model.scheduler.reset()
self.inputs["previmg_encoder_output"] = self.prepare_prev_latents(self.prev_video, prev_frame_length=5)
self.model.scheduler.reset(self.inputs["previmg_encoder_output"])
@ProfilingContext4Debug("End run segment")
def end_run_segment(self):
......@@ -600,3 +603,48 @@ class WanAudioRunner(WanRunner): # type:ignore
ret["target_shape"] = self.config.target_shape
return ret
@RUNNER_REGISTER("wan2.2_audio")
class Wan22AudioRunner(WanAudioRunner):
def __init__(self, config):
super().__init__(config)
def load_vae_decoder(self):
# offload config
vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
if vae_offload:
vae_device = torch.device("cpu")
else:
vae_device = torch.device("cuda")
vae_config = {
"vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
"device": vae_device,
"cpu_offload": vae_offload,
"offload_cache": self.config.get("vae_offload_cache", False),
}
vae_decoder = Wan2_2_VAE(**vae_config)
return vae_decoder
def load_vae_encoder(self):
# offload config
vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
if vae_offload:
vae_device = torch.device("cpu")
else:
vae_device = torch.device("cuda")
vae_config = {
"vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
"device": vae_device,
"cpu_offload": vae_offload,
"offload_cache": self.config.get("vae_offload_cache", False),
}
if self.config.task != "i2v":
return None
else:
return Wan2_2_VAE(**vae_config)
def load_vae(self):
vae_encoder = self.load_vae_encoder()
vae_decoder = self.load_vae_decoder()
return vae_encoder, vae_decoder
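The VAE offload decision in both loaders falls back from a VAE-specific flag to the global one. A quick illustration of the fallback chain (the config values here are made up):

config = {"cpu_offload": True}                                           # no explicit vae_cpu_offload set
vae_offload = config.get("vae_cpu_offload", config.get("cpu_offload"))   # falls back to the global flag
assert vae_offload is True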
......@@ -309,7 +309,7 @@ class WanRunner(DefaultRunner):
dim=1,
).cuda()
vae_encoder_out = self.vae_encoder.encode([vae_input])[0]
vae_encoder_out = self.vae_encoder.encode(vae_input.unsqueeze(0))
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
del self.vae_encoder
......
import gc
import math
import numpy as np
import torch
......@@ -6,12 +7,22 @@ from loguru import logger
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.utils.envs import *
from lightx2v.utils.utils import masks_like
class ConsistencyModelScheduler(WanScheduler):
def __init__(self, config):
super().__init__(config)
if self.config.parallel:
self.sp_size = self.config.parallel.get("seq_p_size", 1)
else:
self.sp_size = 1
if self.config["model_cls"] == "wan2.2_audio":
self.prev_latents = None
self.prev_len = 0
def set_audio_adapter(self, audio_adapter):
self.audio_adapter = audio_adapter
......@@ -23,7 +34,45 @@ class ConsistencyModelScheduler(WanScheduler):
if self.audio_adapter.cpu_offload:
self.audio_adapter.time_embedding.to("cpu")
def prepare(self, image_encoder_output=None):
if self.config.model_cls == "wan2.2_audio":
_, lat_f, lat_h, lat_w = self.latents.shape
F = (lat_f - 1) * self.config.vae_stride[0] + 1
per_latent_token_len = lat_h * lat_w // (self.config.patch_size[1] * self.config.patch_size[2])
max_seq_len = ((F - 1) // self.config.vae_stride[0] + 1) * per_latent_token_len
max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
temp_ts = (self.mask[0][:, ::2, ::2] * self.timestep_input).flatten()
self.timestep_input = torch.cat([temp_ts, temp_ts.new_ones(max_seq_len - temp_ts.size(0)) * self.timestep_input]).unsqueeze(0)
self.timestep_input = torch.cat(
[
self.timestep_input,
torch.zeros(
(1, per_latent_token_len), # padding for reference frame latent
dtype=self.timestep_input.dtype,
device=self.timestep_input.device,
),
],
dim=1,
)
def prepare_latents(self, target_shape, dtype=torch.float32):
self.generator = torch.Generator(device=self.device).manual_seed(self.config.seed)
self.latents = torch.randn(
target_shape[0],
target_shape[1],
target_shape[2],
target_shape[3],
dtype=dtype,
device=self.device,
generator=self.generator,
)
if self.config["model_cls"] == "wan2.2_audio":
self.mask = masks_like(self.latents, zero=True, prev_len=self.prev_len)
if self.prev_latents is not None:
self.latents = (1.0 - self.mask) * self.prev_latents + self.mask * self.latents
def prepare(self, previmg_encoder_output=None):
self.prepare_latents(self.config.target_shape, dtype=torch.float32)
timesteps = np.linspace(self.num_train_timesteps, 0, self.infer_steps + 1, dtype=np.float32)
......@@ -43,8 +92,13 @@ class ConsistencyModelScheduler(WanScheduler):
x0 = sample - model_output * sigma
x_t_next = x0 * (1 - sigma_next) + sigma_next * torch.randn(x0.shape, dtype=x0.dtype, device=x0.device, generator=self.generator)
self.latents = x_t_next
if self.config["model_cls"] == "wan2.2_audio" and self.prev_latents is not None:
self.latents = (1.0 - self.mask) * self.prev_latents + self.mask * self.latents
def reset(self):
def reset(self, previmg_encoder_output=None):
if self.config["model_cls"] == "wan2.2_audio":
self.prev_latents = previmg_encoder_output["prev_latents"]
self.prev_len = previmg_encoder_output["prev_len"]
self.prepare_latents(self.config.target_shape, dtype=torch.float32)
gc.collect()
torch.cuda.empty_cache()
......
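To make the "[Fix] move ref timestep padding to Scheduler" change concrete, here is a worked shape sketch of the per-token timestep construction for wan2.2_audio. All sizes, the 1000.0 timestep, and the mask channel count are example values standing in for the real config and scheduler state.

import math
import torch

lat_f, lat_h, lat_w = 31, 30, 52          # latent frames / height / width (examples)
patch_size = (1, 2, 2)
vae_stride_t = 4
sp_size = 1                               # sequence-parallel world size

F = (lat_f - 1) * vae_stride_t + 1                                        # pixel-space frame count
per_latent_token_len = lat_h * lat_w // (patch_size[1] * patch_size[2])   # tokens per latent frame
max_seq_len = ((F - 1) // vae_stride_t + 1) * per_latent_token_len        # == lat_f * per_latent_token_len
max_seq_len = int(math.ceil(max_seq_len / sp_size)) * sp_size

timestep = torch.tensor(1000.0)                                   # example scalar timestep
mask = torch.ones(16, lat_f, lat_h, lat_w)                        # stand-in for the scheduler mask
temp_ts = (mask[0][:, ::2, ::2] * timestep).flatten()             # per-token timesteps, zero where the mask is zero
ts = torch.cat([temp_ts, temp_ts.new_ones(max_seq_len - temp_ts.size(0)) * timestep]).unsqueeze(0)
ts = torch.cat([ts, torch.zeros(1, per_latent_token_len)], dim=1) # zero timestep padding for the ref-frame tokens
assert ts.shape == (1, max_seq_len + per_latent_token_len)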
......@@ -892,17 +892,17 @@ class WanVAE:
self.inv_std = self.inv_std.cuda()
self.scale = [self.mean, self.inv_std]
def encode(self, videos):
def encode(self, video):
"""
videos: A list of videos each with shape [C, T, H, W].
video: one video with shape [1, C, T, H, W].
"""
if self.cpu_offload:
self.to_cuda()
if self.use_tiling:
out = [self.model.tiled_encode(u.unsqueeze(0).to(self.current_device()), self.scale).float().squeeze(0) for u in videos]
out = self.model.tiled_encode(video, self.scale).float().squeeze(0)
else:
out = [self.model.encode(u.unsqueeze(0).to(self.current_device()), self.scale).float().squeeze(0) for u in videos]
out = self.model.encode(video, self.scale).float().squeeze(0)
if self.cpu_offload:
self.to_cpu()
......
......@@ -985,10 +985,10 @@ class Wan2_2_VAE:
self.inv_std = self.inv_std.cuda()
self.scale = [self.mean, self.inv_std]
def encode(self, videos):
def encode(self, video):
if self.cpu_offload:
self.to_cuda()
out = self.model.encode(videos.unsqueeze(0), self.scale).float().squeeze(0)
out = self.model.encode(video, self.scale).float().squeeze(0)
if self.cpu_offload:
self.to_cpu()
return out
......
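Both VAE wrappers now take one batched video instead of a list, which is what the runner call sites above were updated to. A call-site sketch of the signature change (the tensor sizes are small examples and "vae" stands in for either wrapper):

import torch

video = torch.randn(1, 3, 9, 64, 64)                      # [1, C, T, H, W] -- one batched video
# old usage: latents = vae.encode([video.squeeze(0)])[0]  # list of [C, T, H, W] tensors
# new usage: latents = vae.encode(video)                  # single [1, C, T, H, W] tensor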
......@@ -434,17 +434,16 @@ def load_weights(checkpoint_path, cpu_offload=False, remove_key=None):
return distributed_weight_dict
def masks_like(tensor, zero=False, generator=None, p=0.2):
def masks_like(tensor, zero=False, generator=None, p=0.2, prev_len=1):
assert isinstance(tensor, torch.Tensor)
out = torch.ones_like(tensor)
if zero:
if generator is not None:
random_num = torch.rand(1, generator=generator, device=generator.device).item()
if random_num < p:
out[:, 0] = torch.zeros_like(out[:, 0])
out[:, :prev_len] = torch.zeros_like(out[:, :prev_len])
else:
out[:, 0] = torch.zeros_like(out[:, 0])
out[:, :prev_len] = torch.zeros_like(out[:, :prev_len])
return out
......
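A minimal sketch of how the prev_len-aware mask is consumed; the blend mirrors the scheduler diff above, the random-drop branch of the real masks_like is omitted, and all sizes are examples.

import torch

def masks_like(tensor, zero=False, prev_len=1):
    # simplified: zero out the leading temporal latents that are already known
    out = torch.ones_like(tensor)
    if zero:
        out[:, :prev_len] = 0
    return out

latents = torch.randn(48, 31, 30, 52)                    # (C, T_lat, H_lat, W_lat), example sizes
prev_latents = torch.randn_like(latents)
prev_len = (5 - 1) // 4 + 1                              # 5 overlapping frames -> 2 latent frames
mask = masks_like(latents, zero=True, prev_len=prev_len)
latents = (1.0 - mask) * prev_latents + mask * latents   # keep previous-segment latents, noise the rest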
......@@ -9,6 +9,8 @@ export CUDA_VISIBLE_DEVICES=0
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
export ENABLE_GRAPH_MODE=false
export SENSITIVE_LAYER_DTYPE=None
python -m lightx2v.infer \
--model_cls wan2.2_audio \
......