Commit 6060ff4f authored by wangshankun

Support: radial attention

parent b2147c40
{
"infer_steps": 8,
"infer_steps": 5,
"target_fps": 16,
"video_duration": 12,
"audio_sr": 16000,
"target_video_length": 81,
"target_height": 480,
"target_width": 832,
"attention_type": "flash_attn3",
"self_attn_1_type": "radial_attn",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale":1,
"sample_shift": 5,
"enable_cfg": false,
"cpu_offload": false,
"feature_caching": "Tea",
"coefficients": [
[8.10705460e03, 2.13393892e03, -3.72934672e02, 1.66203073e01, -4.17769401e-02],
[-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
],
"use_ret_steps": true,
"teacache_thresh": 0.12
"cpu_offload": false
}
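The new per-site keys (self_attn_1_type, cross_attn_1_type, cross_attn_2_type) let self-attention run on the radial backend while both cross-attention sites stay on FlashAttention-3. Below is a minimal sketch of how such a config selects the registered attention weights; the config file name and the registry import path are assumptions, while the lookups mirror the weight-module changes further down in this commit:
import json

from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER  # assumed import path

with open("wan_i2v_audio_radial.json") as f:  # hypothetical config file name
    config = json.load(f)

# Each attention site resolves its own backend from the registry:
self_attn_1 = ATTN_WEIGHT_REGISTER[config["self_attn_1_type"]]()    # -> RadialAttnWeight
cross_attn_1 = ATTN_WEIGHT_REGISTER[config["cross_attn_1_type"]]()  # -> FlashAttn3Weight
cross_attn_2 = ATTN_WEIGHT_REGISTER[config["cross_attn_2_type"]]()  # -> FlashAttn3Weight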
......@@ -2,6 +2,7 @@ from lightx2v.attentions.common.torch_sdpa import torch_sdpa
from lightx2v.attentions.common.flash_attn2 import flash_attn2
from lightx2v.attentions.common.flash_attn3 import flash_attn3
from lightx2v.attentions.common.sage_attn2 import sage_attn2
from lightx2v.attentions.common.radial_attn import radial_attn
def attention(attention_type="flash_attn2", *args, **kwargs):
......@@ -13,5 +14,7 @@ def attention(attention_type="flash_attn2", *args, **kwargs):
return flash_attn3(*args, **kwargs)
elif attention_type == "sage_attn2":
return sage_attn2(*args, **kwargs)
elif attention_type == "radial_attn":
return radial_attn(*args, **kwargs)
else:
raise NotImplementedError(f"Unsupported attention mode: {attention_type}")
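For reference, a hedged usage sketch of the new radial_attn dispatch path. The dispatcher import location and the tensor/MaskMap sizes are illustrative, and running it requires flashinfer plus a CUDA device:
import torch

from lightx2v.attentions import attention  # assumed location of the dispatcher above
from lightx2v.attentions.common.radial_attn import MaskMap

# Toy sizes: 4 latent frames x 128 tokens per frame, 8 heads, head_dim 64.
q = torch.randn(512, 8, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)
mask_map = MaskMap(video_token_num=512, num_frame=4)

# The dispatcher forwards *args/**kwargs straight into radial_attn.
out = attention("radial_attn", q, k, v, mask_map=mask_map, model_cls="wan")
print(out.shape)  # expected: torch.Size([512, 8, 64])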
import torch
import flashinfer
###
### Code from radial-attention
### https://github.com/mit-han-lab/radial-attention/blob/main/radial_attn/attn_mask.py#L150
###
def radial_attn(
query, key, value, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, mask_map=None, sparsity_type="radial", block_size=128, decay_factor=1, model_cls="wan"
):
orig_seqlen, num_head, hidden_dim = query.shape
query = pad_qkv(query, block_size=block_size)
key = pad_qkv(key, block_size=block_size)
value = pad_qkv(value, block_size=block_size)
mask = mask_map.queryLogMask(query, sparsity_type, block_size=block_size, decay_factor=decay_factor, model_type=model_cls) if mask_map else None
seqlen = query.shape[0]
workspace_buffer = torch.empty(128 * 1024 * 1024, device=query.device, dtype=torch.uint8)
bsr_wrapper = flashinfer.BlockSparseAttentionWrapper(
workspace_buffer,
backend="fa2",
)
indptr = get_indptr_from_mask(mask, query)
indices = get_indices_from_mask(mask, query)
bsr_wrapper.plan(
indptr=indptr,
indices=indices,
M=seqlen,
N=seqlen,
R=block_size,
C=block_size,
num_qo_heads=num_head,
num_kv_heads=num_head,
head_dim=hidden_dim,
q_data_type=query.dtype,
kv_data_type=key.dtype,
o_data_type=query.dtype,
use_fp16_qk_reduction=True,
)
o = bsr_wrapper.run(query, key, value)
return o[:orig_seqlen, :, :]
def get_indptr_from_mask(mask, query):
# query shows the device of the indptr
# indptr (torch.Tensor) - the block index pointer of the block-sparse matrix on row dimension,
# shape `(MB + 1,)`, where `MB` is the number of blocks in the row dimension.
# The first element is always 0, and the last element is the number of blocks in the row dimension.
# The rest of the elements are the number of blocks in each row.
# the mask is already a block sparse mask
indptr = torch.zeros(mask.shape[0] + 1, device=query.device, dtype=torch.int32)
indptr[0] = 0
row_counts = mask.sum(dim=1).flatten() # Ensure 1D output [num_blocks_row]
indptr[1:] = torch.cumsum(row_counts, dim=0)
return indptr
def get_indices_from_mask(mask, query):
# indices (torch.Tensor) - the block indices of the block-sparse matrix on column dimension,
# shape `(nnz,),` where `nnz` is the number of non-zero blocks.
# The elements in `indices` array should be less than `NB`: the number of blocks in the column dimension.
nonzero_indices = torch.nonzero(mask)
indices = nonzero_indices[:, 1].to(dtype=torch.int32, device=query.device)
return indices
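# Worked example (illustrative): for the 3x3 block mask
#   [[1, 1, 0],
#    [0, 1, 0],
#    [1, 0, 1]]
# the per-row non-zero counts are [2, 1, 2], so
#   indptr  = [0, 2, 3, 5]     (cumulative row counts, length MB + 1)
#   indices = [0, 1, 1, 0, 2]  (column index of every non-zero block, row-major)
# which is exactly the BSR layout consumed by flashinfer's BlockSparseAttentionWrapper.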
def shrinkMaskStrict(mask, block_size=128):
seqlen = mask.shape[0]
block_num = seqlen // block_size
mask = mask[: block_num * block_size, : block_num * block_size].view(block_num, block_size, block_num, block_size)
col_densities = mask.sum(dim=1) / block_size
# we want the minimum non-zero column density in the block
non_zero_densities = col_densities > 0
high_density_cols = col_densities > 1 / 3
frac_high_density_cols = high_density_cols.sum(dim=-1) / (non_zero_densities.sum(dim=-1) + 1e-9)
block_mask = frac_high_density_cols > 0.6
block_mask[0:0] = True
block_mask[-1:-1] = True
return block_mask
def pad_qkv(input_tensor, block_size=128):
"""
Pad the input tensor to be a multiple of the block size.
input shape: (seqlen, num_heads, hidden_dim)
"""
seqlen, num_heads, hidden_dim = input_tensor.shape
# Calculate the necessary padding
padding_length = (block_size - (seqlen % block_size)) % block_size
# Create a padded tensor with zeros
padded_tensor = torch.zeros((seqlen + padding_length, num_heads, hidden_dim), device=input_tensor.device, dtype=input_tensor.dtype)
# Copy the original tensor into the padded tensor
padded_tensor[:seqlen, :, :] = input_tensor
return padded_tensor
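# Example (illustrative): with block_size=128 a (300, 8, 64) tensor is padded to
# (384, 8, 64); the 84 zero rows are stripped again at the end of radial_attn
# via o[:orig_seqlen, :, :].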
def get_diagonal_split_mask(i, j, token_per_frame, sparse_type, query):
assert sparse_type in ["radial"]
dist = abs(i - j)
group = dist.bit_length()
threshold = 128 # hardcoded threshold for now, which is equal to block-size
decay_length = 2 ** token_per_frame.bit_length() / 2**group
if decay_length >= threshold:
return torch.ones((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
split_factor = int(threshold / decay_length)
modular = dist % split_factor
if modular == 0:
return torch.ones((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
else:
return torch.zeros((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
def get_window_width(i, j, token_per_frame, sparse_type, num_frame, decay_factor=1, block_size=128, model_type=None):
assert sparse_type in ["radial"]
dist = abs(i - j)
if model_type == "wan":
if dist < 1:
return token_per_frame
if dist == 1:
return token_per_frame // 2
elif model_type == "hunyuan":
if dist <= 1:
return token_per_frame
else:
raise ValueError(f"Unknown model type: {model_type}")
group = dist.bit_length()
decay_length = 2 ** token_per_frame.bit_length() / 2**group * decay_factor
threshold = block_size
if decay_length >= threshold:
return decay_length
else:
return threshold
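# Example (illustrative, model_type="wan", token_per_frame=1560, decay_factor=1, block_size=128):
#   dist=0 -> 1560 (full frame), dist=1 -> 780, dist=2..3 -> 512, dist=4..7 -> 256,
#   dist=8..15 -> 128, dist>=16 -> clamped to the block_size threshold of 128,
# i.e. the attention window roughly halves each time the frame distance doubles.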
def gen_log_mask_shrinked(query, s, video_token_num, num_frame, block_size=128, sparse_type="log", decay_factor=0.5, model_type=None):
"""
A more memory friendly version, we generate the attention mask of each frame pair at a time,
shrinks it, and stores it into the final result
"""
final_log_mask = torch.zeros((s // block_size, s // block_size), device=query.device, dtype=torch.bool)
token_per_frame = video_token_num // num_frame
video_text_border = video_token_num // block_size
col_indices = torch.arange(0, token_per_frame, device=query.device).view(1, -1)
row_indices = torch.arange(0, token_per_frame, device=query.device).view(-1, 1)
final_log_mask[video_text_border:] = True
final_log_mask[:, video_text_border:] = True
for i in range(num_frame):
for j in range(num_frame):
local_mask = torch.zeros((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
if j == 0: # this is attention sink
local_mask = torch.ones((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
else:
window_width = get_window_width(i, j, token_per_frame, sparse_type, num_frame, decay_factor=decay_factor, block_size=block_size, model_type=model_type)
local_mask = torch.abs(col_indices - row_indices) <= window_width
split_mask = get_diagonal_split_mask(i, j, token_per_frame, sparse_type, query)
local_mask = torch.logical_and(local_mask, split_mask)
remainder_row = (i * token_per_frame) % block_size
remainder_col = (j * token_per_frame) % block_size
# get the padded size
all_length_row = remainder_row + ((token_per_frame - 1) // block_size + 1) * block_size
all_length_col = remainder_col + ((token_per_frame - 1) // block_size + 1) * block_size
padded_local_mask = torch.zeros((all_length_row, all_length_col), device=query.device, dtype=torch.bool)
padded_local_mask[remainder_row : remainder_row + token_per_frame, remainder_col : remainder_col + token_per_frame] = local_mask
# shrink the mask
block_mask = shrinkMaskStrict(padded_local_mask, block_size=block_size)
# set the block mask to the final log mask
block_row_start = (i * token_per_frame) // block_size
block_col_start = (j * token_per_frame) // block_size
block_row_end = block_row_start + block_mask.shape[0]
block_col_end = block_col_start + block_mask.shape[1]
final_log_mask[block_row_start:block_row_end, block_col_start:block_col_end] = torch.logical_or(final_log_mask[block_row_start:block_row_end, block_col_start:block_col_end], block_mask)
print(f"mask sparsity: {1 - final_log_mask.sum() / final_log_mask.numel()}")
return final_log_mask
class MaskMap:
def __init__(self, video_token_num=79200, num_frame=22):
self.video_token_num = video_token_num
self.num_frame = num_frame
self.log_mask = None
def queryLogMask(self, query, sparse_type, block_size=128, decay_factor=0.5, model_type=None):
log_mask = torch.ones((query.shape[0] // block_size, query.shape[0] // block_size), device=query.device, dtype=torch.bool)
if self.log_mask is None:
self.log_mask = gen_log_mask_shrinked(
query, query.shape[0], self.video_token_num, self.num_frame, sparse_type=sparse_type, decay_factor=decay_factor, model_type=model_type, block_size=block_size
)
block_bound = self.video_token_num // block_size
log_mask[:block_bound, :block_bound] = self.log_mask[:block_bound, :block_bound]
return log_mask
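A minimal sketch of how the MaskMap class above is meant to be shared (sizes are illustrative): the expensive gen_log_mask_shrinked pass runs only on the first queryLogMask call, later calls re-slice the cached log_mask, and the text-token region of the returned mask stays dense (all True).
import torch

q = torch.randn(512, 8, 64)  # illustrative (seqlen, num_heads, head_dim); CPU is fine for mask generation
mask_map = MaskMap(video_token_num=512, num_frame=4)

# First call builds and caches mask_map.log_mask; the second call reuses it.
mask_first = mask_map.queryLogMask(q, "radial", block_size=128, decay_factor=1, model_type="wan")
mask_again = mask_map.queryLogMask(q, "radial", block_size=128, decay_factor=1, model_type="wan")
assert mask_first.shape == (512 // 128, 512 // 128)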
......@@ -37,6 +37,9 @@ else:
sageattn = None
from lightx2v.attentions.common.radial_attn import radial_attn
class AttnWeightTemplate(metaclass=ABCMeta):
def __init__(self, weight_name):
self.weight_name = weight_name
......@@ -70,7 +73,7 @@ class FlashAttn2Weight(AttnWeightTemplate):
def __init__(self):
self.config = {}
def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None):
def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, mask_map=None):
x = flash_attn_varlen_func(
q,
k,
......@@ -88,7 +91,7 @@ class FlashAttn3Weight(AttnWeightTemplate):
def __init__(self):
self.config = {}
def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None):
def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, mask_map=None):
x = flash_attn_varlen_func_v3(
q,
k,
......@@ -101,6 +104,28 @@ class FlashAttn3Weight(AttnWeightTemplate):
return x
@ATTN_WEIGHT_REGISTER("radial_attn")
class RadialAttnWeight(AttnWeightTemplate):
def __init__(self):
self.config = {}
def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, mask_map=None, sparsity_type="radial", block_size=128, decay_factor=1, model_cls="wan"):
assert len(q.shape) == 3
x = radial_attn(
q,
k,
v,
mask_map=mask_map,
sparsity_type=sparsity_type,
block_size=block_size,
model_cls=model_cls[:3], # Use first 3 characters to match "wan", "wan2", etc.
decay_factor=decay_factor,
)
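# reshape (seqlen, num_heads, head_dim) -> (max_seqlen_q, num_heads * head_dim)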
x = x.view(max_seqlen_q, -1)
return x
@ATTN_WEIGHT_REGISTER("sage_attn2")
class SageAttn2Weight(AttnWeightTemplate):
def __init__(self):
......
......@@ -42,7 +42,9 @@ def init_runner(config):
async def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox", "wan2.1_audio"], default="hunyuan")
parser.add_argument(
"--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox", "wan2.1_audio"], default="hunyuan"
)
parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--config_json", type=str, required=True)
......
......@@ -449,7 +449,7 @@ class WanVideoIPHandler:
but Wan2.1 official use no_crop resize by default
so I don't use CLIPImageProcessor
"""
image_encoder = CLIPVisionModel.from_pretrained(repo_or_path, subfolder="image_encoder", torch_dtype=dtype)
image_encoder = CLIPVisionModel.from_pretrained(repo_or_path, torch_dtype=dtype)
logger.info(f"Using image encoder {model_name} from {repo_or_path}")
image_encoder.requires_grad_(require_grad)
if mode == "eval":
......
......@@ -7,11 +7,9 @@ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from einops import rearrange
from transformers import AutoModel
from loguru import logger
import pdb
import os
import safetensors
from typing import List, Optional, Tuple, Union
......
......@@ -19,6 +19,8 @@ from safetensors import safe_open
import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
from lightx2v.attentions.common.radial_attn import MaskMap
from lightx2v.models.networks.wan.infer.transformer_infer import (
WanTransformerInfer,
)
......@@ -51,6 +53,15 @@ class WanAudioModel(WanModel):
self.pre_weight.to_cuda()
self.post_weight.to_cuda()
if self.transformer_infer.mask_map is None:
_, c, h, w = self.scheduler.latents.shape
num_frame = c + 1 # for r2v
video_token_num = num_frame * (h // 2) * (w // 2)
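# assumes a 2x2 spatial patchify, so each latent frame contributes (h // 2) * (w // 2) tokens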
from loguru import logger
logger.info(f"video_token_num: {video_token_num}, num_frame: {num_frame}")
self.transformer_infer.mask_map = MaskMap(video_token_num, num_frame)
embed, grid_sizes, pre_infer_out, valid_patch_length = self.pre_infer.infer(self.pre_weight, inputs, positive=True)
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes, valid_patch_length)[0]
......
......@@ -7,7 +7,6 @@ from lightx2v.common.offload.manager import (
from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
from lightx2v.utils.envs import *
from loguru import logger
import pdb
import os
......@@ -24,6 +23,8 @@ class WanTransformerInfer(BaseTransformerInfer):
self.parallel_attention = None
self.apply_rotary_emb_func = apply_rotary_emb_chunk if config.get("rotary_chunk", False) else apply_rotary_emb
self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
self.mask_map = None
if self.config["cpu_offload"]:
if "offload_ratio" in self.config:
offload_ratio = self.config["offload_ratio"]
......@@ -290,7 +291,7 @@ class WanTransformerInfer(BaseTransformerInfer):
if not self.parallel_attention:
if self.config.get("audio_sr", False):
freqs_i = compute_freqs_audio(q.size(0), q.size(2) // 2, grid_sizes, freqs)
freqs_i = compute_freqs_audio(q.size(2) // 2, grid_sizes, freqs)
else:
freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
else:
......@@ -321,6 +322,7 @@ class WanTransformerInfer(BaseTransformerInfer):
max_seqlen_q=q.size(0),
max_seqlen_kv=k.size(0),
model_cls=self.config["model_cls"],
mask_map=self.mask_map,
)
else:
attn_out = self.parallel_attention(
......
......@@ -23,7 +23,7 @@ def compute_freqs(c, grid_sizes, freqs):
def compute_freqs_audio(c, grid_sizes, freqs):
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0].tolist()
f = f + 1
f = f + 1  ## add one extra frame for r2v
seq_len = f * h * w
freqs_i = torch.cat(
[
......
......@@ -50,6 +50,7 @@ class WanLoraWrapper:
self.model._init_weights(weight_dict)
logger.info(f"Applied LoRA: {lora_name} with alpha={alpha}")
del lora_weights  # delete to save GPU memory
return True
@torch.no_grad()
......
......@@ -184,7 +184,7 @@ class WanSelfAttention(WeightModule):
sparge_ckpt = torch.load(self.config["sparge_ckpt"])
self.self_attn_1.load(sparge_ckpt)
else:
self.add_module("self_attn_1", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())
self.add_module("self_attn_1", ATTN_WEIGHT_REGISTER[self.config["self_attn_1_type"]]())
if self.quant_method in ["smoothquant", "awq"]:
self.add_module(
"smooth_norm1_weight",
......@@ -275,7 +275,7 @@ class WanCrossAttention(WeightModule):
self.lazy_load_file,
),
)
self.add_module("cross_attn_1", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())
self.add_module("cross_attn_1", ATTN_WEIGHT_REGISTER[self.config["cross_attn_1_type"]]())
if self.config.task == "i2v":
self.add_module(
......@@ -304,7 +304,7 @@ class WanCrossAttention(WeightModule):
self.lazy_load_file,
),
)
self.add_module("cross_attn_2", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())
self.add_module("cross_attn_2", ATTN_WEIGHT_REGISTER[self.config["cross_attn_2_type"]]())
class WanFFN(WeightModule):
......
......@@ -31,7 +31,6 @@ from torchvision.transforms.functional import resize
import subprocess
import warnings
from typing import Optional, Tuple, Union
import pdb
def get_crop_bbox(ori_h, ori_w, tgt_h, tgt_w):
......@@ -210,7 +209,6 @@ def generate_unique_path(path):
def save_to_video(gen_lvideo, out_path, target_fps):
print(gen_lvideo.shape)
gen_lvideo = rearrange(gen_lvideo, "B C T H W -> B T H W C")
gen_lvideo = (gen_lvideo[0].cpu().numpy() * 127.5 + 127.5).astype(np.uint8)
gen_lvideo = gen_lvideo[..., ::-1].copy()
......@@ -219,21 +217,29 @@ def save_to_video(gen_lvideo, out_path, target_fps):
def save_audio(
audio_array: str,
audio_array,
audio_name: str,
video_name: str = None,
sr: int = 16000,
):
logger.info(f"Saving audio to {audio_name} type: {type(audio_array)}")
if not os.path.exists(audio_name):
ta.save(
audio_name,
torch.tensor(audio_array[None]),
sample_rate=sr,
)
ta.save(
audio_name,
torch.tensor(audio_array[None]),
sample_rate=sr,
)
out_video = f"{video_name[:-4]}_with_audio.mp4"
# generate_unique_path(out_path)
# Make sure the parent directory exists
parent_dir = os.path.dirname(out_video)
if parent_dir and not os.path.exists(parent_dir):
os.makedirs(parent_dir, exist_ok=True)
# If the output video already exists, delete it first
if os.path.exists(out_video):
os.remove(out_video)
cmd = f"/usr/bin/ffmpeg -i {video_name} -i {audio_name} {out_video}"
subprocess.call(cmd, shell=True)
......@@ -246,6 +252,9 @@ class WanAudioRunner(WanRunner):
def load_audio_models(self):
## Audio feature extractor
self.audio_preprocess = AutoFeatureExtractor.from_pretrained(self.config["model_path"], subfolder="audio_encoder")
## Adapter for audio-driven video generation
audio_adapter_path = self.config["model_path"] + "/audio_adapter.safetensors"
audio_adapter = AudioAdapter.from_transformer(
self.model,
audio_feature_dim=1024,
......@@ -253,11 +262,11 @@ class WanAudioRunner(WanRunner):
time_freq_dim=256,
projection_transformer_layers=4,
)
load_path = "/mnt/aigc/zoemodels/Zoetrained/vigendit/audio_driven/audio_adapter/audio_adapter_V1_0507_bf16.safetensors"
audio_adapter = rank0_load_state_dict_from_path(audio_adapter, load_path, strict=False)
audio_adapter = rank0_load_state_dict_from_path(audio_adapter, audio_adapter_path, strict=False)
## Audio feature encoder
device = self.model.device
audio_encoder_repo = "/mnt/aigc/zoemodels/models--TencentGameMate--chinese-hubert-large/snapshots/90cb660492214f687e60f5ca509b20edae6e75bd"
audio_encoder_repo = self.config["model_path"] + "/audio_encoder"
audio_adapter_pipe = AudioAdapterPipe(audio_adapter, audio_encoder_repo=audio_encoder_repo, dtype=torch.bfloat16, device=device, generator=torch.Generator(device), weight=1.0)
return audio_adapter_pipe
......@@ -275,9 +284,8 @@ class WanAudioRunner(WanRunner):
return base_model
def load_image_encoder(self):
image_encoder = WanVideoIPHandler(
"CLIPModel", repo_or_path="/mnt/aigc/zoemodels/Wan21/Wan2.1-I2V-14B-720P-Diffusers", require_grad=False, mode="eval", device=self.init_device, dtype=torch.float16
)
clip_model_dir = self.config["model_path"] + "/image_encoder"
image_encoder = WanVideoIPHandler("CLIPModel", repo_or_path=clip_model_dir, require_grad=False, mode="eval", device=self.init_device, dtype=torch.float16)
return image_encoder
......@@ -325,6 +333,7 @@ class WanAudioRunner(WanRunner):
self.set_target_shape()
self.inputs = {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}
del self.image_encoder # delete the reference CLIP model; it is only used once
gc.collect()
torch.cuda.empty_cache()
......@@ -360,15 +369,15 @@ class WanAudioRunner(WanRunner):
self.inputs["audio_adapter_pipe"] = self.load_audio_models()
# process audio
audio_sr = 16000
max_num_frames = 81 # wan2.1 generates at most 81 frames per segment (5 s at 16 fps)
audio_sr = self.config.get("audio_sr", 16000)
max_num_frames = self.config.get("target_video_length", 81) # wan2.1 generates at most 81 frames per segment (5 s at 16 fps)
target_fps = self.config.get("target_fps", 16) # frame rate used to keep audio and video in sync
video_duration = self.config.get("video_duration", 8) # desired output video duration
video_duration = self.config.get("video_duration", 5) # desired output video duration
audio_array = load_audio(self.config["audio_path"], sr=audio_sr)
audio_len = int(audio_array.shape[0] / audio_sr * target_fps)
prev_frame_length = 5
prev_token_length = (prev_frame_length - 1) // 4 + 1
max_num_audio_length = int((max_num_frames + 1) / target_fps * 16000)
max_num_audio_length = int((max_num_frames + 1) / target_fps * audio_sr)
interval_num = 1
# expected_frames
......@@ -463,13 +472,10 @@ class WanAudioRunner(WanRunner):
latents = self.model.scheduler.latents
generator = self.model.scheduler.generator
gen_video = self.vae_decoder.decode(latents, generator=generator, config=self.config)
# gen_img = vae_handler.decode(xt.to(vae_dtype))
# B, C, T, H, W
gen_video = torch.clamp(gen_video, -1, 1)
start_frame = 0 if idx == 0 else prev_frame_length
start_audio_frame = 0 if idx == 0 else int((prev_frame_length + 1) * audio_sr / target_fps)
print(f"---- {idx}, {gen_video[:, :, start_frame:].shape}")
if res_frame_num > 5 and idx == interval_num - 1:
gen_video_list.append(gen_video[:, :, start_frame:res_frame_num])
cut_audio_list.append(audio_array[start_audio_frame:useful_length])
......@@ -482,7 +488,7 @@ class WanAudioRunner(WanRunner):
gen_lvideo = torch.cat(gen_video_list, dim=2).float()
merge_audio = np.concatenate(cut_audio_list, axis=0).astype(np.float32)
out_path = os.path.join("./", "video_merge.mp4")
out_path = self.config.save_video_path
audio_file = os.path.join("./", "audio_merge.wav")
save_to_video(gen_lvideo, out_path, target_fps)
save_audio(merge_audio, audio_file, out_path)
......@@ -501,5 +507,5 @@ class WanAudioRunner(WanRunner):
self.run()
self.end_run()
torch.cuda.empty_cache()
gc.collect()
torch.cuda.empty_cache()
import math
import numpy as np
import torch
import gc
from typing import List, Optional, Tuple, Union
from lightx2v.models.schedulers.scheduler import BaseScheduler
......@@ -123,6 +124,8 @@ class WanScheduler(BaseScheduler):
self.this_order = None
self.lower_order_nums = 0
self.prepare_latents(self.config.target_shape, dtype=torch.float32)
gc.collect()
torch.cuda.empty_cache()
def multistep_uni_p_bh_update(
self,
......
......@@ -27,6 +27,7 @@ export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
export DTYPE=BF16
python -m lightx2v.infer \
--model_cls wan2.1_audio \
......