Commit e58dd9fe authored by wangshankun's avatar wangshankun

Audio-driven Wan video generation

parent 7260cb2e
{
"infer_steps": 20,
"target_video_length": 81,
"target_height": 480,
"target_width": 832,
"attention_type": "flash_attn3",
"seed": 42,
"sample_guide_scale":5.0,
"sample_shift": 5,
"enable_cfg": false,
"cpu_offload": false
}
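This new config is consumed through the --config_json flag added to lightx2v.infer below. A minimal sketch, assuming a plain JSON read with the CLI flags overlaid on top (the load_config helper is hypothetical, not the project's actual loader):

import argparse
import json

def load_config(args: argparse.Namespace) -> dict:
    # Read the JSON file passed via --config_json.
    with open(args.config_json, "r", encoding="utf-8") as f:
        config = json.load(f)
    # Overlay non-empty CLI values (model_path, audio_path, ...) so flags win over the file.
    config.update({k: v for k, v in vars(args).items() if v not in ("", None)})
    return config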
@@ -14,6 +14,7 @@ from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner
from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner
from lightx2v.models.runners.wan.wan_audio_runner import WanAudioRunner
from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner
from lightx2v.models.runners.graph_runner import GraphRunner
from lightx2v.models.runners.cogvideox.cogvidex_runner import CogvideoxRunner
@@ -41,14 +42,16 @@ def init_runner(config):
async def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox", "wan2.1_audio"], default="hunyuan")
parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--config_json", type=str, required=True)
parser.add_argument("--use_prompt_enhancer", action="store_true")
parser.add_argument("--prompt", type=str, default="", help="The input prompt for text-to-video generation")
parser.add_argument("--negative_prompt", type=str, default="")
parser.add_argument("--prompt_path", type=str, default="", help="The path to input prompt file")
parser.add_argument("--audio_path", type=str, default="", help="The path to input audio file")
parser.add_argument("--image_path", type=str, default="", help="The path to input image file or path for image-to-video (i2v) task")
parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4", help="The path to save video path/file")
args = parser.parse_args()
...
import flash_attn
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from einops import rearrange
from transformers import AutoModel
from loguru import logger
import pdb
import os
import safetensors
from typing import List, Optional, Tuple, Union
def load_safetensors(in_path: str):
if os.path.isdir(in_path):
return load_safetensors_from_dir(in_path)
elif os.path.isfile(in_path):
return load_safetensors_from_path(in_path)
else:
raise ValueError(f"{in_path} does not exist")
def load_safetensors_from_path(in_path: str):
tensors = {}
with safetensors.safe_open(in_path, framework="pt", device="cpu") as f:
for key in f.keys():
tensors[key] = f.get_tensor(key)
return tensors
def load_safetensors_from_dir(in_dir: str):
tensors = {}
files = [f for f in os.listdir(in_dir) if f.endswith(".safetensors")]  # avoid shadowing the safetensors module
for f in files:
tensors.update(load_safetensors_from_path(os.path.join(in_dir, f)))
return tensors
def load_pt_safetensors(in_path: str):
ext = os.path.splitext(in_path)[-1]
if ext in (".pt", ".pth", ".tar"):
state_dict = torch.load(in_path, map_location="cpu", weights_only=True)
else:
state_dict = load_safetensors(in_path)
return state_dict
def rank0_load_state_dict_from_path(model, in_path: str, strict: bool = True):
import torch.distributed as dist
if (dist.is_initialized() and dist.get_rank() == 0) or (not dist.is_initialized()):
state_dict = load_pt_safetensors(in_path)
model.load_state_dict(state_dict, strict=strict)
if dist.is_initialized():
dist.barrier()
return model.to(dtype=torch.bfloat16, device="cuda")
def linear_interpolation(features, output_len: int):
features = features.transpose(1, 2)
output_features = F.interpolate(features, size=output_len, align_corners=False, mode="linear")
return output_features.transpose(1, 2)
def get_q_lens_audio_range(
batchsize,
n_tokens_per_rank,
n_query_tokens,
n_tokens_per_frame,
sp_rank,
):
if n_query_tokens == 0:
q_lens = [1] * batchsize
return q_lens, 0, 1
idx0 = n_tokens_per_rank * sp_rank
first_length = idx0 - idx0 // n_tokens_per_frame * n_tokens_per_frame
n_frames = (n_query_tokens - first_length) // n_tokens_per_frame
last_length = n_query_tokens - n_frames * n_tokens_per_frame - first_length
q_lens = []
if first_length > 0:
q_lens.append(first_length)
q_lens += [n_tokens_per_frame] * n_frames
if last_length > 0:
q_lens.append(last_length)
t0 = idx0 // n_tokens_per_frame
idx1 = idx0 + n_query_tokens
t1 = math.ceil(idx1 / n_tokens_per_frame)
return q_lens * batchsize, t0, t1
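# Worked example for get_q_lens_audio_range (illustrative numbers, not from the commit): with
# 480x832 latents, 8x spatial VAE downsampling and a (1, 2, 2) patch size, each latent frame holds
# 30 * 52 = 1560 tokens. A sequence-parallel rank owning tokens [3900, 7800) starts mid-frame, so
# get_q_lens_audio_range(1, 3900, 3900, 1560, sp_rank=1) returns q_lens = [780, 1560, 1560] and
# (t0, t1) = (2, 5), i.e. this rank cross-attends to the audio tokens of latent frames 2..4.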
class PerceiverAttentionCA(nn.Module):
def __init__(self, dim_head=128, heads=16, kv_dim=2048, adaLN: bool = False):
super().__init__()
self.dim_head = dim_head
self.heads = heads
inner_dim = dim_head * heads
kv_dim = inner_dim if kv_dim is None else kv_dim
self.norm_kv = nn.LayerNorm(kv_dim)
self.norm_q = nn.LayerNorm(inner_dim, elementwise_affine=not adaLN)
self.to_q = nn.Linear(inner_dim, inner_dim)
self.to_kv = nn.Linear(kv_dim, inner_dim * 2)
self.to_out = nn.Linear(inner_dim, inner_dim)
if adaLN:
self.shift_scale_gate = nn.Parameter(torch.randn(1, 3, inner_dim) / inner_dim**0.5)
else:
shift_scale_gate = torch.zeros((1, 3, inner_dim))
shift_scale_gate[:, 2] = 1
self.register_buffer("shift_scale_gate", shift_scale_gate, persistent=False)
def forward(self, x, latents, t_emb, q_lens, k_lens):
"""x shape (batchsize, latent_frame, audio_tokens_per_latent,
model_dim) latents (batchsize, length, model_dim)"""
batchsize = len(x)
x = self.norm_kv(x)
shift, scale, gate = (t_emb + self.shift_scale_gate).chunk(3, dim=1)
latents = self.norm_q(latents) * (1 + scale) + shift
q = self.to_q(latents)
k, v = self.to_kv(x).chunk(2, dim=-1)
q = rearrange(q, "B L (H C) -> (B L) H C", H=self.heads)
k = rearrange(k, "B T L (H C) -> (B T L) H C", H=self.heads)
v = rearrange(v, "B T L (H C) -> (B T L) H C", H=self.heads)
out = flash_attn.flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True),
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True),
max_seqlen_q=q_lens.max(),
max_seqlen_k=k_lens.max(),
dropout_p=0.0,
softmax_scale=None,
causal=False,
window_size=(-1, -1),
deterministic=False,
)
out = rearrange(out, "(B L) H C -> B L (H C)", B=batchsize)
return self.to_out(out) * gate
class AudioProjection(nn.Module):
def __init__(
self,
audio_feature_dim: int = 768,
n_neighbors: tuple = (2, 2),
num_tokens: int = 32,
mlp_dims: tuple = (1024, 1024, 32 * 768),
transformer_layers: int = 4,
):
super().__init__()
mlp = []
self.left, self.right = n_neighbors
self.audio_frames = sum(n_neighbors) + 1
in_dim = audio_feature_dim * self.audio_frames
for i, out_dim in enumerate(mlp_dims):
mlp.append(nn.Linear(in_dim, out_dim))
if i != len(mlp_dims) - 1:
mlp.append(nn.ReLU())
in_dim = out_dim
self.mlp = nn.Sequential(*mlp)
self.norm = nn.LayerNorm(mlp_dims[-1] // num_tokens)
self.num_tokens = num_tokens
if transformer_layers > 0:
decoder_layer = nn.TransformerDecoderLayer(d_model=audio_feature_dim, nhead=audio_feature_dim // 64, dim_feedforward=4 * audio_feature_dim, dropout=0.0, batch_first=True)
self.transformer_decoder = nn.TransformerDecoder(
decoder_layer,
num_layers=transformer_layers,
)
else:
self.transformer_decoder = None
def forward(self, audio_feature, latent_frame):
video_frame = (latent_frame - 1) * 4 + 1
audio_feature_ori = audio_feature
audio_feature = linear_interpolation(audio_feature_ori, video_frame)
if self.transformer_decoder is not None:
audio_feature = self.transformer_decoder(audio_feature, audio_feature_ori)
audio_feature = F.pad(audio_feature, pad=(0, 0, self.left, self.right), mode="replicate")
audio_feature = audio_feature.unfold(dimension=1, size=self.audio_frames, step=1)
audio_feature = rearrange(audio_feature, "B T C W -> B T (W C)")
audio_feature = self.mlp(audio_feature) # (B, video_frame, C)
audio_feature = rearrange(audio_feature, "B T (N C) -> B T N C", N=self.num_tokens) # (B, video_frame, num_tokens, C)
return self.norm(audio_feature)
class TimeEmbedding(nn.Module):
def __init__(self, dim, time_freq_dim, time_proj_dim):
super().__init__()
self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
self.act_fn = nn.SiLU()
self.time_proj = nn.Linear(dim, time_proj_dim)
def forward(
self,
timestep: torch.Tensor,
):
timestep = self.timesteps_proj(timestep)
time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
timestep = timestep.to(time_embedder_dtype)
temb = self.time_embedder(timestep)
timestep_proj = self.time_proj(self.act_fn(temb))
return timestep_proj
class AudioAdapter(nn.Module):
def __init__(
self,
attention_head_dim=64,
num_attention_heads=40,
base_num_layers=30,
interval=1,
audio_feature_dim: int = 768,
num_tokens: int = 32,
mlp_dims: tuple = (1024, 1024, 32 * 768),
time_freq_dim: int = 256,
projection_transformer_layers: int = 4,
):
super().__init__()
self.audio_proj = AudioProjection(
audio_feature_dim=audio_feature_dim,
n_neighbors=(2, 2),
num_tokens=num_tokens,
mlp_dims=mlp_dims,
transformer_layers=projection_transformer_layers,
)
# self.num_tokens = num_tokens * 4
self.num_tokens_x4 = num_tokens * 4
self.audio_pe = nn.Parameter(torch.randn(self.num_tokens_x4, mlp_dims[-1] // num_tokens) * 0.02)
ca_num = math.ceil(base_num_layers / interval)
self.base_num_layers = base_num_layers
self.interval = interval
self.ca = nn.ModuleList(
[
PerceiverAttentionCA(
dim_head=attention_head_dim,
heads=num_attention_heads,
kv_dim=mlp_dims[-1] // num_tokens,
adaLN=time_freq_dim > 0,
)
for _ in range(ca_num)
]
)
self.dim = attention_head_dim * num_attention_heads
if time_freq_dim > 0:
self.time_embedding = TimeEmbedding(self.dim, time_freq_dim, self.dim * 3)
else:
self.time_embedding = None
def rearange_audio_features(self, audio_feature: torch.Tensor):
# audio_feature (B, video_frame, num_tokens, C)
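# The first video frame's tokens are repeated 4x below so that, after regrouping with S=4,
# every latent frame (Wan's temporal VAE stride is 4) maps to exactly 4 * num_tokens audio tokens.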
audio_feature_0 = audio_feature[:, :1]
audio_feature_0 = torch.repeat_interleave(audio_feature_0, repeats=4, dim=1)
audio_feature = torch.cat([audio_feature_0, audio_feature[:, 1:]], dim=1) # (B, 4 * latent_frame, num_tokens, C)
audio_feature = rearrange(audio_feature, "B (T S) N C -> B T (S N) C", S=4)
return audio_feature
def forward(self, audio_feat: torch.Tensor, timestep: torch.Tensor, latent_frame: int, weight: float = 1.0):
def modify_hidden_states(hidden_states, grid_sizes, ca_block: PerceiverAttentionCA, x, t_emb, dtype, weight):
"""thw specify the latent_frame, latent_height, latenf_width after
hidden_states is patchified.
latent_frame does not include the reference images so that the
audios and hidden_states are strictly aligned
"""
if len(hidden_states.shape) == 2:  # add a batch dimension
hidden_states = hidden_states.unsqueeze(0) # bs = 1
# print(weight)
t, h, w = grid_sizes[0].tolist()
n_tokens = t * h * w
ori_dtype = hidden_states.dtype
device = hidden_states.device
bs, n_tokens_per_rank = hidden_states.shape[:2]
tail_length = n_tokens_per_rank - n_tokens
n_query_tokens = n_tokens_per_rank - tail_length % n_tokens_per_rank
if n_query_tokens > 0:
hidden_states_aligned = hidden_states[:, :n_query_tokens]
hidden_states_tail = hidden_states[:, n_query_tokens:]
else:
# for ranks that should be excluded from cross-attn, fake cross-attn will be applied so that FSDP works.
hidden_states_aligned = hidden_states[:, :1]
hidden_states_tail = hidden_states[:, 1:]
q_lens, t0, t1 = get_q_lens_audio_range(batchsize=bs, n_tokens_per_rank=n_tokens_per_rank, n_query_tokens=n_query_tokens, n_tokens_per_frame=h * w, sp_rank=0)
q_lens = torch.tensor(q_lens, device=device, dtype=torch.int32)
"""
processing audio features in sp_state can be moved outside.
"""
x = x[:, t0:t1]
x = x.to(dtype)
k_lens = torch.tensor([self.num_tokens_x4] * (t1 - t0) * bs, device=device, dtype=torch.int32)
assert q_lens.shape == k_lens.shape
# ca_block: the cross-attention module (PerceiverAttentionCA)
residual = ca_block(x, hidden_states_aligned, t_emb, q_lens, k_lens) * weight
residual = residual.to(ori_dtype)  # the audio cross-attention output is injected as a residual
if n_query_tokens == 0:
residual = residual * 0.0
hidden_states = torch.cat([hidden_states_aligned + residual, hidden_states_tail], dim=1)
if len(hidden_states.shape) == 3:  # drop the batch dimension again
hidden_states = hidden_states.squeeze(0)  # bs = 1
return hidden_states
x = self.audio_proj(audio_feat, latent_frame)
x = self.rearange_audio_features(x)
x = x + self.audio_pe
if self.time_embedding is not None:
t_emb = self.time_embedding(timestep).unflatten(1, (3, -1))
else:
t_emb = torch.zeros((len(x), 3, self.dim), device=x.device, dtype=x.dtype)
ret_dict = {}
for block_idx, base_idx in enumerate(range(0, self.base_num_layers, self.interval)):
block_dict = {
"kwargs": {
"ca_block": self.ca[block_idx],
"x": x,
"weight": weight,
"t_emb": t_emb,
"dtype": x.dtype,
},
"modify_func": modify_hidden_states,
}
ret_dict[base_idx] = block_dict
return ret_dict
@classmethod
def from_transformer(
cls,
transformer,
audio_feature_dim: int = 1024,
interval: int = 1,
time_freq_dim: int = 256,
projection_transformer_layers: int = 4,
):
num_attention_heads = transformer.config["num_heads"]
base_num_layers = transformer.config["num_layers"]
attention_head_dim = transformer.config["dim"] // num_attention_heads
audio_adapter = AudioAdapter(
attention_head_dim,
num_attention_heads,
base_num_layers,
interval=interval,
audio_feature_dim=audio_feature_dim,
time_freq_dim=time_freq_dim,
projection_transformer_layers=projection_transformer_layers,
mlp_dims=(1024, 1024, 32 * audio_feature_dim),
)
return audio_adapter
def get_fsdp_wrap_module_list(
self,
):
ret_list = list(self.ca)
return ret_list
def enable_gradient_checkpointing(
self,
):
pass
class AudioAdapterPipe:
def __init__(
self, audio_adapter: AudioAdapter, audio_encoder_repo: str = "microsoft/wavlm-base-plus", dtype=torch.float32, device="cuda", generator=None, tgt_fps: int = 15, weight: float = 1.0
) -> None:
self.audio_adapter = audio_adapter
self.dtype = dtype
self.device = device
self.generator = generator
self.audio_encoder_dtype = torch.float16
self.audio_encoder = AutoModel.from_pretrained(audio_encoder_repo)
self.audio_encoder.eval()
self.audio_encoder.to(device, self.audio_encoder_dtype)
self.tgt_fps = tgt_fps
self.weight = weight
if "base" in audio_encoder_repo:
self.audio_feature_dim = 768
else:
self.audio_feature_dim = 1024
def update_model(self, audio_adapter):
self.audio_adapter = audio_adapter
def __call__(self, audio_input_feat, timestep, latent_shape: tuple, dropout_cond: callable = None):
# audio_input_feat is from AudioPreprocessor
latent_frame = latent_shape[2]
if len(audio_input_feat.shape) == 1:  # add a batch dimension (batchsize = 1)
audio_input_feat = audio_input_feat.unsqueeze(0)
latent_frame = latent_shape[1]
video_frame = (latent_frame - 1) * 4 + 1
audio_length = int(50 / self.tgt_fps * video_frame)
with torch.no_grad():
audio_input_feat = audio_input_feat.to(self.device, self.audio_encoder_dtype)
try:
audio_feat = self.audio_encoder(audio_input_feat, return_dict=True).last_hidden_state
except Exception as err:
audio_feat = torch.rand(1, audio_length, self.audio_feature_dim).to(self.device)
print(err)
audio_feat = audio_feat.to(self.dtype)
if dropout_cond is not None:
audio_feat = dropout_cond(audio_feat)
return self.audio_adapter(audio_feat=audio_feat, timestep=timestep, latent_frame=latent_frame, weight=self.weight)
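# --- Illustrative wiring (a sketch, not part of this file) ---
# This mirrors WanAudioRunner.load_audio_models and WanAudioPreInfer later in this commit; the
# checkpoint/encoder paths and the (16, 21, 60, 104) latent shape are placeholders.
#
#   adapter = AudioAdapter.from_transformer(transformer, audio_feature_dim=1024, interval=1)
#   adapter = rank0_load_state_dict_from_path(adapter, "/path/to/audio_adapter.safetensors", strict=False)
#   pipe = AudioAdapterPipe(adapter, audio_encoder_repo="/path/to/chinese-hubert-large", dtype=torch.bfloat16, device="cuda")
#   # audio_input_feat is the raw 16 kHz waveform run through an AutoFeatureExtractor (see the runner)
#   injection = pipe(audio_input_feat, timestep, latent_shape=(16, 21, 60, 104))
#   # injection maps DiT block index -> {"modify_func", "kwargs"}; the transformer applies modify_func
#   # to its hidden states right after each matching block (see transformer_infer.py below).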
import os
import torch
import time
import glob
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
from lightx2v.models.networks.wan.weights.transformer_weights import (
WanTransformerWeights,
)
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
from lightx2v.models.networks.wan.infer.pre_wan_audio_infer import WanAudioPreInfer
from lightx2v.models.networks.wan.infer.post_wan_audio_infer import WanAudioPostInfer
from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import WanTransformerInferTeaCaching
from safetensors import safe_open
import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
from lightx2v.models.networks.wan.infer.transformer_infer import (
    WanTransformerInfer,
)
class WanAudioModel(WanModel):
pre_weight_class = WanPreWeights
post_weight_class = WanPostWeights
transformer_weight_class = WanTransformerWeights
def __init__(self, model_path, config, device):
super().__init__(model_path, config, device)
def _init_infer_class(self):
self.pre_infer_class = WanAudioPreInfer
self.post_infer_class = WanAudioPostInfer
if self.config["feature_caching"] == "NoCaching":
self.transformer_infer_class = WanTransformerInfer
elif self.config["feature_caching"] == "Tea":
self.transformer_infer_class = WanTransformerInferTeaCaching
else:
raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
@torch.no_grad()
def infer(self, inputs):
if self.config["cpu_offload"]:
self.pre_weight.to_cuda()
self.post_weight.to_cuda()
embed, grid_sizes, pre_infer_out, valid_patch_length = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes, valid_patch_length)[0]
if self.config["feature_caching"] == "Tea":
self.scheduler.cnt += 1
if self.scheduler.cnt >= self.scheduler.num_steps:
self.scheduler.cnt = 0
self.scheduler.noise_pred = noise_pred_cond
if self.config["enable_cfg"]:
embed, grid_sizes, pre_infer_out, valid_patch_length = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes, valid_patch_length)[0]
if self.config["feature_caching"] == "Tea":
self.scheduler.cnt += 1
if self.scheduler.cnt >= self.scheduler.num_steps:
self.scheduler.cnt = 0
self.scheduler.noise_pred = noise_pred_uncond + self.config.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
if self.config["cpu_offload"]:
self.pre_weight.to_cpu()
self.post_weight.to_cpu()
import math
import torch
import torch.cuda.amp as amp
from loguru import logger
from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
class WanAudioPostInfer(WanPostInfer):
def __init__(self, config):
self.out_dim = config["out_dim"]
self.patch_size = (1, 2, 2)
def set_scheduler(self, scheduler):
self.scheduler = scheduler
def infer(self, weights, x, e, grid_sizes, valid_patch_length):
if e.dim() == 2:
modulation = weights.head_modulation.tensor # 1, 2, dim
e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
elif e.dim() == 3:  # for diffusion forcing
modulation = weights.head_modulation.tensor.unsqueeze(2) # 1, 2, seq, dim
e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
e = [ei.squeeze(1) for ei in e]
norm_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6).type_as(x)
out = norm_out * (1 + e[1].squeeze(0)) + e[0].squeeze(0)
x = weights.head.apply(out)
x = x[:, :valid_patch_length]
x = self.unpatchify(x, grid_sizes)
return [u.float() for u in x]
def unpatchify(self, x, grid_sizes):
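# Reverses the (1, 2, 2) patch embedding: tokens laid out over the latent grid (f, h, w) are
# reshaped to (f, h, w, 1, 2, 2, c), and the einsum "fhwpqrc->cfphqwr" interleaves each patch
# back into place, yielding a (c, f, 2 * h, 2 * w) latent per sample.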
x = x.unsqueeze(0)
c = self.out_dim
out = []
for u, v in zip(x, grid_sizes.tolist()):
u = u[: math.prod(v)].view(*v, *self.patch_size, c)
u = torch.einsum("fhwpqrc->cfphqwr", u)
u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
out.append(u)
return out
import torch
import math
from .utils import rope_params, sinusoidal_embedding_1d
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from loguru import logger
class WanAudioPreInfer(WanPreInfer):
def __init__(self, config):
assert (config["dim"] % config["num_heads"]) == 0 and (config["dim"] // config["num_heads"]) % 2 == 0
d = config["dim"] // config["num_heads"]
self.task = config["task"]
self.freqs = torch.cat(
[
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
],
dim=1,
).cuda()
self.freq_dim = config["freq_dim"]
self.dim = config["dim"]
self.text_len = config["text_len"]
def infer(self, weights, inputs, positive):
ltnt_channel = self.scheduler.latents.size(0)
prev_latents = inputs["previmg_encoder_output"]["prev_latents"].unsqueeze(0)
prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
hidden_states = self.scheduler.latents.unsqueeze(0)
hidden_states = torch.cat([hidden_states[:, :ltnt_channel], prev_latents, prev_mask], dim=1)
hidden_states = hidden_states.squeeze(0)
x = [hidden_states]
t = torch.stack([self.scheduler.timesteps[self.scheduler.step_index]])
audio_dit_blocks = []
audio_encoder_output = inputs["audio_encoder_output"]
audio_model_input = {
"audio_input_feat": audio_encoder_output.to(hidden_states.device),
"latent_shape": hidden_states.shape,
"timestep": t,
}
audio_dit_blocks.append(inputs["audio_adapter_pipe"](**audio_model_input))
if positive:
context = inputs["text_encoder_output"]["context"]
else:
context = inputs["text_encoder_output"]["context_null"]
seq_len = self.scheduler.seq_len
clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]
ref_image_encoder = inputs["image_encoder_output"]["vae_encode_out"]
batch_size = len(x)
num_channels, num_frames, height, width = x[0].shape
_, ref_num_channels, ref_num_frames, _, _ = ref_image_encoder.shape
if ref_num_channels != num_channels:
zero_padding = torch.zeros(
(batch_size, num_channels - ref_num_channels, ref_num_frames, height, width),
dtype=self.scheduler.latents.dtype,
device=self.scheduler.latents.device,
)
ref_image_encoder = torch.concat([ref_image_encoder, zero_padding], dim=1)
y = list(torch.unbind(ref_image_encoder, dim=0))  # unbind the leading batch dimension into a list
# embeddings
x = [weights.patch_embedding.apply(u.unsqueeze(0)) for u in x]
x_grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
x = [u.flatten(2).transpose(1, 2) for u in x]
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long).cuda()
assert seq_lens.max() <= seq_len
x = torch.cat([torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) for u in x])
valid_patch_length = x[0].size(0)
y = [weights.patch_embedding.apply(u.unsqueeze(0)) for u in y]
y_grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in y])
y = [u.flatten(2).transpose(1, 2).squeeze(0) for u in y]
x = [torch.cat([a, b], dim=0) for a, b in zip(x, y)]
x = torch.stack(x, dim=0)
embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten())
embed = weights.time_embedding_0.apply(embed)
embed = torch.nn.functional.silu(embed)
embed = weights.time_embedding_2.apply(embed)
embed0 = torch.nn.functional.silu(embed)
embed0 = weights.time_projection_1.apply(embed0).unflatten(1, (6, self.dim))
# text embeddings
stacked = torch.stack([torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) for u in context])
out = weights.text_embedding_0.apply(stacked.squeeze(0))
out = torch.nn.functional.gelu(out, approximate="tanh")
context = weights.text_embedding_2.apply(out)
if self.task == "i2v":
context_clip = weights.proj_0.apply(clip_fea)
context_clip = weights.proj_1.apply(context_clip)
context_clip = torch.nn.functional.gelu(context_clip, approximate="none")
context_clip = weights.proj_3.apply(context_clip)
context_clip = weights.proj_4.apply(context_clip)
context = torch.concat([context_clip, context], dim=0)
return (embed, x_grid_sizes, (x.squeeze(0), embed0.squeeze(0), seq_lens, self.freqs, context, audio_dit_blocks), valid_patch_length)
import torch
from .utils import compute_freqs, compute_freqs_dist, compute_freqs_audio, compute_freqs_audio_dist, apply_rotary_emb, apply_rotary_emb_chunk
from lightx2v.common.offload.manager import (
    WeightAsyncStreamManager,
    LazyWeightAsyncStreamManager,
)
from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
from lightx2v.utils.envs import *
from loguru import logger
import pdb
import os
class WanTransformerInfer(BaseTransformerInfer):
@@ -64,10 +67,10 @@ class WanTransformerInfer(BaseTransformerInfer):
return cu_seqlens_q, cu_seqlens_k
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks)
def _infer_with_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
for block_idx in range(self.blocks_num):
if block_idx == 0:
self.weights_stream_mgr.active_weights[0] = weights.blocks[0]
@@ -92,7 +95,7 @@ class WanTransformerInfer(BaseTransformerInfer):
return x
def _infer_with_phases_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
for block_idx in range(weights.blocks_num):
for phase_idx in range(self.phases_num):
if block_idx == 0 and phase_idx == 0:
@@ -133,7 +136,7 @@ class WanTransformerInfer(BaseTransformerInfer):
return x
def _infer_with_phases_lazy_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
self.weights_stream_mgr.prefetch_weights_from_disk(weights)
for block_idx in range(weights.blocks_num):
@@ -194,7 +197,22 @@ class WanTransformerInfer(BaseTransformerInfer):
return x
def zero_temporal_component_in_3DRoPE(self, valid_token_length, rotary_emb=None):
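# Zeroes the temporal slice of the rotary embedding for tokens past valid_token_length, i.e. the
# reference-image tokens appended after the valid video tokens, so the reference frame carries no
# temporal position. rope_t_dim = 44 appears to match the temporal rotary width d - 4 * (d // 6)
# for head dim d = 128 (the Wan 14B configuration).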
if rotary_emb is None:
return None
self.use_real = False
rope_t_dim = 44
if self.use_real:
freqs_cos, freqs_sin = rotary_emb
freqs_cos[valid_token_length:, :, :rope_t_dim] = 0
freqs_sin[valid_token_length:, :, :rope_t_dim] = 0
return freqs_cos, freqs_sin
else:
freqs_cis = rotary_emb
freqs_cis[valid_token_length:, :, : rope_t_dim // 2] = 0
return freqs_cis
def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
for block_idx in range(self.blocks_num):
x = self.infer_block(
weights.blocks[block_idx],
@@ -206,6 +224,12 @@ class WanTransformerInfer(BaseTransformerInfer):
freqs,
context,
)
if audio_dit_blocks is not None and len(audio_dit_blocks) > 0:
for ipa_out in audio_dit_blocks:
if block_idx in ipa_out:
cur_modify = ipa_out[block_idx]
x = cur_modify["modify_func"](x, grid_sizes, **cur_modify["kwargs"])
return x
def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
@@ -265,14 +289,23 @@ class WanTransformerInfer(BaseTransformerInfer):
v = weights.self_attn_v.apply(norm1_out).view(s, n, d)
if not self.parallel_attention:
if self.config.get("audio_sr", False):
freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
else:
freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
else:
if self.config.get("audio_sr", False):
freqs_i = compute_freqs_audio_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
else:
freqs_i = compute_freqs_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
freqs_i = self.zero_temporal_component_in_3DRoPE(seq_lens, freqs_i)
q = self.apply_rotary_emb_func(q, freqs_i)
k = self.apply_rotary_emb_func(k, freqs_i)
k_lens = torch.empty_like(seq_lens).fill_(freqs_i.size(0))
cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q, k_lens=k_lens)
if self.clean_cuda_cache:
del freqs_i, norm1_out, norm1_weight, norm1_bias
@@ -353,7 +386,6 @@ class WanTransformerInfer(BaseTransformerInfer):
q,
k_lens=torch.tensor([k_img.size(0)], dtype=torch.int32, device=k.device),
)
img_attn_out = weights.cross_attn_2.apply(
q=q,
k=k_img,
...
import torch
import torch.distributed as dist
from loguru import logger
from lightx2v.utils.envs import *
@@ -19,6 +20,45 @@ def compute_freqs(c, grid_sizes, freqs):
return freqs_i
def compute_freqs_audio(c, grid_sizes, freqs):
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0].tolist()
f = f + 1
seq_len = f * h * w
freqs_i = torch.cat(
[
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
],
dim=-1,
).reshape(seq_len, 1, -1)
return freqs_i
def compute_freqs_audio_dist(s, c, grid_sizes, freqs):
world_size = dist.get_world_size()
cur_rank = dist.get_rank()
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0].tolist()
f = f + 1
seq_len = f * h * w
freqs_i = torch.cat(
[
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
],
dim=-1,
).reshape(seq_len, 1, -1)
freqs_i = pad_freqs(freqs_i, s * world_size)
s_per_rank = s
freqs_i_rank = freqs_i[(cur_rank * s_per_rank) : ((cur_rank + 1) * s_per_rank), :, :]
return freqs_i_rank
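# Note: the extra temporal position (f = f + 1) in the two audio variants above appears to cover the
# reference-image latents that WanAudioPreInfer appends after the video tokens, so those tokens still
# get height/width rotary components while zero_temporal_component_in_3DRoPE removes their temporal part.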
def compute_freqs_causvid(c, grid_sizes, freqs, start_frame=0):
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0].tolist()
...
import os
import gc
import numpy as np
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
from lightx2v.models.networks.wan.audio_model import WanAudioModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from lightx2v.models.networks.wan.audio_adapter import AudioAdapter, AudioAdapterPipe, rank0_load_state_dict_from_path
from loguru import logger
import torch.distributed as dist
from einops import rearrange
import torchaudio as ta
from transformers import AutoFeatureExtractor
from torchvision.datasets.folder import IMG_EXTENSIONS
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize
import subprocess
import warnings
from typing import Optional, Tuple, Union
import pdb
def get_crop_bbox(ori_h, ori_w, tgt_h, tgt_w):
tgt_ar = tgt_h / tgt_w
ori_ar = ori_h / ori_w
if abs(ori_ar - tgt_ar) < 0.01:
return 0, ori_h, 0, ori_w
if ori_ar > tgt_ar:
crop_h = int(tgt_ar * ori_w)
y0 = (ori_h - crop_h) // 2
y1 = y0 + crop_h
return y0, y1, 0, ori_w
else:
crop_w = int(ori_h / tgt_ar)
x0 = (ori_w - crop_w) // 2
x1 = x0 + crop_w
return 0, ori_h, x0, x1
def isotropic_crop_resize(frames: torch.Tensor, size: tuple):
"""
frames: (T, C, H, W)
size: (H, W)
"""
ori_h, ori_w = frames.shape[2:]
h, w = size
y0, y1, x0, x1 = get_crop_bbox(ori_h, ori_w, h, w)
cropped_frames = frames[:, :, y0:y1, x0:x1]
resized_frames = resize(cropped_frames, size, InterpolationMode.BICUBIC, antialias=True)
return resized_frames
def adaptive_resize(img):
bucket_config = {
0.667: (np.array([[480, 832], [544, 960], [720, 1280]], dtype=np.int64), np.array([0.2, 0.5, 0.3])),
1.0: (np.array([[480, 480], [576, 576], [704, 704], [960, 960]], dtype=np.int64), np.array([0.1, 0.1, 0.5, 0.3])),
1.5: (np.array([[480, 832], [544, 960], [720, 1280]], dtype=np.int64)[:, ::-1], np.array([0.2, 0.5, 0.3])),
}
ori_height = img.shape[-2]
ori_width = img.shape[-1]
ori_ratio = ori_height / ori_width
aspect_ratios = np.array(list(bucket_config.keys()))
closest_aspect_idx = np.argmin(np.abs(aspect_ratios - ori_ratio))
closest_ratio = aspect_ratios[closest_aspect_idx]
target_h, target_w = 480, 832
for resolution in bucket_config[closest_ratio][0]:
if ori_height * ori_width >= resolution[0] * resolution[1]:
target_h, target_w = resolution
cropped_img = isotropic_crop_resize(img, (target_h, target_w))
return cropped_img, target_h, target_w
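# Example (illustrative): a 1920x1080 (H, W) portrait reference image has ratio ~1.78, so the closest
# bucket key is 1.5; the largest bucket whose area fits is (1280, 720), and the image is center-cropped
# and resized to H=1280, W=720 before CLIP/VAE encoding.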
def array_to_video(
image_array: np.ndarray,
output_path: str,
fps: Union[int, float] = 30,
resolution: Optional[Union[Tuple[int, int], Tuple[float, float]]] = None,
disable_log: bool = False,
lossless: bool = True,
) -> None:
if not isinstance(image_array, np.ndarray):
raise TypeError("Input should be np.ndarray.")
assert image_array.ndim == 4
assert image_array.shape[-1] == 3
if resolution:
height, width = resolution
width += width % 2
height += height % 2
else:
image_array = pad_for_libx264(image_array)
height, width = image_array.shape[1], image_array.shape[2]
if lossless:
command = [
"/usr/bin/ffmpeg",
"-y", # (optional) overwrite output file if it exists
"-f",
"rawvideo",
"-s",
f"{int(width)}x{int(height)}", # size of one frame
"-pix_fmt",
"bgr24",
"-r",
f"{fps}", # frames per second
"-loglevel",
"error",
"-threads",
"4",
"-i",
"-", # The input comes from a pipe
"-vcodec",
"libx264rgb",
"-crf",
"0",
"-an", # Tells FFMPEG not to expect any audio
output_path,
]
else:
command = [
"/usr/bin/ffmpeg",
"-y", # (optional) overwrite output file if it exists
"-f",
"rawvideo",
"-s",
f"{int(width)}x{int(height)}", # size of one frame
"-pix_fmt",
"bgr24",
"-r",
f"{fps}", # frames per second
"-loglevel",
"error",
"-threads",
"4",
"-i",
"-", # The input comes from a pipe
"-vcodec",
"libx264",
"-an", # Tells FFMPEG not to expect any audio
output_path,
]
if not disable_log:
print(f'Running "{" ".join(command)}"')
process = subprocess.Popen(
command,
stdin=subprocess.PIPE,
stderr=subprocess.PIPE,
)
if process.stdin is None or process.stderr is None:
raise BrokenPipeError("No buffer received.")
index = 0
while True:
if index >= image_array.shape[0]:
break
process.stdin.write(image_array[index].tobytes())
index += 1
process.stdin.close()
process.stderr.close()
process.wait()
def pad_for_libx264(image_array):
if image_array.ndim == 2 or (image_array.ndim == 3 and image_array.shape[2] == 3):
hei_index = 0
wid_index = 1
elif image_array.ndim == 4 or (image_array.ndim == 3 and image_array.shape[2] != 3):
hei_index = 1
wid_index = 2
else:
return image_array
hei_pad = image_array.shape[hei_index] % 2
wid_pad = image_array.shape[wid_index] % 2
if hei_pad + wid_pad > 0:
pad_width = []
for dim_index in range(image_array.ndim):
if dim_index == hei_index:
pad_width.append((0, hei_pad))
elif dim_index == wid_index:
pad_width.append((0, wid_pad))
else:
pad_width.append((0, 0))
values = 0
image_array = np.pad(image_array, pad_width, mode="constant", constant_values=values)
return image_array
def generate_unique_path(path):
if not os.path.exists(path):
return path
root, ext = os.path.splitext(path)
index = 1
new_path = f"{root}-{index}{ext}"
while os.path.exists(new_path):
index += 1
new_path = f"{root}-{index}{ext}"
return new_path
def save_to_video(gen_lvideo, out_path, target_fps):
print(gen_lvideo.shape)
gen_lvideo = rearrange(gen_lvideo, "B C T H W -> B T H W C")
gen_lvideo = (gen_lvideo[0].cpu().numpy() * 127.5 + 127.5).astype(np.uint8)
gen_lvideo = gen_lvideo[..., ::-1].copy()
generate_unique_path(out_path)
array_to_video(gen_lvideo, output_path=out_path, fps=target_fps, lossless=False)
def save_audio(
audio_array: np.ndarray,
audio_name: str,
video_name: Optional[str] = None,
sr: int = 16000,
):
logger.info(f"Saving audio to {audio_name} type: {type(audio_array)}")
if not os.path.exists(audio_name):
ta.save(
audio_name,
torch.tensor(audio_array[None]),
sample_rate=sr,
)
out_video = f"{video_name[:-4]}_with_audio.mp4"
# generate_unique_path(out_path)
cmd = f"/usr/bin/ffmpeg -i {video_name} -i {audio_name} {out_video}"
subprocess.call(cmd, shell=True)
@RUNNER_REGISTER("wan2.1_audio")
class WanAudioRunner(WanRunner):
def __init__(self, config):
super().__init__(config)
def load_audio_models(self):
self.audio_encoder = AutoFeatureExtractor.from_pretrained(self.config["model_path"], subfolder="audio_encoder")
audio_adapter = AudioAdapter.from_transformer(
self.model,
audio_feature_dim=1024,
interval=1,
time_freq_dim=256,
projection_transformer_layers=4,
)
load_path = "/mnt/aigc/zoemodels/Zoetrained/vigendit/audio_driven/audio_adapter/audio_adapter_V1_0507_bf16.safetensors"
audio_adapter = rank0_load_state_dict_from_path(audio_adapter, load_path, strict=False)
device = self.model.device
audio_encoder_repo = "/mnt/aigc/zoemodels/models--TencentGameMate--chinese-hubert-large/snapshots/90cb660492214f687e60f5ca509b20edae6e75bd"
audio_adapter_pipe = AudioAdapterPipe(audio_adapter, audio_encoder_repo=audio_encoder_repo, dtype=torch.bfloat16, device=device, generator=torch.Generator(device), weight=1.0)
return audio_adapter_pipe
def load_transformer(self):
base_model = WanAudioModel(self.config.model_path, self.config, self.init_device)
return base_model
def run_image_encoder(self, config, vae_model):
ref_img = Image.open(config.image_path)
ref_img = (np.array(ref_img).astype(np.float32) - 127.5) / 127.5
ref_img = torch.from_numpy(ref_img).to(vae_model.device)
ref_img = rearrange(ref_img, "H W C -> 1 C H W")
ref_img = ref_img[:, :3]
# resize and crop image
cond_frms, tgt_h, tgt_w = adaptive_resize(ref_img)
config.tgt_h = tgt_h
config.tgt_w = tgt_w
clip_encoder_out = self.image_encoder.visual([cond_frms.squeeze(0)[:, None, :, :]], config).squeeze(0).to(torch.bfloat16)
cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
lat_h, lat_w = tgt_h // 8, tgt_w // 8
config.lat_h = lat_h
config.lat_w = lat_w
vae_encode_out = vae_model.encode(cond_frms.to(torch.float), config)
if isinstance(vae_encode_out, list):
# convert the list of latents to a single tensor
vae_encode_out = torch.stack(vae_encode_out, dim=0).to(torch.bfloat16)
return vae_encode_out, clip_encoder_out
def run_input_encoder_internal(self):
image_encoder_output = None
if os.path.isfile(self.config.image_path):
with ProfilingContext("Run Img Encoder"):
vae_encode_out, clip_encoder_out = self.run_image_encoder(self.config, self.vae_encoder)
image_encoder_output = {
"clip_encoder_out": clip_encoder_out,
"vae_encode_out": vae_encode_out,
}
logger.info(f"clip_encoder_out:{clip_encoder_out.shape} vae_encode_out:{vae_encode_out.shape}")
with ProfilingContext("Run Text Encoder"):
with open(self.config["prompt_path"], "r", encoding="utf-8") as f:
prompt = f.readline().strip()
logger.info(f"Prompt: {prompt}")
img = Image.open(self.config["image_path"]).convert("RGB")
text_encoder_output = self.run_text_encoder(prompt, img)
self.set_target_shape()
self.inputs = {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}
gc.collect()
torch.cuda.empty_cache()
def set_target_shape(self):
ret = {}
num_channels_latents = 16
if self.config.task == "i2v":
self.config.target_shape = (
num_channels_latents,
(self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
self.config.lat_h,
self.config.lat_w,
)
ret["lat_h"] = self.config.lat_h
ret["lat_w"] = self.config.lat_w
else:
raise NotImplementedError("t2v task is not supported in WanAudioRunner")
ret["target_shape"] = self.config.target_shape
return ret
def run(self):
def load_audio(in_path: str, sr: float = 16000):
audio_array, ori_sr = ta.load(in_path)
audio_array = ta.functional.resample(audio_array.mean(0), orig_freq=ori_sr, new_freq=sr)
return audio_array.numpy()
def get_audio_range(start_frame: int, end_frame: int, fps: float, audio_sr: float = 16000):
audio_frame_rate = audio_sr / fps
return round(start_frame * audio_frame_rate), round((end_frame + 1) * audio_frame_rate)
self.inputs["audio_adapter_pipe"] = self.load_audio_models()
# process audio
audio_sr = 16000
max_num_frames = 81  # Wan2.1 generates at most 81 frames per segment (about 5 s at 16 fps)
target_fps = self.config.get("target_fps", 16)  # frame rate used to keep audio and video in sync
video_duration = self.config.get("video_duration", 8)  # desired output video duration in seconds
audio_array = load_audio(self.config["audio_path"], sr=audio_sr)
audio_len = int(audio_array.shape[0] / audio_sr * target_fps)
prev_frame_length = 5
prev_token_length = (prev_frame_length - 1) // 4 + 1
max_num_audio_length = int((max_num_frames + 1) / target_fps * 16000)
interval_num = 1
# expected_frames
expected_frames = min(max(1, int(float(video_duration) * target_fps)), audio_len)
res_frame_num = 0
if expected_frames <= max_num_frames:
interval_num = 1
else:
interval_num = max(int((expected_frames - max_num_frames) / (max_num_frames - prev_frame_length)) + 1, 1)
res_frame_num = expected_frames - interval_num * (max_num_frames - prev_frame_length)
if res_frame_num > 5:
interval_num += 1
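# Worked example (illustrative): video_duration = 8 s at target_fps = 16 gives expected_frames = 128
# (assuming the audio is long enough). Since 128 > 81, interval_num = int((128 - 81) / 76) + 1 = 1, and
# res_frame_num = 128 - 76 = 52 > 5 bumps interval_num to 2: two segments of up to 81 frames that
# overlap by prev_frame_length = 5 frames, the second conditioned on the last 5 frames of the first.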
audio_start, audio_end = get_audio_range(0, expected_frames, fps=target_fps, audio_sr=audio_sr)
audio_array_ori = audio_array[audio_start:audio_end]
gen_video_list = []
cut_audio_list = []
# reference latents
tgt_h = self.config.tgt_h
tgt_w = self.config.tgt_w
device = self.model.scheduler.latents.device
dtype = torch.bfloat16
vae_dtype = torch.float
for idx in range(interval_num):
torch.manual_seed(42 + idx)
logger.info(f"### manual_seed: {42 + idx} ####")
useful_length = -1
if idx == 0:  # first segment: the previous-frame condition is all-zero padding
prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
prev_len = 0
audio_start, audio_end = get_audio_range(0, max_num_frames, fps=target_fps, audio_sr=audio_sr)
audio_array = audio_array_ori[audio_start:audio_end]
if expected_frames < max_num_frames:
useful_length = audio_array.shape[0]
audio_array = np.concatenate((audio_array, np.zeros(max_num_audio_length)[: max_num_audio_length - useful_length]), axis=0)
audio_input_feat = self.audio_encoder(audio_array, sampling_rate=audio_sr, return_tensors="pt").input_values.squeeze(0)
elif res_frame_num > 5 and idx == interval_num - 1:  # last segment, which may have fewer than 81 frames
prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
prev_frames[:, :, :prev_frame_length] = gen_video_list[-1][:, :, -prev_frame_length:]
prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
prev_len = prev_token_length
audio_start, audio_end = get_audio_range(idx * max_num_frames - idx * prev_frame_length, expected_frames, fps=target_fps, audio_sr=audio_sr)
audio_array = audio_array_ori[audio_start:audio_end]
useful_length = audio_array.shape[0]
audio_array = np.concatenate((audio_array, np.zeros(max_num_audio_length)[: max_num_audio_length - useful_length]), axis=0)
audio_input_feat = self.audio_encoder(audio_array, sampling_rate=audio_sr, return_tensors="pt").input_values.squeeze(0)
else:  # middle segments: a full 81 frames conditioned on prev_latents
prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
prev_frames[:, :, :prev_frame_length] = gen_video_list[-1][:, :, -prev_frame_length:]
prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
prev_len = prev_token_length
audio_start, audio_end = get_audio_range(idx * max_num_frames - idx * prev_frame_length, (idx + 1) * max_num_frames - idx * prev_frame_length, fps=target_fps, audio_sr=audio_sr)
audio_array = audio_array_ori[audio_start:audio_end]
audio_input_feat = self.audio_encoder(audio_array, sampling_rate=audio_sr, return_tensors="pt").input_values.squeeze(0)
self.inputs["audio_encoder_output"] = audio_input_feat.to(device)
if idx != 0:
self.model.scheduler.reset()
if prev_latents is not None:
ltnt_channel, nframe, height, width = self.model.scheduler.latents.shape
bs = 1
prev_mask = torch.zeros((bs, 1, nframe, height, width), device=device, dtype=dtype)
if prev_len > 0:
prev_mask[:, :, :prev_len] = 1.0
previmg_encoder_output = {
"prev_latents": prev_latents,
"prev_mask": prev_mask,
}
self.inputs["previmg_encoder_output"] = previmg_encoder_output
for step_index in range(self.model.scheduler.infer_steps):
logger.info(f"==> step_index: {step_index} / {self.model.scheduler.infer_steps}")
with ProfilingContext4Debug("step_pre"):
self.model.scheduler.step_pre(step_index=step_index)
with ProfilingContext4Debug("infer"):
self.model.infer(self.inputs)
with ProfilingContext4Debug("step_post"):
self.model.scheduler.step_post()
latents = self.model.scheduler.latents
generator = self.model.scheduler.generator
gen_video = self.vae_decoder.decode(latents, generator=generator, config=self.config)
# gen_img = vae_handler.decode(xt.to(vae_dtype))
# B, C, T, H, W
gen_video = torch.clamp(gen_video, -1, 1)
start_frame = 0 if idx == 0 else prev_frame_length
start_audio_frame = 0 if idx == 0 else int((prev_frame_length + 1) * audio_sr / target_fps)
print(f"---- {idx}, {gen_video[:, :, start_frame:].shape}")
if res_frame_num > 5 and idx == interval_num - 1:
gen_video_list.append(gen_video[:, :, start_frame:res_frame_num])
cut_audio_list.append(audio_array[start_audio_frame:useful_length])
elif expected_frames < max_num_frames and useful_length != -1:
gen_video_list.append(gen_video[:, :, start_frame:expected_frames])
cut_audio_list.append(audio_array[start_audio_frame:useful_length])
else:
gen_video_list.append(gen_video[:, :, start_frame:])
cut_audio_list.append(audio_array[start_audio_frame:])
gen_lvideo = torch.cat(gen_video_list, dim=2).float()
merge_audio = np.concatenate(cut_audio_list, axis=0).astype(np.float32)
out_path = os.path.join("./", "video_merge.mp4")
audio_file = os.path.join("./", "audio_merge.wav")
save_to_video(gen_lvideo, out_path, target_fps)
save_audio(merge_audio, audio_file, out_path)
os.remove(out_path)
os.remove(audio_file)
async def run_pipeline(self):
if self.config["use_prompt_enhancer"]:
self.config["prompt_enhanced"] = self.post_prompt_enhancer()
self.run_input_encoder_internal()
self.set_target_shape()
self.init_scheduler()
self.model.scheduler.prepare(self.inputs["image_encoder_output"])
self.run()
self.end_run()
torch.cuda.empty_cache()
gc.collect()
@@ -115,6 +115,15 @@ class WanScheduler(BaseScheduler):
x0_pred = sample - sigma_t * model_output
return x0_pred
def reset(self):
self.model_outputs = [None] * self.solver_order
self.timestep_list = [None] * self.solver_order
self.last_sample = None
self.noise_pred = None
self.this_order = None
self.lower_order_nums = 0
self.prepare_latents(self.config.target_shape, dtype=torch.float32)
def multistep_uni_p_bh_update(
self,
model_output: torch.Tensor,
...
#!/bin/bash
# set paths first
lightx2v_path="/mnt/Text2Video/wangshankun/lightx2v"
model_path="/mnt/Text2Video/wangshankun/HF_Cache/Wan2.1-I2V-Audio-14B-720P/"
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
python -m lightx2v.infer \
--model_cls wan2.1_audio \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan_i2v_audio.json \
--prompt_path ${lightx2v_path}/assets/inputs/audio/15.txt \
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--image_path ${lightx2v_path}/assets/inputs/audio/15.png \
--audio_path ${lightx2v_path}/assets/inputs/audio/15.wav \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_audio.mp4