Commit 98a85500, authored by sandy and committed by GitHub
Browse files

[Refactor] save video (#316)

parent c2f2d263
import gc import gc
import os import os
import subprocess import warnings
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
...@@ -12,6 +12,7 @@ import torchvision.transforms.functional as TF ...@@ -12,6 +12,7 @@ import torchvision.transforms.functional as TF
from PIL import Image from PIL import Image
from einops import rearrange from einops import rearrange
from loguru import logger from loguru import logger
from torchvision.io import write_video
from torchvision.transforms import InterpolationMode from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize from torchvision.transforms.functional import resize
...@@ -27,7 +28,9 @@ from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE ...@@ -27,7 +28,9 @@ from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
from lightx2v.utils.envs import * from lightx2v.utils.envs import *
from lightx2v.utils.profiler import * from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import find_torch_model_path, load_weights, save_to_video, vae_to_comfyui_image from lightx2v.utils.utils import find_torch_model_path, load_weights, vae_to_comfyui_image
warnings.filterwarnings("ignore", category=UserWarning, module="torchvision.io._video_deprecation_warning")
def get_optimal_patched_size_with_sp(patched_h, patched_w, sp_size): def get_optimal_patched_size_with_sp(patched_h, patched_w, sp_size):
...@@ -664,34 +667,26 @@ class WanAudioRunner(WanRunner): # type:ignore ...@@ -664,34 +667,26 @@ class WanAudioRunner(WanRunner): # type:ignore
self.run_input_encoder = self._run_input_encoder_local_r2v_audio self.run_input_encoder = self._run_input_encoder_local_r2v_audio
def _save_video_with_audio(self, images, audio_array, fps): def _save_video_with_audio(self, images, audio_array, fps):
"""Save video with audio""" output_path = self.config.get("save_video_path")
import tempfile parent_dir = os.path.dirname(output_path)
if parent_dir and not os.path.exists(parent_dir):
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as video_tmp: os.makedirs(parent_dir, exist_ok=True)
video_path = video_tmp.name
sample_rate = self._audio_processor.audio_sr
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as audio_tmp:
audio_path = audio_tmp.name if images.dtype != torch.uint8:
images = (images * 255).clamp(0, 255).to(torch.uint8)
try:
save_to_video(images, video_path, fps) write_video(
ta.save(audio_path, torch.tensor(audio_array[None]), sample_rate=self._audio_processor.audio_sr) # type: ignore filename=output_path,
video_array=images,
output_path = self.config.get("save_video_path") fps=fps,
parent_dir = os.path.dirname(output_path) video_codec="libx264",
if parent_dir and not os.path.exists(parent_dir): audio_array=torch.tensor(audio_array[None]),
os.makedirs(parent_dir, exist_ok=True) audio_fps=sample_rate,
audio_codec="aac",
subprocess.call(["/usr/bin/ffmpeg", "-y", "-i", video_path, "-i", audio_path, output_path]) options={"preset": "medium", "crf": "23"}, # 可调整视频输出质量
)
logger.info(f"Saved video with audio to: {output_path}")
finally:
# Clean up temp files
if os.path.exists(video_path):
os.remove(video_path)
if os.path.exists(audio_path):
os.remove(audio_path)
def load_transformer(self): def load_transformer(self):
"""Load transformer with LoRA support""" """Load transformer with LoRA support"""
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment