Commit 98a85500 authored by sandy's avatar sandy Committed by GitHub
Browse files

[Refactor] save video (#316)

parent c2f2d263
import gc
import os
import subprocess
import warnings
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
......@@ -12,6 +12,7 @@ import torchvision.transforms.functional as TF
from PIL import Image
from einops import rearrange
from loguru import logger
from torchvision.io import write_video
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize
......@@ -27,7 +28,9 @@ from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
from lightx2v.utils.envs import *
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import find_torch_model_path, load_weights, save_to_video, vae_to_comfyui_image
from lightx2v.utils.utils import find_torch_model_path, load_weights, vae_to_comfyui_image
warnings.filterwarnings("ignore", category=UserWarning, module="torchvision.io._video_deprecation_warning")
def get_optimal_patched_size_with_sp(patched_h, patched_w, sp_size):
......@@ -664,34 +667,26 @@ class WanAudioRunner(WanRunner): # type:ignore
self.run_input_encoder = self._run_input_encoder_local_r2v_audio
def _save_video_with_audio(self, images, audio_array, fps):
"""Save video with audio"""
import tempfile
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as video_tmp:
video_path = video_tmp.name
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as audio_tmp:
audio_path = audio_tmp.name
try:
save_to_video(images, video_path, fps)
ta.save(audio_path, torch.tensor(audio_array[None]), sample_rate=self._audio_processor.audio_sr) # type: ignore
output_path = self.config.get("save_video_path")
parent_dir = os.path.dirname(output_path)
if parent_dir and not os.path.exists(parent_dir):
os.makedirs(parent_dir, exist_ok=True)
subprocess.call(["/usr/bin/ffmpeg", "-y", "-i", video_path, "-i", audio_path, output_path])
sample_rate = self._audio_processor.audio_sr
logger.info(f"Saved video with audio to: {output_path}")
if images.dtype != torch.uint8:
images = (images * 255).clamp(0, 255).to(torch.uint8)
finally:
# Clean up temp files
if os.path.exists(video_path):
os.remove(video_path)
if os.path.exists(audio_path):
os.remove(audio_path)
write_video(
filename=output_path,
video_array=images,
fps=fps,
video_codec="libx264",
audio_array=torch.tensor(audio_array[None]),
audio_fps=sample_rate,
audio_codec="aac",
options={"preset": "medium", "crf": "23"}, # 可调整视频输出质量
)
def load_transformer(self):
"""Load transformer with LoRA support"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment