Commit d048b178 authored by gaclove

refactor: enhance WanAudioRunner to improve audio handling and frame interpolation

parent 5bd9bdbd
@@ -4,6 +4,10 @@ import numpy as np
 import torch
 import torchvision.transforms.functional as TF
 from PIL import Image
+from contextlib import contextmanager
+from typing import Optional, Tuple, Union, List, Dict, Any
+from dataclasses import dataclass
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.models.runners.wan.wan_runner import WanRunner
 from lightx2v.models.runners.default_runner import DefaultRunner

@@ -34,46 +38,266 @@ from torchvision.transforms.functional import resize
 import subprocess
 import warnings
-from typing import Optional, Tuple, Union
-
-
-def add_mask_to_frames(
-    frames: np.ndarray,
-    mask_rate: float = 0.1,
-    rnd_state: np.random.RandomState = None,
-) -> np.ndarray:
-    if mask_rate is None:
-        return frames
-    if rnd_state is None:
-        rnd_state = np.random.RandomState()
-    h, w = frames.shape[-2:]
-    mask = rnd_state.rand(h, w) > mask_rate
-    frames = frames * mask
-    return frames
-
-
-def add_noise_to_frames(
-    frames: np.ndarray,
-    noise_mean: float = -3.0,
-    noise_std: float = 0.5,
-    rnd_state: np.random.RandomState = None,
-) -> np.ndarray:
-    if noise_mean is None or noise_std is None:
-        return frames
-    if rnd_state is None:
-        rnd_state = np.random.RandomState()
-    shape = frames.shape
-    bs = 1 if len(shape) == 4 else shape[0]
-    sigma = rnd_state.normal(loc=noise_mean, scale=noise_std, size=(bs,))
-    sigma = np.exp(sigma)
-    sigma = np.expand_dims(sigma, axis=tuple(range(1, len(shape))))
-    noise = rnd_state.randn(*shape) * sigma
-    frames = frames + noise
-    return frames
+@contextmanager
+def memory_efficient_inference():
+    """Context manager for memory-efficient inference"""
+    try:
+        yield
+    finally:
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        gc.collect()
+
+
+@dataclass
+class AudioSegment:
+    """Data class for audio segment information"""
+
+    audio_array: np.ndarray
+    start_frame: int
+    end_frame: int
+    is_last: bool = False
+    useful_length: Optional[int] = None
+
+
+class FramePreprocessor:
+    """Handles frame preprocessing, including noise injection and masking"""
+
+    def __init__(self, noise_mean: float = -3.0, noise_std: float = 0.5, mask_rate: float = 0.1):
+        self.noise_mean = noise_mean
+        self.noise_std = noise_std
+        self.mask_rate = mask_rate
+
+    def add_noise(self, frames: np.ndarray, rnd_state: Optional[np.random.RandomState] = None) -> np.ndarray:
+        """Add Gaussian noise with a log-normal per-sample scale to frames"""
+        if self.noise_mean is None or self.noise_std is None:
+            return frames
+        if rnd_state is None:
+            rnd_state = np.random.RandomState()
+        shape = frames.shape
+        bs = 1 if len(shape) == 4 else shape[0]
+        sigma = rnd_state.normal(loc=self.noise_mean, scale=self.noise_std, size=(bs,))
+        sigma = np.exp(sigma)
+        sigma = np.expand_dims(sigma, axis=tuple(range(1, len(shape))))
+        noise = rnd_state.randn(*shape) * sigma
+        return frames + noise
+
+    def add_mask(self, frames: np.ndarray, rnd_state: Optional[np.random.RandomState] = None) -> np.ndarray:
+        """Apply a random per-pixel dropout mask to frames"""
+        if self.mask_rate is None:
+            return frames
+        if rnd_state is None:
+            rnd_state = np.random.RandomState()
+        h, w = frames.shape[-2:]
+        mask = rnd_state.rand(h, w) > self.mask_rate
+        return frames * mask
+
+    def process_prev_frames(self, frames: torch.Tensor) -> torch.Tensor:
+        """Process previous frames with noise and masking"""
+        frames_np = frames.cpu().detach().numpy()
+        frames_np = self.add_noise(frames_np)
+        frames_np = self.add_mask(frames_np)
+        return torch.from_numpy(frames_np).to(dtype=frames.dtype, device=frames.device)
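
Note: for intuition about the defaults above, `sigma` is drawn log-normally, so it concentrates near exp(-3.0) ≈ 0.05, i.e. roughly 5% noise on frames scaled to [-1, 1]. A standalone sketch (illustration only, not part of this commit):

import numpy as np

# sigma = exp(N(-3.0, 0.5)): its median is exp(-3.0) ~= 0.050, so the
# injected noise is mild relative to frames in [-1, 1].
rnd = np.random.RandomState(0)
sigma = np.exp(rnd.normal(loc=-3.0, scale=0.5, size=(10_000,)))
print(np.median(sigma), sigma.mean())  # ~0.050, ~0.056
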
+
+
+class AudioProcessor:
+    """Handles audio loading and segmentation"""
+
+    def __init__(self, audio_sr: int = 16000, target_fps: int = 16):
+        self.audio_sr = audio_sr
+        self.target_fps = target_fps
+
+    def load_audio(self, audio_path: str) -> np.ndarray:
+        """Load audio, downmix to mono, and resample to the target sample rate"""
+        audio_array, ori_sr = ta.load(audio_path)
+        audio_array = ta.functional.resample(audio_array.mean(0), orig_freq=ori_sr, new_freq=self.audio_sr)
+        return audio_array.numpy()
+
+    def get_audio_range(self, start_frame: int, end_frame: int) -> Tuple[int, int]:
+        """Calculate the audio sample range for a given video frame range"""
+        audio_frame_rate = self.audio_sr / self.target_fps
+        return round(start_frame * audio_frame_rate), round((end_frame + 1) * audio_frame_rate)
+
+    def segment_audio(self, audio_array: np.ndarray, expected_frames: int, max_num_frames: int, prev_frame_length: int = 5) -> List[AudioSegment]:
+        """Segment audio based on frame requirements"""
+        segments = []
+
+        # Calculate how many segments are needed
+        interval_num = 1
+        res_frame_num = 0
+        if expected_frames <= max_num_frames:
+            interval_num = 1
+        else:
+            interval_num = max(int((expected_frames - max_num_frames) / (max_num_frames - prev_frame_length)) + 1, 1)
+            res_frame_num = expected_frames - interval_num * (max_num_frames - prev_frame_length)
+            if res_frame_num > 5:
+                interval_num += 1
+
+        # Create segments
+        for idx in range(interval_num):
+            if idx == 0:
+                # First segment
+                audio_start, audio_end = self.get_audio_range(0, max_num_frames)
+                segment_audio = audio_array[audio_start:audio_end]
+                useful_length = None
+                if expected_frames < max_num_frames:
+                    useful_length = segment_audio.shape[0]
+                    max_num_audio_length = int((max_num_frames + 1) / self.target_fps * self.audio_sr)
+                    segment_audio = np.concatenate((segment_audio, np.zeros(max_num_audio_length - useful_length)), axis=0)
+                segments.append(AudioSegment(segment_audio, 0, max_num_frames, False, useful_length))
+            elif res_frame_num > 5 and idx == interval_num - 1:
+                # Last segment (may be shorter than max_num_frames)
+                start_frame = idx * max_num_frames - idx * prev_frame_length
+                audio_start, audio_end = self.get_audio_range(start_frame, expected_frames)
+                segment_audio = audio_array[audio_start:audio_end]
+                useful_length = segment_audio.shape[0]
+                max_num_audio_length = int((max_num_frames + 1) / self.target_fps * self.audio_sr)
+                segment_audio = np.concatenate((segment_audio, np.zeros(max_num_audio_length - useful_length)), axis=0)
+                segments.append(AudioSegment(segment_audio, start_frame, expected_frames, True, useful_length))
+            else:
+                # Middle segments
+                start_frame = idx * max_num_frames - idx * prev_frame_length
+                end_frame = (idx + 1) * max_num_frames - idx * prev_frame_length
+                audio_start, audio_end = self.get_audio_range(start_frame, end_frame)
+                segment_audio = audio_array[audio_start:audio_end]
+                segments.append(AudioSegment(segment_audio, start_frame, end_frame, False))
+
+        return segments
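
Note: to make the interval arithmetic concrete, here is a worked example with hypothetical numbers (a 12.5 s clip at 16 fps, with the defaults max_num_frames=81 and prev_frame_length=5 used above):

expected_frames = 200                      # 12.5 s * 16 fps
max_num_frames, prev = 81, 5

interval_num = max(int((expected_frames - max_num_frames) / (max_num_frames - prev)) + 1, 1)  # -> 2
res_frame_num = expected_frames - interval_num * (max_num_frames - prev)                      # -> 48
if res_frame_num > 5:
    interval_num += 1                                                                         # -> 3

# Resulting segments, each reusing `prev` frames from its predecessor:
#   idx 0: frames   0..81   (first)
#   idx 1: frames  76..157  (middle)
#   idx 2: frames 152..200  (last, shorter; its audio is zero-padded)
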
+
+
+class VideoGenerator:
+    """Handles video generation for each segment"""
+
+    def __init__(self, model, vae_encoder, vae_decoder, config):
+        self.model = model
+        self.vae_encoder = vae_encoder
+        self.vae_decoder = vae_decoder
+        self.config = config
+        self.frame_preprocessor = FramePreprocessor()
+
+    def prepare_prev_latents(self, prev_video: Optional[torch.Tensor], prev_frame_length: int) -> Optional[Dict[str, torch.Tensor]]:
+        """Prepare previous latents for conditioning"""
+        if prev_video is None:
+            return None
+
+        device = self.model.device
+        dtype = torch.bfloat16
+        vae_dtype = torch.float
+        tgt_h, tgt_w = self.config.tgt_h, self.config.tgt_w
+
+        prev_frames = torch.zeros((1, 3, self.config.target_video_length, tgt_h, tgt_w), device=device)
+
+        # Extract and process last frames
+        last_frames = prev_video[:, :, -prev_frame_length:].clone().to(device)
+        last_frames = self.frame_preprocessor.process_prev_frames(last_frames)
+        prev_frames[:, :, :prev_frame_length] = last_frames
+
+        prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
+
+        # Create mask
+        prev_token_length = (prev_frame_length - 1) // 4 + 1
+        _, nframe, height, width = self.model.scheduler.latents.shape
+        frames_n = (nframe - 1) * 4 + 1
+        prev_frame_len = max((prev_token_length - 1) * 4 + 1, 0)
+        prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
+        prev_mask[:, prev_frame_len:] = 0
+        prev_mask = self._wan_mask_rearrange(prev_mask).unsqueeze(0)
+
+        return {"prev_latents": prev_latents, "prev_mask": prev_mask}
+
+    def _wan_mask_rearrange(self, mask: torch.Tensor) -> torch.Tensor:
+        """Rearrange mask for the WAN model"""
+        if mask.ndim == 3:
+            mask = mask[None]
+        assert mask.ndim == 4
+        _, t, h, w = mask.shape
+        assert t == ((t - 1) // 4 * 4 + 1)
+        mask_first_frame = torch.repeat_interleave(mask[:, 0:1], repeats=4, dim=1)
+        mask = torch.concat([mask_first_frame, mask[:, 1:]], dim=1)
+        mask = mask.view(mask.shape[1] // 4, 4, h, w)
+        return mask.transpose(0, 1)
+
+    @torch.no_grad()
+    def generate_segment(self, inputs: Dict[str, Any], audio_features: torch.Tensor, prev_video: Optional[torch.Tensor] = None, prev_frame_length: int = 5, segment_idx: int = 0) -> torch.Tensor:
+        """Generate a video segment"""
+        # Update inputs with audio features
+        inputs["audio_encoder_output"] = audio_features
+
+        # Reset scheduler for non-first segments
+        if segment_idx > 0:
+            self.model.scheduler.reset()
+
+        # Prepare previous latents - ALWAYS needed, even for the first segment
+        device = self.model.device
+        dtype = torch.bfloat16
+        vae_dtype = torch.float
+        tgt_h, tgt_w = self.config.tgt_h, self.config.tgt_w
+        max_num_frames = self.config.target_video_length
+
+        if segment_idx == 0:
+            # First segment - create zero frames
+            prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
+            prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
+            prev_len = 0
+        else:
+            # Subsequent segments - use the previous video
+            previmg_encoder_output = self.prepare_prev_latents(prev_video, prev_frame_length)
+            if previmg_encoder_output:
+                prev_latents = previmg_encoder_output["prev_latents"]
+                prev_len = (prev_frame_length - 1) // 4 + 1
+            else:
+                # Fall back to zeros if prepare_prev_latents fails
+                prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
+                prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
+                prev_len = 0
+
+        # Create mask for prev_latents
+        _, nframe, height, width = self.model.scheduler.latents.shape
+        frames_n = (nframe - 1) * 4 + 1
+        prev_frame_len = max((prev_len - 1) * 4 + 1, 0)
+        prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
+        prev_mask[:, prev_frame_len:] = 0
+        prev_mask = self._wan_mask_rearrange(prev_mask).unsqueeze(0)
+
+        # Always set previmg_encoder_output
+        inputs["previmg_encoder_output"] = {"prev_latents": prev_latents, "prev_mask": prev_mask}
+
+        # Run inference loop
+        for step_index in range(self.model.scheduler.infer_steps):
+            logger.info(f"==> Segment {segment_idx}, Step {step_index}/{self.model.scheduler.infer_steps}")
+
+            with ProfilingContext4Debug("step_pre"):
+                self.model.scheduler.step_pre(step_index=step_index)
+
+            with ProfilingContext4Debug("infer"):
+                self.model.infer(inputs)
+
+            with ProfilingContext4Debug("step_post"):
+                self.model.scheduler.step_post()
+
+        # Decode latents
+        latents = self.model.scheduler.latents
+        generator = self.model.scheduler.generator
+        gen_video = self.vae_decoder.decode(latents, generator=generator, config=self.config)
+        gen_video = torch.clamp(gen_video, -1, 1).to(torch.float)
+
+        return gen_video
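
Note: a shape walkthrough of the mask path in `generate_segment` may help. This sketch (illustration only, not from the commit) assumes 21 latent frames, i.e. (21 - 1) * 4 + 1 = 81 pixel frames, and `prev_len = 2` latent tokens:

import torch

nframe, h, w = 21, 60, 104                    # hypothetical latent grid
frames_n = (nframe - 1) * 4 + 1               # 81 pixel frames
prev_frame_len = max((2 - 1) * 4 + 1, 0)      # 5 conditioning pixel frames

mask = torch.ones(1, frames_n, h, w)
mask[:, prev_frame_len:] = 0                          # keep only the conditioning frames
first = torch.repeat_interleave(mask[:, 0:1], repeats=4, dim=1)
mask = torch.cat([first, mask[:, 1:]], dim=1)         # (1, 84, h, w): frame 0 padded to 4 copies
mask = mask.view(mask.shape[1] // 4, 4, h, w).transpose(0, 1)
print(mask.shape)                                     # torch.Size([4, 21, 60, 104])
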

 def get_crop_bbox(ori_h, ori_w, tgt_h, tgt_w):

@@ -131,221 +355,226 @@ def adaptive_resize(img):
     return cropped_img, target_h, target_w
-def array_to_video(
-    image_array: np.ndarray,
-    output_path: str,
-    fps: int | float = 30,
-    resolution: tuple[int, int] | tuple[float, float] | None = None,
-    disable_log: bool = False,
-    lossless: bool = True,
-    output_pix_fmt: str = "yuv420p",
-) -> None:
-    """Convert an array to a video directly; gif is not supported.
-
-    Args:
-        image_array (np.ndarray): shape should be (f * h * w * 3).
-        output_path (str): output video file path.
-        fps (Union[int, float], optional): fps. Defaults to 30.
-        resolution (Optional[Union[Tuple[int, int], Tuple[float, float]]],
-            optional): (height, width) of the output video. Defaults to None.
-        disable_log (bool, optional): whether to suppress the ffmpeg command info.
-            Defaults to False.
-        output_pix_fmt (str): output pix_fmt for the ffmpeg command.
-
-    Raises:
-        FileNotFoundError: check output path.
-        TypeError: check input array.
-
-    Returns:
-        None.
-    """
-    if not isinstance(image_array, np.ndarray):
-        raise TypeError("Input should be np.ndarray.")
-    assert image_array.ndim == 4
-    assert image_array.shape[-1] == 3
-    if resolution:
-        height, width = resolution
-        width += width % 2
-        height += height % 2
-    else:
-        image_array = pad_for_libx264(image_array)
-        height, width = image_array.shape[1], image_array.shape[2]
-    if lossless:
-        command = [
-            "/usr/bin/ffmpeg",
-            "-y",  # (optional) overwrite output file if it exists
-            "-f",
-            "rawvideo",
-            "-s",
-            f"{int(width)}x{int(height)}",  # size of one frame
-            "-pix_fmt",
-            "bgr24",
-            "-r",
-            f"{fps}",  # frames per second
-            "-loglevel",
-            "error",
-            "-threads",
-            "4",
-            "-i",
-            "-",  # the input comes from a pipe
-            "-vcodec",
-            "libx264rgb",
-            "-crf",
-            "0",
-            "-an",  # tells ffmpeg not to expect any audio
-            output_path,
-        ]
-    else:
-        output_pix_fmt = output_pix_fmt or "yuv420p"
-        command = [
-            "/usr/bin/ffmpeg",
-            "-y",  # (optional) overwrite output file if it exists
-            "-f",
-            "rawvideo",
-            "-s",
-            f"{int(width)}x{int(height)}",  # size of one frame
-            "-pix_fmt",
-            "bgr24",
-            "-r",
-            f"{fps}",  # frames per second
-            "-loglevel",
-            "error",
-            "-threads",
-            "4",
-            "-i",
-            "-",  # the input comes from a pipe
-            "-vcodec",
-            "libx264",
-            "-pix_fmt",
-            f"{output_pix_fmt}",
-            "-an",  # tells ffmpeg not to expect any audio
-            output_path,
-        ]
-    if output_pix_fmt is not None:
-        command += ["-pix_fmt", output_pix_fmt]
-    if not disable_log:
-        print(f'Running "{" ".join(command)}"')
-    process = subprocess.Popen(
-        command,
-        stdin=subprocess.PIPE,
-        stderr=subprocess.PIPE,
-    )
-    if process.stdin is None or process.stderr is None:
-        raise BrokenPipeError("No buffer received.")
-    index = 0
-    while True:
-        if index >= image_array.shape[0]:
-            break
-        process.stdin.write(image_array[index].tobytes())
-        index += 1
-    process.stdin.close()
-    process.stderr.close()
-    process.wait()
-
-
-def pad_for_libx264(image_array):
-    if image_array.ndim == 2 or (image_array.ndim == 3 and image_array.shape[2] == 3):
-        hei_index = 0
-        wid_index = 1
-    elif image_array.ndim == 4 or (image_array.ndim == 3 and image_array.shape[2] != 3):
-        hei_index = 1
-        wid_index = 2
-    else:
-        return image_array
-    hei_pad = image_array.shape[hei_index] % 2
-    wid_pad = image_array.shape[wid_index] % 2
-    if hei_pad + wid_pad > 0:
-        pad_width = []
-        for dim_index in range(image_array.ndim):
-            if dim_index == hei_index:
-                pad_width.append((0, hei_pad))
-            elif dim_index == wid_index:
-                pad_width.append((0, wid_pad))
-            else:
-                pad_width.append((0, 0))
-        values = 0
-        image_array = np.pad(image_array, pad_width, mode="constant", constant_values=values)
-    return image_array
-
-
-def generate_unique_path(path):
-    if not os.path.exists(path):
-        return path
-    root, ext = os.path.splitext(path)
-    index = 1
-    new_path = f"{root}-{index}{ext}"
-    while os.path.exists(new_path):
-        index += 1
-        new_path = f"{root}-{index}{ext}"
-    return new_path
-
-
-def save_audio(
-    audio_array,
-    audio_name: str,
-    video_name: str,
-    sr: int = 16000,
-    output_path: Optional[str] = None,
-):
-    logger.info(f"Saving audio to {audio_name} type: {type(audio_array)}")
-    ta.save(
-        audio_name,
-        torch.tensor(audio_array[None]),
-        sample_rate=sr,
-    )
-    if output_path is None:
-        out_video = f"{video_name[:-4]}_with_audio.mp4"
-    else:
-        out_video = output_path
-
-    parent_dir = os.path.dirname(out_video)
-    if parent_dir and not os.path.exists(parent_dir):
-        os.makedirs(parent_dir, exist_ok=True)
-
-    if os.path.exists(out_video):
-        os.remove(out_video)
-
-    subprocess.call(["/usr/bin/ffmpeg", "-y", "-i", video_name, "-i", audio_name, out_video])
-
-    return out_video
-
-
-@RUNNER_REGISTER("wan2.1_audio")
-class WanAudioRunner(WanRunner):
-    def __init__(self, config):
-        super().__init__(config)
+@RUNNER_REGISTER("wan2.1_audio")
+class WanAudioRunner(WanRunner):
+    def __init__(self, config):
+        super().__init__(config)
+        self._is_initialized = False
+        self._audio_adapter_pipe = None
+        self._audio_processor = None
+        self._video_generator = None
+        self._audio_preprocess = None
+
+    def initialize_once(self):
+        """Initialize all models once for multiple runs"""
+        if self._is_initialized:
+            return
+
+        logger.info("Initializing models (one-time setup)...")
+
+        # Initialize audio processor
+        audio_sr = self.config.get("audio_sr", 16000)
+        target_fps = self.config.get("target_fps", 16)
+        self._audio_processor = AudioProcessor(audio_sr, target_fps)
+
+        # Load audio feature extractor
+        self._audio_preprocess = AutoFeatureExtractor.from_pretrained(self.config["model_path"], subfolder="audio_encoder")
+
+        # Initialize scheduler
+        self.init_scheduler()
+
+        self._is_initialized = True
+        logger.info("Model initialization complete")
     def init_scheduler(self):
+        """Initialize the consistency model scheduler"""
         scheduler = ConsistencyModelScheduler(self.config)
         self.model.set_scheduler(scheduler)

-    def load_audio_models(self):
-        ## Audio feature extractor
-        self.audio_preprocess = AutoFeatureExtractor.from_pretrained(self.config["model_path"], subfolder="audio_encoder")
-
-        ## Audio-driven video generation adapter
+    def load_audio_adapter_lazy(self):
+        """Lazily load the audio adapter when needed"""
+        if self._audio_adapter_pipe is not None:
+            return self._audio_adapter_pipe
+
+        # Audio adapter
         audio_adapter_path = self.config["model_path"] + "/audio_adapter.safetensors"
-        audio_adaper = AudioAdapter.from_transformer(
+        audio_adapter = AudioAdapter.from_transformer(
             self.model,
             audio_feature_dim=1024,
             interval=1,
             time_freq_dim=256,
             projection_transformer_layers=4,
         )
-        audio_adapter = rank0_load_state_dict_from_path(audio_adaper, audio_adapter_path, strict=False)
+        audio_adapter = rank0_load_state_dict_from_path(audio_adapter, audio_adapter_path, strict=False)

-        ## Audio feature encoder
+        # Audio encoder
         device = self.model.device
         audio_encoder_repo = self.config["model_path"] + "/audio_encoder"
-        audio_adapter_pipe = AudioAdapterPipe(audio_adapter, audio_encoder_repo=audio_encoder_repo, dtype=torch.bfloat16, device=device, generator=torch.Generator(device), weight=1.0)
-        return audio_adapter_pipe
+        self._audio_adapter_pipe = AudioAdapterPipe(audio_adapter, audio_encoder_repo=audio_encoder_repo, dtype=torch.bfloat16, device=device, generator=torch.Generator(device), weight=1.0)
+        return self._audio_adapter_pipe
+    def prepare_inputs(self):
+        """Prepare inputs for the model"""
+        image_encoder_output = None
+        if os.path.isfile(self.config.image_path):
+            with ProfilingContext("Run Img Encoder"):
+                vae_encode_out, clip_encoder_out = self.run_image_encoder(self.config, self.vae_encoder)
+            image_encoder_output = {
+                "clip_encoder_out": clip_encoder_out,
+                "vae_encode_out": vae_encode_out,
+            }
+        with ProfilingContext("Run Text Encoder"):
+            img = Image.open(self.config["image_path"]).convert("RGB")
+            text_encoder_output = self.run_text_encoder(self.config["prompt"], img)
+        self.set_target_shape()
+        return {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output, "audio_adapter_pipe": self.load_audio_adapter_lazy()}
+    def run_pipeline(self, save_video=True):
+        """Optimized pipeline with modular components"""
+        # Ensure models are initialized
+        self.initialize_once()
+
+        # Initialize the video generator if needed
+        if self._video_generator is None:
+            self._video_generator = VideoGenerator(self.model, self.vae_encoder, self.vae_decoder, self.config)
+
+        # Prepare inputs
+        with memory_efficient_inference():
+            if self.config["use_prompt_enhancer"]:
+                self.config["prompt_enhanced"] = self.post_prompt_enhancer()
+            self.inputs = self.prepare_inputs()
+            self.model.scheduler.prepare(self.inputs["image_encoder_output"])
+
+        # Process audio
+        audio_array = self._audio_processor.load_audio(self.config["audio_path"])
+        video_duration = self.config.get("video_duration", 5)
+        target_fps = self.config.get("target_fps", 16)
+        max_num_frames = self.config.get("target_video_length", 81)
+
+        audio_len = int(audio_array.shape[0] / self._audio_processor.audio_sr * target_fps)
+        expected_frames = min(max(1, int(video_duration * target_fps)), audio_len)
+
+        # Segment audio
+        audio_segments = self._audio_processor.segment_audio(audio_array, expected_frames, max_num_frames)
+
+        # Generate video segments
+        gen_video_list = []
+        cut_audio_list = []
+        prev_video = None
+
+        for idx, segment in enumerate(audio_segments):
+            # Update the seed for each segment
+            self.config.seed = self.config.seed + idx
+            torch.manual_seed(self.config.seed)
+            logger.info(f"Processing segment {idx + 1}/{len(audio_segments)}, seed: {self.config.seed}")
+
+            # Process audio features
+            audio_features = self._audio_preprocess(segment.audio_array, sampling_rate=self._audio_processor.audio_sr, return_tensors="pt").input_values.squeeze(0).to(self.model.device)
+
+            # Generate the video segment
+            with memory_efficient_inference():
+                gen_video = self._video_generator.generate_segment(
+                    self.inputs.copy(),  # copy to avoid modifying the original
+                    audio_features,
+                    prev_video=prev_video,
+                    prev_frame_length=5,
+                    segment_idx=idx,
+                )
+
+            # Extract the relevant frames
+            start_frame = 0 if idx == 0 else 5
+            start_audio_frame = 0 if idx == 0 else int(6 * self._audio_processor.audio_sr / target_fps)
+
+            if segment.is_last and segment.useful_length:
+                end_frame = segment.end_frame - segment.start_frame
+                gen_video_list.append(gen_video[:, :, start_frame:end_frame].cpu())
+                cut_audio_list.append(segment.audio_array[start_audio_frame : segment.useful_length])
+            elif segment.useful_length and expected_frames < max_num_frames:
+                gen_video_list.append(gen_video[:, :, start_frame:expected_frames].cpu())
+                cut_audio_list.append(segment.audio_array[start_audio_frame : segment.useful_length])
+            else:
+                gen_video_list.append(gen_video[:, :, start_frame:].cpu())
+                cut_audio_list.append(segment.audio_array[start_audio_frame:])
+
+            # Update prev_video for the next iteration
+            prev_video = gen_video
+
+            # Clean up GPU memory after each segment
+            del gen_video
+            torch.cuda.empty_cache()
+
+        # Merge results
+        with memory_efficient_inference():
+            gen_lvideo = torch.cat(gen_video_list, dim=2).float()
+            merge_audio = np.concatenate(cut_audio_list, axis=0).astype(np.float32)
+            comfyui_images = vae_to_comfyui_image(gen_lvideo)
+
+        # Apply frame interpolation if configured
+        if "video_frame_interpolation" in self.config and self.vfi_model is not None:
+            interpolation_target_fps = self.config["video_frame_interpolation"]["target_fps"]
+            logger.info(f"Interpolating frames from {target_fps} to {interpolation_target_fps}")
+            comfyui_images = self.vfi_model.interpolate_frames(
+                comfyui_images,
+                source_fps=target_fps,
+                target_fps=interpolation_target_fps,
+            )
+            target_fps = interpolation_target_fps
+
+        # Convert audio to ComfyUI format
+        audio_waveform = torch.from_numpy(merge_audio).unsqueeze(0).unsqueeze(0)
+        comfyui_audio = {"waveform": audio_waveform, "sample_rate": self._audio_processor.audio_sr}
+
+        # Save video if requested
+        if save_video and self.config.get("save_video_path", None):
+            self._save_video_with_audio(comfyui_images, merge_audio, target_fps)
+
+        # Final cleanup
+        self.end_run()
+
+        return comfyui_images, comfyui_audio
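
Note: the trimming above reassembles exactly `expected_frames`: each non-first segment drops its 5-frame overlap with the previous one, and skips 6 frames' worth of audio samples. A quick check (illustration only) using the hypothetical 200-frame ranges from the segmentation example earlier:

seg_ranges = [(0, 81), (76, 157), (152, 200)]   # from the worked example above
total = 0
for idx, (s, e) in enumerate(seg_ranges):
    start_frame = 0 if idx == 0 else 5          # overlap dropped for idx > 0
    total += (e - s) - start_frame
print(total)                                    # 200 frames, as expected
print(int(6 * 16000 / 16))                      # 6000 audio samples skipped per joint
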
+    def _save_video_with_audio(self, images, audio_array, fps):
+        """Save video with audio"""
+        import tempfile
+
+        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as video_tmp:
+            video_path = video_tmp.name
+        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as audio_tmp:
+            audio_path = audio_tmp.name
+
+        try:
+            # Save video
+            save_to_video(images, video_path, fps)
+
+            # Save audio
+            ta.save(audio_path, torch.tensor(audio_array[None]), sample_rate=self._audio_processor.audio_sr)
+
+            # Merge video and audio
+            output_path = self.config.get("save_video_path")
+            parent_dir = os.path.dirname(output_path)
+            if parent_dir and not os.path.exists(parent_dir):
+                os.makedirs(parent_dir, exist_ok=True)
+
+            subprocess.call(["/usr/bin/ffmpeg", "-y", "-i", video_path, "-i", audio_path, output_path])
+            logger.info(f"Saved video with audio to: {output_path}")
+        finally:
+            # Clean up temp files
+            if os.path.exists(video_path):
+                os.remove(video_path)
+            if os.path.exists(audio_path):
+                os.remove(audio_path)
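
Note: `initialize_once` plus the lazily loaded adapter is what makes repeated runs cheap. A hedged usage sketch (the construction path and the mutability of `config` are assumptions, not shown in this commit):

# Hypothetical multi-clip usage: weights, feature extractor, and scheduler
# are set up once, then run_pipeline() is called per audio clip.
runner = WanAudioRunner(config)            # config carries model_path, image_path, prompt, seed, ...
for clip in ["a.wav", "b.wav"]:
    runner.config["audio_path"] = clip     # assumption: config supports item assignment
    images, audio = runner.run_pipeline(save_video=True)
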
     def load_transformer(self):
+        """Load the transformer, with LoRA support"""
         base_model = WanAudioModel(self.config.model_path, self.config, self.init_device)
         if self.config.get("lora_configs") and self.config.lora_configs:

@@ -361,19 +590,21 @@ class WanAudioRunner(WanRunner):
         return base_model

     def load_image_encoder(self):
+        """Load the image encoder"""
         clip_model_dir = self.config["model_path"] + "/image_encoder"
         image_encoder = WanVideoIPHandler("CLIPModel", repo_or_path=clip_model_dir, require_grad=False, mode="eval", device=self.init_device, dtype=torch.float16)
         return image_encoder

     def run_image_encoder(self, config, vae_model):
+        """Run the image encoder"""
         ref_img = Image.open(config.image_path)
         ref_img = (np.array(ref_img).astype(np.float32) - 127.5) / 127.5
         ref_img = torch.from_numpy(ref_img).to(vae_model.device)
         ref_img = rearrange(ref_img, "H W C -> 1 C H W")
         ref_img = ref_img[:, :3]
-        # resize and crop image
+        # Resize and crop image
         cond_frms, tgt_h, tgt_w = adaptive_resize(ref_img)
         config.tgt_h = tgt_h
         config.tgt_w = tgt_w

@@ -384,36 +615,13 @@ class WanAudioRunner(WanRunner):
         config.lat_h = lat_h
         config.lat_w = lat_w
         vae_encode_out = vae_model.encode(cond_frms.to(torch.float), config)
-        if isinstance(vae_encode_out, list):
-            # convert the list to a tensor
-            vae_encode_out = torch.stack(vae_encode_out, dim=0).to(torch.bfloat16)
+        # if isinstance(vae_encode_out, list):
+        vae_encode_out = torch.stack(vae_encode_out, dim=0).to(torch.bfloat16)
         return vae_encode_out, clip_encoder_out
-    def run_input_encoder_internal(self):
-        image_encoder_output = None
-        if os.path.isfile(self.config.image_path):
-            with ProfilingContext("Run Img Encoder"):
-                vae_encode_out, clip_encoder_out = self.run_image_encoder(self.config, self.vae_encoder)
-            image_encoder_output = {
-                "clip_encoder_out": clip_encoder_out,
-                "vae_encode_out": vae_encode_out,
-            }
-            logger.info(f"clip_encoder_out:{clip_encoder_out.shape} vae_encode_out:{vae_encode_out.shape}")
-        with ProfilingContext("Run Text Encoder"):
-            logger.info(f"Prompt: {self.config['prompt']}")
-            img = Image.open(self.config["image_path"]).convert("RGB")
-            text_encoder_output = self.run_text_encoder(self.config["prompt"], img)
-        self.set_target_shape()
-        self.inputs = {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}
-        # del self.image_encoder  # drop the reference CLIP model; it is only used once
-        gc.collect()
-        torch.cuda.empty_cache()
     def set_target_shape(self):
+        """Set the target latent shape for generation"""
         ret = {}
         num_channels_latents = 16
         if self.config.task == "i2v":

@@ -427,216 +635,7 @@ class WanAudioRunner(WanRunner):
             ret["lat_h"] = self.config.lat_h
             ret["lat_w"] = self.config.lat_w
         else:
             error_msg = "t2v task is not supported in WanAudioRunner"
-            assert 1 == 0, error_msg
+            assert False, error_msg
         ret["target_shape"] = self.config.target_shape
         return ret
-    def run(self, save_video=True):
-        def load_audio(in_path: str, sr: float = 16000):
-            audio_array, ori_sr = ta.load(in_path)
-            audio_array = ta.functional.resample(audio_array.mean(0), orig_freq=ori_sr, new_freq=sr)
-            return audio_array.numpy()
-
-        def get_audio_range(start_frame: int, end_frame: int, fps: float, audio_sr: float = 16000):
-            audio_frame_rate = audio_sr / fps
-            return round(start_frame * audio_frame_rate), round((end_frame + 1) * audio_frame_rate)
-
-        def wan_mask_rearrange(mask: torch.Tensor):
-            # mask: 1, T, H, W, where 1 means the input mask is one-channel
-            if mask.ndim == 3:
-                mask = mask[None]
-            assert mask.ndim == 4
-            _, t, h, w = mask.shape
-            assert t == ((t - 1) // 4 * 4 + 1)
-            mask_first_frame = torch.repeat_interleave(mask[:, 0:1], repeats=4, dim=1)
-            mask = torch.concat([mask_first_frame, mask[:, 1:]], dim=1)
-            mask = mask.view(mask.shape[1] // 4, 4, h, w)
-            return mask.transpose(0, 1)  # 4, T // 4, H, W
-
-        self.inputs["audio_adapter_pipe"] = self.load_audio_models()
-
-        # Process audio
-        audio_sr = self.config.get("audio_sr", 16000)
-        max_num_frames = self.config.get("target_video_length", 81)  # wan2.1 generates at most 81 frames per segment (5 s at 16 fps)
-        target_fps = self.config.get("target_fps", 16)  # frame rate used for audio/video sync
-        video_duration = self.config.get("video_duration", 5)  # desired output video duration
-
-        audio_array = load_audio(self.config["audio_path"], sr=audio_sr)
-        audio_len = int(audio_array.shape[0] / audio_sr * target_fps)
-        prev_frame_length = 5
-        prev_token_length = (prev_frame_length - 1) // 4 + 1
-        max_num_audio_length = int((max_num_frames + 1) / target_fps * audio_sr)
-        interval_num = 1
-
-        # expected_frames
-        expected_frames = min(max(1, int(float(video_duration) * target_fps)), audio_len)
-        res_frame_num = 0
-        if expected_frames <= max_num_frames:
-            interval_num = 1
-        else:
-            interval_num = max(int((expected_frames - max_num_frames) / (max_num_frames - prev_frame_length)) + 1, 1)
-            res_frame_num = expected_frames - interval_num * (max_num_frames - prev_frame_length)
-            if res_frame_num > 5:
-                interval_num += 1
-
-        audio_start, audio_end = get_audio_range(0, expected_frames, fps=target_fps, audio_sr=audio_sr)
-        audio_array_ori = audio_array[audio_start:audio_end]
-
-        gen_video_list = []
-        cut_audio_list = []
-
-        # Reference latents
-        tgt_h = self.config.tgt_h
-        tgt_w = self.config.tgt_w
-        device = self.model.scheduler.latents.device
-        dtype = torch.bfloat16
-        vae_dtype = torch.float
-
-        for idx in range(interval_num):
-            self.config.seed = self.config.seed + idx
-            torch.manual_seed(self.config.seed)
-            logger.info(f"### manual_seed: {self.config.seed} ####")
-            useful_length = -1
-            if idx == 0:  # first segment: condition padded with zeros
-                prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
-                prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
-                prev_len = 0
-                audio_start, audio_end = get_audio_range(0, max_num_frames, fps=target_fps, audio_sr=audio_sr)
-                audio_array = audio_array_ori[audio_start:audio_end]
-                if expected_frames < max_num_frames:
-                    useful_length = audio_array.shape[0]
-                    audio_array = np.concatenate((audio_array, np.zeros(max_num_audio_length)[: max_num_audio_length - useful_length]), axis=0)
-                audio_input_feat = self.audio_preprocess(audio_array, sampling_rate=audio_sr, return_tensors="pt").input_values.squeeze(0)
-            elif res_frame_num > 5 and idx == interval_num - 1:  # the last segment may have fewer than 81 frames
-                prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
-                last_frames = gen_video_list[-1][:, :, -prev_frame_length:].clone().to(device)
-                last_frames = last_frames.cpu().detach().numpy()
-                last_frames = add_noise_to_frames(last_frames)
-                last_frames = add_mask_to_frames(last_frames, mask_rate=0.1)  # mask 0.10
-                last_frames = torch.from_numpy(last_frames).to(dtype=dtype, device=device)
-                prev_frames[:, :, :prev_frame_length] = last_frames
-                prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
-                prev_len = prev_token_length
-                audio_start, audio_end = get_audio_range(idx * max_num_frames - idx * prev_frame_length, expected_frames, fps=target_fps, audio_sr=audio_sr)
-                audio_array = audio_array_ori[audio_start:audio_end]
-                useful_length = audio_array.shape[0]
-                audio_array = np.concatenate((audio_array, np.zeros(max_num_audio_length)[: max_num_audio_length - useful_length]), axis=0)
-                audio_input_feat = self.audio_preprocess(audio_array, sampling_rate=audio_sr, return_tensors="pt").input_values.squeeze(0)
-            else:  # middle segments: full 81 frames with prev_latents
-                prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
-                last_frames = gen_video_list[-1][:, :, -prev_frame_length:].clone().to(device)
-                last_frames = last_frames.cpu().detach().numpy()
-                last_frames = add_noise_to_frames(last_frames)  # mean: -3.0, std: 0.5
-                last_frames = add_mask_to_frames(last_frames, mask_rate=0.1)  # mask 0.10
-                last_frames = torch.from_numpy(last_frames).to(dtype=dtype, device=device)
-                prev_frames[:, :, :prev_frame_length] = last_frames
-                prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
-                prev_len = prev_token_length
-                audio_start, audio_end = get_audio_range(idx * max_num_frames - idx * prev_frame_length, (idx + 1) * max_num_frames - idx * prev_frame_length, fps=target_fps, audio_sr=audio_sr)
-                audio_array = audio_array_ori[audio_start:audio_end]
-                audio_input_feat = self.audio_preprocess(audio_array, sampling_rate=audio_sr, return_tensors="pt").input_values.squeeze(0)
-
-            self.inputs["audio_encoder_output"] = audio_input_feat.to(device)
-
-            if idx != 0:
-                self.model.scheduler.reset()
-
-            if prev_latents is not None:
-                _, nframe, height, width = self.model.scheduler.latents.shape
-                # bs = 1
-                frames_n = (nframe - 1) * 4 + 1
-                prev_frame_len = max((prev_len - 1) * 4 + 1, 0)
-                prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
-                prev_mask[:, prev_frame_len:] = 0
-                prev_mask = wan_mask_rearrange(prev_mask).unsqueeze(0)
-                previmg_encoder_output = {
-                    "prev_latents": prev_latents,
-                    "prev_mask": prev_mask,
-                }
-                self.inputs["previmg_encoder_output"] = previmg_encoder_output
-
-            for step_index in range(self.model.scheduler.infer_steps):
-                logger.info(f"==> step_index: {step_index} / {self.model.scheduler.infer_steps}")
-
-                with ProfilingContext4Debug("step_pre"):
-                    self.model.scheduler.step_pre(step_index=step_index)
-
-                with ProfilingContext4Debug("infer"):
-                    self.model.infer(self.inputs)
-
-                with ProfilingContext4Debug("step_post"):
-                    self.model.scheduler.step_post()
-
-            latents = self.model.scheduler.latents
-            generator = self.model.scheduler.generator
-            gen_video = self.vae_decoder.decode(latents, generator=generator, config=self.config)
-            gen_video = torch.clamp(gen_video, -1, 1).to(torch.float)
-
-            start_frame = 0 if idx == 0 else prev_frame_length
-            start_audio_frame = 0 if idx == 0 else int((prev_frame_length + 1) * audio_sr / target_fps)
-            if res_frame_num > 5 and idx == interval_num - 1:
-                gen_video_list.append(gen_video[:, :, start_frame:res_frame_num].cpu())
-                cut_audio_list.append(audio_array[start_audio_frame:useful_length])
-            elif expected_frames < max_num_frames and useful_length != -1:
-                gen_video_list.append(gen_video[:, :, start_frame:expected_frames].cpu())
-                cut_audio_list.append(audio_array[start_audio_frame:useful_length])
-            else:
-                gen_video_list.append(gen_video[:, :, start_frame:].cpu())
-                cut_audio_list.append(audio_array[start_audio_frame:])
-
-        gen_lvideo = torch.cat(gen_video_list, dim=2).float()
-        merge_audio = np.concatenate(cut_audio_list, axis=0).astype(np.float32)
-        comfyui_images = vae_to_comfyui_image(gen_lvideo)
-
-        # Apply frame interpolation if configured
-        if "video_frame_interpolation" in self.config:
-            assert self.vfi_model is not None and self.config["video_frame_interpolation"].get("target_fps", None) is not None
-            interpolation_target_fps = self.config["video_frame_interpolation"]["target_fps"]
-            logger.info(f"Interpolating frames from {target_fps} to {interpolation_target_fps}")
-            comfyui_images = self.vfi_model.interpolate_frames(
-                comfyui_images,
-                source_fps=target_fps,
-                target_fps=interpolation_target_fps,
-            )
-            # Update target_fps for saving
-            target_fps = interpolation_target_fps
-
-        # Convert audio to ComfyUI format
-        # Convert the numpy array to a torch tensor and add batch dimensions
-        audio_waveform = torch.from_numpy(merge_audio).unsqueeze(0).unsqueeze(0)  # [batch, channels, samples]
-        comfyui_audio = {"waveform": audio_waveform, "sample_rate": audio_sr}
-
-        # Save video if requested
-        if save_video and self.config.get("save_video_path", None):
-            out_path = os.path.join("./", "video_merge.mp4")
-            audio_file = os.path.join("./", "audio_merge.wav")
-            # Use the updated target_fps (after interpolation if applied)
-            save_to_video(comfyui_images, out_path, target_fps)
-            save_audio(merge_audio, audio_file, out_path, output_path=self.config.get("save_video_path", None))
-            os.remove(out_path)
-            os.remove(audio_file)
-
-        return comfyui_images, comfyui_audio
-
-    def run_pipeline(self, save_video=True):
-        if self.config["use_prompt_enhancer"]:
-            self.config["prompt_enhanced"] = self.post_prompt_enhancer()
-        self.run_input_encoder_internal()
-        self.set_target_shape()
-        self.init_scheduler()
-        self.model.scheduler.prepare(self.inputs["image_encoder_output"])
-        images, audio = self.run(save_video)  # run() returns both images and audio
-        self.end_run()
-        gc.collect()
-        torch.cuda.empty_cache()
-        return images, audio