"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "b2816bca67ae5e47f1c285c5ee72929769932585"
Commit 6060ff4f authored by wangshankun

Support: radial attention

parent b2147c40
 {
-    "infer_steps": 8,
+    "infer_steps": 5,
+    "target_fps": 16,
+    "video_duration": 12,
+    "audio_sr": 16000,
     "target_video_length": 81,
     "target_height": 480,
     "target_width": 832,
-    "attention_type": "flash_attn3",
+    "self_attn_1_type": "radial_attn",
+    "cross_attn_1_type": "flash_attn3",
+    "cross_attn_2_type": "flash_attn3",
     "seed": 42,
     "sample_guide_scale": 1,
     "sample_shift": 5,
     "enable_cfg": false,
-    "cpu_offload": false,
-    "feature_caching": "Tea",
-    "coefficients": [
-        [8.10705460e03, 2.13393892e03, -3.72934672e02, 1.66203073e01, -4.17769401e-02],
-        [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
-    ],
-    "use_ret_steps": true,
-    "teacache_thresh": 0.12
+    "cpu_offload": false
 }
@@ -2,6 +2,7 @@ from lightx2v.attentions.common.torch_sdpa import torch_sdpa
 from lightx2v.attentions.common.flash_attn2 import flash_attn2
 from lightx2v.attentions.common.flash_attn3 import flash_attn3
 from lightx2v.attentions.common.sage_attn2 import sage_attn2
+from lightx2v.attentions.common.radial_attn import radial_attn


 def attention(attention_type="flash_attn2", *args, **kwargs):
@@ -13,5 +14,7 @@ def attention(attention_type="flash_attn2", *args, **kwargs):
         return flash_attn3(*args, **kwargs)
     elif attention_type == "sage_attn2":
         return sage_attn2(*args, **kwargs)
+    elif attention_type == "radial_attn":
+        return radial_attn(*args, **kwargs)
     else:
         raise NotImplementedError(f"Unsupported attention mode: {attention_type}")
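
For reference, a minimal end-to-end sketch of the new code path (hypothetical tensor sizes; it assumes a CUDA device with flashinfer installed, and that the dispatcher above is lightx2v/attentions/__init__.py):

import torch
from lightx2v.attentions import attention
from lightx2v.attentions.common.radial_attn import MaskMap

# Hypothetical sizes: 4 latent frames of 1560 tokens each, 12 heads of dim 128.
num_frame, token_per_frame, num_heads, head_dim = 4, 1560, 12, 128
seqlen = num_frame * token_per_frame

q = torch.randn(seqlen, num_heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# One MaskMap is shared across layers; it caches the block mask on first use.
mask_map = MaskMap(video_token_num=seqlen, num_frame=num_frame)
out = attention("radial_attn", q, k, v, mask_map=mask_map, sparsity_type="radial", block_size=128, decay_factor=1, model_cls="wan")
print(out.shape)  # torch.Size([6240, 12, 128])

The new module (lightx2v/attentions/common/radial_attn.py, per the import path) follows: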
import torch
import flashinfer
###
### Code from radial-attention
### https://github.com/mit-han-lab/radial-attention/blob/main/radial_attn/attn_mask.py#L150
###
def radial_attn(
    query, key, value, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, mask_map=None, sparsity_type="radial", block_size=128, decay_factor=1, model_cls="wan"
):
    orig_seqlen, num_head, hidden_dim = query.shape

    # pad q/k/v so the sequence length is a multiple of the block size
    query = pad_qkv(query, block_size=block_size)
    key = pad_qkv(key, block_size=block_size)
    value = pad_qkv(value, block_size=block_size)

    # build (or fetch the cached) boolean block mask describing which tiles to compute
    mask = mask_map.queryLogMask(query, sparsity_type, block_size=block_size, decay_factor=decay_factor, model_type=model_cls) if mask_map else None
    seqlen = query.shape[0]

    workspace_buffer = torch.empty(128 * 1024 * 1024, device=query.device, dtype=torch.uint8)
    bsr_wrapper = flashinfer.BlockSparseAttentionWrapper(
        workspace_buffer,
        backend="fa2",
    )

    # convert the block mask into the CSR-style indptr/indices layout expected by flashinfer
    indptr = get_indptr_from_mask(mask, query)
    indices = get_indices_from_mask(mask, query)

    bsr_wrapper.plan(
        indptr=indptr,
        indices=indices,
        M=seqlen,
        N=seqlen,
        R=block_size,
        C=block_size,
        num_qo_heads=num_head,
        num_kv_heads=num_head,
        head_dim=hidden_dim,
        q_data_type=query.dtype,
        kv_data_type=key.dtype,
        o_data_type=query.dtype,
        use_fp16_qk_reduction=True,
    )

    o = bsr_wrapper.run(query, key, value)

    # drop the padded rows before returning
    return o[:orig_seqlen, :, :]
def get_indptr_from_mask(mask, query):
    # query only provides the device for the indptr tensor
    # indptr (torch.Tensor): block index pointer of the block-sparse matrix on the row dimension,
    # shape (MB + 1,), where MB is the number of blocks in the row dimension.
    # The first element is always 0, and the last element is the total number of non-zero blocks;
    # differences between consecutive elements give the number of non-zero blocks in each row.
    # The mask is already a block-sparse mask.
    indptr = torch.zeros(mask.shape[0] + 1, device=query.device, dtype=torch.int32)
    indptr[0] = 0
    row_counts = mask.sum(dim=1).flatten()  # ensure 1D output [num_blocks_row]
    indptr[1:] = torch.cumsum(row_counts, dim=0)
    return indptr
def get_indices_from_mask(mask, query):
    # indices (torch.Tensor): block indices of the block-sparse matrix on the column dimension,
    # shape (nnz,), where nnz is the number of non-zero blocks.
    # Every element in indices must be less than NB, the number of blocks in the column dimension.
    nonzero_indices = torch.nonzero(mask)
    indices = nonzero_indices[:, 1].to(dtype=torch.int32, device=query.device)
    return indices
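
# Worked example (illustrative, not part of the module): for a tiny 3x3 block mask
#     mask = [[1, 0, 1],
#             [0, 1, 0],
#             [1, 1, 1]]
# the per-row non-zero block counts are [2, 1, 3], so
#     indptr  = [0, 2, 3, 6]
#     indices = [0, 2, 1, 0, 1, 2]
# i.e. the CSR-style block layout that flashinfer's BlockSparseAttentionWrapper consumes.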
def shrinkMaskStrict(mask, block_size=128):
    seqlen = mask.shape[0]
    block_num = seqlen // block_size
    mask = mask[: block_num * block_size, : block_num * block_size].view(block_num, block_size, block_num, block_size)
    col_densities = mask.sum(dim=1) / block_size
    # we want the minimum non-zero column density in the block
    non_zero_densities = col_densities > 0
    high_density_cols = col_densities > 1 / 3
    frac_high_density_cols = high_density_cols.sum(dim=-1) / (non_zero_densities.sum(dim=-1) + 1e-9)
    block_mask = frac_high_density_cols > 0.6
    block_mask[0:0] = True
    block_mask[-1:-1] = True
    return block_mask
def pad_qkv(input_tensor, block_size=128):
    """
    Pad the input tensor to be a multiple of the block size.
    input shape: (seqlen, num_heads, hidden_dim)
    """
    seqlen, num_heads, hidden_dim = input_tensor.shape
    # Calculate the necessary padding
    padding_length = (block_size - (seqlen % block_size)) % block_size
    # Create a padded tensor with zeros
    padded_tensor = torch.zeros((seqlen + padding_length, num_heads, hidden_dim), device=input_tensor.device, dtype=input_tensor.dtype)
    # Copy the original tensor into the padded tensor
    padded_tensor[:seqlen, :, :] = input_tensor
    return padded_tensor
def get_diagonal_split_mask(i, j, token_per_frame, sparse_type, query):
    assert sparse_type in ["radial"]
    dist = abs(i - j)
    group = dist.bit_length()
    threshold = 128  # hardcoded threshold for now, equal to the block size
    decay_length = 2 ** token_per_frame.bit_length() / 2**group
    if decay_length >= threshold:
        return torch.ones((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)

    split_factor = int(threshold / decay_length)
    modular = dist % split_factor
    if modular == 0:
        return torch.ones((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
    else:
        return torch.zeros((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
def get_window_width(i, j, token_per_frame, sparse_type, num_frame, decay_factor=1, block_size=128, model_type=None):
    assert sparse_type in ["radial"]
    dist = abs(i - j)
    if model_type == "wan":
        if dist < 1:
            return token_per_frame
        if dist == 1:
            return token_per_frame // 2
    elif model_type == "hunyuan":
        if dist <= 1:
            return token_per_frame
    else:
        raise ValueError(f"Unknown model type: {model_type}")
    group = dist.bit_length()
    decay_length = 2 ** token_per_frame.bit_length() / 2**group * decay_factor
    threshold = block_size
    if decay_length >= threshold:
        return decay_length
    else:
        return threshold
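
# Worked example (illustrative, not part of the module): with model_type="wan",
# token_per_frame=1560 and decay_factor=1, token_per_frame.bit_length() is 11, so the
# base decay length is 2**11 = 2048 and the attention window shrinks with frame distance:
#     dist 0     -> 1560 (full frame)
#     dist 1     -> 780  (half frame)
#     dist 2-3   -> 2048 / 2**2 = 512
#     dist 4-7   -> 2048 / 2**3 = 256
#     dist 8-15  -> 2048 / 2**4 = 128
#     dist >= 16 -> clamped up to the block_size threshold of 128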
def gen_log_mask_shrinked(query, s, video_token_num, num_frame, block_size=128, sparse_type="log", decay_factor=0.5, model_type=None):
    """
    A more memory-friendly version: generate the attention mask one frame pair at a time,
    shrink it to block granularity, and store it into the final result.
    """
    final_log_mask = torch.zeros((s // block_size, s // block_size), device=query.device, dtype=torch.bool)
    token_per_frame = video_token_num // num_frame
    video_text_border = video_token_num // block_size

    col_indices = torch.arange(0, token_per_frame, device=query.device).view(1, -1)
    row_indices = torch.arange(0, token_per_frame, device=query.device).view(-1, 1)
    final_log_mask[video_text_border:] = True
    final_log_mask[:, video_text_border:] = True
    for i in range(num_frame):
        for j in range(num_frame):
            local_mask = torch.zeros((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
            if j == 0:  # this is the attention sink
                local_mask = torch.ones((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
            else:
                window_width = get_window_width(i, j, token_per_frame, sparse_type, num_frame, decay_factor=decay_factor, block_size=block_size, model_type=model_type)
                local_mask = torch.abs(col_indices - row_indices) <= window_width
                split_mask = get_diagonal_split_mask(i, j, token_per_frame, sparse_type, query)
                local_mask = torch.logical_and(local_mask, split_mask)

            remainder_row = (i * token_per_frame) % block_size
            remainder_col = (j * token_per_frame) % block_size
            # get the padded size
            all_length_row = remainder_row + ((token_per_frame - 1) // block_size + 1) * block_size
            all_length_col = remainder_col + ((token_per_frame - 1) // block_size + 1) * block_size
            padded_local_mask = torch.zeros((all_length_row, all_length_col), device=query.device, dtype=torch.bool)
            padded_local_mask[remainder_row : remainder_row + token_per_frame, remainder_col : remainder_col + token_per_frame] = local_mask
            # shrink the mask
            block_mask = shrinkMaskStrict(padded_local_mask, block_size=block_size)
            # set the block mask into the final log mask
            block_row_start = (i * token_per_frame) // block_size
            block_col_start = (j * token_per_frame) // block_size
            block_row_end = block_row_start + block_mask.shape[0]
            block_col_end = block_col_start + block_mask.shape[1]
            final_log_mask[block_row_start:block_row_end, block_col_start:block_col_end] = torch.logical_or(
                final_log_mask[block_row_start:block_row_end, block_col_start:block_col_end], block_mask
            )
    print(f"mask sparsity: {1 - final_log_mask.sum() / final_log_mask.numel()}")
    return final_log_mask
class MaskMap:
    def __init__(self, video_token_num=79200, num_frame=22):
        self.video_token_num = video_token_num
        self.num_frame = num_frame
        self.log_mask = None

    def queryLogMask(self, query, sparse_type, block_size=128, decay_factor=0.5, model_type=None):
        log_mask = torch.ones((query.shape[0] // block_size, query.shape[0] // block_size), device=query.device, dtype=torch.bool)
        if self.log_mask is None:
            self.log_mask = gen_log_mask_shrinked(
                query, query.shape[0], self.video_token_num, self.num_frame, sparse_type=sparse_type, decay_factor=decay_factor, model_type=model_type, block_size=block_size
            )
        block_bound = self.video_token_num // block_size
        log_mask[:block_bound, :block_bound] = self.log_mask[:block_bound, :block_bound]
        return log_mask
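
Because the mask construction only depends on tensor shapes and the device, the block-sparse layout can be inspected without flashinfer or a GPU. The following is a CPU-only sketch with made-up sizes (not part of the commit):

import torch
from lightx2v.attentions.common.radial_attn import MaskMap

# Hypothetical sizes: 4 latent frames of 1560 tokens each.
num_frame, token_per_frame = 4, 1560
video_token_num = num_frame * token_per_frame        # 6240 video tokens
padded_len = ((video_token_num + 127) // 128) * 128  # radial_attn pads q/k/v to 6272

# Only query.shape and query.device are used when building the mask, so zeros suffice.
query = torch.zeros(padded_len, 1, 128)
mask_map = MaskMap(video_token_num=video_token_num, num_frame=num_frame)
block_mask = mask_map.queryLogMask(query, "radial", block_size=128, decay_factor=1, model_type="wan")

print(block_mask.shape)                      # torch.Size([49, 49]) boolean block mask
print(1 - block_mask.float().mean().item())  # fraction of 128x128 blocks the kernel can skip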
@@ -37,6 +37,9 @@ else:
     sageattn = None

+from lightx2v.attentions.common.radial_attn import radial_attn
+

 class AttnWeightTemplate(metaclass=ABCMeta):
     def __init__(self, weight_name):
         self.weight_name = weight_name
@@ -70,7 +73,7 @@ class FlashAttn2Weight(AttnWeightTemplate):
     def __init__(self):
         self.config = {}

-    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None):
+    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, mask_map=None):
         x = flash_attn_varlen_func(
             q,
             k,
@@ -88,7 +91,7 @@ class FlashAttn3Weight(AttnWeightTemplate):
     def __init__(self):
         self.config = {}

-    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None):
+    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, mask_map=None):
         x = flash_attn_varlen_func_v3(
             q,
             k,
@@ -101,6 +104,28 @@ class FlashAttn3Weight(AttnWeightTemplate):
         return x


+@ATTN_WEIGHT_REGISTER("radial_attn")
+class RadialAttnWeight(AttnWeightTemplate):
+    def __init__(self):
+        self.config = {}
+
+    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, mask_map=None, sparsity_type="radial", block_size=128, decay_factor=1, model_cls="wan"):
+        assert len(q.shape) == 3
+        x = radial_attn(
+            q,
+            k,
+            v,
+            mask_map=mask_map,
+            sparsity_type=sparsity_type,
+            block_size=block_size,
+            model_cls=model_cls[:3],  # use first 3 characters to match "wan", "wan2", etc.
+            decay_factor=decay_factor,
+        )
+        x = x.view(max_seqlen_q, -1)
+        return x
+
+
 @ATTN_WEIGHT_REGISTER("sage_attn2")
 class SageAttn2Weight(AttnWeightTemplate):
     def __init__(self):
......
@@ -42,7 +42,9 @@ def init_runner(config):

 async def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox", "wan2.1_audio"], default="hunyuan")
+    parser.add_argument(
+        "--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox", "wan2.1_audio"], default="hunyuan"
+    )
     parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
     parser.add_argument("--model_path", type=str, required=True)
     parser.add_argument("--config_json", type=str, required=True)
......
@@ -449,7 +449,7 @@ class WanVideoIPHandler:
         but Wan2.1 official use no_crop resize by default
         so I don't use CLIPImageProcessor
         """
-        image_encoder = CLIPVisionModel.from_pretrained(repo_or_path, subfolder="image_encoder", torch_dtype=dtype)
+        image_encoder = CLIPVisionModel.from_pretrained(repo_or_path, torch_dtype=dtype)
         logger.info(f"Using image encoder {model_name} from {repo_or_path}")
         image_encoder.requires_grad_(require_grad)
         if mode == "eval":
......
@@ -7,11 +7,9 @@ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
 from einops import rearrange
 from transformers import AutoModel
 from loguru import logger
-import pdb

 import os
 import safetensors
 from typing import List, Optional, Tuple, Union
......
@@ -19,6 +19,8 @@ from safetensors import safe_open
 import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
 import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap

+from lightx2v.attentions.common.radial_attn import MaskMap
+
 from lightx2v.models.networks.wan.infer.transformer_infer import (
     WanTransformerInfer,
 )
@@ -51,6 +53,15 @@ class WanAudioModel(WanModel):
             self.pre_weight.to_cuda()
             self.post_weight.to_cuda()

+        if self.transformer_infer.mask_map is None:
+            _, c, h, w = self.scheduler.latents.shape
+            num_frame = c + 1  # for r2v
+            video_token_num = num_frame * (h // 2) * (w // 2)
+            from loguru import logger
+
+            logger.info(f"video_token_num: {video_token_num}, num_frame: {num_frame}")
+            self.transformer_infer.mask_map = MaskMap(video_token_num, num_frame)
+
         embed, grid_sizes, pre_infer_out, valid_patch_length = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
         x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
         noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes, valid_patch_length)[0]
......
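
As a rough worked example for the 480x832, 81-frame configuration above (assuming Wan2.1's 8x spatial and 4x temporal VAE compression plus 2x2 patchify, so the latents are 21x60x104): num_frame = 21 + 1 = 22 for r2v, and video_token_num = 22 * (60 // 2) * (104 // 2) = 22 * 30 * 52 = 34,320 tokens, which the MaskMap covers with roughly 268 blocks of 128 tokens per side.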
@@ -7,7 +7,6 @@ from lightx2v.common.offload.manager import (
 from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
 from lightx2v.utils.envs import *
 from loguru import logger
-import pdb
 import os
@@ -24,6 +23,8 @@ class WanTransformerInfer(BaseTransformerInfer):
         self.parallel_attention = None
         self.apply_rotary_emb_func = apply_rotary_emb_chunk if config.get("rotary_chunk", False) else apply_rotary_emb
         self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
+        self.mask_map = None
+
         if self.config["cpu_offload"]:
             if "offload_ratio" in self.config:
                 offload_ratio = self.config["offload_ratio"]
@@ -290,7 +291,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         if not self.parallel_attention:
             if self.config.get("audio_sr", False):
-                freqs_i = compute_freqs_audio(q.size(0), q.size(2) // 2, grid_sizes, freqs)
+                freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
             else:
                 freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
         else:
@@ -321,6 +322,7 @@ class WanTransformerInfer(BaseTransformerInfer):
                 max_seqlen_q=q.size(0),
                 max_seqlen_kv=k.size(0),
                 model_cls=self.config["model_cls"],
+                mask_map=self.mask_map,
             )
         else:
             attn_out = self.parallel_attention(
......
@@ -23,7 +23,7 @@ def compute_freqs(c, grid_sizes, freqs):
 def compute_freqs_audio(c, grid_sizes, freqs):
     freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
     f, h, w = grid_sizes[0].tolist()
-    f = f + 1
+    f = f + 1  # for r2v add 1 channel
     seq_len = f * h * w
     freqs_i = torch.cat(
         [
......
@@ -50,6 +50,7 @@ class WanLoraWrapper:
         self.model._init_weights(weight_dict)

         logger.info(f"Applied LoRA: {lora_name} with alpha={alpha}")
+        del lora_weights  # delete to save GPU memory
         return True

     @torch.no_grad()
......
@@ -184,7 +184,7 @@ class WanSelfAttention(WeightModule):
             sparge_ckpt = torch.load(self.config["sparge_ckpt"])
             self.self_attn_1.load(sparge_ckpt)
         else:
-            self.add_module("self_attn_1", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())
+            self.add_module("self_attn_1", ATTN_WEIGHT_REGISTER[self.config["self_attn_1_type"]]())
         if self.quant_method in ["smoothquant", "awq"]:
             self.add_module(
                 "smooth_norm1_weight",
@@ -275,7 +275,7 @@ class WanCrossAttention(WeightModule):
                     self.lazy_load_file,
                 ),
             )
-        self.add_module("cross_attn_1", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())
+        self.add_module("cross_attn_1", ATTN_WEIGHT_REGISTER[self.config["cross_attn_1_type"]]())

         if self.config.task == "i2v":
             self.add_module(
@@ -304,7 +304,7 @@ class WanCrossAttention(WeightModule):
                     self.lazy_load_file,
                 ),
             )
-        self.add_module("cross_attn_2", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())
+        self.add_module("cross_attn_2", ATTN_WEIGHT_REGISTER[self.config["cross_attn_2_type"]]())


 class WanFFN(WeightModule):
......
...@@ -31,7 +31,6 @@ from torchvision.transforms.functional import resize ...@@ -31,7 +31,6 @@ from torchvision.transforms.functional import resize
import subprocess import subprocess
import warnings import warnings
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
import pdb
def get_crop_bbox(ori_h, ori_w, tgt_h, tgt_w): def get_crop_bbox(ori_h, ori_w, tgt_h, tgt_w):
@@ -210,7 +209,6 @@ def generate_unique_path(path):

 def save_to_video(gen_lvideo, out_path, target_fps):
-    print(gen_lvideo.shape)
     gen_lvideo = rearrange(gen_lvideo, "B C T H W -> B T H W C")
     gen_lvideo = (gen_lvideo[0].cpu().numpy() * 127.5 + 127.5).astype(np.uint8)
     gen_lvideo = gen_lvideo[..., ::-1].copy()
@@ -219,21 +217,29 @@ def save_to_video(gen_lvideo, out_path, target_fps):

 def save_audio(
-    audio_array: str,
+    audio_array,
     audio_name: str,
     video_name: str = None,
     sr: int = 16000,
 ):
     logger.info(f"Saving audio to {audio_name} type: {type(audio_array)}")
-    ta.save(
-        audio_name,
-        torch.tensor(audio_array[None]),
-        sample_rate=sr,
-    )
+    if not os.path.exists(audio_name):
+        ta.save(
+            audio_name,
+            torch.tensor(audio_array[None]),
+            sample_rate=sr,
+        )
     out_video = f"{video_name[:-4]}_with_audio.mp4"
-    # generate_unique_path(out_path)
+    # make sure the parent directory exists
+    parent_dir = os.path.dirname(out_video)
+    if parent_dir and not os.path.exists(parent_dir):
+        os.makedirs(parent_dir, exist_ok=True)
+
+    # if the output video already exists, delete it first
+    if os.path.exists(out_video):
+        os.remove(out_video)
+
     cmd = f"/usr/bin/ffmpeg -i {video_name} -i {audio_name} {out_video}"
     subprocess.call(cmd, shell=True)
@@ -246,6 +252,9 @@ class WanAudioRunner(WanRunner):

     def load_audio_models(self):
         # audio feature extractor
         self.audio_preprocess = AutoFeatureExtractor.from_pretrained(self.config["model_path"], subfolder="audio_encoder")
+
+        # audio-driven video generation adapter
+        audio_adapter_path = self.config["model_path"] + "/audio_adapter.safetensors"
         audio_adaper = AudioAdapter.from_transformer(
             self.model,
             audio_feature_dim=1024,
@@ -253,11 +262,11 @@ class WanAudioRunner(WanRunner):
             time_freq_dim=256,
             projection_transformer_layers=4,
         )
-        load_path = "/mnt/aigc/zoemodels/Zoetrained/vigendit/audio_driven/audio_adapter/audio_adapter_V1_0507_bf16.safetensors"
-        audio_adapter = rank0_load_state_dict_from_path(audio_adaper, load_path, strict=False)
+        audio_adapter = rank0_load_state_dict_from_path(audio_adaper, audio_adapter_path, strict=False)

-        # audio feature encoder
         device = self.model.device
-        audio_encoder_repo = "/mnt/aigc/zoemodels/models--TencentGameMate--chinese-hubert-large/snapshots/90cb660492214f687e60f5ca509b20edae6e75bd"
+        audio_encoder_repo = self.config["model_path"] + "/audio_encoder"
         audio_adapter_pipe = AudioAdapterPipe(audio_adapter, audio_encoder_repo=audio_encoder_repo, dtype=torch.bfloat16, device=device, generator=torch.Generator(device), weight=1.0)
         return audio_adapter_pipe
@@ -275,9 +284,8 @@ class WanAudioRunner(WanRunner):
         return base_model

     def load_image_encoder(self):
-        image_encoder = WanVideoIPHandler(
-            "CLIPModel", repo_or_path="/mnt/aigc/zoemodels/Wan21/Wan2.1-I2V-14B-720P-Diffusers", require_grad=False, mode="eval", device=self.init_device, dtype=torch.float16
-        )
+        clip_model_dir = self.config["model_path"] + "/image_encoder"
+        image_encoder = WanVideoIPHandler("CLIPModel", repo_or_path=clip_model_dir, require_grad=False, mode="eval", device=self.init_device, dtype=torch.float16)
         return image_encoder
@@ -325,6 +333,7 @@ class WanAudioRunner(WanRunner):
         self.set_target_shape()
         self.inputs = {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}

+        del self.image_encoder  # drop the reference-image CLIP model; it is only used once
         gc.collect()
         torch.cuda.empty_cache()
@@ -360,15 +369,15 @@ class WanAudioRunner(WanRunner):
         self.inputs["audio_adapter_pipe"] = self.load_audio_models()

         # process audio
-        audio_sr = 16000
-        max_num_frames = 81  # wan2.1 generates at most 81 frames per segment (5 s at 16 fps)
+        audio_sr = self.config.get("audio_sr", 16000)
+        max_num_frames = self.config.get("target_video_length", 81)  # wan2.1 generates at most 81 frames per segment (5 s at 16 fps)
         target_fps = self.config.get("target_fps", 16)  # audio/video sync frame rate
-        video_duration = self.config.get("video_duration", 8)  # desired output video duration
+        video_duration = self.config.get("video_duration", 5)  # desired output video duration
         audio_array = load_audio(self.config["audio_path"], sr=audio_sr)
         audio_len = int(audio_array.shape[0] / audio_sr * target_fps)
         prev_frame_length = 5
         prev_token_length = (prev_frame_length - 1) // 4 + 1
-        max_num_audio_length = int((max_num_frames + 1) / target_fps * 16000)
+        max_num_audio_length = int((max_num_frames + 1) / target_fps * audio_sr)
         interval_num = 1

         # expected_frames
@@ -463,13 +472,10 @@ class WanAudioRunner(WanRunner):
             latents = self.model.scheduler.latents
             generator = self.model.scheduler.generator
             gen_video = self.vae_decoder.decode(latents, generator=generator, config=self.config)
-            # gen_img = vae_handler.decode(xt.to(vae_dtype))
-            # B, C, T, H, W
             gen_video = torch.clamp(gen_video, -1, 1)

             start_frame = 0 if idx == 0 else prev_frame_length
             start_audio_frame = 0 if idx == 0 else int((prev_frame_length + 1) * audio_sr / target_fps)
-            print(f"---- {idx}, {gen_video[:, :, start_frame:].shape}")
             if res_frame_num > 5 and idx == interval_num - 1:
                 gen_video_list.append(gen_video[:, :, start_frame:res_frame_num])
                 cut_audio_list.append(audio_array[start_audio_frame:useful_length])
@@ -482,7 +488,7 @@ class WanAudioRunner(WanRunner):
         gen_lvideo = torch.cat(gen_video_list, dim=2).float()
         merge_audio = np.concatenate(cut_audio_list, axis=0).astype(np.float32)

-        out_path = os.path.join("./", "video_merge.mp4")
+        out_path = self.config.save_video_path
         audio_file = os.path.join("./", "audio_merge.wav")
         save_to_video(gen_lvideo, out_path, target_fps)
         save_audio(merge_audio, audio_file, out_path)
@@ -501,5 +507,5 @@ class WanAudioRunner(WanRunner):
             self.run()
             self.end_run()
-            torch.cuda.empty_cache()
             gc.collect()
+            torch.cuda.empty_cache()
 import math
 import numpy as np
 import torch
+import gc
 from typing import List, Optional, Tuple, Union

 from lightx2v.models.schedulers.scheduler import BaseScheduler
@@ -123,6 +124,8 @@ class WanScheduler(BaseScheduler):
         self.this_order = None
         self.lower_order_nums = 0
         self.prepare_latents(self.config.target_shape, dtype=torch.float32)
+        gc.collect()
+        torch.cuda.empty_cache()

     def multistep_uni_p_bh_update(
         self,
......
@@ -27,6 +27,7 @@ export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
 export ENABLE_PROFILING_DEBUG=true
 export ENABLE_GRAPH_MODE=false
+export DTYPE=BF16

 python -m lightx2v.infer \
     --model_cls wan2.1_audio \
......