Commit fad005dd authored by sandy, committed by GitHub

Merge pull request #86 from ModelTC/dev/wan_audio

  feature: audio driven video gen
parents 973dd66b 6060ff4f
{
"infer_steps": 5,
"target_fps": 16,
"video_duration": 12,
"audio_sr": 16000,
"target_video_length": 81,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "radial_attn",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale":1,
"sample_shift": 5,
"enable_cfg": false,
"cpu_offload": false
}
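A config like the one above is handed to the runner through --config_json (the run script at the bottom of this diff points at configs/wan_i2v_audio.json). A minimal sketch of reading it, assuming only the keys shown above:

import json

with open("configs/wan_i2v_audio.json") as f:  # path assumed from the run script below
    cfg = json.load(f)

# 81 target frames at 16 fps, 5 denoising steps, radial attention for self-attention
print(cfg["target_video_length"], cfg["target_fps"], cfg["infer_steps"], cfg["self_attn_1_type"])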
@@ -2,6 +2,7 @@ from lightx2v.attentions.common.torch_sdpa import torch_sdpa
from lightx2v.attentions.common.flash_attn2 import flash_attn2
from lightx2v.attentions.common.flash_attn3 import flash_attn3
from lightx2v.attentions.common.sage_attn2 import sage_attn2
from lightx2v.attentions.common.radial_attn import radial_attn
def attention(attention_type="flash_attn2", *args, **kwargs):
@@ -13,5 +14,7 @@ def attention(attention_type="flash_attn2", *args, **kwargs):
return flash_attn3(*args, **kwargs)
elif attention_type == "sage_attn2":
return sage_attn2(*args, **kwargs)
elif attention_type == "radial_attn":
return radial_attn(*args, **kwargs)
else:
raise NotImplementedError(f"Unsupported attention mode: {attention_type}")
import torch
import flashinfer
###
### Code from radial-attention
### https://github.com/mit-han-lab/radial-attention/blob/main/radial_attn/attn_mask.py#L150
###
def radial_attn(
query, key, value, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, mask_map=None, sparsity_type="radial", block_size=128, decay_factor=1, model_cls="wan"
):
orig_seqlen, num_head, hidden_dim = query.shape
query = pad_qkv(query, block_size=block_size)
key = pad_qkv(key, block_size=block_size)
value = pad_qkv(value, block_size=block_size)
mask = mask_map.queryLogMask(query, sparsity_type, block_size=block_size, decay_factor=decay_factor, model_type=model_cls) if mask_map else None
seqlen = query.shape[0]
workspace_buffer = torch.empty(128 * 1024 * 1024, device=query.device, dtype=torch.uint8)
bsr_wrapper = flashinfer.BlockSparseAttentionWrapper(
workspace_buffer,
backend="fa2",
)
indptr = get_indptr_from_mask(mask, query)
indices = get_indices_from_mask(mask, query)
bsr_wrapper.plan(
indptr=indptr,
indices=indices,
M=seqlen,
N=seqlen,
R=block_size,
C=block_size,
num_qo_heads=num_head,
num_kv_heads=num_head,
head_dim=hidden_dim,
q_data_type=query.dtype,
kv_data_type=key.dtype,
o_data_type=query.dtype,
use_fp16_qk_reduction=True,
)
o = bsr_wrapper.run(query, key, value)
return o[:orig_seqlen, :, :]
def get_indptr_from_mask(mask, query):
# query shows the device of the indptr
# indptr (torch.Tensor) - the block index pointer of the block-sparse matrix on row dimension,
# shape `(MB + 1,)`, where `MB` is the number of blocks in the row dimension.
# The first element is always 0, and the last element is the number of blocks in the row dimension.
# The rest of the elements are the number of blocks in each row.
# the mask is already a block sparse mask
indptr = torch.zeros(mask.shape[0] + 1, device=query.device, dtype=torch.int32)
indptr[0] = 0
row_counts = mask.sum(dim=1).flatten() # Ensure 1D output [num_blocks_row]
indptr[1:] = torch.cumsum(row_counts, dim=0)
return indptr
def get_indices_from_mask(mask, query):
# indices (torch.Tensor) - the block indices of the block-sparse matrix on column dimension,
# shape `(nnz,),` where `nnz` is the number of non-zero blocks.
# The elements in `indices` array should be less than `NB`: the number of blocks in the column dimension.
nonzero_indices = torch.nonzero(mask)
indices = nonzero_indices[:, 1].to(dtype=torch.int32, device=query.device)
return indices
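To make the block-sparse (BSR) layout concrete, here is a small worked example with an assumed 3x3 block mask (each entry stands for one block_size x block_size tile); the expected outputs follow directly from the two helpers above.

import torch

toy_mask = torch.tensor([[1, 0, 1],
                         [0, 1, 0],
                         [1, 1, 1]], dtype=torch.bool)
toy_query = torch.empty(3 * 128, 1, 1)  # the helpers only read .device from query
# get_indptr_from_mask(toy_mask, toy_query)  -> tensor([0, 2, 3, 6], dtype=torch.int32)
# get_indices_from_mask(toy_mask, toy_query) -> tensor([0, 2, 1, 0, 1, 2], dtype=torch.int32)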
def shrinkMaskStrict(mask, block_size=128):
seqlen = mask.shape[0]
block_num = seqlen // block_size
mask = mask[: block_num * block_size, : block_num * block_size].view(block_num, block_size, block_num, block_size)
col_densities = mask.sum(dim=1) / block_size
# we want the minimum non-zero column density in the block
non_zero_densities = col_densities > 0
high_density_cols = col_densities > 1 / 3
frac_high_density_cols = high_density_cols.sum(dim=-1) / (non_zero_densities.sum(dim=-1) + 1e-9)
block_mask = frac_high_density_cols > 0.6
block_mask[0:0] = True
block_mask[-1:-1] = True
return block_mask
def pad_qkv(input_tensor, block_size=128):
"""
Pad the input tensor to be a multiple of the block size.
input shape: (seqlen, num_heads, hidden_dim)
"""
seqlen, num_heads, hidden_dim = input_tensor.shape
# Calculate the necessary padding
padding_length = (block_size - (seqlen % block_size)) % block_size
# Create a padded tensor with zeros
padded_tensor = torch.zeros((seqlen + padding_length, num_heads, hidden_dim), device=input_tensor.device, dtype=input_tensor.dtype)
# Copy the original tensor into the padded tensor
padded_tensor[:seqlen, :, :] = input_tensor
return padded_tensor
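A quick check of the padding arithmetic above, with assumed example lengths:

for seqlen in (75, 128, 79200):  # assumed example sequence lengths
    padding_length = (128 - (seqlen % 128)) % 128
    assert (seqlen + padding_length) % 128 == 0  # 75 -> 53, 128 -> 0, 79200 -> 32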
def get_diagonal_split_mask(i, j, token_per_frame, sparse_type, query):
assert sparse_type in ["radial"]
dist = abs(i - j)
group = dist.bit_length()
threshold = 128 # hardcoded threshold for now, which is equal to block-size
decay_length = 2 ** token_per_frame.bit_length() / 2**group
if decay_length >= threshold:
return torch.ones((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
split_factor = int(threshold / decay_length)
modular = dist % split_factor
if modular == 0:
return torch.ones((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
else:
return torch.zeros((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
def get_window_width(i, j, token_per_frame, sparse_type, num_frame, decay_factor=1, block_size=128, model_type=None):
assert sparse_type in ["radial"]
dist = abs(i - j)
if model_type == "wan":
if dist < 1:
return token_per_frame
if dist == 1:
return token_per_frame // 2
elif model_type == "hunyuan":
if dist <= 1:
return token_per_frame
else:
raise ValueError(f"Unknown model type: {model_type}")
group = dist.bit_length()
decay_length = 2 ** token_per_frame.bit_length() / 2**group * decay_factor
threshold = block_size
if decay_length >= threshold:
return decay_length
else:
return threshold
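For intuition, a worked example of the window schedule above for the "wan" branch, assuming token_per_frame = 3600 and decay_factor = 1: frame distance 0 keeps the full frame (3600 tokens), distance 1 keeps half (1800), and for larger distances the width is 2 ** token_per_frame.bit_length() / 2 ** dist.bit_length(), i.e. 1024 for distances 2-3, 512 for 4-7, and so on, clamped below by block_size = 128 once the distance reaches 32.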
def gen_log_mask_shrinked(query, s, video_token_num, num_frame, block_size=128, sparse_type="log", decay_factor=0.5, model_type=None):
"""
A more memory friendly version, we generate the attention mask of each frame pair at a time,
shrinks it, and stores it into the final result
"""
final_log_mask = torch.zeros((s // block_size, s // block_size), device=query.device, dtype=torch.bool)
token_per_frame = video_token_num // num_frame
video_text_border = video_token_num // block_size
col_indices = torch.arange(0, token_per_frame, device=query.device).view(1, -1)
row_indices = torch.arange(0, token_per_frame, device=query.device).view(-1, 1)
final_log_mask[video_text_border:] = True
final_log_mask[:, video_text_border:] = True
for i in range(num_frame):
for j in range(num_frame):
local_mask = torch.zeros((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
if j == 0: # this is attention sink
local_mask = torch.ones((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
else:
window_width = get_window_width(i, j, token_per_frame, sparse_type, num_frame, decay_factor=decay_factor, block_size=block_size, model_type=model_type)
local_mask = torch.abs(col_indices - row_indices) <= window_width
split_mask = get_diagonal_split_mask(i, j, token_per_frame, sparse_type, query)
local_mask = torch.logical_and(local_mask, split_mask)
remainder_row = (i * token_per_frame) % block_size
remainder_col = (j * token_per_frame) % block_size
# get the padded size
all_length_row = remainder_row + ((token_per_frame - 1) // block_size + 1) * block_size
all_length_col = remainder_col + ((token_per_frame - 1) // block_size + 1) * block_size
padded_local_mask = torch.zeros((all_length_row, all_length_col), device=query.device, dtype=torch.bool)
padded_local_mask[remainder_row : remainder_row + token_per_frame, remainder_col : remainder_col + token_per_frame] = local_mask
# shrink the mask
block_mask = shrinkMaskStrict(padded_local_mask, block_size=block_size)
# set the block mask to the final log mask
block_row_start = (i * token_per_frame) // block_size
block_col_start = (j * token_per_frame) // block_size
block_row_end = block_row_start + block_mask.shape[0]
block_col_end = block_col_start + block_mask.shape[1]
final_log_mask[block_row_start:block_row_end, block_col_start:block_col_end] = torch.logical_or(final_log_mask[block_row_start:block_row_end, block_col_start:block_col_end], block_mask)
print(f"mask sparsity: {1 - final_log_mask.sum() / final_log_mask.numel()}")
return final_log_mask
class MaskMap:
def __init__(self, video_token_num=79200, num_frame=22):
self.video_token_num = video_token_num
self.num_frame = num_frame
self.log_mask = None
def queryLogMask(self, query, sparse_type, block_size=128, decay_factor=0.5, model_type=None):
log_mask = torch.ones((query.shape[0] // block_size, query.shape[0] // block_size), device=query.device, dtype=torch.bool)
if self.log_mask is None:
self.log_mask = gen_log_mask_shrinked(
query, query.shape[0], self.video_token_num, self.num_frame, sparse_type=sparse_type, decay_factor=decay_factor, model_type=model_type, block_size=block_size
)
block_bound = self.video_token_num // block_size
log_mask[:block_bound, :block_bound] = self.log_mask[:block_bound, :block_bound]
return log_mask
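Putting the pieces together, a minimal usage sketch with assumed shapes (it mirrors how WanAudioModel.infer wires this up further down in this diff): a MaskMap built from the video token count is handed to radial_attn, which pads q/k/v to the block size, builds the log-sparse block mask once, converts it to indptr/indices, and runs flashinfer's block-sparse kernel.

# Assumed toy sizes: 22 latent frames of 3600 tokens each, 40 heads of dim 128.
video_token_num, num_frame = 22 * 3600, 22
mask_map = MaskMap(video_token_num=video_token_num, num_frame=num_frame)
q = torch.randn(video_token_num, 40, 128, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)
out = radial_attn(q, k, v, mask_map=mask_map, sparsity_type="radial", block_size=128, model_cls="wan")
# out keeps q's (seq_len, num_heads, head_dim) layout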
@@ -37,6 +37,9 @@ else:
sageattn = None
from lightx2v.attentions.common.radial_attn import radial_attn
class AttnWeightTemplate(metaclass=ABCMeta):
def __init__(self, weight_name):
self.weight_name = weight_name
@@ -70,7 +73,7 @@ class FlashAttn2Weight(AttnWeightTemplate):
def __init__(self):
self.config = {}
def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, mask_map=None):
x = flash_attn_varlen_func(
q,
k,
@@ -88,7 +91,7 @@ class FlashAttn3Weight(AttnWeightTemplate):
def __init__(self):
self.config = {}
def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, mask_map=None):
x = flash_attn_varlen_func_v3(
q,
k,
@@ -101,6 +104,28 @@ class FlashAttn3Weight(AttnWeightTemplate):
return x
@ATTN_WEIGHT_REGISTER("radial_attn")
class RadialAttnWeight(AttnWeightTemplate):
def __init__(self):
self.config = {}
def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, mask_map=None, sparsity_type="radial", block_size=128, decay_factor=1, model_cls="wan"):
assert len(q.shape) == 3
x = radial_attn(
q,
k,
v,
mask_map=mask_map,
sparsity_type=sparsity_type,
block_size=block_size,
model_cls=model_cls[:3], # Use first 3 characters to match "wan", "wan2", etc.
decay_factor=decay_factor,
)
x = x.view(max_seqlen_q, -1)
return x
@ATTN_WEIGHT_REGISTER("sage_attn2") @ATTN_WEIGHT_REGISTER("sage_attn2")
class SageAttn2Weight(AttnWeightTemplate): class SageAttn2Weight(AttnWeightTemplate):
def __init__(self): def __init__(self):
......
@@ -14,6 +14,7 @@ from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner
from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner
from lightx2v.models.runners.wan.wan_audio_runner import WanAudioRunner
from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner
from lightx2v.models.runners.graph_runner import GraphRunner
from lightx2v.models.runners.cogvideox.cogvidex_runner import CogvideoxRunner
@@ -41,14 +42,19 @@ def init_runner(config):
async def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox", "wan2.1_audio"], default="hunyuan"
)
parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--config_json", type=str, required=True)
parser.add_argument("--use_prompt_enhancer", action="store_true")
parser.add_argument("--prompt", type=str, default="", help="The input prompt for text-to-video generation")
parser.add_argument("--negative_prompt", type=str, default="")
parser.add_argument("--lora_path", type=str, default="", help="The lora file path")
parser.add_argument("--prompt_path", type=str, default="", help="The path to input prompt file")
parser.add_argument("--audio_path", type=str, default="", help="The path to input audio file")
parser.add_argument("--image_path", type=str, default="", help="The path to input image file or path for image-to-video (i2v) task")
parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4", help="The path to save video path/file")
args = parser.parse_args()
...
@@ -11,6 +11,9 @@ import torchvision.transforms as T
from lightx2v.attentions import attention
from loguru import logger
from lightx2v.models.input_encoders.hf.q_linear import QuantLinearInt8, QuantLinearFp8
from einops import rearrange
from torch import Tensor
from transformers import CLIPVisionModel
__all__ = [
@@ -428,3 +431,51 @@ class CLIPModel:
def to_cpu(self):
self.model = self.model.cpu()
class WanVideoIPHandler:
def __init__(self, model_name, repo_or_path, require_grad=False, mode="eval", device="cuda", dtype=torch.float16):
# image_processor = CLIPImageProcessor.from_pretrained(
# repo_or_path, subfolder='image_processor')
"""720P-I2V-diffusers config is
"size": {
"shortest_edge": 224
}
and 480P-I2V-diffusers config is
"size": {
"height": 224,
"width": 224
}
but the official Wan2.1 pipeline uses a no-crop resize by default,
so CLIPImageProcessor is not used here
"""
image_encoder = CLIPVisionModel.from_pretrained(repo_or_path, torch_dtype=dtype)
logger.info(f"Using image encoder {model_name} from {repo_or_path}")
image_encoder.requires_grad_(require_grad)
if mode == "eval":
image_encoder.eval()
else:
image_encoder.train()
self.dtype = dtype
self.device = device
self.image_encoder = image_encoder.to(device=device, dtype=dtype)
self.size = (224, 224)
mean = [0.48145466, 0.4578275, 0.40821073]
std = [0.26862954, 0.26130258, 0.27577711]
self.normalize = T.Normalize(mean=mean, std=std)
# self.image_processor = image_processor
def encode(
self,
img_tensor: Tensor,
):
if img_tensor.ndim == 5: # B C T H W
# img_tensor = img_tensor[:, :, 0]
img_tensor = rearrange(img_tensor, "B C 1 H W -> B C H W")
img_tensor = torch.clamp(img_tensor.float() * 0.5 + 0.5, min=0.0, max=1.0).to(self.device)
img_tensor = F.interpolate(img_tensor, size=self.size, mode="bicubic", align_corners=False)
img_tensor = self.normalize(img_tensor).to(self.dtype)
image_embeds = self.image_encoder(pixel_values=img_tensor, output_hidden_states=True)
return image_embeds.hidden_states[-1]
import flash_attn
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from einops import rearrange
from transformers import AutoModel
from loguru import logger
import os
import safetensors
from typing import List, Optional, Tuple, Union
def load_safetensors(in_path: str):
if os.path.isdir(in_path):
return load_safetensors_from_dir(in_path)
elif os.path.isfile(in_path):
return load_safetensors_from_path(in_path)
else:
raise ValueError(f"{in_path} does not exist")
def load_safetensors_from_path(in_path: str):
tensors = {}
with safetensors.safe_open(in_path, framework="pt", device="cpu") as f:
for key in f.keys():
tensors[key] = f.get_tensor(key)
return tensors
def load_safetensors_from_dir(in_dir: str):
tensors = {}
# use a local name that does not shadow the `safetensors` module
files = [f for f in os.listdir(in_dir) if f.endswith(".safetensors")]
for f in files:
tensors.update(load_safetensors_from_path(os.path.join(in_dir, f)))
return tensors
def load_pt_safetensors(in_path: str):
ext = os.path.splitext(in_path)[-1]
if ext in (".pt", ".pth", ".tar"):
state_dict = torch.load(in_path, map_location="cpu", weights_only=True)
else:
state_dict = load_safetensors(in_path)
return state_dict
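A small usage sketch of the loaders above (checkpoint path assumed): they return a flat name-to-tensor dict whether the input is a .pt/.pth file, a single .safetensors file, or a directory of shards.

state_dict = load_pt_safetensors("/path/to/audio_adapter.safetensors")  # assumed path
print(len(state_dict), next(iter(state_dict)))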
def rank0_load_state_dict_from_path(model, in_path: str, strict: bool = True):
import torch.distributed as dist
if (dist.is_initialized() and dist.get_rank() == 0) or (not dist.is_initialized()):
state_dict = load_pt_safetensors(in_path)
model.load_state_dict(state_dict, strict=strict)
if dist.is_initialized():
dist.barrier()
return model.to(dtype=torch.bfloat16, device="cuda")
def linear_interpolation(features, output_len: int):
features = features.transpose(1, 2)
output_features = F.interpolate(features, size=output_len, align_corners=False, mode="linear")
return output_features.transpose(1, 2)
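linear_interpolation resamples along the token axis, e.g. mapping 50 Hz audio features onto the video frame count (shapes assumed):

feats = torch.randn(1, 100, 768)                 # 2 s of 50 Hz WavLM features (assumed)
resampled = linear_interpolation(feats, output_len=81)
print(resampled.shape)                           # torch.Size([1, 81, 768])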
def get_q_lens_audio_range(
batchsize,
n_tokens_per_rank,
n_query_tokens,
n_tokens_per_frame,
sp_rank,
):
if n_query_tokens == 0:
q_lens = [1] * batchsize
return q_lens, 0, 1
idx0 = n_tokens_per_rank * sp_rank
first_length = idx0 - idx0 // n_tokens_per_frame * n_tokens_per_frame
n_frames = (n_query_tokens - first_length) // n_tokens_per_frame
last_length = n_query_tokens - n_frames * n_tokens_per_frame - first_length
q_lens = []
if first_length > 0:
q_lens.append(first_length)
q_lens += [n_tokens_per_frame] * n_frames
if last_length > 0:
q_lens.append(last_length)
t0 = idx0 // n_tokens_per_frame
idx1 = idx0 + n_query_tokens
t1 = math.ceil(idx1 / n_tokens_per_frame)
return q_lens * batchsize, t0, t1
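A worked example of the helper above with assumed sizes and no sequence parallelism (sp_rank = 0): 1560 tokens per frame and 4680 query tokens on the rank split cleanly into three full frames covering the frame range [0, 3).

q_lens, t0, t1 = get_q_lens_audio_range(batchsize=1, n_tokens_per_rank=4680, n_query_tokens=4680, n_tokens_per_frame=1560, sp_rank=0)
print(q_lens, t0, t1)  # [1560, 1560, 1560] 0 3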
class PerceiverAttentionCA(nn.Module):
def __init__(self, dim_head=128, heads=16, kv_dim=2048, adaLN: bool = False):
super().__init__()
self.dim_head = dim_head
self.heads = heads
inner_dim = dim_head * heads
kv_dim = inner_dim if kv_dim is None else kv_dim
self.norm_kv = nn.LayerNorm(kv_dim)
self.norm_q = nn.LayerNorm(inner_dim, elementwise_affine=not adaLN)
self.to_q = nn.Linear(inner_dim, inner_dim)
self.to_kv = nn.Linear(kv_dim, inner_dim * 2)
self.to_out = nn.Linear(inner_dim, inner_dim)
if adaLN:
self.shift_scale_gate = nn.Parameter(torch.randn(1, 3, inner_dim) / inner_dim**0.5)
else:
shift_scale_gate = torch.zeros((1, 3, inner_dim))
shift_scale_gate[:, 2] = 1
self.register_buffer("shift_scale_gate", shift_scale_gate, persistent=False)
def forward(self, x, latents, t_emb, q_lens, k_lens):
"""x shape (batchsize, latent_frame, audio_tokens_per_latent,
model_dim) latents (batchsize, length, model_dim)"""
batchsize = len(x)
x = self.norm_kv(x)
shift, scale, gate = (t_emb + self.shift_scale_gate).chunk(3, dim=1)
latents = self.norm_q(latents) * (1 + scale) + shift
q = self.to_q(latents)
k, v = self.to_kv(x).chunk(2, dim=-1)
q = rearrange(q, "B L (H C) -> (B L) H C", H=self.heads)
k = rearrange(k, "B T L (H C) -> (B T L) H C", H=self.heads)
v = rearrange(v, "B T L (H C) -> (B T L) H C", H=self.heads)
out = flash_attn.flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True),
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True),
max_seqlen_q=q_lens.max(),
max_seqlen_k=k_lens.max(),
dropout_p=0.0,
softmax_scale=None,
causal=False,
window_size=(-1, -1),
deterministic=False,
)
out = rearrange(out, "(B L) H C -> B L (H C)", B=batchsize)
return self.to_out(out) * gate
class AudioProjection(nn.Module):
def __init__(
self,
audio_feature_dim: int = 768,
n_neighbors: tuple = (2, 2),
num_tokens: int = 32,
mlp_dims: tuple = (1024, 1024, 32 * 768),
transformer_layers: int = 4,
):
super().__init__()
mlp = []
self.left, self.right = n_neighbors
self.audio_frames = sum(n_neighbors) + 1
in_dim = audio_feature_dim * self.audio_frames
for i, out_dim in enumerate(mlp_dims):
mlp.append(nn.Linear(in_dim, out_dim))
if i != len(mlp_dims) - 1:
mlp.append(nn.ReLU())
in_dim = out_dim
self.mlp = nn.Sequential(*mlp)
self.norm = nn.LayerNorm(mlp_dims[-1] // num_tokens)
self.num_tokens = num_tokens
if transformer_layers > 0:
decoder_layer = nn.TransformerDecoderLayer(d_model=audio_feature_dim, nhead=audio_feature_dim // 64, dim_feedforward=4 * audio_feature_dim, dropout=0.0, batch_first=True)
self.transformer_decoder = nn.TransformerDecoder(
decoder_layer,
num_layers=transformer_layers,
)
else:
self.transformer_decoder = None
def forward(self, audio_feature, latent_frame):
video_frame = (latent_frame - 1) * 4 + 1
audio_feature_ori = audio_feature
audio_feature = linear_interpolation(audio_feature_ori, video_frame)
if self.transformer_decoder is not None:
audio_feature = self.transformer_decoder(audio_feature, audio_feature_ori)
audio_feature = F.pad(audio_feature, pad=(0, 0, self.left, self.right), mode="replicate")
audio_feature = audio_feature.unfold(dimension=1, size=self.audio_frames, step=1)
audio_feature = rearrange(audio_feature, "B T C W -> B T (W C)")
audio_feature = self.mlp(audio_feature) # (B, video_frame, C)
audio_feature = rearrange(audio_feature, "B T (N C) -> B T N C", N=self.num_tokens) # (B, video_frame, num_tokens, C)
return self.norm(audio_feature)
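A shape walkthrough of AudioProjection with its defaults (input length assumed): 21 latent frames correspond to (21 - 1) * 4 + 1 = 81 video frames, the features are resampled to 81 steps, each step is concatenated with its 2 + 2 neighbours (5 * 768 = 3840 MLP input features), and the MLP output is reshaped into 32 tokens of width 768 per video frame.

proj = AudioProjection(audio_feature_dim=768, n_neighbors=(2, 2), num_tokens=32)
audio_feat = torch.randn(1, 100, 768)   # (B, T_audio, C), length assumed
tokens = proj(audio_feat, latent_frame=21)
print(tokens.shape)                     # torch.Size([1, 81, 32, 768])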
class TimeEmbedding(nn.Module):
def __init__(self, dim, time_freq_dim, time_proj_dim):
super().__init__()
self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
self.act_fn = nn.SiLU()
self.time_proj = nn.Linear(dim, time_proj_dim)
def forward(
self,
timestep: torch.Tensor,
):
timestep = self.timesteps_proj(timestep)
time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
timestep = timestep.to(time_embedder_dtype)
temb = self.time_embedder(timestep)
timestep_proj = self.time_proj(self.act_fn(temb))
return timestep_proj
class AudioAdapter(nn.Module):
def __init__(
self,
attention_head_dim=64,
num_attention_heads=40,
base_num_layers=30,
interval=1,
audio_feature_dim: int = 768,
num_tokens: int = 32,
mlp_dims: tuple = (1024, 1024, 32 * 768),
time_freq_dim: int = 256,
projection_transformer_layers: int = 4,
):
super().__init__()
self.audio_proj = AudioProjection(
audio_feature_dim=audio_feature_dim,
n_neighbors=(2, 2),
num_tokens=num_tokens,
mlp_dims=mlp_dims,
transformer_layers=projection_transformer_layers,
)
# self.num_tokens = num_tokens * 4
self.num_tokens_x4 = num_tokens * 4
self.audio_pe = nn.Parameter(torch.randn(self.num_tokens_x4, mlp_dims[-1] // num_tokens) * 0.02)
ca_num = math.ceil(base_num_layers / interval)
self.base_num_layers = base_num_layers
self.interval = interval
self.ca = nn.ModuleList(
[
PerceiverAttentionCA(
dim_head=attention_head_dim,
heads=num_attention_heads,
kv_dim=mlp_dims[-1] // num_tokens,
adaLN=time_freq_dim > 0,
)
for _ in range(ca_num)
]
)
self.dim = attention_head_dim * num_attention_heads
if time_freq_dim > 0:
self.time_embedding = TimeEmbedding(self.dim, time_freq_dim, self.dim * 3)
else:
self.time_embedding = None
def rearange_audio_features(self, audio_feature: torch.Tensor):
# audio_feature (B, video_frame, num_tokens, C)
audio_feature_0 = audio_feature[:, :1]
audio_feature_0 = torch.repeat_interleave(audio_feature_0, repeats=4, dim=1)
audio_feature = torch.cat([audio_feature_0, audio_feature[:, 1:]], dim=1) # (B, 4 * latent_frame, num_tokens, C)
audio_feature = rearrange(audio_feature, "B (T S) N C -> B T (S N) C", S=4)
return audio_feature
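The rearrangement above groups per-video-frame tokens back into latent frames: the first video frame is repeated 4 times so that 1 + 4 * (latent_frame - 1) frames become exactly 4 * latent_frame slots (sizes assumed):

tokens = torch.randn(1, 81, 32, 768)    # (B, video_frame, num_tokens, C)
# 81 frames -> first frame repeated 4x -> 84 = 4 * 21 slots of 4 * 32 = 128 tokens each
# rearange_audio_features(tokens).shape == (1, 21, 128, 768)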
def forward(self, audio_feat: torch.Tensor, timestep: torch.Tensor, latent_frame: int, weight: float = 1.0):
def modify_hidden_states(hidden_states, grid_sizes, ca_block: PerceiverAttentionCA, x, t_emb, dtype, weight):
"""thw specify the latent_frame, latent_height, latenf_width after
hidden_states is patchified.
latent_frame does not include the reference images so that the
audios and hidden_states are strictly aligned
"""
if len(hidden_states.shape) == 2:  # add a batch dimension (bs = 1)
hidden_states = hidden_states.unsqueeze(0) # bs = 1
# print(weight)
t, h, w = grid_sizes[0].tolist()
n_tokens = t * h * w
ori_dtype = hidden_states.dtype
device = hidden_states.device
bs, n_tokens_per_rank = hidden_states.shape[:2]
tail_length = n_tokens_per_rank - n_tokens
n_query_tokens = n_tokens_per_rank - tail_length % n_tokens_per_rank
if n_query_tokens > 0:
hidden_states_aligned = hidden_states[:, :n_query_tokens]
hidden_states_tail = hidden_states[:, n_query_tokens:]
else:
# for ranks that should be excluded from cross-attn, fake cross-attn will be applied so that FSDP works.
hidden_states_aligned = hidden_states[:, :1]
hidden_states_tail = hidden_states[:, 1:]
q_lens, t0, t1 = get_q_lens_audio_range(batchsize=bs, n_tokens_per_rank=n_tokens_per_rank, n_query_tokens=n_query_tokens, n_tokens_per_frame=h * w, sp_rank=0)
q_lens = torch.tensor(q_lens, device=device, dtype=torch.int32)
"""
processing audio features in sp_state can be moved outside.
"""
x = x[:, t0:t1]
x = x.to(dtype)
k_lens = torch.tensor([self.num_tokens_x4] * (t1 - t0) * bs, device=device, dtype=torch.int32)
assert q_lens.shape == k_lens.shape
# ca_block is the audio cross-attention module
residual = ca_block(x, hidden_states_aligned, t_emb, q_lens, k_lens) * weight
residual = residual.to(ori_dtype)  # the audio cross-attention output is injected back as a residual
if n_query_tokens == 0:
residual = residual * 0.0
hidden_states = torch.cat([hidden_states_aligned + residual, hidden_states_tail], dim=1)
if len(hidden_states.shape) == 3:  # drop the batch dimension again
hidden_states = hidden_states.squeeze(0) # bs = 1
return hidden_states
x = self.audio_proj(audio_feat, latent_frame)
x = self.rearange_audio_features(x)
x = x + self.audio_pe
if self.time_embedding is not None:
t_emb = self.time_embedding(timestep).unflatten(1, (3, -1))
else:
t_emb = torch.zeros((len(x), 3, self.dim), device=x.device, dtype=x.dtype)
ret_dict = {}
for block_idx, base_idx in enumerate(range(0, self.base_num_layers, self.interval)):
block_dict = {
"kwargs": {
"ca_block": self.ca[block_idx],
"x": x,
"weight": weight,
"t_emb": t_emb,
"dtype": x.dtype,
},
"modify_func": modify_hidden_states,
}
ret_dict[base_idx] = block_dict
return ret_dict
@classmethod
def from_transformer(
cls,
transformer,
audio_feature_dim: int = 1024,
interval: int = 1,
time_freq_dim: int = 256,
projection_transformer_layers: int = 4,
):
num_attention_heads = transformer.config["num_heads"]
base_num_layers = transformer.config["num_layers"]
attention_head_dim = transformer.config["dim"] // num_attention_heads
audio_adapter = AudioAdapter(
attention_head_dim,
num_attention_heads,
base_num_layers,
interval=interval,
audio_feature_dim=audio_feature_dim,
time_freq_dim=time_freq_dim,
projection_transformer_layers=projection_transformer_layers,
mlp_dims=(1024, 1024, 32 * audio_feature_dim),
)
return audio_adapter
def get_fsdp_wrap_module_list(
self,
):
ret_list = list(self.ca)
return ret_list
def enable_gradient_checkpointing(
self,
):
pass
class AudioAdapterPipe:
def __init__(
self, audio_adapter: AudioAdapter, audio_encoder_repo: str = "microsoft/wavlm-base-plus", dtype=torch.float32, device="cuda", generator=None, tgt_fps: int = 15, weight: float = 1.0
) -> None:
self.audio_adapter = audio_adapter
self.dtype = dtype
self.device = device
self.generator = generator
self.audio_encoder_dtype = torch.float16
# audio encoder
self.audio_encoder = AutoModel.from_pretrained(audio_encoder_repo)
self.audio_encoder.eval()
self.audio_encoder.to(device, self.audio_encoder_dtype)
self.tgt_fps = tgt_fps
self.weight = weight
if "base" in audio_encoder_repo:
self.audio_feature_dim = 768
else:
self.audio_feature_dim = 1024
def update_model(self, audio_adapter):
self.audio_adapter = audio_adapter
def __call__(self, audio_input_feat, timestep, latent_shape: tuple, dropout_cond: callable = None):
# audio_input_feat is from AudioPreprocessor
latent_frame = latent_shape[2]
if len(audio_input_feat.shape) == 1:  # add a batch dimension (bs = 1)
audio_input_feat = audio_input_feat.unsqueeze(0)
latent_frame = latent_shape[1]
video_frame = (latent_frame - 1) * 4 + 1
audio_length = int(50 / self.tgt_fps * video_frame)
with torch.no_grad():
audio_input_feat = audio_input_feat.to(self.device, self.audio_encoder_dtype)
try:
audio_feat = self.audio_encoder(audio_input_feat, return_dict=True).last_hidden_state
except Exception as err:
audio_feat = torch.rand(1, audio_length, self.audio_feature_dim).to(self.device)
print(err)
audio_feat = audio_feat.to(self.dtype)
if dropout_cond is not None:
audio_feat = dropout_cond(audio_feat)
return self.audio_adapter(audio_feat=audio_feat, timestep=timestep, latent_frame=latent_frame, weight=self.weight)
import os
import torch
import time
import glob
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
from lightx2v.models.networks.wan.weights.transformer_weights import (
WanTransformerWeights,
)
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
from lightx2v.models.networks.wan.infer.pre_wan_audio_infer import WanAudioPreInfer
from lightx2v.models.networks.wan.infer.post_wan_audio_infer import WanAudioPostInfer
from lightx2v.models.networks.wan.infer.transformer_infer import (
WanTransformerInfer,
)
from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import (
WanTransformerInferTeaCaching,
)
from safetensors import safe_open
import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
from lightx2v.attentions.common.radial_attn import MaskMap
class WanAudioModel(WanModel):
pre_weight_class = WanPreWeights
post_weight_class = WanPostWeights
transformer_weight_class = WanTransformerWeights
def __init__(self, model_path, config, device):
super().__init__(model_path, config, device)
def _init_infer_class(self):
self.pre_infer_class = WanAudioPreInfer
self.post_infer_class = WanAudioPostInfer
if self.config["feature_caching"] == "NoCaching":
self.transformer_infer_class = WanTransformerInfer
elif self.config["feature_caching"] == "Tea":
self.transformer_infer_class = WanTransformerInferTeaCaching
else:
raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
@torch.no_grad()
def infer(self, inputs):
if self.config["cpu_offload"]:
self.pre_weight.to_cuda()
self.post_weight.to_cuda()
if self.transformer_infer.mask_map is None:
_, c, h, w = self.scheduler.latents.shape
num_frame = c + 1 # for r2v
video_token_num = num_frame * (h // 2) * (w // 2)
from loguru import logger
logger.info(f"video_token_num: {video_token_num}, num_frame: {num_frame}")
self.transformer_infer.mask_map = MaskMap(video_token_num, num_frame)
embed, grid_sizes, pre_infer_out, valid_patch_length = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes, valid_patch_length)[0]
if self.config["feature_caching"] == "Tea":
self.scheduler.cnt += 1
if self.scheduler.cnt >= self.scheduler.num_steps:
self.scheduler.cnt = 0
self.scheduler.noise_pred = noise_pred_cond
if self.config["enable_cfg"]:
embed, grid_sizes, pre_infer_out, valid_patch_length = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes, valid_patch_length)[0]
if self.config["feature_caching"] == "Tea":
self.scheduler.cnt += 1
if self.scheduler.cnt >= self.scheduler.num_steps:
self.scheduler.cnt = 0
self.scheduler.noise_pred = noise_pred_uncond + self.config.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
if self.config["cpu_offload"]:
self.pre_weight.to_cpu()
self.post_weight.to_cpu()
import math
import torch
import torch.cuda.amp as amp
from loguru import logger
from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
class WanAudioPostInfer(WanPostInfer):
def __init__(self, config):
self.out_dim = config["out_dim"]
self.patch_size = (1, 2, 2)
def set_scheduler(self, scheduler):
self.scheduler = scheduler
def infer(self, weights, x, e, grid_sizes, valid_patch_length):
if e.dim() == 2:
modulation = weights.head_modulation.tensor # 1, 2, dim
e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
elif e.dim() == 3:  # for diffusion forcing
modulation = weights.head_modulation.tensor.unsqueeze(2) # 1, 2, seq, dim
e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
e = [ei.squeeze(1) for ei in e]
norm_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6).type_as(x)
out = norm_out * (1 + e[1].squeeze(0)) + e[0].squeeze(0)
x = weights.head.apply(out)
x = x[:, :valid_patch_length]
x = self.unpatchify(x, grid_sizes)
return [u.float() for u in x]
def unpatchify(self, x, grid_sizes):
x = x.unsqueeze(0)
c = self.out_dim
out = []
for u, v in zip(x, grid_sizes.tolist()):
u = u[: math.prod(v)].view(*v, *self.patch_size, c)
u = torch.einsum("fhwpqrc->cfphqwr", u)
u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
out.append(u)
return out
import torch
import math
from .utils import rope_params, sinusoidal_embedding_1d
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from loguru import logger
class WanAudioPreInfer(WanPreInfer):
def __init__(self, config):
assert (config["dim"] % config["num_heads"]) == 0 and (config["dim"] // config["num_heads"]) % 2 == 0
d = config["dim"] // config["num_heads"]
self.task = config["task"]
self.freqs = torch.cat(
[
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
],
dim=1,
).cuda()
self.freq_dim = config["freq_dim"]
self.dim = config["dim"]
self.text_len = config["text_len"]
def infer(self, weights, inputs, positive):
ltnt_channel = self.scheduler.latents.size(0)
prev_latents = inputs["previmg_encoder_output"]["prev_latents"].unsqueeze(0)
prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
hidden_states = self.scheduler.latents.unsqueeze(0)
hidden_states = torch.cat([hidden_states[:, :ltnt_channel], prev_latents, prev_mask], dim=1)
hidden_states = hidden_states.squeeze(0)
x = [hidden_states]
t = torch.stack([self.scheduler.timesteps[self.scheduler.step_index]])
audio_dit_blocks = []
audio_encoder_output = inputs["audio_encoder_output"]
audio_model_input = {
"audio_input_feat": audio_encoder_output.to(hidden_states.device),
"latent_shape": hidden_states.shape,
"timestep": t,
}
audio_dit_blocks.append(inputs["audio_adapter_pipe"](**audio_model_input))
if positive:
context = inputs["text_encoder_output"]["context"]
else:
context = inputs["text_encoder_output"]["context_null"]
seq_len = self.scheduler.seq_len
clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]
ref_image_encoder = inputs["image_encoder_output"]["vae_encode_out"]
batch_size = len(x)
num_channels, num_frames, height, width = x[0].shape
_, ref_num_channels, ref_num_frames, _, _ = ref_image_encoder.shape
if ref_num_channels != num_channels:
zero_padding = torch.zeros(
(batch_size, num_channels - ref_num_channels, ref_num_frames, height, width),
dtype=self.scheduler.latents.dtype,
device=self.scheduler.latents.device,
)
ref_image_encoder = torch.concat([ref_image_encoder, zero_padding], dim=1)
y = list(torch.unbind(ref_image_encoder, dim=0))  # split the leading batch dimension into a list
# embeddings
x = [weights.patch_embedding.apply(u.unsqueeze(0)) for u in x]
x_grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
x = [u.flatten(2).transpose(1, 2) for u in x]
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long).cuda()
assert seq_lens.max() <= seq_len
x = torch.cat([torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) for u in x])
valid_patch_length = x[0].size(0)
y = [weights.patch_embedding.apply(u.unsqueeze(0)) for u in y]
y_grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in y])
y = [u.flatten(2).transpose(1, 2).squeeze(0) for u in y]
x = [torch.cat([a, b], dim=0) for a, b in zip(x, y)]
x = torch.stack(x, dim=0)
embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten())
embed = weights.time_embedding_0.apply(embed)
embed = torch.nn.functional.silu(embed)
embed = weights.time_embedding_2.apply(embed)
embed0 = torch.nn.functional.silu(embed)
embed0 = weights.time_projection_1.apply(embed0).unflatten(1, (6, self.dim))
# text embeddings
stacked = torch.stack([torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) for u in context])
out = weights.text_embedding_0.apply(stacked.squeeze(0))
out = torch.nn.functional.gelu(out, approximate="tanh")
context = weights.text_embedding_2.apply(out)
if self.task == "i2v":
context_clip = weights.proj_0.apply(clip_fea)
context_clip = weights.proj_1.apply(context_clip)
context_clip = torch.nn.functional.gelu(context_clip, approximate="none")
context_clip = weights.proj_3.apply(context_clip)
context_clip = weights.proj_4.apply(context_clip)
context = torch.concat([context_clip, context], dim=0)
return (embed, x_grid_sizes, (x.squeeze(0), embed0.squeeze(0), seq_lens, self.freqs, context, audio_dit_blocks), valid_patch_length)
import torch
from .utils import compute_freqs, compute_freqs_dist, compute_freqs_audio, compute_freqs_audio_dist, apply_rotary_emb, apply_rotary_emb_chunk
from lightx2v.common.offload.manager import (
WeightAsyncStreamManager,
LazyWeightAsyncStreamManager,
)
from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
from lightx2v.utils.envs import *
from loguru import logger
import os
class WanTransformerInfer(BaseTransformerInfer):
@@ -21,6 +23,8 @@ class WanTransformerInfer(BaseTransformerInfer):
self.parallel_attention = None
self.apply_rotary_emb_func = apply_rotary_emb_chunk if config.get("rotary_chunk", False) else apply_rotary_emb
self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
self.mask_map = None
if self.config["cpu_offload"]:
if "offload_ratio" in self.config:
offload_ratio = self.config["offload_ratio"]
@@ -64,10 +68,10 @@ class WanTransformerInfer(BaseTransformerInfer):
return cu_seqlens_q, cu_seqlens_k
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks)
def _infer_with_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
for block_idx in range(self.blocks_num):
if block_idx == 0:
self.weights_stream_mgr.active_weights[0] = weights.blocks[0]
@@ -92,7 +96,7 @@ class WanTransformerInfer(BaseTransformerInfer):
return x
def _infer_with_phases_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
for block_idx in range(weights.blocks_num):
for phase_idx in range(self.phases_num):
if block_idx == 0 and phase_idx == 0:
@@ -133,7 +137,7 @@ class WanTransformerInfer(BaseTransformerInfer):
return x
def _infer_with_phases_lazy_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
self.weights_stream_mgr.prefetch_weights_from_disk(weights)
for block_idx in range(weights.blocks_num):
@@ -194,7 +198,22 @@ class WanTransformerInfer(BaseTransformerInfer):
return x
def zero_temporal_component_in_3DRoPE(self, valid_token_length, rotary_emb=None):
if rotary_emb is None:
return None
self.use_real = False
rope_t_dim = 44
if self.use_real:
freqs_cos, freqs_sin = rotary_emb
freqs_cos[valid_token_length:, :, :rope_t_dim] = 0
freqs_sin[valid_token_length:, :, :rope_t_dim] = 0
return freqs_cos, freqs_sin
else:
freqs_cis = rotary_emb
freqs_cis[valid_token_length:, :, : rope_t_dim // 2] = 0
return freqs_cis
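The hardcoded rope_t_dim = 44 matches the RoPE split used elsewhere in this file: assuming the Wan 14B layout (dim = 5120, num_heads = 40), the per-head dimension is 128, compute_freqs* is called with c = 128 // 2 = 64 complex frequencies, and the split [c - 2 * (c // 3), c // 3, c // 3] = [22, 21, 21] assigns 22 complex (44 real) channels to the temporal axis. The use_real=False branch therefore zeroes freqs_cis[..., :22], blanking exactly the temporal component for the tokens past valid_token_length (the padding and reference-image tokens appended after the video tokens).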
def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, audio_dit_blocks=None):
for block_idx in range(self.blocks_num):
x = self.infer_block(
weights.blocks[block_idx],
@@ -206,6 +225,12 @@ class WanTransformerInfer(BaseTransformerInfer):
freqs,
context,
)
if audio_dit_blocks is not None and len(audio_dit_blocks) > 0:
for ipa_out in audio_dit_blocks:
if block_idx in ipa_out:
cur_modify = ipa_out[block_idx]
x = cur_modify["modify_func"](x, grid_sizes, **cur_modify["kwargs"])
return x
def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
@@ -265,14 +290,23 @@ class WanTransformerInfer(BaseTransformerInfer):
v = weights.self_attn_v.apply(norm1_out).view(s, n, d)
if not self.parallel_attention:
if self.config.get("audio_sr", False):
freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
else:
freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
else:
if self.config.get("audio_sr", False):
freqs_i = compute_freqs_audio_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
else:
freqs_i = compute_freqs_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
freqs_i = self.zero_temporal_component_in_3DRoPE(seq_lens, freqs_i)
q = self.apply_rotary_emb_func(q, freqs_i)
k = self.apply_rotary_emb_func(k, freqs_i)
k_lens = torch.empty_like(seq_lens).fill_(freqs_i.size(0))
cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q, k_lens=k_lens)
if self.clean_cuda_cache:
del freqs_i, norm1_out, norm1_weight, norm1_bias
@@ -288,6 +322,7 @@ class WanTransformerInfer(BaseTransformerInfer):
max_seqlen_q=q.size(0),
max_seqlen_kv=k.size(0),
model_cls=self.config["model_cls"],
mask_map=self.mask_map,
)
else:
attn_out = self.parallel_attention(
@@ -353,7 +388,6 @@ class WanTransformerInfer(BaseTransformerInfer):
q,
k_lens=torch.tensor([k_img.size(0)], dtype=torch.int32, device=k.device),
)
img_attn_out = weights.cross_attn_2.apply(
q=q,
k=k_img,
...
import torch
import torch.distributed as dist
from loguru import logger
from lightx2v.utils.envs import *
@@ -19,6 +20,45 @@ def compute_freqs(c, grid_sizes, freqs):
return freqs_i
def compute_freqs_audio(c, grid_sizes, freqs):
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0].tolist()
f = f + 1  # for r2v, account for one extra frame
seq_len = f * h * w
freqs_i = torch.cat(
[
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
],
dim=-1,
).reshape(seq_len, 1, -1)
return freqs_i
def compute_freqs_audio_dist(s, c, grid_sizes, freqs):
world_size = dist.get_world_size()
cur_rank = dist.get_rank()
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0].tolist()
f = f + 1
seq_len = f * h * w
freqs_i = torch.cat(
[
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
],
dim=-1,
).reshape(seq_len, 1, -1)
freqs_i = pad_freqs(freqs_i, s * world_size)
s_per_rank = s
freqs_i_rank = freqs_i[(cur_rank * s_per_rank) : ((cur_rank + 1) * s_per_rank), :, :]
return freqs_i_rank
def compute_freqs_causvid(c, grid_sizes, freqs, start_frame=0):
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0].tolist()
...
@@ -50,6 +50,7 @@ class WanLoraWrapper:
self.model._init_weights(weight_dict)
logger.info(f"Applied LoRA: {lora_name} with alpha={alpha}")
del lora_weights  # free the LoRA weights to save memory
return True
@torch.no_grad()
@@ -84,7 +85,8 @@ class WanLoraWrapper:
if name in lora_pairs:
if name not in self.override_dict:
self.override_dict[name] = param.clone().cpu()
name_lora_A, name_lora_B = lora_pairs[name]
lora_A = lora_weights[name_lora_A].to(param.device, param.dtype)
lora_B = lora_weights[name_lora_B].to(param.device, param.dtype)
...
@@ -184,7 +184,7 @@ class WanSelfAttention(WeightModule):
sparge_ckpt = torch.load(self.config["sparge_ckpt"])
self.self_attn_1.load(sparge_ckpt)
else:
self.add_module("self_attn_1", ATTN_WEIGHT_REGISTER[self.config["self_attn_1_type"]]())
if self.quant_method in ["smoothquant", "awq"]:
self.add_module(
"smooth_norm1_weight",
@@ -275,7 +275,7 @@ class WanCrossAttention(WeightModule):
self.lazy_load_file,
),
)
self.add_module("cross_attn_1", ATTN_WEIGHT_REGISTER[self.config["cross_attn_1_type"]]())
if self.config.task == "i2v":
self.add_module(
@@ -304,7 +304,7 @@ class WanCrossAttention(WeightModule):
self.lazy_load_file,
),
)
self.add_module("cross_attn_2", ATTN_WEIGHT_REGISTER[self.config["cross_attn_2_type"]]())
class WanFFN(WeightModule):
...
import math
import numpy as np
import torch
import gc
from typing import List, Optional, Tuple, Union
from lightx2v.models.schedulers.scheduler import BaseScheduler
@@ -115,6 +116,17 @@ class WanScheduler(BaseScheduler):
x0_pred = sample - sigma_t * model_output
return x0_pred
def reset(self):
self.model_outputs = [None] * self.solver_order
self.timestep_list = [None] * self.solver_order
self.last_sample = None
self.noise_pred = None
self.this_order = None
self.lower_order_nums = 0
self.prepare_latents(self.config.target_shape, dtype=torch.float32)
gc.collect()
torch.cuda.empty_cache()
def multistep_uni_p_bh_update(
self,
model_output: torch.Tensor,
...
#!/bin/bash
# set the paths first
lightx2v_path="/mnt/Text2Video/wangshankun/lightx2v"
model_path="/mnt/Text2Video/wangshankun/HF_Cache/Wan2.1-I2V-Audio-14B-720P/"
lora_path="/mnt/Text2Video/wangshankun/HF_Cache/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32.safetensors"
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
export DTYPE=BF16
python -m lightx2v.infer \
--model_cls wan2.1_audio \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan_i2v_audio.json \
--prompt_path ${lightx2v_path}/assets/inputs/audio/15.txt \
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--image_path ${lightx2v_path}/assets/inputs/audio/15.png \
--audio_path ${lightx2v_path}/assets/inputs/audio/15.wav \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_audio.mp4 \
--lora_path ${lora_path}