Commit fad005dd authored by sandy, committed by GitHub

Merge pull request #86 from ModelTC/dev/wan_audio

  feature: audio driven video gen
parents 973dd66b 6060ff4f
{
"infer_steps": 5,
"target_fps": 16,
"video_duration": 12,
"audio_sr": 16000,
"target_video_length": 81,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "radial_attn",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale":1,
"sample_shift": 5,
"enable_cfg": false,
"cpu_offload": false
}
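For orientation, the duration-related keys above feed the segment planning done later in WanAudioRunner.run; a minimal, CPU-only sketch (assuming this JSON is saved locally as, say, wan_audio.json, and ignoring the audio-length cap applied by the runner) of how they translate into 81-frame segments:

import json

cfg = json.load(open("wan_audio.json"))  # hypothetical local copy of the config above
target_fps = cfg["target_fps"]                             # 16
max_num_frames = cfg["target_video_length"]                # 81 frames per segment
expected_frames = int(cfg["video_duration"] * target_fps)  # 12 s * 16 fps = 192 frames
prev_frame_length = 5                                      # overlap frames reused as conditioning (see WanAudioRunner.run)
if expected_frames <= max_num_frames:
    interval_num = 1
else:
    interval_num = max((expected_frames - max_num_frames) // (max_num_frames - prev_frame_length) + 1, 1)
    if expected_frames - interval_num * (max_num_frames - prev_frame_length) > 5:
        interval_num += 1
print(interval_num)  # -> 3 segments for a 12 s clip at 16 fps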
......@@ -2,6 +2,7 @@ from lightx2v.attentions.common.torch_sdpa import torch_sdpa
from lightx2v.attentions.common.flash_attn2 import flash_attn2
from lightx2v.attentions.common.flash_attn3 import flash_attn3
from lightx2v.attentions.common.sage_attn2 import sage_attn2
from lightx2v.attentions.common.radial_attn import radial_attn
def attention(attention_type="flash_attn2", *args, **kwargs):
......@@ -13,5 +14,7 @@ def attention(attention_type="flash_attn2", *args, **kwargs):
return flash_attn3(*args, **kwargs)
elif attention_type == "sage_attn2":
return sage_attn2(*args, **kwargs)
elif attention_type == "radial_attn":
return radial_attn(*args, **kwargs)
else:
raise NotImplementedError(f"Unsupported attention mode: {attention_type}")
import torch
import flashinfer
###
### Code from radial-attention
### https://github.com/mit-han-lab/radial-attention/blob/main/radial_attn/attn_mask.py#L150
###
def radial_attn(
query, key, value, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, mask_map=None, sparsity_type="radial", block_size=128, decay_factor=1, model_cls="wan"
):
orig_seqlen, num_head, hidden_dim = query.shape
query = pad_qkv(query, block_size=block_size)
key = pad_qkv(key, block_size=block_size)
value = pad_qkv(value, block_size=block_size)
mask = mask_map.queryLogMask(query, sparsity_type, block_size=block_size, decay_factor=decay_factor, model_type=model_cls) if mask_map else None
seqlen = query.shape[0]
workspace_buffer = torch.empty(128 * 1024 * 1024, device=query.device, dtype=torch.uint8)
bsr_wrapper = flashinfer.BlockSparseAttentionWrapper(
workspace_buffer,
backend="fa2",
)
indptr = get_indptr_from_mask(mask, query)
indices = get_indices_from_mask(mask, query)
bsr_wrapper.plan(
indptr=indptr,
indices=indices,
M=seqlen,
N=seqlen,
R=block_size,
C=block_size,
num_qo_heads=num_head,
num_kv_heads=num_head,
head_dim=hidden_dim,
q_data_type=query.dtype,
kv_data_type=key.dtype,
o_data_type=query.dtype,
use_fp16_qk_reduction=True,
)
o = bsr_wrapper.run(query, key, value)
return o[:orig_seqlen, :, :]
def get_indptr_from_mask(mask, query):
# query only supplies the device for indptr
# indptr (torch.Tensor) - the block index pointer of the block-sparse matrix on the row dimension,
# shape `(MB + 1,)`, where `MB` is the number of blocks in the row dimension.
# The first element is always 0 and the last element is the total number of non-zero blocks;
# element i is the cumulative count of non-zero blocks in rows 0..i-1.
# the mask is already a block-sparse mask
indptr = torch.zeros(mask.shape[0] + 1, device=query.device, dtype=torch.int32)
indptr[0] = 0
row_counts = mask.sum(dim=1).flatten() # Ensure 1D output [num_blocks_row]
indptr[1:] = torch.cumsum(row_counts, dim=0)
return indptr
def get_indices_from_mask(mask, query):
# indices (torch.Tensor) - the block indices of the block-sparse matrix on column dimension,
# shape `(nnz,)`, where `nnz` is the number of non-zero blocks.
# The elements in `indices` array should be less than `NB`: the number of blocks in the column dimension.
nonzero_indices = torch.nonzero(mask)
indices = nonzero_indices[:, 1].to(dtype=torch.int32, device=query.device)
return indices
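As a quick CPU-only illustration (assuming the two helpers above are importable from this module), the BSR pointers extracted from a tiny 3x3 block mask look like this:

import torch

mask = torch.tensor([[1, 0, 1],
                     [0, 1, 0],
                     [1, 1, 1]], dtype=torch.bool)
query = torch.zeros(3 * 128, 1, 64)        # only supplies the device for the outputs
print(get_indptr_from_mask(mask, query))   # tensor([0, 2, 3, 6], dtype=torch.int32)
print(get_indices_from_mask(mask, query))  # tensor([0, 2, 1, 0, 1, 2], dtype=torch.int32)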
def shrinkMaskStrict(mask, block_size=128):
seqlen = mask.shape[0]
block_num = seqlen // block_size
mask = mask[: block_num * block_size, : block_num * block_size].view(block_num, block_size, block_num, block_size)
col_densities = mask.sum(dim=1) / block_size
# keep a block only if most of its non-empty columns are sufficiently dense
non_zero_densities = col_densities > 0
high_density_cols = col_densities > 1 / 3
frac_high_density_cols = high_density_cols.sum(dim=-1) / (non_zero_densities.sum(dim=-1) + 1e-9)
block_mask = frac_high_density_cols > 0.6
block_mask[0:0] = True
block_mask[-1:-1] = True
return block_mask
def pad_qkv(input_tensor, block_size=128):
"""
Pad the input tensor to be a multiple of the block size.
input shape: (seqlen, num_heads, hidden_dim)
"""
seqlen, num_heads, hidden_dim = input_tensor.shape
# Calculate the necessary padding
padding_length = (block_size - (seqlen % block_size)) % block_size
# Create a padded tensor with zeros
padded_tensor = torch.zeros((seqlen + padding_length, num_heads, hidden_dim), device=input_tensor.device, dtype=input_tensor.dtype)
# Copy the original tensor into the padded tensor
padded_tensor[:seqlen, :, :] = input_tensor
return padded_tensor
def get_diagonal_split_mask(i, j, token_per_frame, sparse_type, query):
assert sparse_type in ["radial"]
dist = abs(i - j)
group = dist.bit_length()
threshold = 128 # hardcoded threshold for now, which is equal to block-size
decay_length = 2 ** token_per_frame.bit_length() / 2**group
if decay_length >= threshold:
return torch.ones((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
split_factor = int(threshold / decay_length)
modular = dist % split_factor
if modular == 0:
return torch.ones((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
else:
return torch.zeros((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
def get_window_width(i, j, token_per_frame, sparse_type, num_frame, decay_factor=1, block_size=128, model_type=None):
assert sparse_type in ["radial"]
dist = abs(i - j)
if model_type == "wan":
if dist < 1:
return token_per_frame
if dist == 1:
return token_per_frame // 2
elif model_type == "hunyuan":
if dist <= 1:
return token_per_frame
else:
raise ValueError(f"Unknown model type: {model_type}")
group = dist.bit_length()
decay_length = 2 ** token_per_frame.bit_length() / 2**group * decay_factor
threshold = block_size
if decay_length >= threshold:
return decay_length
else:
return threshold
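For intuition, the wan schedule above shrinks the attention window roughly by powers of two as the frame distance grows; a small check (values assume token_per_frame=1560, decay_factor=1, block_size=128):

widths = {d: get_window_width(0, d, 1560, "radial", num_frame=21, model_type="wan") for d in range(10)}
# -> {0: 1560, 1: 780, 2: 512.0, 3: 512.0, 4: 256.0, 5: 256.0, 6: 256.0, 7: 256.0, 8: 128.0, 9: 128.0}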
def gen_log_mask_shrinked(query, s, video_token_num, num_frame, block_size=128, sparse_type="log", decay_factor=0.5, model_type=None):
"""
A more memory-friendly version: we generate the attention mask one frame pair at a time,
shrink it, and store it into the final result
"""
final_log_mask = torch.zeros((s // block_size, s // block_size), device=query.device, dtype=torch.bool)
token_per_frame = video_token_num // num_frame
video_text_border = video_token_num // block_size
col_indices = torch.arange(0, token_per_frame, device=query.device).view(1, -1)
row_indices = torch.arange(0, token_per_frame, device=query.device).view(-1, 1)
final_log_mask[video_text_border:] = True
final_log_mask[:, video_text_border:] = True
for i in range(num_frame):
for j in range(num_frame):
local_mask = torch.zeros((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
if j == 0: # this is attention sink
local_mask = torch.ones((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
else:
window_width = get_window_width(i, j, token_per_frame, sparse_type, num_frame, decay_factor=decay_factor, block_size=block_size, model_type=model_type)
local_mask = torch.abs(col_indices - row_indices) <= window_width
split_mask = get_diagonal_split_mask(i, j, token_per_frame, sparse_type, query)
local_mask = torch.logical_and(local_mask, split_mask)
remainder_row = (i * token_per_frame) % block_size
remainder_col = (j * token_per_frame) % block_size
# get the padded size
all_length_row = remainder_row + ((token_per_frame - 1) // block_size + 1) * block_size
all_length_col = remainder_col + ((token_per_frame - 1) // block_size + 1) * block_size
padded_local_mask = torch.zeros((all_length_row, all_length_col), device=query.device, dtype=torch.bool)
padded_local_mask[remainder_row : remainder_row + token_per_frame, remainder_col : remainder_col + token_per_frame] = local_mask
# shrink the mask
block_mask = shrinkMaskStrict(padded_local_mask, block_size=block_size)
# set the block mask to the final log mask
block_row_start = (i * token_per_frame) // block_size
block_col_start = (j * token_per_frame) // block_size
block_row_end = block_row_start + block_mask.shape[0]
block_col_end = block_col_start + block_mask.shape[1]
final_log_mask[block_row_start:block_row_end, block_col_start:block_col_end] = torch.logical_or(final_log_mask[block_row_start:block_row_end, block_col_start:block_col_end], block_mask)
print(f"mask sparsity: {1 - final_log_mask.sum() / final_log_mask.numel()}")
return final_log_mask
class MaskMap:
def __init__(self, video_token_num=79200, num_frame=22):
self.video_token_num = video_token_num
self.num_frame = num_frame
self.log_mask = None
def queryLogMask(self, query, sparse_type, block_size=128, decay_factor=0.5, model_type=None):
log_mask = torch.ones((query.shape[0] // block_size, query.shape[0] // block_size), device=query.device, dtype=torch.bool)
if self.log_mask is None:
self.log_mask = gen_log_mask_shrinked(
query, query.shape[0], self.video_token_num, self.num_frame, sparse_type=sparse_type, decay_factor=decay_factor, model_type=model_type, block_size=block_size
)
block_bound = self.video_token_num // block_size
log_mask[:block_bound, :block_bound] = self.log_mask[:block_bound, :block_bound]
return log_mask
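A hedged, CPU-only usage sketch of MaskMap with deliberately small shapes (WanAudioModel.infer further down builds it from the latent grid instead):

import torch

tokens_per_frame, num_frame = 256, 4
video_token_num = tokens_per_frame * num_frame      # 1024 video tokens
query = torch.randn(video_token_num + 128, 8, 64)   # (seqlen, heads, head_dim); 128 extra text/ref tokens
mask_map = MaskMap(video_token_num=video_token_num, num_frame=num_frame)
block_mask = mask_map.queryLogMask(query, "radial", block_size=128, model_type="wan")
print(block_mask.shape, block_mask.float().mean())  # (9, 9) block mask and its density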
......@@ -37,6 +37,9 @@ else:
sageattn = None
from lightx2v.attentions.common.radial_attn import radial_attn
class AttnWeightTemplate(metaclass=ABCMeta):
def __init__(self, weight_name):
self.weight_name = weight_name
......@@ -70,7 +73,7 @@ class FlashAttn2Weight(AttnWeightTemplate):
def __init__(self):
self.config = {}
def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None):
def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, mask_map=None):
x = flash_attn_varlen_func(
q,
k,
......@@ -88,7 +91,7 @@ class FlashAttn3Weight(AttnWeightTemplate):
def __init__(self):
self.config = {}
def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None):
def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, mask_map=None):
x = flash_attn_varlen_func_v3(
q,
k,
......@@ -101,6 +104,28 @@ class FlashAttn3Weight(AttnWeightTemplate):
return x
@ATTN_WEIGHT_REGISTER("radial_attn")
class RadialAttnWeight(AttnWeightTemplate):
def __init__(self):
self.config = {}
def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, mask_map=None, sparsity_type="radial", block_size=128, decay_factor=1, model_cls="wan"):
assert len(q.shape) == 3
x = radial_attn(
q,
k,
v,
mask_map=mask_map,
sparsity_type=sparsity_type,
block_size=block_size,
model_cls=model_cls[:3],  # truncate e.g. "wan2.1" / "wan2.1_audio" to "wan" so the radial mask helpers recognize the model type
decay_factor=decay_factor,
)
x = x.view(max_seqlen_q, -1)
return x
@ATTN_WEIGHT_REGISTER("sage_attn2")
class SageAttn2Weight(AttnWeightTemplate):
def __init__(self):
......
......@@ -14,6 +14,7 @@ from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner
from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner
from lightx2v.models.runners.wan.wan_audio_runner import WanAudioRunner
from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner
from lightx2v.models.runners.graph_runner import GraphRunner
from lightx2v.models.runners.cogvideox.cogvidex_runner import CogvideoxRunner
......@@ -41,14 +42,19 @@ def init_runner(config):
async def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox"], default="hunyuan")
parser.add_argument(
"--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox", "wan2.1_audio"], default="hunyuan"
)
parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--config_json", type=str, required=True)
parser.add_argument("--use_prompt_enhancer", action="store_true")
parser.add_argument("--prompt", type=str, required=True)
parser.add_argument("--prompt", type=str, default="", help="The input prompt for text-to-video generation")
parser.add_argument("--negative_prompt", type=str, default="")
parser.add_argument("--lora_path", type=str, default="", help="The lora file path")
parser.add_argument("--prompt_path", type=str, default="", help="The path to input prompt file")
parser.add_argument("--audio_path", type=str, default="", help="The path to input audio file")
parser.add_argument("--image_path", type=str, default="", help="The path to input image file or path for image-to-video (i2v) task")
parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4", help="The path to save video path/file")
args = parser.parse_args()
......
......@@ -11,6 +11,9 @@ import torchvision.transforms as T
from lightx2v.attentions import attention
from loguru import logger
from lightx2v.models.input_encoders.hf.q_linear import QuantLinearInt8, QuantLinearFp8
from einops import rearrange
from torch import Tensor
from transformers import CLIPVisionModel
__all__ = [
......@@ -428,3 +431,51 @@ class CLIPModel:
def to_cpu(self):
self.model = self.model.cpu()
class WanVideoIPHandler:
def __init__(self, model_name, repo_or_path, require_grad=False, mode="eval", device="cuda", dtype=torch.float16):
# image_processor = CLIPImageProcessor.from_pretrained(
# repo_or_path, subfolder='image_processor')
"""720P-I2V-diffusers config is
"size": {
"shortest_edge": 224
}
and 480P-I2V-diffusers config is
"size": {
"height": 224,
"width": 224
}
but the official Wan2.1 code uses a no-crop resize by default,
so I don't use CLIPImageProcessor here
"""
image_encoder = CLIPVisionModel.from_pretrained(repo_or_path, torch_dtype=dtype)
logger.info(f"Using image encoder {model_name} from {repo_or_path}")
image_encoder.requires_grad_(require_grad)
if mode == "eval":
image_encoder.eval()
else:
image_encoder.train()
self.dtype = dtype
self.device = device
self.image_encoder = image_encoder.to(device=device, dtype=dtype)
self.size = (224, 224)
mean = [0.48145466, 0.4578275, 0.40821073]
std = [0.26862954, 0.26130258, 0.27577711]
self.normalize = T.Normalize(mean=mean, std=std)
# self.image_processor = image_processor
def encode(
self,
img_tensor: Tensor,
):
if img_tensor.ndim == 5: # B C T H W
# img_tensor = img_tensor[:, :, 0]
img_tensor = rearrange(img_tensor, "B C 1 H W -> B C H W")
img_tensor = torch.clamp(img_tensor.float() * 0.5 + 0.5, min=0.0, max=1.0).to(self.device)
img_tensor = F.interpolate(img_tensor, size=self.size, mode="bicubic", align_corners=False)
img_tensor = self.normalize(img_tensor).to(self.dtype)
image_embeds = self.image_encoder(pixel_values=img_tensor, output_hidden_states=True)
return image_embeds.hidden_states[-1]
import flash_attn
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from einops import rearrange
from transformers import AutoModel
from loguru import logger
import os
import safetensors
from typing import List, Optional, Tuple, Union
def load_safetensors(in_path: str):
if os.path.isdir(in_path):
return load_safetensors_from_dir(in_path)
elif os.path.isfile(in_path):
return load_safetensors_from_path(in_path)
else:
raise ValueError(f"{in_path} does not exist")
def load_safetensors_from_path(in_path: str):
tensors = {}
with safetensors.safe_open(in_path, framework="pt", device="cpu") as f:
for key in f.keys():
tensors[key] = f.get_tensor(key)
return tensors
def load_safetensors_from_dir(in_dir: str):
tensors = {}
safetensor_files = [f for f in os.listdir(in_dir) if f.endswith(".safetensors")]  # avoid shadowing the `safetensors` module
for f in safetensor_files:
tensors.update(load_safetensors_from_path(os.path.join(in_dir, f)))
return tensors
def load_pt_safetensors(in_path: str):
ext = os.path.splitext(in_path)[-1]
if ext in (".pt", ".pth", ".tar"):
state_dict = torch.load(in_path, map_location="cpu", weights_only=True)
else:
state_dict = load_safetensors(in_path)
return state_dict
def rank0_load_state_dict_from_path(model, in_path: str, strict: bool = True):
import torch.distributed as dist
if (dist.is_initialized() and dist.get_rank() == 0) or (not dist.is_initialized()):
state_dict = load_pt_safetensors(in_path)
model.load_state_dict(state_dict, strict=strict)
if dist.is_initialized():
dist.barrier()
return model.to(dtype=torch.bfloat16, device="cuda")
def linear_interpolation(features, output_len: int):
features = features.transpose(1, 2)
output_features = F.interpolate(features, size=output_len, align_corners=False, mode="linear")
return output_features.transpose(1, 2)
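For example (CPU-only, assuming the helper above), resampling ~2 s of 50 Hz audio features onto 33 video frames:

import torch

feats = torch.randn(1, 100, 768)                          # (B, audio_frames, C)
frame_feats = linear_interpolation(feats, output_len=33)  # one feature row per video frame
print(frame_feats.shape)                                  # torch.Size([1, 33, 768])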
def get_q_lens_audio_range(
batchsize,
n_tokens_per_rank,
n_query_tokens,
n_tokens_per_frame,
sp_rank,
):
if n_query_tokens == 0:
q_lens = [1] * batchsize
return q_lens, 0, 1
idx0 = n_tokens_per_rank * sp_rank
first_length = idx0 - idx0 // n_tokens_per_frame * n_tokens_per_frame
n_frames = (n_query_tokens - first_length) // n_tokens_per_frame
last_length = n_query_tokens - n_frames * n_tokens_per_frame - first_length
q_lens = []
if first_length > 0:
q_lens.append(first_length)
q_lens += [n_tokens_per_frame] * n_frames
if last_length > 0:
q_lens.append(last_length)
t0 = idx0 // n_tokens_per_frame
idx1 = idx0 + n_query_tokens
t1 = math.ceil(idx1 / n_tokens_per_frame)
return q_lens * batchsize, t0, t1
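A worked example of the helper above: with 3600 tokens on a rank, 1560 tokens per latent frame and sp_rank=1, the rank starts mid-frame, so the first query chunk is just the remainder of that frame:

q_lens, t0, t1 = get_q_lens_audio_range(
    batchsize=1, n_tokens_per_rank=3600, n_query_tokens=3600, n_tokens_per_frame=1560, sp_rank=1
)
# q_lens == [480, 1560, 1560]   (partial first frame, then two full frames)
# (t0, t1) == (2, 5)            (this rank's queries touch latent frames 2..4)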
class PerceiverAttentionCA(nn.Module):
def __init__(self, dim_head=128, heads=16, kv_dim=2048, adaLN: bool = False):
super().__init__()
self.dim_head = dim_head
self.heads = heads
inner_dim = dim_head * heads
kv_dim = inner_dim if kv_dim is None else kv_dim
self.norm_kv = nn.LayerNorm(kv_dim)
self.norm_q = nn.LayerNorm(inner_dim, elementwise_affine=not adaLN)
self.to_q = nn.Linear(inner_dim, inner_dim)
self.to_kv = nn.Linear(kv_dim, inner_dim * 2)
self.to_out = nn.Linear(inner_dim, inner_dim)
if adaLN:
self.shift_scale_gate = nn.Parameter(torch.randn(1, 3, inner_dim) / inner_dim**0.5)
else:
shift_scale_gate = torch.zeros((1, 3, inner_dim))
shift_scale_gate[:, 2] = 1
self.register_buffer("shift_scale_gate", shift_scale_gate, persistent=False)
def forward(self, x, latents, t_emb, q_lens, k_lens):
"""x shape (batchsize, latent_frame, audio_tokens_per_latent,
model_dim) latents (batchsize, length, model_dim)"""
batchsize = len(x)
x = self.norm_kv(x)
shift, scale, gate = (t_emb + self.shift_scale_gate).chunk(3, dim=1)
latents = self.norm_q(latents) * (1 + scale) + shift
q = self.to_q(latents)
k, v = self.to_kv(x).chunk(2, dim=-1)
q = rearrange(q, "B L (H C) -> (B L) H C", H=self.heads)
k = rearrange(k, "B T L (H C) -> (B T L) H C", H=self.heads)
v = rearrange(v, "B T L (H C) -> (B T L) H C", H=self.heads)
out = flash_attn.flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True),
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True),
max_seqlen_q=q_lens.max(),
max_seqlen_k=k_lens.max(),
dropout_p=0.0,
softmax_scale=None,
causal=False,
window_size=(-1, -1),
deterministic=False,
)
out = rearrange(out, "(B L) H C -> B L (H C)", B=batchsize)
return self.to_out(out) * gate
class AudioProjection(nn.Module):
def __init__(
self,
audio_feature_dim: int = 768,
n_neighbors: tuple = (2, 2),
num_tokens: int = 32,
mlp_dims: tuple = (1024, 1024, 32 * 768),
transformer_layers: int = 4,
):
super().__init__()
mlp = []
self.left, self.right = n_neighbors
self.audio_frames = sum(n_neighbors) + 1
in_dim = audio_feature_dim * self.audio_frames
for i, out_dim in enumerate(mlp_dims):
mlp.append(nn.Linear(in_dim, out_dim))
if i != len(mlp_dims) - 1:
mlp.append(nn.ReLU())
in_dim = out_dim
self.mlp = nn.Sequential(*mlp)
self.norm = nn.LayerNorm(mlp_dims[-1] // num_tokens)
self.num_tokens = num_tokens
if transformer_layers > 0:
decoder_layer = nn.TransformerDecoderLayer(d_model=audio_feature_dim, nhead=audio_feature_dim // 64, dim_feedforward=4 * audio_feature_dim, dropout=0.0, batch_first=True)
self.transformer_decoder = nn.TransformerDecoder(
decoder_layer,
num_layers=transformer_layers,
)
else:
self.transformer_decoder = None
def forward(self, audio_feature, latent_frame):
video_frame = (latent_frame - 1) * 4 + 1
audio_feature_ori = audio_feature
audio_feature = linear_interpolation(audio_feature_ori, video_frame)
if self.transformer_decoder is not None:
audio_feature = self.transformer_decoder(audio_feature, audio_feature_ori)
audio_feature = F.pad(audio_feature, pad=(0, 0, self.left, self.right), mode="replicate")
audio_feature = audio_feature.unfold(dimension=1, size=self.audio_frames, step=1)
audio_feature = rearrange(audio_feature, "B T C W -> B T (W C)")
audio_feature = self.mlp(audio_feature) # (B, video_frame, C)
audio_feature = rearrange(audio_feature, "B T (N C) -> B T N C", N=self.num_tokens) # (B, video_frame, num_tokens, C)
return self.norm(audio_feature)
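A hedged smoke test of AudioProjection with deliberately small dimensions (the defaults above target WavLM/wav2vec-sized features; assumes the class is importable from the audio_adapter module added in this PR):

import torch

proj = AudioProjection(audio_feature_dim=64, n_neighbors=(2, 2), num_tokens=4, mlp_dims=(128, 128, 4 * 64), transformer_layers=1)
audio_feature = torch.randn(1, 100, 64)    # (B, audio_frames, C)
out = proj(audio_feature, latent_frame=9)  # video_frame = (9 - 1) * 4 + 1 = 33
print(out.shape)                           # torch.Size([1, 33, 4, 64]) -> (B, video_frame, num_tokens, C)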
class TimeEmbedding(nn.Module):
def __init__(self, dim, time_freq_dim, time_proj_dim):
super().__init__()
self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
self.act_fn = nn.SiLU()
self.time_proj = nn.Linear(dim, time_proj_dim)
def forward(
self,
timestep: torch.Tensor,
):
timestep = self.timesteps_proj(timestep)
time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
timestep = timestep.to(time_embedder_dtype)
temb = self.time_embedder(timestep)
timestep_proj = self.time_proj(self.act_fn(temb))
return timestep_proj
class AudioAdapter(nn.Module):
def __init__(
self,
attention_head_dim=64,
num_attention_heads=40,
base_num_layers=30,
interval=1,
audio_feature_dim: int = 768,
num_tokens: int = 32,
mlp_dims: tuple = (1024, 1024, 32 * 768),
time_freq_dim: int = 256,
projection_transformer_layers: int = 4,
):
super().__init__()
self.audio_proj = AudioProjection(
audio_feature_dim=audio_feature_dim,
n_neighbors=(2, 2),
num_tokens=num_tokens,
mlp_dims=mlp_dims,
transformer_layers=projection_transformer_layers,
)
# self.num_tokens = num_tokens * 4
self.num_tokens_x4 = num_tokens * 4
self.audio_pe = nn.Parameter(torch.randn(self.num_tokens_x4, mlp_dims[-1] // num_tokens) * 0.02)
ca_num = math.ceil(base_num_layers / interval)
self.base_num_layers = base_num_layers
self.interval = interval
self.ca = nn.ModuleList(
[
PerceiverAttentionCA(
dim_head=attention_head_dim,
heads=num_attention_heads,
kv_dim=mlp_dims[-1] // num_tokens,
adaLN=time_freq_dim > 0,
)
for _ in range(ca_num)
]
)
self.dim = attention_head_dim * num_attention_heads
if time_freq_dim > 0:
self.time_embedding = TimeEmbedding(self.dim, time_freq_dim, self.dim * 3)
else:
self.time_embedding = None
def rearange_audio_features(self, audio_feature: torch.Tensor):
# audio_feature (B, video_frame, num_tokens, C)
audio_feature_0 = audio_feature[:, :1]
audio_feature_0 = torch.repeat_interleave(audio_feature_0, repeats=4, dim=1)
audio_feature = torch.cat([audio_feature_0, audio_feature[:, 1:]], dim=1) # (B, 4 * latent_frame, num_tokens, C)
audio_feature = rearrange(audio_feature, "B (T S) N C -> B T (S N) C", S=4)
return audio_feature
def forward(self, audio_feat: torch.Tensor, timestep: torch.Tensor, latent_frame: int, weight: float = 1.0):
def modify_hidden_states(hidden_states, grid_sizes, ca_block: PerceiverAttentionCA, x, t_emb, dtype, weight):
"""thw specify the latent_frame, latent_height, latenf_width after
hidden_states is patchified.
latent_frame does not include the reference images so that the
audios and hidden_states are strictly aligned
"""
if len(hidden_states.shape) == 2:  # add a batch dimension
hidden_states = hidden_states.unsqueeze(0) # bs = 1
# print(weight)
t, h, w = grid_sizes[0].tolist()
n_tokens = t * h * w
ori_dtype = hidden_states.dtype
device = hidden_states.device
bs, n_tokens_per_rank = hidden_states.shape[:2]
tail_length = n_tokens_per_rank - n_tokens
n_query_tokens = n_tokens_per_rank - tail_length % n_tokens_per_rank
if n_query_tokens > 0:
hidden_states_aligned = hidden_states[:, :n_query_tokens]
hidden_states_tail = hidden_states[:, n_query_tokens:]
else:
# for ranks that should be excluded from cross-attn, fake cross-attn will be applied so that FSDP works.
hidden_states_aligned = hidden_states[:, :1]
hidden_states_tail = hidden_states[:, 1:]
q_lens, t0, t1 = get_q_lens_audio_range(batchsize=bs, n_tokens_per_rank=n_tokens_per_rank, n_query_tokens=n_query_tokens, n_tokens_per_frame=h * w, sp_rank=0)
q_lens = torch.tensor(q_lens, device=device, dtype=torch.int32)
"""
Processing of the audio features for sequence parallelism could be moved outside this function.
"""
x = x[:, t0:t1]
x = x.to(dtype)
k_lens = torch.tensor([self.num_tokens_x4] * (t1 - t0) * bs, device=device, dtype=torch.int32)
assert q_lens.shape == k_lens.shape
# ca_block: the PerceiverAttentionCA cross-attention module
residual = ca_block(x, hidden_states_aligned, t_emb, q_lens, k_lens) * weight
residual = residual.to(ori_dtype)  # audio features are injected as a residual after cross-attention
if n_query_tokens == 0:
residual = residual * 0.0
hidden_states = torch.cat([hidden_states_aligned + residual, hidden_states_tail], dim=1)
if len(hidden_states.shape) == 3:  # drop the batch dimension added above
hidden_states = hidden_states.squeeze(0) # bs = 1
return hidden_states
x = self.audio_proj(audio_feat, latent_frame)
x = self.rearange_audio_features(x)
x = x + self.audio_pe
if self.time_embedding is not None:
t_emb = self.time_embedding(timestep).unflatten(1, (3, -1))
else:
t_emb = torch.zeros((len(x), 3, self.dim), device=x.device, dtype=x.dtype)
ret_dict = {}
for block_idx, base_idx in enumerate(range(0, self.base_num_layers, self.interval)):
block_dict = {
"kwargs": {
"ca_block": self.ca[block_idx],
"x": x,
"weight": weight,
"t_emb": t_emb,
"dtype": x.dtype,
},
"modify_func": modify_hidden_states,
}
ret_dict[base_idx] = block_dict
return ret_dict
@classmethod
def from_transformer(
cls,
transformer,
audio_feature_dim: int = 1024,
interval: int = 1,
time_freq_dim: int = 256,
projection_transformer_layers: int = 4,
):
num_attention_heads = transformer.config["num_heads"]
base_num_layers = transformer.config["num_layers"]
attention_head_dim = transformer.config["dim"] // num_attention_heads
audio_adapter = AudioAdapter(
attention_head_dim,
num_attention_heads,
base_num_layers,
interval=interval,
audio_feature_dim=audio_feature_dim,
time_freq_dim=time_freq_dim,
projection_transformer_layers=projection_transformer_layers,
mlp_dims=(1024, 1024, 32 * audio_feature_dim),
)
return audio_adapter
def get_fsdp_wrap_module_list(
self,
):
ret_list = list(self.ca)
return ret_list
def enable_gradient_checkpointing(
self,
):
pass
class AudioAdapterPipe:
def __init__(
self, audio_adapter: AudioAdapter, audio_encoder_repo: str = "microsoft/wavlm-base-plus", dtype=torch.float32, device="cuda", generator=None, tgt_fps: int = 15, weight: float = 1.0
) -> None:
self.audio_adapter = audio_adapter
self.dtype = dtype
self.device = device
self.generator = generator
self.audio_encoder_dtype = torch.float16
## audio encoder
self.audio_encoder = AutoModel.from_pretrained(audio_encoder_repo)
self.audio_encoder.eval()
self.audio_encoder.to(device, self.audio_encoder_dtype)
self.tgt_fps = tgt_fps
self.weight = weight
if "base" in audio_encoder_repo:
self.audio_feature_dim = 768
else:
self.audio_feature_dim = 1024
def update_model(self, audio_adapter):
self.audio_adapter = audio_adapter
def __call__(self, audio_input_feat, timestep, latent_shape: tuple, dropout_cond: callable = None):
# audio_input_feat is from AudioPreprocessor
latent_frame = latent_shape[2]
if len(audio_input_feat.shape) == 1:  # add a batch dimension (batch size = 1)
audio_input_feat = audio_input_feat.unsqueeze(0)
latent_frame = latent_shape[1]
video_frame = (latent_frame - 1) * 4 + 1
audio_length = int(50 / self.tgt_fps * video_frame)
with torch.no_grad():
audio_input_feat = audio_input_feat.to(self.device, self.audio_encoder_dtype)
try:
audio_feat = self.audio_encoder(audio_input_feat, return_dict=True).last_hidden_state
except Exception as err:
audio_feat = torch.rand(1, audio_length, self.audio_feature_dim).to(self.device)
print(err)
audio_feat = audio_feat.to(self.dtype)
if dropout_cond is not None:
audio_feat = dropout_cond(audio_feat)
return self.audio_adapter(audio_feat=audio_feat, timestep=timestep, latent_frame=latent_frame, weight=self.weight)
import os
import torch
import time
import glob
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
from lightx2v.models.networks.wan.weights.transformer_weights import (
WanTransformerWeights,
)
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from lightx2v.models.networks.wan.infer.pre_wan_audio_infer import WanAudioPreInfer
from lightx2v.models.networks.wan.infer.post_wan_audio_infer import WanAudioPostInfer
from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import WanTransformerInferTeaCaching
from safetensors import safe_open
import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
from lightx2v.attentions.common.radial_attn import MaskMap
from lightx2v.models.networks.wan.infer.transformer_infer import (
WanTransformerInfer,
)
from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import (
WanTransformerInferTeaCaching,
)
class WanAudioModel(WanModel):
pre_weight_class = WanPreWeights
post_weight_class = WanPostWeights
transformer_weight_class = WanTransformerWeights
def __init__(self, model_path, config, device):
super().__init__(model_path, config, device)
def _init_infer_class(self):
self.pre_infer_class = WanAudioPreInfer
self.post_infer_class = WanAudioPostInfer
if self.config["feature_caching"] == "NoCaching":
self.transformer_infer_class = WanTransformerInfer
elif self.config["feature_caching"] == "Tea":
self.transformer_infer_class = WanTransformerInferTeaCaching
else:
raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
@torch.no_grad()
def infer(self, inputs):
if self.config["cpu_offload"]:
self.pre_weight.to_cuda()
self.post_weight.to_cuda()
if self.transformer_infer.mask_map is None:
_, c, h, w = self.scheduler.latents.shape
num_frame = c + 1  # add one frame for the reference image (r2v)
video_token_num = num_frame * (h // 2) * (w // 2)
from loguru import logger
logger.info(f"video_token_num: {video_token_num}, num_frame: {num_frame}")
self.transformer_infer.mask_map = MaskMap(video_token_num, num_frame)
embed, grid_sizes, pre_infer_out, valid_patch_length = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes, valid_patch_length)[0]
if self.config["feature_caching"] == "Tea":
self.scheduler.cnt += 1
if self.scheduler.cnt >= self.scheduler.num_steps:
self.scheduler.cnt = 0
self.scheduler.noise_pred = noise_pred_cond
if self.config["enable_cfg"]:
embed, grid_sizes, pre_infer_out, valid_patch_length = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes, valid_patch_length)[0]
if self.config["feature_caching"] == "Tea":
self.scheduler.cnt += 1
if self.scheduler.cnt >= self.scheduler.num_steps:
self.scheduler.cnt = 0
self.scheduler.noise_pred = noise_pred_uncond + self.config.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
if self.config["cpu_offload"]:
self.pre_weight.to_cpu()
self.post_weight.to_cpu()
import math
import torch
import torch.cuda.amp as amp
from loguru import logger
from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
class WanAudioPostInfer(WanPostInfer):
def __init__(self, config):
self.out_dim = config["out_dim"]
self.patch_size = (1, 2, 2)
def set_scheduler(self, scheduler):
self.scheduler = scheduler
def infer(self, weights, x, e, grid_sizes, valid_patch_length):
if e.dim() == 2:
modulation = weights.head_modulation.tensor # 1, 2, dim
e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
elif e.dim() == 3:  # For diffusion forcing
modulation = weights.head_modulation.tensor.unsqueeze(2) # 1, 2, seq, dim
e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
e = [ei.squeeze(1) for ei in e]
norm_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6).type_as(x)
out = norm_out * (1 + e[1].squeeze(0)) + e[0].squeeze(0)
x = weights.head.apply(out)
x = x[:, :valid_patch_length]
x = self.unpatchify(x, grid_sizes)
return [u.float() for u in x]
def unpatchify(self, x, grid_sizes):
x = x.unsqueeze(0)
c = self.out_dim
out = []
for u, v in zip(x, grid_sizes.tolist()):
u = u[: math.prod(v)].view(*v, *self.patch_size, c)
u = torch.einsum("fhwpqrc->cfphqwr", u)
u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
out.append(u)
return out
import torch
import math
from .utils import rope_params, sinusoidal_embedding_1d
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from loguru import logger
class WanAudioPreInfer(WanPreInfer):
def __init__(self, config):
assert (config["dim"] % config["num_heads"]) == 0 and (config["dim"] // config["num_heads"]) % 2 == 0
d = config["dim"] // config["num_heads"]
self.task = config["task"]
self.freqs = torch.cat(
[
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
],
dim=1,
).cuda()
self.freq_dim = config["freq_dim"]
self.dim = config["dim"]
self.text_len = config["text_len"]
def infer(self, weights, inputs, positive):
ltnt_channel = self.scheduler.latents.size(0)
prev_latents = inputs["previmg_encoder_output"]["prev_latents"].unsqueeze(0)
prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
hidden_states = self.scheduler.latents.unsqueeze(0)
hidden_states = torch.cat([hidden_states[:, :ltnt_channel], prev_latents, prev_mask], dim=1)
hidden_states = hidden_states.squeeze(0)
x = [hidden_states]
t = torch.stack([self.scheduler.timesteps[self.scheduler.step_index]])
audio_dit_blocks = []
audio_encoder_output = inputs["audio_encoder_output"]
audio_model_input = {
"audio_input_feat": audio_encoder_output.to(hidden_states.device),
"latent_shape": hidden_states.shape,
"timestep": t,
}
audio_dit_blocks.append(inputs["audio_adapter_pipe"](**audio_model_input))
if positive:
context = inputs["text_encoder_output"]["context"]
else:
context = inputs["text_encoder_output"]["context_null"]
seq_len = self.scheduler.seq_len
clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]
ref_image_encoder = inputs["image_encoder_output"]["vae_encode_out"]
batch_size = len(x)
num_channels, num_frames, height, width = x[0].shape
_, ref_num_channels, ref_num_frames, _, _ = ref_image_encoder.shape
if ref_num_channels != num_channels:
zero_padding = torch.zeros(
(batch_size, num_channels - ref_num_channels, ref_num_frames, height, width),
dtype=self.scheduler.latents.dtype,
device=self.scheduler.latents.device,
)
ref_image_encoder = torch.concat([ref_image_encoder, zero_padding], dim=1)
y = list(torch.unbind(ref_image_encoder, dim=0))  # unbind the first (batch) dimension into a list
# embeddings
x = [weights.patch_embedding.apply(u.unsqueeze(0)) for u in x]
x_grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
x = [u.flatten(2).transpose(1, 2) for u in x]
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long).cuda()
assert seq_lens.max() <= seq_len
x = torch.cat([torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) for u in x])
valid_patch_length = x[0].size(0)
y = [weights.patch_embedding.apply(u.unsqueeze(0)) for u in y]
y_grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in y])
y = [u.flatten(2).transpose(1, 2).squeeze(0) for u in y]
x = [torch.cat([a, b], dim=0) for a, b in zip(x, y)]
x = torch.stack(x, dim=0)
embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten())
embed = weights.time_embedding_0.apply(embed)
embed = torch.nn.functional.silu(embed)
embed = weights.time_embedding_2.apply(embed)
embed0 = torch.nn.functional.silu(embed)
embed0 = weights.time_projection_1.apply(embed0).unflatten(1, (6, self.dim))
# text embeddings
stacked = torch.stack([torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) for u in context])
out = weights.text_embedding_0.apply(stacked.squeeze(0))
out = torch.nn.functional.gelu(out, approximate="tanh")
context = weights.text_embedding_2.apply(out)
if self.task == "i2v":
context_clip = weights.proj_0.apply(clip_fea)
context_clip = weights.proj_1.apply(context_clip)
context_clip = torch.nn.functional.gelu(context_clip, approximate="none")
context_clip = weights.proj_3.apply(context_clip)
context_clip = weights.proj_4.apply(context_clip)
context = torch.concat([context_clip, context], dim=0)
return (embed, x_grid_sizes, (x.squeeze(0), embed0.squeeze(0), seq_lens, self.freqs, context, audio_dit_blocks), valid_patch_length)
import torch
from .utils import compute_freqs, compute_freqs_dist, apply_rotary_emb, apply_rotary_emb_chunk
from .utils import compute_freqs, compute_freqs_dist, compute_freqs_audio, compute_freqs_audio_dist, apply_rotary_emb, apply_rotary_emb_chunk
from lightx2v.common.offload.manager import (
WeightAsyncStreamManager,
LazyWeightAsyncStreamManager,
)
from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
from lightx2v.utils.envs import *
from loguru import logger
import os
class WanTransformerInfer(BaseTransformerInfer):
......@@ -21,6 +23,8 @@ class WanTransformerInfer(BaseTransformerInfer):
self.parallel_attention = None
self.apply_rotary_emb_func = apply_rotary_emb_chunk if config.get("rotary_chunk", False) else apply_rotary_emb
self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
self.mask_map = None
if self.config["cpu_offload"]:
if "offload_ratio" in self.config:
offload_ratio = self.config["offload_ratio"]
......@@ -64,10 +68,10 @@ class WanTransformerInfer(BaseTransformerInfer):
return cu_seqlens_q, cu_seqlens_k
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks)
def _infer_with_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
def _infer_with_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
for block_idx in range(self.blocks_num):
if block_idx == 0:
self.weights_stream_mgr.active_weights[0] = weights.blocks[0]
......@@ -92,7 +96,7 @@ class WanTransformerInfer(BaseTransformerInfer):
return x
def _infer_with_phases_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
def _infer_with_phases_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
for block_idx in range(weights.blocks_num):
for phase_idx in range(self.phases_num):
if block_idx == 0 and phase_idx == 0:
......@@ -133,7 +137,7 @@ class WanTransformerInfer(BaseTransformerInfer):
return x
def _infer_with_phases_lazy_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
def _infer_with_phases_lazy_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
self.weights_stream_mgr.prefetch_weights_from_disk(weights)
for block_idx in range(weights.blocks_num):
......@@ -194,7 +198,22 @@ class WanTransformerInfer(BaseTransformerInfer):
return x
def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
def zero_temporal_component_in_3DRoPE(self, valid_token_length, rotary_emb=None):
if rotary_emb is None:
return None
self.use_real = False
rope_t_dim = 44  # temporal RoPE width: d - 4 * (d // 6) = 44 for head_dim d = 128
if self.use_real:
freqs_cos, freqs_sin = rotary_emb
freqs_cos[valid_token_length:, :, :rope_t_dim] = 0
freqs_sin[valid_token_length:, :, :rope_t_dim] = 0
return freqs_cos, freqs_sin
else:
freqs_cis = rotary_emb
freqs_cis[valid_token_length:, :, : rope_t_dim // 2] = 0
return freqs_cis
def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
for block_idx in range(self.blocks_num):
x = self.infer_block(
weights.blocks[block_idx],
......@@ -206,6 +225,12 @@ class WanTransformerInfer(BaseTransformerInfer):
freqs,
context,
)
if audio_dit_blocks is not None and len(audio_dit_blocks) > 0:
for ipa_out in audio_dit_blocks:
if block_idx in ipa_out:
cur_modify = ipa_out[block_idx]
x = cur_modify["modify_func"](x, grid_sizes, **cur_modify["kwargs"])
return x
def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
......@@ -265,14 +290,23 @@ class WanTransformerInfer(BaseTransformerInfer):
v = weights.self_attn_v.apply(norm1_out).view(s, n, d)
if not self.parallel_attention:
freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
if self.config.get("audio_sr", False):
freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
else:
freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
else:
freqs_i = compute_freqs_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
if self.config.get("audio_sr", False):
freqs_i = compute_freqs_audio_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
else:
freqs_i = compute_freqs_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
freqs_i = self.zero_temporal_component_in_3DRoPE(seq_lens, freqs_i)
q = self.apply_rotary_emb_func(q, freqs_i)
k = self.apply_rotary_emb_func(k, freqs_i)
cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q, k_lens=seq_lens)
k_lens = torch.empty_like(seq_lens).fill_(freqs_i.size(0))
cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q, k_lens=k_lens)
if self.clean_cuda_cache:
del freqs_i, norm1_out, norm1_weight, norm1_bias
......@@ -288,6 +322,7 @@ class WanTransformerInfer(BaseTransformerInfer):
max_seqlen_q=q.size(0),
max_seqlen_kv=k.size(0),
model_cls=self.config["model_cls"],
mask_map=self.mask_map,
)
else:
attn_out = self.parallel_attention(
......@@ -353,7 +388,6 @@ class WanTransformerInfer(BaseTransformerInfer):
q,
k_lens=torch.tensor([k_img.size(0)], dtype=torch.int32, device=k.device),
)
img_attn_out = weights.cross_attn_2.apply(
q=q,
k=k_img,
......
import torch
import torch.distributed as dist
from loguru import logger
from lightx2v.utils.envs import *
......@@ -19,6 +20,45 @@ def compute_freqs(c, grid_sizes, freqs):
return freqs_i
def compute_freqs_audio(c, grid_sizes, freqs):
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0].tolist()
f = f + 1  ## add one frame for the reference image (r2v)
seq_len = f * h * w
freqs_i = torch.cat(
[
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
],
dim=-1,
).reshape(seq_len, 1, -1)
return freqs_i
def compute_freqs_audio_dist(s, c, grid_sizes, freqs):
world_size = dist.get_world_size()
cur_rank = dist.get_rank()
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0].tolist()
f = f + 1
seq_len = f * h * w
freqs_i = torch.cat(
[
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
],
dim=-1,
).reshape(seq_len, 1, -1)
freqs_i = pad_freqs(freqs_i, s * world_size)
s_per_rank = s
freqs_i_rank = freqs_i[(cur_rank * s_per_rank) : ((cur_rank + 1) * s_per_rank), :, :]
return freqs_i_rank
def compute_freqs_causvid(c, grid_sizes, freqs, start_frame=0):
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0].tolist()
......
......@@ -50,6 +50,7 @@ class WanLoraWrapper:
self.model._init_weights(weight_dict)
logger.info(f"Applied LoRA: {lora_name} with alpha={alpha}")
del lora_weights  # free memory
return True
@torch.no_grad()
......@@ -84,7 +85,8 @@ class WanLoraWrapper:
if name in lora_pairs:
if name not in self.override_dict:
self.override_dict[name] = param.clone().cpu()
# import pdb
# pdb.set_trace()
name_lora_A, name_lora_B = lora_pairs[name]
lora_A = lora_weights[name_lora_A].to(param.device, param.dtype)
lora_B = lora_weights[name_lora_B].to(param.device, param.dtype)
......
......@@ -184,7 +184,7 @@ class WanSelfAttention(WeightModule):
sparge_ckpt = torch.load(self.config["sparge_ckpt"])
self.self_attn_1.load(sparge_ckpt)
else:
self.add_module("self_attn_1", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())
self.add_module("self_attn_1", ATTN_WEIGHT_REGISTER[self.config["self_attn_1_type"]]())
if self.quant_method in ["smoothquant", "awq"]:
self.add_module(
"smooth_norm1_weight",
......@@ -275,7 +275,7 @@ class WanCrossAttention(WeightModule):
self.lazy_load_file,
),
)
self.add_module("cross_attn_1", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())
self.add_module("cross_attn_1", ATTN_WEIGHT_REGISTER[self.config["cross_attn_1_type"]]())
if self.config.task == "i2v":
self.add_module(
......@@ -304,7 +304,7 @@ class WanCrossAttention(WeightModule):
self.lazy_load_file,
),
)
self.add_module("cross_attn_2", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())
self.add_module("cross_attn_2", ATTN_WEIGHT_REGISTER[self.config["cross_attn_2_type"]]())
class WanFFN(WeightModule):
......
import os
import gc
import numpy as np
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel, WanVideoIPHandler
from lightx2v.models.networks.wan.audio_model import WanAudioModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from lightx2v.models.networks.wan.audio_adapter import AudioAdapter, AudioAdapterPipe, rank0_load_state_dict_from_path
from loguru import logger
import torch.distributed as dist
from einops import rearrange
import torchaudio as ta
from transformers import AutoFeatureExtractor
from torchvision.datasets.folder import IMG_EXTENSIONS
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize
import subprocess
import warnings
from typing import Optional, Tuple, Union
def get_crop_bbox(ori_h, ori_w, tgt_h, tgt_w):
tgt_ar = tgt_h / tgt_w
ori_ar = ori_h / ori_w
if abs(ori_ar - tgt_ar) < 0.01:
return 0, ori_h, 0, ori_w
if ori_ar > tgt_ar:
crop_h = int(tgt_ar * ori_w)
y0 = (ori_h - crop_h) // 2
y1 = y0 + crop_h
return y0, y1, 0, ori_w
else:
crop_w = int(ori_h / tgt_ar)
x0 = (ori_w - crop_w) // 2
x1 = x0 + crop_w
return 0, ori_h, x0, x1
def isotropic_crop_resize(frames: torch.Tensor, size: tuple):
"""
frames: (T, C, H, W)
size: (H, W)
"""
ori_h, ori_w = frames.shape[2:]
h, w = size
y0, y1, x0, x1 = get_crop_bbox(ori_h, ori_w, h, w)
cropped_frames = frames[:, :, y0:y1, x0:x1]
resized_frames = resize(cropped_frames, size, InterpolationMode.BICUBIC, antialias=True)
return resized_frames
def adaptive_resize(img):
bucket_config = {
0.667: (np.array([[480, 832], [544, 960], [720, 1280]], dtype=np.int64), np.array([0.2, 0.5, 0.3])),
1.0: (np.array([[480, 480], [576, 576], [704, 704], [960, 960]], dtype=np.int64), np.array([0.1, 0.1, 0.5, 0.3])),
1.5: (np.array([[480, 832], [544, 960], [720, 1280]], dtype=np.int64)[:, ::-1], np.array([0.2, 0.5, 0.3])),
}
ori_height = img.shape[-2]
ori_width = img.shape[-1]
ori_ratio = ori_height / ori_width
aspect_ratios = np.array(list(bucket_config.keys()))
closest_aspect_idx = np.argmin(np.abs(aspect_ratios - ori_ratio))
closest_ratio = aspect_ratios[closest_aspect_idx]
target_h, target_w = 480, 832
for resolution in bucket_config[closest_ratio][0]:
if ori_height * ori_width >= resolution[0] * resolution[1]:
target_h, target_w = resolution
cropped_img = isotropic_crop_resize(img, (target_h, target_w))
return cropped_img, target_h, target_w
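A CPU-only sketch of the crop/resize helpers above on a synthetic portrait frame (assumes they are importable from the wan_audio_runner module added in this PR):

import torch

frames = torch.rand(1, 3, 1080, 608)       # (T, C, H, W), roughly 9:16 portrait
print(get_crop_bbox(1080, 608, 832, 480))  # center-crop box for an 832x480 (H x W) target
resized, target_h, target_w = adaptive_resize(frames)
print(resized.shape, target_h, target_w)   # e.g. torch.Size([1, 3, 960, 544]) 960 544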
def array_to_video(
image_array: np.ndarray,
output_path: str,
fps: Union[int, float] = 30,
resolution: Optional[Union[Tuple[int, int], Tuple[float, float]]] = None,
disable_log: bool = False,
lossless: bool = True,
) -> None:
if not isinstance(image_array, np.ndarray):
raise TypeError("Input should be np.ndarray.")
assert image_array.ndim == 4
assert image_array.shape[-1] == 3
if resolution:
height, width = resolution
width += width % 2
height += height % 2
else:
image_array = pad_for_libx264(image_array)
height, width = image_array.shape[1], image_array.shape[2]
if lossless:
command = [
"/usr/bin/ffmpeg",
"-y", # (optional) overwrite output file if it exists
"-f",
"rawvideo",
"-s",
f"{int(width)}x{int(height)}", # size of one frame
"-pix_fmt",
"bgr24",
"-r",
f"{fps}", # frames per second
"-loglevel",
"error",
"-threads",
"4",
"-i",
"-", # The input comes from a pipe
"-vcodec",
"libx264rgb",
"-crf",
"0",
"-an", # Tells FFMPEG not to expect any audio
output_path,
]
else:
command = [
"/usr/bin/ffmpeg",
"-y", # (optional) overwrite output file if it exists
"-f",
"rawvideo",
"-s",
f"{int(width)}x{int(height)}", # size of one frame
"-pix_fmt",
"bgr24",
"-r",
f"{fps}", # frames per second
"-loglevel",
"error",
"-threads",
"4",
"-i",
"-", # The input comes from a pipe
"-vcodec",
"libx264",
"-an", # Tells FFMPEG not to expect any audio
output_path,
]
if not disable_log:
print(f'Running "{" ".join(command)}"')
process = subprocess.Popen(
command,
stdin=subprocess.PIPE,
stderr=subprocess.PIPE,
)
if process.stdin is None or process.stderr is None:
raise BrokenPipeError("No buffer received.")
index = 0
while True:
if index >= image_array.shape[0]:
break
process.stdin.write(image_array[index].tobytes())
index += 1
process.stdin.close()
process.stderr.close()
process.wait()
def pad_for_libx264(image_array):
if image_array.ndim == 2 or (image_array.ndim == 3 and image_array.shape[2] == 3):
hei_index = 0
wid_index = 1
elif image_array.ndim == 4 or (image_array.ndim == 3 and image_array.shape[2] != 3):
hei_index = 1
wid_index = 2
else:
return image_array
hei_pad = image_array.shape[hei_index] % 2
wid_pad = image_array.shape[wid_index] % 2
if hei_pad + wid_pad > 0:
pad_width = []
for dim_index in range(image_array.ndim):
if dim_index == hei_index:
pad_width.append((0, hei_pad))
elif dim_index == wid_index:
pad_width.append((0, wid_pad))
else:
pad_width.append((0, 0))
values = 0
image_array = np.pad(image_array, pad_width, mode="constant", constant_values=values)
return image_array
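For example (assuming the helper above), libx264 requires even frame dimensions, so odd sizes are padded by one pixel:

import numpy as np

frames = np.zeros((10, 481, 833, 3), dtype=np.uint8)
print(pad_for_libx264(frames).shape)  # (10, 482, 834, 3)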
def generate_unique_path(path):
if not os.path.exists(path):
return path
root, ext = os.path.splitext(path)
index = 1
new_path = f"{root}-{index}{ext}"
while os.path.exists(new_path):
index += 1
new_path = f"{root}-{index}{ext}"
return new_path
def save_to_video(gen_lvideo, out_path, target_fps):
gen_lvideo = rearrange(gen_lvideo, "B C T H W -> B T H W C")
gen_lvideo = (gen_lvideo[0].cpu().numpy() * 127.5 + 127.5).astype(np.uint8)
gen_lvideo = gen_lvideo[..., ::-1].copy()
generate_unique_path(out_path)
array_to_video(gen_lvideo, output_path=out_path, fps=target_fps, lossless=False)
def save_audio(
audio_array,
audio_name: str,
video_name: str = None,
sr: int = 16000,
):
logger.info(f"Saving audio to {audio_name} type: {type(audio_array)}")
ta.save(
audio_name,
torch.tensor(audio_array[None]),
sample_rate=sr,
)
out_video = f"{video_name[:-4]}_with_audio.mp4"
# make sure the parent directory exists
parent_dir = os.path.dirname(out_video)
if parent_dir and not os.path.exists(parent_dir):
os.makedirs(parent_dir, exist_ok=True)
# remove the output video if it already exists
if os.path.exists(out_video):
os.remove(out_video)
cmd = f"/usr/bin/ffmpeg -i {video_name} -i {audio_name} {out_video}"
subprocess.call(cmd, shell=True)
@RUNNER_REGISTER("wan2.1_audio")
class WanAudioRunner(WanRunner):
def __init__(self, config):
super().__init__(config)
def load_audio_models(self):
## audio feature extractor
self.audio_preprocess = AutoFeatureExtractor.from_pretrained(self.config["model_path"], subfolder="audio_encoder")
## adapter for audio-driven video generation
audio_adapter_path = self.config["model_path"] + "/audio_adapter.safetensors"
audio_adapter = AudioAdapter.from_transformer(
self.model,
audio_feature_dim=1024,
interval=1,
time_freq_dim=256,
projection_transformer_layers=4,
)
audio_adapter = rank0_load_state_dict_from_path(audio_adapter, audio_adapter_path, strict=False)
## audio feature encoder
device = self.model.device
audio_encoder_repo = self.config["model_path"] + "/audio_encoder"
audio_adapter_pipe = AudioAdapterPipe(audio_adapter, audio_encoder_repo=audio_encoder_repo, dtype=torch.bfloat16, device=device, generator=torch.Generator(device), weight=1.0)
return audio_adapter_pipe
def load_transformer(self):
base_model = WanAudioModel(self.config.model_path, self.config, self.init_device)
if self.config.lora_path:
assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
lora_wrapper = WanLoraWrapper(base_model)
lora_name = lora_wrapper.load_lora(self.config.lora_path)
lora_wrapper.apply_lora(lora_name, self.config.strength_model)
logger.info(f"Loaded LoRA: {lora_name}")
return base_model
def load_image_encoder(self):
clip_model_dir = self.config["model_path"] + "/image_encoder"
image_encoder = WanVideoIPHandler("CLIPModel", repo_or_path=clip_model_dir, require_grad=False, mode="eval", device=self.init_device, dtype=torch.float16)
return image_encoder
def run_image_encoder(self, config, vae_model):
ref_img = Image.open(config.image_path)
ref_img = (np.array(ref_img).astype(np.float32) - 127.5) / 127.5
ref_img = torch.from_numpy(ref_img).to(vae_model.device)
ref_img = rearrange(ref_img, "H W C -> 1 C H W")
ref_img = ref_img[:, :3]
# resize and crop image
cond_frms, tgt_h, tgt_w = adaptive_resize(ref_img)
config.tgt_h = tgt_h
config.tgt_w = tgt_w
clip_encoder_out = self.image_encoder.encode(cond_frms).squeeze(0).to(torch.bfloat16)
cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
lat_h, lat_w = tgt_h // 8, tgt_w // 8
config.lat_h = lat_h
config.lat_w = lat_w
vae_encode_out = vae_model.encode(cond_frms.to(torch.float), config)
if isinstance(vae_encode_out, list):
# convert the list of latents to a single tensor
vae_encode_out = torch.stack(vae_encode_out, dim=0).to(torch.bfloat16)
return vae_encode_out, clip_encoder_out
def run_input_encoder_internal(self):
image_encoder_output = None
if os.path.isfile(self.config.image_path):
with ProfilingContext("Run Img Encoder"):
vae_encode_out, clip_encoder_out = self.run_image_encoder(self.config, self.vae_encoder)
image_encoder_output = {
"clip_encoder_out": clip_encoder_out,
"vae_encode_out": vae_encode_out,
}
logger.info(f"clip_encoder_out:{clip_encoder_out.shape} vae_encode_out:{vae_encode_out.shape}")
with ProfilingContext("Run Text Encoder"):
with open(self.config["prompt_path"], "r", encoding="utf-8") as f:
prompt = f.readline().strip()
logger.info(f"Prompt: {prompt}")
img = Image.open(self.config["image_path"]).convert("RGB")
text_encoder_output = self.run_text_encoder(prompt, img)
self.set_target_shape()
self.inputs = {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}
del self.image_encoder  # free the reference CLIP model; it is only used once
gc.collect()
torch.cuda.empty_cache()
def set_target_shape(self):
ret = {}
num_channels_latents = 16
if self.config.task == "i2v":
self.config.target_shape = (
num_channels_latents,
(self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
self.config.lat_h,
self.config.lat_w,
)
ret["lat_h"] = self.config.lat_h
ret["lat_w"] = self.config.lat_w
else:
raise NotImplementedError("t2v task is not supported in WanAudioRunner")
ret["target_shape"] = self.config.target_shape
return ret
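# Illustrative shape check (values assumed for this example, not taken from the code above):
# with target_video_length=81, vae_stride[0]=4 and a 480x832 reference image,
# run_image_encoder sets lat_h = 480 // 8 = 60 and lat_w = 832 // 8 = 104, so
# target_shape becomes (16, (81 - 1) // 4 + 1, 60, 104) == (16, 21, 60, 104).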
def run(self):
def load_audio(in_path: str, sr: float = 16000):
audio_array, ori_sr = ta.load(in_path)
audio_array = ta.functional.resample(audio_array.mean(0), orig_freq=ori_sr, new_freq=sr)
return audio_array.numpy()
def get_audio_range(start_frame: int, end_frame: int, fps: float, audio_sr: float = 16000):
audio_frame_rate = audio_sr / fps
return round(start_frame * audio_frame_rate), round((end_frame + 1) * audio_frame_rate)
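# Hedged example: with target_fps=16 and audio_sr=16000, each video frame spans
# 16000 / 16 = 1000 audio samples, so get_audio_range(0, 80) returns (0, 81000),
# i.e. the sample range covering frames 0..80 inclusive.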
self.inputs["audio_adapter_pipe"] = self.load_audio_models()
# process audio
audio_sr = self.config.get("audio_sr", 16000)
max_num_frames = self.config.get("target_video_length", 81) # Wan2.1 generates at most 81 frames per segment (about 5 s at 16 fps)
target_fps = self.config.get("target_fps", 16) # frame rate used to keep audio and video in sync
video_duration = self.config.get("video_duration", 5) # desired output video duration in seconds
audio_array = load_audio(self.config["audio_path"], sr=audio_sr)
audio_len = int(audio_array.shape[0] / audio_sr * target_fps)
prev_frame_length = 5
prev_token_length = (prev_frame_length - 1) // 4 + 1
max_num_audio_length = int((max_num_frames + 1) / target_fps * audio_sr)
interval_num = 1
# expected_frames
expected_frames = min(max(1, int(float(video_duration) * target_fps)), audio_len)
res_frame_num = 0
if expected_frames <= max_num_frames:
interval_num = 1
else:
interval_num = max(int((expected_frames - max_num_frames) / (max_num_frames - prev_frame_length)) + 1, 1)
res_frame_num = expected_frames - interval_num * (max_num_frames - prev_frame_length)
if res_frame_num > 5:
interval_num += 1
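# Worked example of the segmenting arithmetic (inputs assumed for illustration):
# video_duration=12 s at target_fps=16 with enough audio gives expected_frames = 192.
# With max_num_frames=81 and prev_frame_length=5:
#   interval_num  = int((192 - 81) / (81 - 5)) + 1 = 2
#   res_frame_num = 192 - 2 * (81 - 5) = 40   (> 5, so interval_num becomes 3)
# i.e. three 81-frame passes with a 5-frame overlap; the final pass is cropped
# to res_frame_num frames further below.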
audio_start, audio_end = get_audio_range(0, expected_frames, fps=target_fps, audio_sr=audio_sr)
audio_array_ori = audio_array[audio_start:audio_end]
gen_video_list = []
cut_audio_list = []
# reference latents
tgt_h = self.config.tgt_h
tgt_w = self.config.tgt_w
device = self.model.scheduler.latents.device
dtype = torch.bfloat16
vae_dtype = torch.float
for idx in range(interval_num):
torch.manual_seed(42 + idx)
logger.info(f"### manual_seed: {42 + idx} ####")
useful_length = -1
if idx == 0: # first segment: condition frames are zero-padded
prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
prev_len = 0
audio_start, audio_end = get_audio_range(0, max_num_frames, fps=target_fps, audio_sr=audio_sr)
audio_array = audio_array_ori[audio_start:audio_end]
if expected_frames < max_num_frames:
useful_length = audio_array.shape[0]
audio_array = np.concatenate((audio_array, np.zeros(max_num_audio_length)[: max_num_audio_length - useful_length]), axis=0)
audio_input_feat = self.audio_preprocess(audio_array, sampling_rate=audio_sr, return_tensors="pt").input_values.squeeze(0)
elif res_frame_num > 5 and idx == interval_num - 1: # last segment, which may have fewer than 81 frames
prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
prev_frames[:, :, :prev_frame_length] = gen_video_list[-1][:, :, -prev_frame_length:]
prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
prev_len = prev_token_length
audio_start, audio_end = get_audio_range(idx * max_num_frames - idx * prev_frame_length, expected_frames, fps=target_fps, audio_sr=audio_sr)
audio_array = audio_array_ori[audio_start:audio_end]
useful_length = audio_array.shape[0]
audio_array = np.concatenate((audio_array, np.zeros(max_num_audio_length)[: max_num_audio_length - useful_length]), axis=0)
audio_input_feat = self.audio_preprocess(audio_array, sampling_rate=audio_sr, return_tensors="pt").input_values.squeeze(0)
else: # middle segments: full 81 frames conditioned on prev_latents
prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
prev_frames[:, :, :prev_frame_length] = gen_video_list[-1][:, :, -prev_frame_length:]
prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
prev_len = prev_token_length
audio_start, audio_end = get_audio_range(idx * max_num_frames - idx * prev_frame_length, (idx + 1) * max_num_frames - idx * prev_frame_length, fps=target_fps, audio_sr=audio_sr)
audio_array = audio_array_ori[audio_start:audio_end]
audio_input_feat = self.audio_preprocess(audio_array, sampling_rate=audio_sr, return_tensors="pt").input_values.squeeze(0)
self.inputs["audio_encoder_output"] = audio_input_feat.to(device)
if idx != 0:
self.model.scheduler.reset()
if prev_latents is not None:
ltnt_channel, nframe, height, width = self.model.scheduler.latents.shape
bs = 1
prev_mask = torch.zeros((bs, 1, nframe, height, width), device=device, dtype=dtype)
if prev_len > 0:
prev_mask[:, :, :prev_len] = 1.0
previmg_encoder_output = {
"prev_latents": prev_latents,
"prev_mask": prev_mask,
}
self.inputs["previmg_encoder_output"] = previmg_encoder_output
for step_index in range(self.model.scheduler.infer_steps):
logger.info(f"==> step_index: {step_index} / {self.model.scheduler.infer_steps}")
with ProfilingContext4Debug("step_pre"):
self.model.scheduler.step_pre(step_index=step_index)
with ProfilingContext4Debug("infer"):
self.model.infer(self.inputs)
with ProfilingContext4Debug("step_post"):
self.model.scheduler.step_post()
latents = self.model.scheduler.latents
generator = self.model.scheduler.generator
gen_video = self.vae_decoder.decode(latents, generator=generator, config=self.config)
gen_video = torch.clamp(gen_video, -1, 1)
start_frame = 0 if idx == 0 else prev_frame_length
start_audio_frame = 0 if idx == 0 else int((prev_frame_length + 1) * audio_sr / target_fps)
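# Illustrative value: for continuation segments at 16 fps / 16 kHz this skips
# int((5 + 1) * 16000 / 16) = 6000 samples at the start of the segment's audio,
# mirroring the +1 end offset used in get_audio_range above.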
if res_frame_num > 5 and idx == interval_num - 1:
gen_video_list.append(gen_video[:, :, start_frame:res_frame_num])
cut_audio_list.append(audio_array[start_audio_frame:useful_length])
elif expected_frames < max_num_frames and useful_length != -1:
gen_video_list.append(gen_video[:, :, start_frame:expected_frames])
cut_audio_list.append(audio_array[start_audio_frame:useful_length])
else:
gen_video_list.append(gen_video[:, :, start_frame:])
cut_audio_list.append(audio_array[start_audio_frame:])
gen_lvideo = torch.cat(gen_video_list, dim=2).float()
merge_audio = np.concatenate(cut_audio_list, axis=0).astype(np.float32)
out_path = self.config.save_video_path
audio_file = os.path.join("./", "audio_merge.wav")
save_to_video(gen_lvideo, out_path, target_fps)
save_audio(merge_audio, audio_file, out_path)
os.remove(out_path)
os.remove(audio_file)
async def run_pipeline(self):
if self.config["use_prompt_enhancer"]:
self.config["prompt_enhanced"] = self.post_prompt_enhancer()
self.run_input_encoder_internal()
self.set_target_shape()
self.init_scheduler()
self.model.scheduler.prepare(self.inputs["image_encoder_output"])
self.run()
self.end_run()
gc.collect()
torch.cuda.empty_cache()
import math
import numpy as np
import torch
import gc
from typing import List, Optional, Tuple, Union
from lightx2v.models.schedulers.scheduler import BaseScheduler
......@@ -115,6 +116,17 @@ class WanScheduler(BaseScheduler):
x0_pred = sample - sigma_t * model_output
return x0_pred
def reset(self):
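# Clears the multistep solver state (cached model outputs, timestep history,
# order bookkeeping) and re-initializes the latents so that each new video
# segment starts from a fresh denoising trajectory.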
self.model_outputs = [None] * self.solver_order
self.timestep_list = [None] * self.solver_order
self.last_sample = None
self.noise_pred = None
self.this_order = None
self.lower_order_nums = 0
self.prepare_latents(self.config.target_shape, dtype=torch.float32)
gc.collect()
torch.cuda.empty_cache()
def multistep_uni_p_bh_update(
self,
model_output: torch.Tensor,
......
#!/bin/bash
# set paths first
lightx2v_path="/mnt/Text2Video/wangshankun/lightx2v"
model_path="/mnt/Text2Video/wangshankun/HF_Cache/Wan2.1-I2V-Audio-14B-720P/"
lora_path="/mnt/Text2Video/wangshankun/HF_Cache/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32.safetensors"
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
export DTYPE=BF16
python -m lightx2v.infer \
--model_cls wan2.1_audio \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan_i2v_audio.json \
--prompt_path ${lightx2v_path}/assets/inputs/audio/15.txt \
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--image_path ${lightx2v_path}/assets/inputs/audio/15.png \
--audio_path ${lightx2v_path}/assets/inputs/audio/15.wav \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_audio.mp4 \
--lora_path ${lora_path}