Commit d048b178 authored by gaclove

refactor: enhance WanAudioRunner to improve audio handling and frame interpolation

parent 5bd9bdbd
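Review note: the refactor below separates one-time model setup (initialize_once) from per-request generation (run_pipeline). A hypothetical driver, sketched only from names that appear in this diff (the registry lookup style and config construction are assumptions, not shown in this commit):

# Sketch only: assumes RUNNER_REGISTER supports dict-style lookup and that
# config carries model_path, image_path, audio_path, save_video_path, seed, ...
runner_cls = RUNNER_REGISTER["wan2.1_audio"]
runner = runner_cls(config)
runner.initialize_once()                              # one-time: audio processor, feature extractor, scheduler
images, audio = runner.run_pipeline(save_video=True)  # first request
images, audio = runner.run_pipeline(save_video=True)  # later requests reuse the cached components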
@@ -4,6 +4,10 @@ import numpy as np
 import torch
 import torchvision.transforms.functional as TF
 from PIL import Image
+from contextlib import contextmanager
+from typing import Optional, Tuple, Union, List, Dict, Any
+from dataclasses import dataclass
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.models.runners.wan.wan_runner import WanRunner
 from lightx2v.models.runners.default_runner import DefaultRunner
@@ -34,33 +38,41 @@ from torchvision.transforms.functional import resize
 import subprocess
 import warnings
-from typing import Optional, Tuple, Union
-
-
-def add_mask_to_frames(
-    frames: np.ndarray,
-    mask_rate: float = 0.1,
-    rnd_state: np.random.RandomState = None,
-) -> np.ndarray:
-    if mask_rate is None:
-        return frames
-
-    if rnd_state is None:
-        rnd_state = np.random.RandomState()
-    h, w = frames.shape[-2:]
-    mask = rnd_state.rand(h, w) > mask_rate
-    frames = frames * mask
-    return frames
-
-
-def add_noise_to_frames(
-    frames: np.ndarray,
-    noise_mean: float = -3.0,
-    noise_std: float = 0.5,
-    rnd_state: np.random.RandomState = None,
-) -> np.ndarray:
-    if noise_mean is None or noise_std is None:
-        return frames
-    if rnd_state is None:
+@contextmanager
+def memory_efficient_inference():
+    """Context manager for memory-efficient inference"""
+    try:
+        yield
+    finally:
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        gc.collect()
+
+
+@dataclass
+class AudioSegment:
+    """Data class for audio segment information"""
+
+    audio_array: np.ndarray
+    start_frame: int
+    end_frame: int
+    is_last: bool = False
+    useful_length: Optional[int] = None
+
+
+class FramePreprocessor:
+    """Handles frame preprocessing including noise and masking"""
+
+    def __init__(self, noise_mean: float = -3.0, noise_std: float = 0.5, mask_rate: float = 0.1):
+        self.noise_mean = noise_mean
+        self.noise_std = noise_std
+        self.mask_rate = mask_rate
+
+    def add_noise(self, frames: np.ndarray, rnd_state: Optional[np.random.RandomState] = None) -> np.ndarray:
+        """Add noise to frames"""
+        if self.noise_mean is None or self.noise_std is None:
+            return frames
+        if rnd_state is None:
@@ -68,13 +80,225 @@ def add_noise_to_frames(
-    shape = frames.shape
-    bs = 1 if len(shape) == 4 else shape[0]
-    sigma = rnd_state.normal(loc=noise_mean, scale=noise_std, size=(bs,))
-    sigma = np.exp(sigma)
-    sigma = np.expand_dims(sigma, axis=tuple(range(1, len(shape))))
-    noise = rnd_state.randn(*shape) * sigma
-    frames = frames + noise
-    return frames
+        shape = frames.shape
+        bs = 1 if len(shape) == 4 else shape[0]
+        sigma = rnd_state.normal(loc=self.noise_mean, scale=self.noise_std, size=(bs,))
+        sigma = np.exp(sigma)
+        sigma = np.expand_dims(sigma, axis=tuple(range(1, len(shape))))
+        noise = rnd_state.randn(*shape) * sigma
+        return frames + noise
+
+    def add_mask(self, frames: np.ndarray, rnd_state: Optional[np.random.RandomState] = None) -> np.ndarray:
+        """Add mask to frames"""
+        if self.mask_rate is None:
+            return frames
+        if rnd_state is None:
+            rnd_state = np.random.RandomState()
+        h, w = frames.shape[-2:]
+        mask = rnd_state.rand(h, w) > self.mask_rate
+        return frames * mask
+
+    def process_prev_frames(self, frames: torch.Tensor) -> torch.Tensor:
+        """Process previous frames with noise and masking"""
+        frames_np = frames.cpu().detach().numpy()
+        frames_np = self.add_noise(frames_np)
+        frames_np = self.add_mask(frames_np)
+        return torch.from_numpy(frames_np).to(dtype=frames.dtype, device=frames.device)
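Review note: a standalone sanity check of the preprocessor defaults (not part of the commit; numbers are illustrative). sigma is sampled in log-space, so noise_mean=-3.0 yields noise on the order of exp(-3) ≈ 0.05, and mask_rate=0.1 zeroes roughly 10% of pixels:

import numpy as np

pre = FramePreprocessor()  # defaults: noise_mean=-3.0, noise_std=0.5, mask_rate=0.1
frames = np.zeros((1, 3, 5, 64, 64), dtype=np.float32)
noisy = pre.add_noise(frames, np.random.RandomState(0))
print(np.abs(noisy).mean())        # roughly 0.8 * exp(-3) ≈ 0.04
masked = pre.add_mask(np.ones((64, 64)), np.random.RandomState(0))
print(masked.mean())               # ≈ 1 - mask_rate = 0.9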
+
+
+class AudioProcessor:
+    """Handles audio loading and segmentation"""
+
+    def __init__(self, audio_sr: int = 16000, target_fps: int = 16):
+        self.audio_sr = audio_sr
+        self.target_fps = target_fps
+
+    def load_audio(self, audio_path: str) -> np.ndarray:
+        """Load and resample audio"""
+        audio_array, ori_sr = ta.load(audio_path)
+        audio_array = ta.functional.resample(audio_array.mean(0), orig_freq=ori_sr, new_freq=self.audio_sr)
+        return audio_array.numpy()
+
+    def get_audio_range(self, start_frame: int, end_frame: int) -> Tuple[int, int]:
+        """Calculate audio sample range for a given frame range"""
+        audio_frame_rate = self.audio_sr / self.target_fps
+        return round(start_frame * audio_frame_rate), round((end_frame + 1) * audio_frame_rate)
+
+    def segment_audio(self, audio_array: np.ndarray, expected_frames: int, max_num_frames: int, prev_frame_length: int = 5) -> List[AudioSegment]:
+        """Segment audio based on frame requirements"""
+        segments = []
+
+        # Calculate intervals
+        interval_num = 1
+        res_frame_num = 0
+        if expected_frames <= max_num_frames:
+            interval_num = 1
+        else:
+            interval_num = max(int((expected_frames - max_num_frames) / (max_num_frames - prev_frame_length)) + 1, 1)
+            res_frame_num = expected_frames - interval_num * (max_num_frames - prev_frame_length)
+            if res_frame_num > 5:
+                interval_num += 1
+
+        # Create segments
+        for idx in range(interval_num):
+            if idx == 0:
+                # First segment
+                audio_start, audio_end = self.get_audio_range(0, max_num_frames)
+                segment_audio = audio_array[audio_start:audio_end]
+                useful_length = None
+                if expected_frames < max_num_frames:
+                    useful_length = segment_audio.shape[0]
+                    max_num_audio_length = int((max_num_frames + 1) / self.target_fps * self.audio_sr)
+                    segment_audio = np.concatenate((segment_audio, np.zeros(max_num_audio_length - useful_length)), axis=0)
+                segments.append(AudioSegment(segment_audio, 0, max_num_frames, False, useful_length))
+            elif res_frame_num > 5 and idx == interval_num - 1:
+                # Last segment (might be shorter)
+                start_frame = idx * max_num_frames - idx * prev_frame_length
+                audio_start, audio_end = self.get_audio_range(start_frame, expected_frames)
+                segment_audio = audio_array[audio_start:audio_end]
+                useful_length = segment_audio.shape[0]
+                max_num_audio_length = int((max_num_frames + 1) / self.target_fps * self.audio_sr)
+                segment_audio = np.concatenate((segment_audio, np.zeros(max_num_audio_length - useful_length)), axis=0)
+                segments.append(AudioSegment(segment_audio, start_frame, expected_frames, True, useful_length))
+            else:
+                # Middle segments
+                start_frame = idx * max_num_frames - idx * prev_frame_length
+                end_frame = (idx + 1) * max_num_frames - idx * prev_frame_length
+                audio_start, audio_end = self.get_audio_range(start_frame, end_frame)
+                segment_audio = audio_array[audio_start:audio_end]
+                segments.append(AudioSegment(segment_audio, start_frame, end_frame, False))
+
+        return segments
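Review note: a worked example of the interval arithmetic above (standalone; the numbers are illustrative). A 12.5 s clip at 16 fps gives expected_frames = 200 against max_num_frames = 81 and prev_frame_length = 5:

expected_frames, max_num_frames, prev_frame_length = 200, 81, 5
step = max_num_frames - prev_frame_length                                  # 76 new frames per extra window
interval_num = max(int((expected_frames - max_num_frames) / step) + 1, 1)  # int(119 / 76) + 1 = 2
res_frame_num = expected_frames - interval_num * step                      # 200 - 152 = 48
if res_frame_num > 5:
    interval_num += 1                                                      # 3 segments in total
# segment_audio then covers frame ranges [0, 81], [76, 157], [152, 200];
# each window overlaps its predecessor by prev_frame_length frames.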
+
+
+class VideoGenerator:
+    """Handles video generation for each segment"""
+
+    def __init__(self, model, vae_encoder, vae_decoder, config):
+        self.model = model
+        self.vae_encoder = vae_encoder
+        self.vae_decoder = vae_decoder
+        self.config = config
+        self.frame_preprocessor = FramePreprocessor()
+
+    def prepare_prev_latents(self, prev_video: Optional[torch.Tensor], prev_frame_length: int) -> Optional[Dict[str, torch.Tensor]]:
+        """Prepare previous latents for conditioning"""
+        if prev_video is None:
+            return None
+
+        device = self.model.device
+        dtype = torch.bfloat16
+        vae_dtype = torch.float
+        tgt_h, tgt_w = self.config.tgt_h, self.config.tgt_w
+
+        prev_frames = torch.zeros((1, 3, self.config.target_video_length, tgt_h, tgt_w), device=device)
+
+        # Extract and process last frames
+        last_frames = prev_video[:, :, -prev_frame_length:].clone().to(device)
+        last_frames = self.frame_preprocessor.process_prev_frames(last_frames)
+        prev_frames[:, :, :prev_frame_length] = last_frames
+
+        prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
+
+        # Create mask
+        prev_token_length = (prev_frame_length - 1) // 4 + 1
+        _, nframe, height, width = self.model.scheduler.latents.shape
+        frames_n = (nframe - 1) * 4 + 1
+        prev_frame_len = max((prev_token_length - 1) * 4 + 1, 0)
+        prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
+        prev_mask[:, prev_frame_len:] = 0
+        prev_mask = self._wan_mask_rearrange(prev_mask).unsqueeze(0)
+
+        return {"prev_latents": prev_latents, "prev_mask": prev_mask}
+
+    def _wan_mask_rearrange(self, mask: torch.Tensor) -> torch.Tensor:
+        """Rearrange mask for WAN model"""
+        if mask.ndim == 3:
+            mask = mask[None]
+        assert mask.ndim == 4
+        _, t, h, w = mask.shape
+        assert t == ((t - 1) // 4 * 4 + 1)
+        mask_first_frame = torch.repeat_interleave(mask[:, 0:1], repeats=4, dim=1)
+        mask = torch.concat([mask_first_frame, mask[:, 1:]], dim=1)
+        mask = mask.view(mask.shape[1] // 4, 4, h, w)
+        return mask.transpose(0, 1)
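Review note: a shape check for _wan_mask_rearrange (standalone sketch; the 60x104 latent grid is an assumed example). 81 pixel frames satisfy t == (t - 1) // 4 * 4 + 1; repeating the first frame 4x gives 84 frames, which regroup into 4 channels of 21 latent frames:

import torch

mask = torch.ones(1, 81, 60, 104)
# The method never touches self, so it can be exercised unbound for a quick check:
out = VideoGenerator._wan_mask_rearrange(None, mask)
print(out.shape)   # torch.Size([4, 21, 60, 104])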
+    @torch.no_grad()
+    def generate_segment(self, inputs: Dict[str, Any], audio_features: torch.Tensor, prev_video: Optional[torch.Tensor] = None, prev_frame_length: int = 5, segment_idx: int = 0) -> torch.Tensor:
+        """Generate video segment"""
+        # Update inputs with audio features
+        inputs["audio_encoder_output"] = audio_features
+
+        # Reset scheduler for non-first segments
+        if segment_idx > 0:
+            self.model.scheduler.reset()
+
+        # Prepare previous latents - ALWAYS needed, even for first segment
+        device = self.model.device
+        dtype = torch.bfloat16
+        vae_dtype = torch.float
+        tgt_h, tgt_w = self.config.tgt_h, self.config.tgt_w
+        max_num_frames = self.config.target_video_length
+
+        if segment_idx == 0:
+            # First segment - create zero frames
+            prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
+            prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
+            prev_len = 0
+        else:
+            # Subsequent segments - use previous video
+            previmg_encoder_output = self.prepare_prev_latents(prev_video, prev_frame_length)
+            if previmg_encoder_output:
+                prev_latents = previmg_encoder_output["prev_latents"]
+                prev_len = (prev_frame_length - 1) // 4 + 1
+            else:
+                # Fallback to zeros if prepare_prev_latents fails
+                prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
+                prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
+                prev_len = 0
+
+        # Create mask for prev_latents
+        _, nframe, height, width = self.model.scheduler.latents.shape
+        frames_n = (nframe - 1) * 4 + 1
+        prev_frame_len = max((prev_len - 1) * 4 + 1, 0)
+        prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
+        prev_mask[:, prev_frame_len:] = 0
+        prev_mask = self._wan_mask_rearrange(prev_mask).unsqueeze(0)
+
+        # Always set previmg_encoder_output
+        inputs["previmg_encoder_output"] = {"prev_latents": prev_latents, "prev_mask": prev_mask}
+
+        # Run inference loop
+        for step_index in range(self.model.scheduler.infer_steps):
+            logger.info(f"==> Segment {segment_idx}, Step {step_index}/{self.model.scheduler.infer_steps}")
+            with ProfilingContext4Debug("step_pre"):
+                self.model.scheduler.step_pre(step_index=step_index)
+            with ProfilingContext4Debug("infer"):
+                self.model.infer(inputs)
+            with ProfilingContext4Debug("step_post"):
+                self.model.scheduler.step_post()
+
+        # Decode latents
+        latents = self.model.scheduler.latents
+        generator = self.model.scheduler.generator
+        gen_video = self.vae_decoder.decode(latents, generator=generator, config=self.config)
+        gen_video = torch.clamp(gen_video, -1, 1).to(torch.float)
+
+        return gen_video
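Review note: how the generator is presumably driven per segment (a sketch; the handles and feature tensors are placeholders, only the signatures come from this diff):

# gen = VideoGenerator(model, vae_encoder, vae_decoder, config)
# seg0 = gen.generate_segment(inputs.copy(), feats0, prev_video=None, segment_idx=0)  # zero-frame conditioning
# seg1 = gen.generate_segment(inputs.copy(), feats1, prev_video=seg0, segment_idx=1)  # reuses seg0's last 5 frames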
 
 def get_crop_bbox(ori_h, ori_w, tgt_h, tgt_w):
     tgt_ar = tgt_h / tgt_w
@@ -131,267 +355,69 @@ def adaptive_resize(img):
     return cropped_img, target_h, target_w
-def array_to_video(
-    image_array: np.ndarray,
-    output_path: str,
-    fps: int | float = 30,
-    resolution: tuple[int, int] | tuple[float, float] | None = None,
-    disable_log: bool = False,
-    lossless: bool = True,
-    output_pix_fmt: str = "yuv420p",
-) -> None:
-    """Convert an array to a video directly, gif not supported.
-
-    Args:
-        image_array (np.ndarray): shape should be (f * h * w * 3).
-        output_path (str): output video file path.
-        fps (Union[int, float], optional): fps. Defaults to 30.
-        resolution (Optional[Union[Tuple[int, int], Tuple[float, float]]],
-            optional): (height, width) of the output video.
-            Defaults to None.
-        disable_log (bool, optional): whether to silence the ffmpeg command
-            info. Defaults to False.
-        output_pix_fmt (str): output pix_fmt in the ffmpeg command.
-
-    Raises:
-        FileNotFoundError: check output path.
-        TypeError: check input array.
-
-    Returns:
-        None.
-    """
-    if not isinstance(image_array, np.ndarray):
-        raise TypeError("Input should be np.ndarray.")
-    assert image_array.ndim == 4
-    assert image_array.shape[-1] == 3
-    if resolution:
-        height, width = resolution
-        width += width % 2
-        height += height % 2
-    else:
-        image_array = pad_for_libx264(image_array)
-        height, width = image_array.shape[1], image_array.shape[2]
-    if lossless:
-        command = [
-            "/usr/bin/ffmpeg",
-            "-y",  # (optional) overwrite output file if it exists
-            "-f",
-            "rawvideo",
-            "-s",
-            f"{int(width)}x{int(height)}",  # size of one frame
-            "-pix_fmt",
-            "bgr24",
-            "-r",
-            f"{fps}",  # frames per second
-            "-loglevel",
-            "error",
-            "-threads",
-            "4",
-            "-i",
-            "-",  # The input comes from a pipe
-            "-vcodec",
-            "libx264rgb",
-            "-crf",
-            "0",
-            "-an",  # Tells FFMPEG not to expect any audio
-            output_path,
-        ]
-    else:
-        output_pix_fmt = output_pix_fmt or "yuv420p"
-        command = [
-            "/usr/bin/ffmpeg",
-            "-y",  # (optional) overwrite output file if it exists
-            "-f",
-            "rawvideo",
-            "-s",
-            f"{int(width)}x{int(height)}",  # size of one frame
-            "-pix_fmt",
-            "bgr24",
-            "-r",
-            f"{fps}",  # frames per second
-            "-loglevel",
-            "error",
-            "-threads",
-            "4",
-            "-i",
-            "-",  # The input comes from a pipe
-            "-vcodec",
-            "libx264",
-            "-pix_fmt",
-            f"{output_pix_fmt}",
-            "-an",  # Tells FFMPEG not to expect any audio
-            output_path,
-        ]
-    if output_pix_fmt is not None:
-        command += ["-pix_fmt", output_pix_fmt]
-    if not disable_log:
-        print(f'Running "{" ".join(command)}"')
-    process = subprocess.Popen(
-        command,
-        stdin=subprocess.PIPE,
-        stderr=subprocess.PIPE,
-    )
-    if process.stdin is None or process.stderr is None:
-        raise BrokenPipeError("No buffer received.")
-    index = 0
-    while True:
-        if index >= image_array.shape[0]:
-            break
-        process.stdin.write(image_array[index].tobytes())
-        index += 1
-    process.stdin.close()
-    process.stderr.close()
-    process.wait()
-
-
-def pad_for_libx264(image_array):
-    if image_array.ndim == 2 or (image_array.ndim == 3 and image_array.shape[2] == 3):
-        hei_index = 0
-        wid_index = 1
-    elif image_array.ndim == 4 or (image_array.ndim == 3 and image_array.shape[2] != 3):
-        hei_index = 1
-        wid_index = 2
-    else:
-        return image_array
-    hei_pad = image_array.shape[hei_index] % 2
-    wid_pad = image_array.shape[wid_index] % 2
-    if hei_pad + wid_pad > 0:
-        pad_width = []
-        for dim_index in range(image_array.ndim):
-            if dim_index == hei_index:
-                pad_width.append((0, hei_pad))
-            elif dim_index == wid_index:
-                pad_width.append((0, wid_pad))
-            else:
-                pad_width.append((0, 0))
-        values = 0
-        image_array = np.pad(image_array, pad_width, mode="constant", constant_values=values)
-    return image_array
-
-
-def generate_unique_path(path):
-    if not os.path.exists(path):
-        return path
-    root, ext = os.path.splitext(path)
-    index = 1
-    new_path = f"{root}-{index}{ext}"
-    while os.path.exists(new_path):
-        index += 1
-        new_path = f"{root}-{index}{ext}"
-    return new_path
-
-
-def save_audio(
-    audio_array,
-    audio_name: str,
-    video_name: str,
-    sr: int = 16000,
-    output_path: Optional[str] = None,
-):
-    logger.info(f"Saving audio to {audio_name} type: {type(audio_array)}")
-    ta.save(
-        audio_name,
-        torch.tensor(audio_array[None]),
-        sample_rate=sr,
-    )
-    if output_path is None:
-        out_video = f"{video_name[:-4]}_with_audio.mp4"
-    else:
-        out_video = output_path
-    parent_dir = os.path.dirname(out_video)
-    if parent_dir and not os.path.exists(parent_dir):
-        os.makedirs(parent_dir, exist_ok=True)
-    if os.path.exists(out_video):
-        os.remove(out_video)
-    subprocess.call(["/usr/bin/ffmpeg", "-y", "-i", video_name, "-i", audio_name, out_video])
-    return out_video
-
-
 @RUNNER_REGISTER("wan2.1_audio")
 class WanAudioRunner(WanRunner):
     def __init__(self, config):
         super().__init__(config)
+        self._is_initialized = False
+        self._audio_adapter_pipe = None
+        self._audio_processor = None
+        self._video_generator = None
+        self._audio_preprocess = None
+
+    def initialize_once(self):
+        """Initialize all models once for multiple runs"""
+        if self._is_initialized:
+            return
+        logger.info("Initializing models (one-time setup)...")
+        # Initialize audio processor
+        audio_sr = self.config.get("audio_sr", 16000)
+        target_fps = self.config.get("target_fps", 16)
+        self._audio_processor = AudioProcessor(audio_sr, target_fps)
+        # Load audio feature extractor
+        self._audio_preprocess = AutoFeatureExtractor.from_pretrained(self.config["model_path"], subfolder="audio_encoder")
+        # Initialize scheduler
+        self.init_scheduler()
+        self._is_initialized = True
+        logger.info("Model initialization complete")
 
     def init_scheduler(self):
+        """Initialize consistency model scheduler"""
         scheduler = ConsistencyModelScheduler(self.config)
         self.model.set_scheduler(scheduler)
 
-    def load_audio_models(self):
-        # Audio feature extractor
-        self.audio_preprocess = AutoFeatureExtractor.from_pretrained(self.config["model_path"], subfolder="audio_encoder")
-
-        # Audio-driven video generation adapter
+    def load_audio_adapter_lazy(self):
+        """Lazy load audio adapter when needed"""
+        if self._audio_adapter_pipe is not None:
+            return self._audio_adapter_pipe
+        # Audio adapter
         audio_adapter_path = self.config["model_path"] + "/audio_adapter.safetensors"
-        audio_adaper = AudioAdapter.from_transformer(
+        audio_adapter = AudioAdapter.from_transformer(
             self.model,
             audio_feature_dim=1024,
             interval=1,
             time_freq_dim=256,
             projection_transformer_layers=4,
         )
-        audio_adapter = rank0_load_state_dict_from_path(audio_adaper, audio_adapter_path, strict=False)
+        audio_adapter = rank0_load_state_dict_from_path(audio_adapter, audio_adapter_path, strict=False)
-        # Audio feature encoder
+        # Audio encoder
         device = self.model.device
         audio_encoder_repo = self.config["model_path"] + "/audio_encoder"
-        audio_adapter_pipe = AudioAdapterPipe(audio_adapter, audio_encoder_repo=audio_encoder_repo, dtype=torch.bfloat16, device=device, generator=torch.Generator(device), weight=1.0)
-        return audio_adapter_pipe
+        self._audio_adapter_pipe = AudioAdapterPipe(audio_adapter, audio_encoder_repo=audio_encoder_repo, dtype=torch.bfloat16, device=device, generator=torch.Generator(device), weight=1.0)
+        return self._audio_adapter_pipe
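Review note: load_audio_adapter_lazy is a plain cache-or-build accessor, so repeated pipeline runs pay the safetensors load and adapter construction only once. The pattern in isolation (a generic sketch, not code from this commit):

class Lazy:
    def __init__(self, factory):
        self._factory = factory
        self._value = None

    def get(self):
        if self._value is None:   # build on first access only
            self._value = self._factory()
        return self._value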
 
-    def load_transformer(self):
-        base_model = WanAudioModel(self.config.model_path, self.config, self.init_device)
-        if self.config.get("lora_configs") and self.config.lora_configs:
-            assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
-            lora_wrapper = WanLoraWrapper(base_model)
-            for lora_config in self.config.lora_configs:
-                lora_path = lora_config["path"]
-                strength = lora_config.get("strength", 1.0)
-                lora_name = lora_wrapper.load_lora(lora_path)
-                lora_wrapper.apply_lora(lora_name, strength)
-                logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
-        return base_model
-
-    def load_image_encoder(self):
-        clip_model_dir = self.config["model_path"] + "/image_encoder"
-        image_encoder = WanVideoIPHandler("CLIPModel", repo_or_path=clip_model_dir, require_grad=False, mode="eval", device=self.init_device, dtype=torch.float16)
-        return image_encoder
-
-    def run_image_encoder(self, config, vae_model):
-        ref_img = Image.open(config.image_path)
-        ref_img = (np.array(ref_img).astype(np.float32) - 127.5) / 127.5
-        ref_img = torch.from_numpy(ref_img).to(vae_model.device)
-        ref_img = rearrange(ref_img, "H W C -> 1 C H W")
-        ref_img = ref_img[:, :3]
-        # resize and crop image
-        cond_frms, tgt_h, tgt_w = adaptive_resize(ref_img)
-        config.tgt_h = tgt_h
-        config.tgt_w = tgt_w
-        clip_encoder_out = self.image_encoder.encode(cond_frms).squeeze(0).to(torch.bfloat16)
-        cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
-        lat_h, lat_w = tgt_h // 8, tgt_w // 8
-        config.lat_h = lat_h
-        config.lat_w = lat_w
-        vae_encode_out = vae_model.encode(cond_frms.to(torch.float), config)
-        if isinstance(vae_encode_out, list):
-            # convert list to tensor
-            vae_encode_out = torch.stack(vae_encode_out, dim=0).to(torch.bfloat16)
-        return vae_encode_out, clip_encoder_out
 
-    def run_input_encoder_internal(self):
+    def prepare_inputs(self):
+        """Prepare inputs for the model"""
         image_encoder_output = None
         if os.path.isfile(self.config.image_path):
             with ProfilingContext("Run Img Encoder"):
                 vae_encode_out, clip_encoder_out = self.run_image_encoder(self.config, self.vae_encoder)
@@ -399,204 +425,98 @@ class WanAudioRunner(WanRunner):
"clip_encoder_out": clip_encoder_out, "clip_encoder_out": clip_encoder_out,
"vae_encode_out": vae_encode_out, "vae_encode_out": vae_encode_out,
} }
logger.info(f"clip_encoder_out:{clip_encoder_out.shape} vae_encode_out:{vae_encode_out.shape}")
with ProfilingContext("Run Text Encoder"): with ProfilingContext("Run Text Encoder"):
logger.info(f"Prompt: {self.config['prompt']}")
img = Image.open(self.config["image_path"]).convert("RGB") img = Image.open(self.config["image_path"]).convert("RGB")
text_encoder_output = self.run_text_encoder(self.config["prompt"], img) text_encoder_output = self.run_text_encoder(self.config["prompt"], img)
self.set_target_shape() self.set_target_shape()
self.inputs = {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}
# del self.image_encoder # 删除ref的clip模型,只使用一次
gc.collect()
torch.cuda.empty_cache()
def set_target_shape(self):
ret = {}
num_channels_latents = 16
if self.config.task == "i2v":
self.config.target_shape = (
num_channels_latents,
(self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
self.config.lat_h,
self.config.lat_w,
)
ret["lat_h"] = self.config.lat_h
ret["lat_w"] = self.config.lat_w
else:
error_msg = "t2v task is not supported in WanAudioRunner"
assert 1 == 0, error_msg
ret["target_shape"] = self.config.target_shape return {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output, "audio_adapter_pipe": self.load_audio_adapter_lazy()}
return ret
 
-    def run(self, save_video=True):
-        def load_audio(in_path: str, sr: float = 16000):
-            audio_array, ori_sr = ta.load(in_path)
-            audio_array = ta.functional.resample(audio_array.mean(0), orig_freq=ori_sr, new_freq=sr)
-            return audio_array.numpy()
-
-        def get_audio_range(start_frame: int, end_frame: int, fps: float, audio_sr: float = 16000):
-            audio_frame_rate = audio_sr / fps
-            return round(start_frame * audio_frame_rate), round((end_frame + 1) * audio_frame_rate)
-
-        def wan_mask_rearrange(mask: torch.Tensor):
-            # mask: 1, T, H, W, where 1 means the input mask is one-channel
-            if mask.ndim == 3:
-                mask = mask[None]
-            assert mask.ndim == 4
-            _, t, h, w = mask.shape
-            assert t == ((t - 1) // 4 * 4 + 1)
-            mask_first_frame = torch.repeat_interleave(mask[:, 0:1], repeats=4, dim=1)
-            mask = torch.concat([mask_first_frame, mask[:, 1:]], dim=1)
-            mask = mask.view(mask.shape[1] // 4, 4, h, w)
-            return mask.transpose(0, 1)  # 4, T // 4, H, W
-
-        self.inputs["audio_adapter_pipe"] = self.load_audio_models()
-        self.model.scheduler.prepare(self.inputs["image_encoder_output"])
-
-        # process audio
-        audio_sr = self.config.get("audio_sr", 16000)
-        max_num_frames = self.config.get("target_video_length", 81)  # wan2.1 generates at most 81 frames per segment (5 s at 16 fps)
-        target_fps = self.config.get("target_fps", 16)  # frame rate used for audio/video sync
-        video_duration = self.config.get("video_duration", 5)  # desired output video duration
-        audio_array = load_audio(self.config["audio_path"], sr=audio_sr)
-        audio_len = int(audio_array.shape[0] / audio_sr * target_fps)
-        prev_frame_length = 5
-        prev_token_length = (prev_frame_length - 1) // 4 + 1
-        max_num_audio_length = int((max_num_frames + 1) / target_fps * audio_sr)
-        interval_num = 1
-        # expected_frames
-        expected_frames = min(max(1, int(float(video_duration) * target_fps)), audio_len)
-        res_frame_num = 0
-        if expected_frames <= max_num_frames:
-            interval_num = 1
-        else:
-            interval_num = max(int((expected_frames - max_num_frames) / (max_num_frames - prev_frame_length)) + 1, 1)
-            res_frame_num = expected_frames - interval_num * (max_num_frames - prev_frame_length)
-            if res_frame_num > 5:
-                interval_num += 1
-        audio_start, audio_end = get_audio_range(0, expected_frames, fps=target_fps, audio_sr=audio_sr)
-        audio_array_ori = audio_array[audio_start:audio_end]
-        gen_video_list = []
-        cut_audio_list = []
-        # reference latents
-        tgt_h = self.config.tgt_h
-        tgt_w = self.config.tgt_w
-        device = self.model.scheduler.latents.device
-        dtype = torch.bfloat16
-        vae_dtype = torch.float
-        for idx in range(interval_num):
-            self.config.seed = self.config.seed + idx
-            torch.manual_seed(self.config.seed)
-            logger.info(f"### manual_seed: {self.config.seed} ####")
-            useful_length = -1
-            if idx == 0:  # first segment: condition padded with zeros
-                prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
-                prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
-                prev_len = 0
-                audio_start, audio_end = get_audio_range(0, max_num_frames, fps=target_fps, audio_sr=audio_sr)
-                audio_array = audio_array_ori[audio_start:audio_end]
-                if expected_frames < max_num_frames:
-                    useful_length = audio_array.shape[0]
-                    audio_array = np.concatenate((audio_array, np.zeros(max_num_audio_length)[: max_num_audio_length - useful_length]), axis=0)
-                audio_input_feat = self.audio_preprocess(audio_array, sampling_rate=audio_sr, return_tensors="pt").input_values.squeeze(0)
-            elif res_frame_num > 5 and idx == interval_num - 1:  # the last segment may be shorter than 81 frames
-                prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
-                last_frames = gen_video_list[-1][:, :, -prev_frame_length:].clone().to(device)
-                last_frames = last_frames.cpu().detach().numpy()
-                last_frames = add_noise_to_frames(last_frames)
-                last_frames = add_mask_to_frames(last_frames, mask_rate=0.1)  # mask 0.10
-                last_frames = torch.from_numpy(last_frames).to(dtype=dtype, device=device)
-                prev_frames[:, :, :prev_frame_length] = last_frames
-                prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
-                prev_len = prev_token_length
-                audio_start, audio_end = get_audio_range(idx * max_num_frames - idx * prev_frame_length, expected_frames, fps=target_fps, audio_sr=audio_sr)
-                audio_array = audio_array_ori[audio_start:audio_end]
-                useful_length = audio_array.shape[0]
-                audio_array = np.concatenate((audio_array, np.zeros(max_num_audio_length)[: max_num_audio_length - useful_length]), axis=0)
-                audio_input_feat = self.audio_preprocess(audio_array, sampling_rate=audio_sr, return_tensors="pt").input_values.squeeze(0)
-            else:  # middle segments: full 81 frames conditioned on prev_latents
-                prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
-                last_frames = gen_video_list[-1][:, :, -prev_frame_length:].clone().to(device)
-                last_frames = last_frames.cpu().detach().numpy()
-                last_frames = add_noise_to_frames(last_frames)  # mean: -3.0, std: 0.5
-                last_frames = add_mask_to_frames(last_frames, mask_rate=0.1)  # mask 0.10
-                last_frames = torch.from_numpy(last_frames).to(dtype=dtype, device=device)
-                prev_frames[:, :, :prev_frame_length] = last_frames
-                prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
-                prev_len = prev_token_length
-                audio_start, audio_end = get_audio_range(idx * max_num_frames - idx * prev_frame_length, (idx + 1) * max_num_frames - idx * prev_frame_length, fps=target_fps, audio_sr=audio_sr)
-                audio_array = audio_array_ori[audio_start:audio_end]
-                audio_input_feat = self.audio_preprocess(audio_array, sampling_rate=audio_sr, return_tensors="pt").input_values.squeeze(0)
-            self.inputs["audio_encoder_output"] = audio_input_feat.to(device)
-            if idx != 0:
-                self.model.scheduler.reset()
-            if prev_latents is not None:
-                _, nframe, height, width = self.model.scheduler.latents.shape
-                # bs = 1
-                frames_n = (nframe - 1) * 4 + 1
-                prev_frame_len = max((prev_len - 1) * 4 + 1, 0)
-                prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
-                prev_mask[:, prev_frame_len:] = 0
-                prev_mask = wan_mask_rearrange(prev_mask).unsqueeze(0)
-                previmg_encoder_output = {
-                    "prev_latents": prev_latents,
-                    "prev_mask": prev_mask,
-                }
-                self.inputs["previmg_encoder_output"] = previmg_encoder_output
-            for step_index in range(self.model.scheduler.infer_steps):
-                logger.info(f"==> step_index: {step_index} / {self.model.scheduler.infer_steps}")
-                with ProfilingContext4Debug("step_pre"):
-                    self.model.scheduler.step_pre(step_index=step_index)
-                with ProfilingContext4Debug("infer"):
-                    self.model.infer(self.inputs)
-                with ProfilingContext4Debug("step_post"):
-                    self.model.scheduler.step_post()
-            latents = self.model.scheduler.latents
-            generator = self.model.scheduler.generator
-            gen_video = self.vae_decoder.decode(latents, generator=generator, config=self.config)
-            gen_video = torch.clamp(gen_video, -1, 1).to(torch.float)
-            start_frame = 0 if idx == 0 else prev_frame_length
-            start_audio_frame = 0 if idx == 0 else int((prev_frame_length + 1) * audio_sr / target_fps)
-            if res_frame_num > 5 and idx == interval_num - 1:
-                gen_video_list.append(gen_video[:, :, start_frame:res_frame_num].cpu())
-                cut_audio_list.append(audio_array[start_audio_frame:useful_length])
-            elif expected_frames < max_num_frames and useful_length != -1:
-                gen_video_list.append(gen_video[:, :, start_frame:expected_frames].cpu())
-                cut_audio_list.append(audio_array[start_audio_frame:useful_length])
-            else:
-                gen_video_list.append(gen_video[:, :, start_frame:].cpu())
-                cut_audio_list.append(audio_array[start_audio_frame:])
-        gen_lvideo = torch.cat(gen_video_list, dim=2).float()
-        merge_audio = np.concatenate(cut_audio_list, axis=0).astype(np.float32)
-        comfyui_images = vae_to_comfyui_image(gen_lvideo)
+    def run_pipeline(self, save_video=True):
+        """Optimized pipeline with modular components"""
+        # Ensure models are initialized
+        self.initialize_once()
+
+        # Initialize video generator if needed
+        if self._video_generator is None:
+            self._video_generator = VideoGenerator(self.model, self.vae_encoder, self.vae_decoder, self.config)
+
+        # Prepare inputs
+        with memory_efficient_inference():
+            if self.config["use_prompt_enhancer"]:
+                self.config["prompt_enhanced"] = self.post_prompt_enhancer()
+            self.inputs = self.prepare_inputs()
+            self.model.scheduler.prepare(self.inputs["image_encoder_output"])
+
+        # Process audio
+        audio_array = self._audio_processor.load_audio(self.config["audio_path"])
+        video_duration = self.config.get("video_duration", 5)
+        target_fps = self.config.get("target_fps", 16)
+        max_num_frames = self.config.get("target_video_length", 81)
+        audio_len = int(audio_array.shape[0] / self._audio_processor.audio_sr * target_fps)
+        expected_frames = min(max(1, int(video_duration * target_fps)), audio_len)
+
+        # Segment audio
+        audio_segments = self._audio_processor.segment_audio(audio_array, expected_frames, max_num_frames)
+
+        # Generate video segments
+        gen_video_list = []
+        cut_audio_list = []
+        prev_video = None
+
+        for idx, segment in enumerate(audio_segments):
+            # Update seed for each segment
+            self.config.seed = self.config.seed + idx
+            torch.manual_seed(self.config.seed)
+            logger.info(f"Processing segment {idx + 1}/{len(audio_segments)}, seed: {self.config.seed}")
+
+            # Process audio features
+            audio_features = self._audio_preprocess(segment.audio_array, sampling_rate=self._audio_processor.audio_sr, return_tensors="pt").input_values.squeeze(0).to(self.model.device)
+
+            # Generate video segment
+            with memory_efficient_inference():
+                gen_video = self._video_generator.generate_segment(
+                    self.inputs.copy(),  # Copy to avoid modifying original
+                    audio_features,
+                    prev_video=prev_video,
+                    prev_frame_length=5,
+                    segment_idx=idx,
+                )
+
+            # Extract relevant frames
+            start_frame = 0 if idx == 0 else 5
+            start_audio_frame = 0 if idx == 0 else int(6 * self._audio_processor.audio_sr / target_fps)
+            if segment.is_last and segment.useful_length:
+                end_frame = segment.end_frame - segment.start_frame
+                gen_video_list.append(gen_video[:, :, start_frame:end_frame].cpu())
+                cut_audio_list.append(segment.audio_array[start_audio_frame : segment.useful_length])
+            elif segment.useful_length and expected_frames < max_num_frames:
+                gen_video_list.append(gen_video[:, :, start_frame:expected_frames].cpu())
+                cut_audio_list.append(segment.audio_array[start_audio_frame : segment.useful_length])
+            else:
+                gen_video_list.append(gen_video[:, :, start_frame:].cpu())
+                cut_audio_list.append(segment.audio_array[start_audio_frame:])
+
+            # Update prev_video for next iteration
+            prev_video = gen_video
+
+            # Clean up GPU memory after each segment
+            del gen_video
+            torch.cuda.empty_cache()
+
+        # Merge results
+        with memory_efficient_inference():
+            gen_lvideo = torch.cat(gen_video_list, dim=2).float()
+            merge_audio = np.concatenate(cut_audio_list, axis=0).astype(np.float32)
+            comfyui_images = vae_to_comfyui_image(gen_lvideo)
 
         # Apply frame interpolation if configured
-        if "video_frame_interpolation" in self.config:
-            assert self.vfi_model is not None and self.config["video_frame_interpolation"].get("target_fps", None) is not None
+        if "video_frame_interpolation" in self.config and self.vfi_model is not None:
             interpolation_target_fps = self.config["video_frame_interpolation"]["target_fps"]
             logger.info(f"Interpolating frames from {target_fps} to {interpolation_target_fps}")
             comfyui_images = self.vfi_model.interpolate_frames(
@@ -604,39 +524,118 @@ class WanAudioRunner(WanRunner):
                 source_fps=target_fps,
                 target_fps=interpolation_target_fps,
             )
+            # Update target_fps for saving
             target_fps = interpolation_target_fps
 
         # Convert audio to ComfyUI format
-        # Convert numpy array to torch tensor and add batch dimension
-        audio_waveform = torch.from_numpy(merge_audio).unsqueeze(0).unsqueeze(0)  # [batch, channels, samples]
-        comfyui_audio = {"waveform": audio_waveform, "sample_rate": audio_sr}
+        audio_waveform = torch.from_numpy(merge_audio).unsqueeze(0).unsqueeze(0)
+        comfyui_audio = {"waveform": audio_waveform, "sample_rate": self._audio_processor.audio_sr}
 
         # Save video if requested
         if save_video and self.config.get("save_video_path", None):
-            out_path = os.path.join("./", "video_merge.mp4")
-            audio_file = os.path.join("./", "audio_merge.wav")
-            # Use the updated target_fps (after interpolation if applied)
-            save_to_video(comfyui_images, out_path, target_fps)
-            save_audio(merge_audio, audio_file, out_path, output_path=self.config.get("save_video_path", None))
-            os.remove(out_path)
-            os.remove(audio_file)
+            self._save_video_with_audio(comfyui_images, merge_audio, target_fps)
+
+        # Final cleanup
+        self.end_run()
 
         return comfyui_images, comfyui_audio
 
-    def run_pipeline(self, save_video=True):
-        if self.config["use_prompt_enhancer"]:
-            self.config["prompt_enhanced"] = self.post_prompt_enhancer()
-        self.run_input_encoder_internal()
-        self.set_target_shape()
-        self.init_scheduler()
-        self.model.scheduler.prepare(self.inputs["image_encoder_output"])
-        images, audio = self.run(save_video)  # run() now returns both images and audio
-        self.end_run()
-        gc.collect()
-        torch.cuda.empty_cache()
-        return images, audio
+    def _save_video_with_audio(self, images, audio_array, fps):
+        """Save video with audio"""
+        import tempfile
+
+        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as video_tmp:
+            video_path = video_tmp.name
+        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as audio_tmp:
+            audio_path = audio_tmp.name
+
+        try:
+            # Save video
+            save_to_video(images, video_path, fps)
+            # Save audio
+            ta.save(audio_path, torch.tensor(audio_array[None]), sample_rate=self._audio_processor.audio_sr)
+            # Merge video and audio
+            output_path = self.config.get("save_video_path")
+            parent_dir = os.path.dirname(output_path)
+            if parent_dir and not os.path.exists(parent_dir):
+                os.makedirs(parent_dir, exist_ok=True)
+            subprocess.call(["/usr/bin/ffmpeg", "-y", "-i", video_path, "-i", audio_path, output_path])
+            logger.info(f"Saved video with audio to: {output_path}")
+        finally:
+            # Clean up temp files
+            if os.path.exists(video_path):
+                os.remove(video_path)
+            if os.path.exists(audio_path):
+                os.remove(audio_path)
 
+    def load_transformer(self):
+        """Load transformer with LoRA support"""
+        base_model = WanAudioModel(self.config.model_path, self.config, self.init_device)
+        if self.config.get("lora_configs") and self.config.lora_configs:
+            assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
+            lora_wrapper = WanLoraWrapper(base_model)
+            for lora_config in self.config.lora_configs:
+                lora_path = lora_config["path"]
+                strength = lora_config.get("strength", 1.0)
+                lora_name = lora_wrapper.load_lora(lora_path)
+                lora_wrapper.apply_lora(lora_name, strength)
+                logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
+        return base_model
+
+    def load_image_encoder(self):
+        """Load image encoder"""
+        clip_model_dir = self.config["model_path"] + "/image_encoder"
+        image_encoder = WanVideoIPHandler("CLIPModel", repo_or_path=clip_model_dir, require_grad=False, mode="eval", device=self.init_device, dtype=torch.float16)
+        return image_encoder
+
+    def run_image_encoder(self, config, vae_model):
+        """Run image encoder"""
+        ref_img = Image.open(config.image_path)
+        ref_img = (np.array(ref_img).astype(np.float32) - 127.5) / 127.5
+        ref_img = torch.from_numpy(ref_img).to(vae_model.device)
+        ref_img = rearrange(ref_img, "H W C -> 1 C H W")
+        ref_img = ref_img[:, :3]
+        # Resize and crop image
+        cond_frms, tgt_h, tgt_w = adaptive_resize(ref_img)
+        config.tgt_h = tgt_h
+        config.tgt_w = tgt_w
+        clip_encoder_out = self.image_encoder.encode(cond_frms).squeeze(0).to(torch.bfloat16)
+        cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
+        lat_h, lat_w = tgt_h // 8, tgt_w // 8
+        config.lat_h = lat_h
+        config.lat_w = lat_w
+        vae_encode_out = vae_model.encode(cond_frms.to(torch.float), config)
+        if isinstance(vae_encode_out, list):
+            vae_encode_out = torch.stack(vae_encode_out, dim=0).to(torch.bfloat16)
+        return vae_encode_out, clip_encoder_out
+
+    def set_target_shape(self):
+        """Set target shape for generation"""
+        ret = {}
+        num_channels_latents = 16
+        if self.config.task == "i2v":
+            self.config.target_shape = (
+                num_channels_latents,
+                (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
+                self.config.lat_h,
+                self.config.lat_w,
+            )
+            ret["lat_h"] = self.config.lat_h
+            ret["lat_w"] = self.config.lat_w
+        else:
+            error_msg = "t2v task is not supported in WanAudioRunner"
+            assert False, error_msg
+        ret["target_shape"] = self.config.target_shape
+        return ret
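Review note: checking the latent arithmetic in set_target_shape (a worked example; vae_stride[0] == 4 is an assumption, consistent with the 4x temporal grouping used elsewhere in this file):

target_video_length = 81
latent_frames = (target_video_length - 1) // 4 + 1   # (81 - 1) // 4 + 1 = 21 latent frames
# With lat_h = tgt_h // 8 and lat_w = tgt_w // 8 set in run_image_encoder,
# an 81-frame i2v run uses target_shape = (16, 21, tgt_h // 8, tgt_w // 8).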