Unverified Commit 1b00af38 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Add option to write audio to video file (#2304)



* Add option to write audio to video file

Summary:
I was trying to use torchvision's `write_video` function and realized there was no option to add in the audio.

Thus, this diff contains the changes necessary such that this is possible. This is my first time trying to contribute to this project, so be as harsh as you need!

Reviewed By: fmassa

Differential Revision: D21480083

fbshipit-source-id: 2e11f2c8728d42f86c94068f75b843793d5a94aa

* Fix typo

* Try fix Windows

* Disable test on Windows
Co-authored-by: default avatarJoanna Bitton <jbitton@fb.com>
parent 4eb9f660
import os import os
import contextlib import contextlib
import sys
import tempfile import tempfile
import torch import torch
import torchvision.datasets.utils as utils import torchvision.datasets.utils as utils
...@@ -20,6 +21,9 @@ except ImportError: ...@@ -20,6 +21,9 @@ except ImportError:
av = None av = None
VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos")
def _create_video_frames(num_frames, height, width): def _create_video_frames(num_frames, height, width):
y, x = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width)) y, x = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width))
data = [] data = []
...@@ -260,6 +264,39 @@ class TestIO(unittest.TestCase): ...@@ -260,6 +264,39 @@ class TestIO(unittest.TestCase):
# and the last few frames are wrong # and the last few frames are wrong
self.assertFalse(video.equal(data)) self.assertFalse(video.equal(data))
@unittest.skipIf(sys.platform == 'win32', 'temporarily disabled on Windows')
def test_write_video_with_audio(self):
f_name = os.path.join(VIDEO_DIR, "R6llTwEh07w.mp4")
video_tensor, audio_tensor, info = io.read_video(f_name, pts_unit="sec")
with get_tmp_dir() as tmpdir:
out_f_name = os.path.join(tmpdir, "testing.mp4")
io.video.write_video(
out_f_name,
video_tensor,
round(info["video_fps"]),
video_codec="libx264rgb",
options={'crf': '0'},
audio_array=audio_tensor,
audio_fps=info["audio_fps"],
audio_codec="aac",
)
out_video_tensor, out_audio_tensor, out_info = io.read_video(
out_f_name, pts_unit="sec"
)
self.assertEqual(info["video_fps"], out_info["video_fps"])
self.assertTrue(video_tensor.equal(out_video_tensor))
audio_stream = av.open(f_name).streams.audio[0]
out_audio_stream = av.open(out_f_name).streams.audio[0]
self.assertEqual(info["audio_fps"], out_info["audio_fps"])
self.assertEqual(audio_stream.rate, out_audio_stream.rate)
self.assertAlmostEqual(audio_stream.frames, out_audio_stream.frames, delta=1)
self.assertEqual(audio_stream.frame_size, out_audio_stream.frame_size)
# TODO add tests for audio # TODO add tests for audio
......
...@@ -55,6 +55,10 @@ def write_video( ...@@ -55,6 +55,10 @@ def write_video(
fps: float, fps: float,
video_codec: str = "libx264", video_codec: str = "libx264",
options: Optional[Dict[str, Any]] = None, options: Optional[Dict[str, Any]] = None,
audio_array: Optional[torch.Tensor] = None,
audio_fps: Optional[float] = None,
audio_codec: Optional[str] = None,
audio_options: Optional[Dict[str, Any]] = None,
) -> None: ) -> None:
""" """
Writes a 4d tensor in [T, H, W, C] format in a video file Writes a 4d tensor in [T, H, W, C] format in a video file
...@@ -66,7 +70,20 @@ def write_video( ...@@ -66,7 +70,20 @@ def write_video(
video_array : Tensor[T, H, W, C] video_array : Tensor[T, H, W, C]
tensor containing the individual frames, as a uint8 tensor in [T, H, W, C] format tensor containing the individual frames, as a uint8 tensor in [T, H, W, C] format
fps : Number fps : Number
frames per second video frames per second
video_codec : str
the name of the video codec, i.e. "libx264", "h264", etc.
options : Dict
dictionary containing options to be passed into the PyAV video stream
audio_array : Tensor[C, N]
tensor containing the audio, where C is the number of channels and N is the
number of samples
audio_fps : Number
audio sample rate, typically 44100 or 48000
audio_codec : str
the name of the audio codec, i.e. "mp3", "aac", etc.
audio_options : Dict
dictionary containing options to be passed into the PyAV audio stream
""" """
_check_av_available() _check_av_available()
video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy() video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy()
...@@ -83,6 +100,41 @@ def write_video( ...@@ -83,6 +100,41 @@ def write_video(
stream.pix_fmt = "yuv420p" if video_codec != "libx264rgb" else "rgb24" stream.pix_fmt = "yuv420p" if video_codec != "libx264rgb" else "rgb24"
stream.options = options or {} stream.options = options or {}
if audio_array is not None:
audio_format_dtypes = {
'dbl': '<f8',
'dblp': '<f8',
'flt': '<f4',
'fltp': '<f4',
's16': '<i2',
's16p': '<i2',
's32': '<i4',
's32p': '<i4',
'u8': 'u1',
'u8p': 'u1',
}
a_stream = container.add_stream(audio_codec, rate=audio_fps)
a_stream.options = audio_options or {}
num_channels = audio_array.shape[0]
audio_layout = "stereo" if num_channels > 1 else "mono"
audio_sample_fmt = container.streams.audio[0].format.name
format_dtype = np.dtype(audio_format_dtypes[audio_sample_fmt])
audio_array = torch.as_tensor(audio_array).numpy().astype(format_dtype)
frame = av.AudioFrame.from_ndarray(
audio_array, format=audio_sample_fmt, layout=audio_layout
)
frame.sample_rate = audio_fps
for packet in a_stream.encode(frame):
container.mux(packet)
for packet in a_stream.encode():
container.mux(packet)
for img in video_array: for img in video_array:
frame = av.VideoFrame.from_ndarray(img, format="rgb24") frame = av.VideoFrame.from_ndarray(img, format="rgb24")
frame.pict_type = "NONE" frame.pict_type = "NONE"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment