Commit 5ec2b691 authored by sandy, committed by GitHub

Merge pull request #135 from ModelTC/audio_r2v

Audio r2v v2
parents 6d07a72e e08c4f90
{
"infer_steps": 5,
"infer_steps": 4,
"target_fps": 16,
"video_duration": 12,
"video_duration": 16,
"audio_sr": 16000,
"target_video_length": 81,
"target_height": 480,
@@ -13,5 +13,6 @@
"sample_guide_scale":1,
"sample_shift": 5,
"enable_cfg": false,
"cpu_offload": false
"cpu_offload": false,
"use_tiling_vae": true
}
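For orientation, a minimal sketch of how a merged config like this might be consumed (the file name below is an assumption, not the repo's actual path, and the keys are just the ones shown in this diff):

import json

# Hypothetical file name; the real config path in the repo may differ.
with open("wan_audio_distill.json") as f:
    cfg = json.load(f)

# After this change: 4 distill steps, 16 s of 16-fps video, CFG off, tiled VAE decoding on.
print(cfg["infer_steps"], cfg["video_duration"], cfg.get("use_tiling_vae", False))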
@@ -24,13 +24,15 @@ class WanAudioPreInfer(WanPreInfer):
self.text_len = config["text_len"]
def infer(self, weights, inputs, positive):
ltnt_channel = self.scheduler.latents.size(0)
ltnt_frames = self.scheduler.latents.size(1)
prev_latents = inputs["previmg_encoder_output"]["prev_latents"].unsqueeze(0)
prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
hidden_states = self.scheduler.latents.unsqueeze(0)
hidden_states = torch.cat([hidden_states[:, :ltnt_channel], prev_latents, prev_mask], dim=1)
# hidden_states = torch.cat([hidden_states[:, :ltnt_channel], prev_latents, prev_mask], dim=1)
# print(f"{prev_mask.shape}, {hidden_states.shape}, {prev_latents.shape},{prev_latents[:, :, :ltnt_frames].shape}")
hidden_states = torch.cat([hidden_states, prev_mask, prev_latents[:, :, :ltnt_frames]], dim=1)
hidden_states = hidden_states.squeeze(0)
x = [hidden_states]
@@ -44,6 +46,7 @@ class WanAudioPreInfer(WanPreInfer):
"timestep": t,
}
audio_dit_blocks.append(inputs["audio_adapter_pipe"](**audio_model_input))
# audio_dit_blocks = None  # debug switch: drop audio conditioning
if positive:
context = inputs["text_encoder_output"]["context"]
......
@@ -18,6 +18,9 @@ from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from lightx2v.models.networks.wan.audio_adapter import AudioAdapter, AudioAdapterPipe, rank0_load_state_dict_from_path
from lightx2v.models.schedulers.wan.step_distill.scheduler import WanStepDistillScheduler
from lightx2v.models.schedulers.wan.audio.scheduler import EulerSchedulerTimestepFix
from loguru import logger
import torch.distributed as dist
from einops import rearrange
@@ -33,6 +36,45 @@ import warnings
from typing import Optional, Tuple, Union
def add_mask_to_frames(
frames: np.ndarray,
mask_rate: float = 0.1,
rnd_state: np.random.RandomState = None,
) -> np.ndarray:
if mask_rate is None:
return frames
if rnd_state is None:
rnd_state = np.random.RandomState()
h, w = frames.shape[-2:]
mask = rnd_state.rand(h, w) > mask_rate
frames = frames * mask
return frames
def add_noise_to_frames(
frames: np.ndarray,
noise_mean: float = -3.0,
noise_std: float = 0.5,
rnd_state: np.random.RandomState = None,
) -> np.ndarray:
if noise_mean is None or noise_std is None:
return frames
if rnd_state is None:
rnd_state = np.random.RandomState()
shape = frames.shape
bs = 1 if len(shape) == 4 else shape[0]
sigma = rnd_state.normal(loc=noise_mean, scale=noise_std, size=(bs,))
sigma = np.exp(sigma)
sigma = np.expand_dims(sigma, axis=tuple(range(1, len(shape))))
noise = rnd_state.randn(*shape) * sigma
frames = frames + noise
return frames
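These two helpers are applied together in the long-video loop further down: the tail frames of the previous clip are lightly noised and randomly masked before being re-encoded by the VAE as conditioning, presumably to reduce error accumulation across segments. A minimal sketch of how they compose (shapes follow the runner's (1, 3, T, H, W) frame layout; all values are illustrative):

rng = np.random.RandomState(0)                            # fixed seed, only for the sketch
frames = np.zeros((1, 3, 5, 480, 832), dtype=np.float32)  # stand-in for the previous clip's tail

degraded = add_noise_to_frames(frames, rnd_state=rng)                  # per-batch sigma ~ exp(N(-3.0, 0.5))
degraded = add_mask_to_frames(degraded, mask_rate=0.1, rnd_state=rng)  # zero out ~10% of pixel positions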
def get_crop_bbox(ori_h, ori_w, tgt_h, tgt_w):
tgt_ar = tgt_h / tgt_w
ori_ar = ori_h / ori_w
@@ -75,7 +117,12 @@ def adaptive_resize(img):
aspect_ratios = np.array(np.array(list(bucket_config.keys())))
closet_aspect_idx = np.argmin(np.abs(aspect_ratios - ori_ratio))
closet_ratio = aspect_ratios[closet_aspect_idx]
target_h, target_w = 480, 832
if ori_ratio < 1.0:
target_h, target_w = 480, 832
elif ori_ratio == 1.0:
target_h, target_w = 480, 480
else:
target_h, target_w = 832, 480
for resolution in bucket_config[closet_ratio][0]:
if ori_height * ori_weight >= resolution[0] * resolution[1]:
target_h, target_w = resolution
@@ -253,6 +300,10 @@ class WanAudioRunner(WanRunner):
def __init__(self, config):
super().__init__(config)
def init_scheduler(self):
scheduler = EulerSchedulerTimestepFix(self.config)
self.model.set_scheduler(scheduler)
def load_audio_models(self):
## Audio feature extractor
self.audio_preprocess = AutoFeatureExtractor.from_pretrained(self.config["model_path"], subfolder="audio_encoder")
@@ -372,6 +423,18 @@ class WanAudioRunner(WanRunner):
audio_frame_rate = audio_sr / fps
return round(start_frame * audio_frame_rate), round((end_frame + 1) * audio_frame_rate)
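With the values used in this runner (audio_sr=16000, target_fps=16) this works out to 1000 audio samples per video frame; a quick check of the arithmetic for an 81-frame segment (frame indices 0..80, purely illustrative):

start, end = get_audio_range(0, 80, fps=16, audio_sr=16000)
assert (start, end) == (0, 81000)   # (end_frame + 1) frames' worth of samples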
def wan_mask_rearrange(mask: torch.Tensor):
# mask: (1, T, H, W); the leading 1 is the single mask channel
if mask.ndim == 3:
mask = mask[None]
assert mask.ndim == 4
_, t, h, w = mask.shape
assert t == ((t - 1) // 4 * 4 + 1)
mask_first_frame = torch.repeat_interleave(mask[:, 0:1], repeats=4, dim=1)
mask = torch.concat([mask_first_frame, mask[:, 1:]], dim=1)
mask = mask.view(mask.shape[1] // 4, 4, h, w)
return mask.transpose(0, 1) # 4, T // 4, H, W
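The rearrangement repeats the first frame 4x and regroups the pixel-frame mask into the VAE's 4x temporal compression layout; a quick shape check (a 60x104 latent grid for 480x832 video and 81 frames, chosen only for illustration):

mask = torch.ones(1, 81, 60, 104)             # 81 pixel frames over a 60x104 latent grid
out = wan_mask_rearrange(mask)
assert tuple(out.shape) == (4, 21, 60, 104)   # 4 channels x ((81 - 1) // 4 + 1) latent frames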
self.inputs["audio_adapter_pipe"] = self.load_audio_models()
# process audio
@@ -427,7 +490,14 @@
elif res_frame_num > 5 and idx == interval_num - 1:  # the last segment may have fewer than 81 frames
prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
prev_frames[:, :, :prev_frame_length] = gen_video_list[-1][:, :, -prev_frame_length:]
last_frames = gen_video_list[-1][:, :, -prev_frame_length:].clone().to(device)
last_frames = last_frames.cpu().detach().numpy()
last_frames = add_noise_to_frames(last_frames)
last_frames = add_mask_to_frames(last_frames, mask_rate=0.1)  # randomly drop ~10% of pixels
last_frames = torch.from_numpy(last_frames).to(dtype=dtype, device=device)
prev_frames[:, :, :prev_frame_length] = last_frames
prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
prev_len = prev_token_length
audio_start, audio_end = get_audio_range(idx * max_num_frames - idx * prev_frame_length, expected_frames, fps=target_fps, audio_sr=audio_sr)
@@ -438,7 +508,14 @@
else:  # middle segments have the full 81 frames plus prev_latents
prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
prev_frames[:, :, :prev_frame_length] = gen_video_list[-1][:, :, -prev_frame_length:]
last_frames = gen_video_list[-1][:, :, -prev_frame_length:].clone().to(device)
last_frames = last_frames.cpu().detach().numpy()
last_frames = add_noise_to_frames(last_frames)
last_frames = add_mask_to_frames(last_frames, mask_rate=0.1)  # randomly drop ~10% of pixels
last_frames = torch.from_numpy(last_frames).to(dtype=dtype, device=device)
prev_frames[:, :, :prev_frame_length] = last_frames
prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
prev_len = prev_token_length
audio_start, audio_end = get_audio_range(idx * max_num_frames - idx * prev_frame_length, (idx + 1) * max_num_frames - idx * prev_frame_length, fps=target_fps, audio_sr=audio_sr)
@@ -452,11 +529,11 @@
if prev_latents is not None:
ltnt_channel, nframe, height, width = self.model.scheduler.latents.shape
bs = 1
prev_mask = torch.zeros((bs, 1, nframe, height, width), device=device, dtype=dtype)
if prev_len > 0:
prev_mask[:, :, :prev_len] = 1.0
# bs = 1
frames_n = (nframe - 1) * 4 + 1
prev_mask = torch.zeros((1, frames_n, height, width), device=device, dtype=dtype)
prev_mask[:, prev_len:] = 0
prev_mask = wan_mask_rearrange(prev_mask).unsqueeze(0)
previmg_encoder_output = {
"prev_latents": prev_latents,
"prev_mask": prev_mask,
@@ -483,13 +560,13 @@
start_audio_frame = 0 if idx == 0 else int((prev_frame_length + 1) * audio_sr / target_fps)
if res_frame_num > 5 and idx == interval_num - 1:
gen_video_list.append(gen_video[:, :, start_frame:res_frame_num])
gen_video_list.append(gen_video[:, :, start_frame:res_frame_num].cpu())
cut_audio_list.append(audio_array[start_audio_frame:useful_length])
elif expected_frames < max_num_frames and useful_length != -1:
gen_video_list.append(gen_video[:, :, start_frame:expected_frames])
gen_video_list.append(gen_video[:, :, start_frame:expected_frames].cpu())
cut_audio_list.append(audio_array[start_audio_frame:useful_length])
else:
gen_video_list.append(gen_video[:, :, start_frame:])
gen_video_list.append(gen_video[:, :, start_frame:].cpu())
cut_audio_list.append(audio_array[start_audio_frame:])
gen_lvideo = torch.cat(gen_video_list, dim=2).float()
......
import os
import gc
import math
import numpy as np
import torch
from typing import List, Optional, Tuple, Union
from lightx2v.utils.envs import *
from diffusers.configuration_utils import register_to_config
from torch import Tensor
from .utils import unsqueeze_to_ndim
from diffusers import (
FlowMatchEulerDiscreteScheduler as FlowMatchEulerDiscreteSchedulerBase, # pyright: ignore
)
def get_timesteps(num_steps, max_steps: int = 1000):
return np.linspace(max_steps, 0, num_steps + 1, dtype=np.float32)
def timestep_shift(timesteps, shift: float = 1.0):
return shift * timesteps / (1 + (shift - 1) * timesteps)
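timestep_shift is the usual flow-matching shift applied to normalized sigmas; with the config's sample_shift of 5 it concentrates steps at high noise levels. A quick numeric check (values rounded):

import numpy as np

sigmas = np.array([1.0, 0.75, 0.5, 0.25, 0.0], dtype=np.float32)
print(timestep_shift(sigmas, shift=5.0))   # -> [1.0, 0.9375, 0.8333, 0.625, 0.0]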
class FlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteSchedulerBase):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.init_noise_sigma = 1.0
def add_noise(self, x0: Tensor, noise: Tensor, timesteps: Tensor):
dtype = x0.dtype
device = x0.device
sigma = timesteps.to(device, torch.float32) / self.config.num_train_timesteps
sigma = unsqueeze_to_ndim(sigma, x0.ndim)
xt = x0.float() * (1 - sigma) + noise.float() * sigma
return xt.to(dtype)
def get_velocity(self, x0: Tensor, noise: Tensor, timesteps: Tensor | None = None):
return noise - x0
def velocity_loss_to_x_loss(self, v_loss: Tensor, timesteps: Tensor):
device = v_loss.device
sigma = timesteps.to(device, torch.float32) / self.config.num_train_timesteps
return v_loss.float() * (sigma**2)
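For orientation, the relations these methods implement, written out as comments (my notation, consistent with the code above and with the Euler update in step_post below):

# x_sigma   = (1 - sigma) * x0 + sigma * noise           (add_noise)
# velocity  = noise - x0                                  (get_velocity)
# x_next    = x_sigma + (sigma_next - sigma) * velocity   (Euler step in step_post)
# loss_x    = loss_v * sigma**2                           (velocity_loss_to_x_loss)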
class EulerSchedulerTimestepFix(FlowMatchEulerDiscreteScheduler):
def __init__(self, config):
self.config = config
self.step_index = 0
self.latents = None
self.caching_records = [True] * config.infer_steps
self.flag_df = False
self.transformer_infer = None
self.device = torch.device("cuda")
self.infer_steps = self.config.infer_steps
self.target_video_length = self.config.target_video_length
self.sample_shift = self.config.sample_shift
self.num_train_timesteps = 1000
self.noise_pred = None
def step_pre(self, step_index):
self.step_index = step_index
if GET_DTYPE() == "BF16":
self.latents = self.latents.to(dtype=torch.bfloat16)
def set_shift(self, shift: float = 1.0):
self.sigmas = self.timesteps_ori / self.num_train_timesteps
self.sigmas = timestep_shift(self.sigmas, shift=shift)
self.timesteps = self.sigmas * self.num_train_timesteps
def set_timesteps(
self,
infer_steps: Union[int, None] = None,
device: Union[str, torch.device] = None,
sigmas: Optional[List[float]] = None,
mu: Optional[Union[float, None]] = None,
shift: Optional[Union[float, None]] = None,
):
timesteps = get_timesteps(num_steps=infer_steps, max_steps=self.num_train_timesteps)
self.timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device or self.device)
self.timesteps_ori = self.timesteps.clone()
self.set_shift(self.sample_shift)
self._step_index = None
self._begin_index = None
def prepare(self, image_encoder_output=None):
self.generator = torch.Generator(device=self.device)
self.generator.manual_seed(self.config.seed)
self.prepare_latents(self.config.target_shape, dtype=torch.float32)
if os.path.isfile(self.config.image_path):
self.seq_len = ((self.config.target_video_length - 1) // self.config.vae_stride[0] + 1) * self.config.lat_h * self.config.lat_w // (self.config.patch_size[1] * self.config.patch_size[2])
else:
self.seq_len = math.ceil((self.config.target_shape[2] * self.config.target_shape[3]) / (self.config.patch_size[1] * self.config.patch_size[2]) * self.config.target_shape[1])
self.model_outputs = [None] * self.solver_order
self.timestep_list = [None] * self.solver_order
self.last_sample = None
self.set_timesteps(infer_steps=self.infer_steps, device=self.device, shift=self.sample_shift)
def prepare_latents(self, target_shape, dtype=torch.float32):
self.latents = (
torch.randn(
target_shape[0],
target_shape[1],
target_shape[2],
target_shape[3],
dtype=dtype,
device=self.device,
generator=self.generator,
)
* self.init_noise_sigma
)
def step_post(self):
model_output = self.noise_pred.to(torch.float32)
timestep = self.timesteps[self.step_index]
sample = self.latents.to(torch.float32)
if self.step_index is None:
self._init_step_index(timestep)
sample = sample.to(torch.float32) # pyright: ignore
sigma = unsqueeze_to_ndim(self.sigmas[self.step_index], sample.ndim)
sigma_next = unsqueeze_to_ndim(self.sigmas[self.step_index + 1], sample.ndim)
# x0 = sample - model_output * sigma
x_t_next = sample + (sigma_next - sigma) * model_output
self._step_index += 1
return x_t_next
def reset(self):
self.model_outputs = [None] * self.solver_order
self.timestep_list = [None] * self.solver_order
self.last_sample = None
self.noise_pred = None
self.this_order = None
self.lower_order_nums = 0
self.prepare_latents(self.config.target_shape, dtype=torch.float32)
gc.collect()
torch.cuda.empty_cache()
import os
import gc
import math
import numpy as np
import torch
import warnings  # used by unsqueeze_to_ndim below
from typing import List, Optional, Tuple, Union
from lightx2v.utils.envs import *
from lightx2v.models.schedulers.scheduler import BaseScheduler
from loguru import logger
from diffusers.configuration_utils import register_to_config
from torch import Tensor
from diffusers import (
FlowMatchEulerDiscreteScheduler as FlowMatchEulerDiscreteSchedulerBase, # pyright: ignore
)
def unsqueeze_to_ndim(in_tensor: Tensor, tgt_n_dim: int):
if in_tensor.ndim > tgt_n_dim:
warnings.warn(f"the given tensor of shape {in_tensor.shape} is expected to unsqueeze to {tgt_n_dim}, the original tensor will be returned")
return in_tensor
if in_tensor.ndim < tgt_n_dim:
in_tensor = in_tensor[(...,) + (None,) * (tgt_n_dim - in_tensor.ndim)]
return in_tensor
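In other words, trailing singleton dimensions are appended so a per-step sigma broadcasts against the latents; a one-line check (torch is already imported in this module):

assert unsqueeze_to_ndim(torch.tensor([1.0, 2.0]), 4).shape == (2, 1, 1, 1)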
class EulerSchedulerTimestepFix(BaseScheduler):
def __init__(self, config, **kwargs):
# super().__init__(**kwargs)
self.init_noise_sigma = 1.0
self.config = config
self.latents = None
self.caching_records = [True] * config.infer_steps
self.flag_df = False
self.transformer_infer = None
self.solver_order = 2
self.device = torch.device("cuda")
self.infer_steps = self.config.infer_steps
self.target_video_length = self.config.target_video_length
self.sample_shift = self.config.sample_shift
self.shift = 1
self.num_train_timesteps = 1000
self.step_index = None
self.noise_pred = None
self._step_index = None
self._begin_index = None
def step_pre(self, step_index):
self.step_index = step_index
if GET_DTYPE() == "BF16":
self.latents = self.latents.to(dtype=torch.bfloat16)
def set_timesteps(
self,
infer_steps: Union[int, None] = None,
device: Union[str, torch.device] = None,
sigmas: Optional[List[float]] = None,
mu: Optional[Union[float, None]] = None,
shift: Optional[Union[float, None]] = None,
):
sigmas = np.linspace(self.sigma_max, self.sigma_min, infer_steps + 1).copy()[:-1]
if shift is None:
shift = self.shift
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
sigma_last = 0
timesteps = sigmas * self.num_train_timesteps
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas)
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
assert len(self.timesteps) == self.infer_steps
self.model_outputs = [
None,
] * self.solver_order
self.lower_order_nums = 0
self.last_sample = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu")
def prepare(self, image_encoder_output=None):
self.generator = torch.Generator(device=self.device)
self.generator.manual_seed(self.config.seed)
self.prepare_latents(self.config.target_shape, dtype=torch.float32)
if self.config.task in ["t2v"]:
self.seq_len = math.ceil((self.config.target_shape[2] * self.config.target_shape[3]) / (self.config.patch_size[1] * self.config.patch_size[2]) * self.config.target_shape[1])
elif self.config.task in ["i2v"]:
self.seq_len = ((self.config.target_video_length - 1) // self.config.vae_stride[0] + 1) * self.config.lat_h * self.config.lat_w // (self.config.patch_size[1] * self.config.patch_size[2])
alphas = np.linspace(1, 1 / self.num_train_timesteps, self.num_train_timesteps)[::-1].copy()
sigmas = 1.0 - alphas
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
self.sigmas = sigmas
self.timesteps = sigmas * self.num_train_timesteps
self.model_outputs = [None] * self.solver_order
self.timestep_list = [None] * self.solver_order
self.last_sample = None
self.sigmas = self.sigmas.to("cpu")
self.sigma_min = self.sigmas[-1].item()
self.sigma_max = self.sigmas[0].item()
self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift)
def prepare_latents(self, target_shape, dtype=torch.float32):
self.latents = (
torch.randn(
target_shape[0],
target_shape[1],
target_shape[2],
target_shape[3],
dtype=dtype,
device=self.device,
generator=self.generator,
)
* self.init_noise_sigma
)
def step_post(self):
model_output = self.noise_pred.to(torch.float32)
sample = self.latents.to(torch.float32)
sample = sample.to(torch.float32)
sigma = unsqueeze_to_ndim(self.sigmas[self.step_index], sample.ndim).to(sample.device, sample.dtype)
sigma_next = unsqueeze_to_ndim(self.sigmas[self.step_index + 1], sample.ndim).to(sample.device, sample.dtype)
# x0 = sample - model_output * sigma
x_t_next = sample + (sigma_next - sigma) * model_output
self.latents = x_t_next
def reset(self):
self.model_outputs = [None]
self.timestep_list = [None]
self.last_sample = None
self.noise_pred = None
self.this_order = None
self.lower_order_nums = 0
self.prepare_latents(self.config.target_shape, dtype=torch.float32)
gc.collect()
torch.cuda.empty_cache()
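Pieced together from the methods above and from WanAudioRunner, the per-segment lifecycle looks roughly like the sketch below. It assumes EulerSchedulerTimestepFix and its module-level imports (e.g. GET_DTYPE) are in scope, the config keys are only the ones this class actually reads, the noise_pred assignment is a placeholder for the DiT forward pass, and CUDA is required because the scheduler hard-codes its device. Treat it as an illustration, not the runner's real call sequence.

import torch
from types import SimpleNamespace

# Minimal stand-in config; real lightx2v configs carry many more fields.
cfg = SimpleNamespace(
    infer_steps=4, sample_shift=5, seed=42,
    target_video_length=81, target_shape=(16, 21, 60, 104),
    patch_size=(1, 2, 2), task="t2v",
)

sched = EulerSchedulerTimestepFix(cfg)
sched.prepare()                                          # seed RNG, sample latents, build shifted sigmas
for i in range(sched.infer_steps):
    sched.step_pre(i)                                    # select step, optionally cast latents to BF16
    sched.noise_pred = torch.zeros_like(sched.latents)   # placeholder for the model's predicted velocity
    sched.step_post()                                    # latents += (sigma_next - sigma) * noise_pred
sched.reset()                                            # fresh latents and cache cleanup for the next segment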
#!/bin/bash
# set these paths first
lightx2v_path=
model_path=
lora_path=
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
@@ -42,5 +44,4 @@ python -m lightx2v.infer \
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--image_path ${lightx2v_path}/assets/inputs/audio/15.png \
--audio_path ${lightx2v_path}/assets/inputs/audio/15.wav \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_audio.mp4 \
--lora_path ${lora_path}
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_audio.mp4
@@ -5,12 +5,13 @@
### ViGen-DiT Project Url: https://github.com/yl-1993/ViGen-DiT
###
import torch
from safetensors.torch import save_file
import sys
import os
from safetensors.torch import save_file
from safetensors.torch import load_file
if len(sys.argv) != 3:
print("用法: python convert_lora.py <输入文件.pt> <输出文件.safetensors>")
print("用法: python convert_lora.py <输入文件> <输出文件.safetensors>")
sys.exit(1)
ckpt_path = sys.argv[1]
@@ -20,7 +21,10 @@ if not os.path.exists(ckpt_path):
print(f"❌ 输入文件不存在: {ckpt_path}")
sys.exit(1)
state_dict = torch.load(ckpt_path, map_location="cpu")
if ckpt_path.endswith(".safetensors"):
state_dict = load_file(ckpt_path)
else:
state_dict = torch.load(ckpt_path, map_location="cpu")
if "state_dict" in state_dict:
state_dict = state_dict["state_dict"]
......