Commit a984872d authored by moto's avatar moto Committed by Facebook GitHub Bot

Add file-like object support to Streaming API (#2400)

Summary:
This commit adds file-like object support to Streaming API.

## Features
- File-like objects are expected to implement `read(self, n)`.
- Additionally, `seek(self, offset, whence)` is used if available.
- Without a `seek` method, some formats cannot be decoded properly.
  - To work around this, one can use the existing `decoder` option to specify which decoder to use.
  - The `decoder` and `decoder_option` arguments were added to the `add_basic_[audio|video]_stream` methods, similar to `add_[audio|video]_stream`.
  - To place the arguments common to both audio and video before the rest, the order of the arguments was changed.
  - The `dtype` and `format` arguments were also changed to make them consistent across the audio/video methods.
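A minimal sketch of the file-like protocol listed above (the `BytesSource` class and the data are illustrative, not part of the API):

```python
import io


class BytesSource:
    """Minimal file-like object: `read(n)` is required, `seek` is optional."""

    def __init__(self, data: bytes):
        self._buf = io.BytesIO(data)

    def read(self, n: int) -> bytes:
        # Return up to `n` bytes, as io.RawIOBase.read does.
        return self._buf.read(n)

    def seek(self, offset: int, whence: int = 0) -> int:
        # Optional; without it, some formats cannot be decoded properly.
        return self._buf.seek(offset, whence)


src = BytesSource(b"RIFF1234WAVE")
assert src.read(4) == b"RIFF"
src.seek(0, 0)
assert src.read(4) == b"RIFF"
```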

## Code structure

The approach is very similar to how file-like objects are supported in the sox-based I/O.
In the Streaming API, if the input `src` is a string, it is passed to the implementation bound with TorchBind;
if `src` has a `read` attribute, it is passed to the same implementation bound via PyBind11.

![Untitled drawing](https://user-images.githubusercontent.com/855818/169098391-6116afee-7b29-460d-b50d-1037bb8a359d.png)
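The dispatch described above can be sketched in a few lines. This is an illustrative sketch only; `open_src`, `_torchbind_open`, and `_pybind11_open` are hypothetical stand-ins, not the actual binding entry points.

```python
import io


def _torchbind_open(src):
    # Stand-in for the implementation bound with TorchBind.
    return ("torchbind", src)


def _pybind11_open(src):
    # Stand-in for the same implementation bound via PyBind11.
    return ("pybind11", src)


def open_src(src):
    # String sources go through TorchBind; anything exposing a `read`
    # attribute goes through PyBind11, as the diagram above shows.
    if isinstance(src, str):
        return _torchbind_open(src)
    if hasattr(src, "read"):
        return _pybind11_open(src)
    raise ValueError("src must be a str or a file-like object with `read`")


assert open_src("video.mp4")[0] == "torchbind"
assert open_src(io.BytesIO(b""))[0] == "pybind11"
```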

## Refactoring involved
- Extracted to https://github.com/pytorch/audio/issues/2402
  - Some implementation in the original TorchBind surface layer was converted to a Wrapper class so that it can be re-used from the PyBind11 bindings. The wrapper class serves to simplify the binding.
  - The `add_basic_[audio|video]_stream` methods were removed from the C++ layer, as they were just constructing a string and passing it to the `add_[audio|video]_stream` method, which is simpler to do in Python.
  - The original core Streamer implementation kept the use of types in the `c10` namespace to a minimum. All the `c10::optional` and `c10::Dict` were converted to their `std` equivalents at the binding layer. But since they work fine with PyBind11, the Streamer core methods deal with them directly.

## TODO:
- [x] Check if it is possible to stream MP4 (yuv420p) from S3 and directly decode (with/without HW decoding).

Pull Request resolved: https://github.com/pytorch/audio/pull/2400

Reviewed By: carolineechen

Differential Revision: D36520073

Pulled By: mthrok

fbshipit-source-id: a11d981bbe99b1ff0cc356e46264ac8e76614bc6
parent 67762993
......@@ -33,6 +33,7 @@ libavfilter provides.
# It can
# - Load audio/video in variety of formats
# - Load audio/video from local/remote source
# - Load audio/video from file-like object
# - Load audio/video from microphone, camera and screen
# - Generate synthetic audio/video signals.
# - Load audio/video chunk by chunk
......@@ -51,7 +52,7 @@ libavfilter provides.
# `<some media source> -> <optional processing> -> <tensor>`
#
# If you have other forms that can be useful to your use cases,
# (such as integration with `torch.Tensor` type and file-like objects)
# (such as integration with `torch.Tensor` type)
# please file a feature request.
#
......@@ -60,11 +61,15 @@ libavfilter provides.
# --------------
#
import IPython
import matplotlib.pyplot as plt
import torch
import torchaudio
print(torch.__version__)
print(torchaudio.__version__)
######################################################################
#
try:
from torchaudio.io import StreamReader
except ModuleNotFoundError:
......@@ -87,8 +92,8 @@ except ModuleNotFoundError:
pass
raise
print(torch.__version__)
print(torchaudio.__version__)
import IPython
import matplotlib.pyplot as plt
base_url = "https://download.pytorch.org/torchaudio/tutorial-assets"
AUDIO_URL = f"{base_url}/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
......@@ -102,7 +107,7 @@ VIDEO_URL = f"{base_url}/stream-api/NASAs_Most_Scientifically_Complex_Space_Obse
# handle. Whichever source is used, the remaining processes
# (configuring the output, applying preprocessing) are the same.
#
# 1. Common media formats
# 1. Common media formats (a resource indicator of string type, or a file-like object)
# 2. Audio / Video devices
# 3. Synthetic audio / video sources
#
......@@ -110,8 +115,18 @@ VIDEO_URL = f"{base_url}/stream-api/NASAs_Most_Scientifically_Complex_Space_Obse
# For the other streams, please refer to the
# `Advanced I/O streams` section.
#
# .. note::
#
# The coverage of the supported media (such as containers, codecs and protocols)
# depends on the FFmpeg libraries found in the system.
#
# If `StreamReader` raises an error when opening a source, please check
# that the `ffmpeg` command can handle it.
#
######################################################################
# Local files
# ~~~~~~~~~~~
#
# To open a media file, you can simply pass the path of the file to
# the constructor of `StreamReader`.
......@@ -132,12 +147,73 @@ VIDEO_URL = f"{base_url}/stream-api/NASAs_Most_Scientifically_Complex_Space_Obse
# # Video file
# StreamReader(src="video.mpeg")
#
######################################################################
# Network protocols
# ~~~~~~~~~~~~~~~~~
#
# You can directly pass a URL as well.
#
# .. code::
#
# # Video on remote server
# StreamReader(src="https://example.com/video.mp4")
#
# # Playlist format
# StreamReader(src="https://example.com/playlist.m3u")
#
# # RTMP
# StreamReader(src="rtmp://example.com:1935/live/app")
#
######################################################################
# File-like objects
# ~~~~~~~~~~~~~~~~~
#
# You can also pass a file-like object. A file-like object must implement
# a ``read`` method conforming to :py:meth:`io.RawIOBase.read`.
#
# If the given file-like object has a ``seek`` method, StreamReader uses it
# as well. In this case, the ``seek`` method is expected to conform to
# :py:meth:`io.IOBase.seek`.
#
# .. code::
#
# # Open as fileobj with seek support
# with open("input.mp4", "rb") as src:
# StreamReader(src=src)
#
# In cases where a third-party library implements ``seek`` so that it raises
# an error, you can write a wrapper class to mask the ``seek`` method.
#
# .. code::
#
# class Wrapper:
# def __init__(self, obj):
# self.obj = obj
#
# def read(self, n):
# return self.obj.read(n)
#
# .. code::
#
# import requests
#
# response = requests.get("https://example.com/video.mp4", stream=True)
# s = StreamReader(Wrapper(response.raw))
#
# .. code::
#
# import boto3
#
# response = boto3.client("s3").get_object(Bucket="my_bucket", Key="key")
# s = StreamReader(Wrapper(response["Body"]))
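The masking behavior can be checked without any network access. This sketch (the data is illustrative) shows that the wrapper exposes only ``read``, so StreamReader falls back to sequential reading:

```python
import io


class Wrapper:
    """Exposes only `read`, hiding any `seek` the wrapped object has."""

    def __init__(self, obj):
        self.obj = obj

    def read(self, n):
        return self.obj.read(n)


inner = io.BytesIO(b"abcdef")  # has a working `seek`
wrapped = Wrapper(inner)
assert wrapped.read(3) == b"abc"
# `seek` is masked, so a reader probing for it will not find one.
assert not hasattr(wrapped, "seek")
```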
#
######################################################################
# Opening headerless data
# ~~~~~~~~~~~~~~~~~~~~~~~
#
# When loading headerless raw data, you can use ``format`` and
# ``option`` to specify the format of the data.
#
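The essence of headerless data is that the byte stream carries no metadata, so the sample format must be supplied out-of-band. A pure-Python sketch of that idea, independent of StreamReader (the file name and parameters are illustrative):

```python
import struct
import tempfile

# Write 8 signed 16-bit little-endian samples with no header.
samples = [0, 1, 2, 3, -4, -3, -2, -1]
with tempfile.NamedTemporaryFile(suffix=".raw", delete=False) as f:
    f.write(struct.pack("<8h", *samples))
    path = f.name

# Decoding requires knowing the layout ("<8h" here) out-of-band,
# which is exactly the role the ``format``/``option`` arguments play.
with open(path, "rb") as f:
    decoded = list(struct.unpack("<8h", f.read()))

assert decoded == samples
```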
......@@ -213,8 +289,8 @@ for i in range(streamer.num_src_streams):
#
######################################################################
# 5.1. Default streams
# --------------------
# Default streams
# ~~~~~~~~~~~~~~~
#
# When there are multiple streams in the source, it is not immediately
# clear which stream should be used.
......@@ -227,8 +303,8 @@ for i in range(streamer.num_src_streams):
#
######################################################################
# 5.2. Configuring output streams
# -------------------------------
# Configuring output streams
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Once you know which source stream you want to use, then you can
# configure output streams with
......@@ -250,21 +326,25 @@ for i in range(streamer.num_src_streams):
# When StreamReader has buffered this number of chunks and is asked to pull
# more frames, it drops the old frames/chunks.
# - ``stream_index``: The index of the source stream.
# - ``decoder``: If provided, override the decoder. Useful when automatic
# codec detection fails.
# - ``decoder_option``: The option for the decoder.
#
# For audio output stream, you can provide the following additional
# parameters to change the audio properties.
#
# - ``sample_rate``: When provided, StreamReader resamples the audio on-the-fly.
# - ``dtype``: By default the StreamReader returns tensor of `float32` dtype,
# with sample values ranging `[-1, 1]`. By providing ``dtype`` argument
# - ``format``: By default the StreamReader returns tensor of `float32` dtype,
# with sample values ranging `[-1, 1]`. By providing ``format`` argument
# the resulting dtype and value range is changed.
# - ``sample_rate``: When provided, StreamReader resamples the audio on-the-fly.
#
# For video output stream, the following parameters are available.
#
# - ``format``: Image frame format. By default StreamReader returns
# frames in 8-bit 3 channel, in RGB order.
# - ``frame_rate``: Change the frame rate by dropping or duplicating
# frames. No interpolation is performed.
# - ``width``, ``height``: Change the image size.
# - ``format``: Change the image format.
#
######################################################################
......@@ -298,7 +378,7 @@ for i in range(streamer.num_src_streams):
# streamer.add_basic_video_stream(
# frames_per_chunk=10,
# frame_rate=30,
# format="RGB"
# format="rgb24"
# )
#
# # Stream video from source stream `j`,
......@@ -310,7 +390,7 @@ for i in range(streamer.num_src_streams):
# frame_rate=30,
# width=128,
# height=128,
# format="BGR"
# format="bgr24"
# )
#
......@@ -341,8 +421,8 @@ for i in range(streamer.num_src_streams):
#
######################################################################
# 5.3. Streaming
# --------------
# 6. Streaming
# ------------
#
# To stream media data, the streamer alternates the process of
# fetching and decoding the source data, and passing the resulting
......@@ -368,7 +448,7 @@ for i in range(streamer.num_src_streams):
#
######################################################################
# 6. Example
# 7. Example
# ----------
#
# Let's take an example video to configure the output streams.
......@@ -392,9 +472,9 @@ for i in range(streamer.num_src_streams):
#
######################################################################
# Opening the source media
# ~~~~~~~~~~~~~~~~~~~~~~~~
#
# 6.1. Opening the source media
# ------------------------------
# First, let's list the available streams and their properties.
#
......@@ -406,8 +486,8 @@ for i in range(streamer.num_src_streams):
#
# Now we configure the output stream.
#
# 6.2. Configuring output streams
# -------------------------------
# Configuring output streams
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
# fmt: off
# Audio stream with 8k Hz
......@@ -428,7 +508,7 @@ streamer.add_basic_video_stream(
frame_rate=1,
width=960,
height=540,
format="RGB",
format="rgb24",
)
# Video stream with 320x320 (stretched) at 3 FPS, grayscale
......@@ -437,7 +517,7 @@ streamer.add_basic_video_stream(
frame_rate=3,
width=320,
height=320,
format="GRAY",
format="gray",
)
# fmt: on
......@@ -466,8 +546,8 @@ for i in range(streamer.num_out_streams):
print(streamer.get_out_stream_info(i))
######################################################################
# 6.3. Streaming
# --------------
# Streaming
# ~~~~~~~~~
#
######################################################################
......@@ -542,7 +622,9 @@ plt.show(block=False)
#
# .. seealso::
#
# `Device ASR with Emformer RNN-T <./device_asr.html>`__.
# - `Accelerated Video Decoding with NVDEC <../hw_acceleration_tutorial.html>`__.
# - `Online ASR with Emformer RNN-T <./online_asr_tutorial.html>`__.
# - `Device ASR with Emformer RNN-T <./device_asr.html>`__.
#
# Given that the system has proper media devices and libavdevice is
# configured to use the devices, the streaming API can
......@@ -622,14 +704,13 @@ plt.show(block=False)
#
######################################################################
# 2.1. Synthetic audio examples
# -----------------------------
# Synthetic audio examples
# ------------------------
#
######################################################################
# Sine wave with
# ~~~~~~~~~~~~~~
#
# Sine wave
# ~~~~~~~~~
# https://ffmpeg.org/ffmpeg-filters.html#sine
#
# .. code::
......@@ -675,8 +756,8 @@ plt.show(block=False)
#
######################################################################
# Generate noise with
# ~~~~~~~~~~~~~~~~~~~
# Noise
# ~~~~~
# https://ffmpeg.org/ffmpeg-filters.html#anoisesrc
#
# .. code::
......@@ -694,8 +775,8 @@ plt.show(block=False)
#
######################################################################
# 2.2. Synthetic video examples
# -----------------------------
# Synthetic video examples
# ------------------------
#
######################################################################
......@@ -811,8 +892,8 @@ plt.show(block=False)
#
######################################################################
# 3.1. Custom audio streams
# -------------------------
# Custom audio streams
# --------------------
#
#
......@@ -897,8 +978,8 @@ _display(2)
_display(3)
######################################################################
# 3.2. Custom video streams
# -------------------------
# Custom video streams
# --------------------
#
# fmt: off
......
......@@ -36,7 +36,6 @@ class TempDirMixin:
@classmethod
def tearDownClass(cls):
super().tearDownClass()
if cls.temp_dir_ is not None:
try:
cls.temp_dir_.cleanup()
......@@ -52,6 +51,7 @@ class TempDirMixin:
#
# Following the above thread, we ignore it.
pass
super().tearDownClass()
def get_temp_path(self, *paths):
temp_dir = os.path.join(self.get_base_temp_dir(), self.id())
......
import torch
from parameterized import parameterized
from parameterized import parameterized, parameterized_class
from torchaudio_unittest.common_utils import (
get_asset_path,
get_image,
......@@ -22,14 +22,46 @@ if is_ffmpeg_available():
)
def get_video_asset(file="nasa_13013.mp4"):
return get_asset_path(file)
################################################################################
# Helper decorator and Mixin to duplicate the tests for fileobj
_media_source = parameterized_class(
("test_fileobj",),
[(False,), (True,)],
class_name_func=lambda cls, _, params: f'{cls.__name__}{"_fileobj" if params["test_fileobj"] else "_path"}',
)
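The ``class_name_func`` above derives suffixed class names from the parameter sets, so each test class runs once against a path and once against a file-like object. A standalone sketch of that naming logic (the stand-in class is hypothetical):

```python
def class_name_func(cls, _, params):
    # Mirrors the naming logic passed to parameterized_class above.
    return f'{cls.__name__}{"_fileobj" if params["test_fileobj"] else "_path"}'


class StreamReaderInterfaceTest:  # stand-in for the decorated test case
    pass


assert (
    class_name_func(StreamReaderInterfaceTest, 0, {"test_fileobj": True})
    == "StreamReaderInterfaceTest_fileobj"
)
assert (
    class_name_func(StreamReaderInterfaceTest, 1, {"test_fileobj": False})
    == "StreamReaderInterfaceTest_path"
)
```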
class _MediaSourceMixin:
def setUp(self):
super().setUp()
self.src = None
def get_src(self, path):
if not self.test_fileobj:
return path
if self.src is not None:
raise ValueError("get_src can be called only once.")
self.src = open(path, "rb")
return self.src
def tearDown(self):
if self.src is not None:
self.src.close()
super().tearDown()
################################################################################
@skipIfNoFFmpeg
class StreamReaderInterfaceTest(TempDirMixin, TorchaudioTestCase):
@_media_source
class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase):
"""Test suite for interface behaviors around StreamReader"""
def get_src(self, file="nasa_13013.mp4"):
return super().get_src(get_asset_path(file))
def test_streamer_invalid_input(self):
"""StreamReader constructor does not segfault but raise an exception when the input is invalid"""
with self.assertRaises(RuntimeError):
......@@ -48,14 +80,13 @@ class StreamReaderInterfaceTest(TempDirMixin, TorchaudioTestCase):
def test_streamer_invalide_option(self, invalid_keys, options):
"""When invalid options are given, StreamReader raises an exception with these keys"""
options.update({k: k for k in invalid_keys})
src = get_video_asset()
with self.assertRaises(RuntimeError) as ctx:
StreamReader(src, option=options)
StreamReader(self.get_src(), option=options)
assert all(f'"{k}"' in str(ctx.exception) for k in invalid_keys)
def test_src_info(self):
"""`get_src_stream_info` properly fetches information"""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_src())
assert s.num_src_streams == 6
expected = [
......@@ -112,35 +143,35 @@ class StreamReaderInterfaceTest(TempDirMixin, TorchaudioTestCase):
bit_rate=None,
),
]
for i, exp in enumerate(expected):
assert exp == s.get_src_stream_info(i)
output = [s.get_src_stream_info(i) for i in range(6)]
assert expected == output
def test_src_info_invalid_index(self):
"""`get_src_stream_info` does not segfault but raise an exception when input is invalid"""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_src())
for i in [-1, 6, 7, 8]:
with self.assertRaises(IndexError):
with self.assertRaises(RuntimeError):
s.get_src_stream_info(i)
def test_default_streams(self):
"""default stream is not None"""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_src())
assert s.default_audio_stream is not None
assert s.default_video_stream is not None
def test_default_audio_stream_none(self):
"""default audio stream is None for video without audio"""
s = StreamReader(get_video_asset("nasa_13013_no_audio.mp4"))
s = StreamReader(self.get_src("nasa_13013_no_audio.mp4"))
assert s.default_audio_stream is None
def test_default_video_stream_none(self):
"""default video stream is None for video with only audio"""
s = StreamReader(get_video_asset("nasa_13013_no_video.mp4"))
s = StreamReader(self.get_src("nasa_13013_no_video.mp4"))
assert s.default_video_stream is None
def test_num_out_stream(self):
"""num_out_streams gives the correct count of output streams"""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_src())
n, m = 6, 4
for i in range(n):
assert s.num_out_streams == i
......@@ -158,10 +189,10 @@ class StreamReaderInterfaceTest(TempDirMixin, TorchaudioTestCase):
def test_basic_audio_stream(self):
"""`add_basic_audio_stream` constructs a correct filter."""
s = StreamReader(get_video_asset())
s.add_basic_audio_stream(frames_per_chunk=-1, dtype=None)
s = StreamReader(self.get_src())
s.add_basic_audio_stream(frames_per_chunk=-1, format=None)
s.add_basic_audio_stream(frames_per_chunk=-1, sample_rate=8000)
s.add_basic_audio_stream(frames_per_chunk=-1, dtype=torch.int16)
s.add_basic_audio_stream(frames_per_chunk=-1, format="s16p")
sinfo = s.get_out_stream_info(0)
assert sinfo.source_index == s.default_audio_stream
......@@ -177,11 +208,11 @@ class StreamReaderInterfaceTest(TempDirMixin, TorchaudioTestCase):
def test_basic_video_stream(self):
"""`add_basic_video_stream` constructs a correct filter."""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_src())
s.add_basic_video_stream(frames_per_chunk=-1, format=None)
s.add_basic_video_stream(frames_per_chunk=-1, width=3, height=5)
s.add_basic_video_stream(frames_per_chunk=-1, frame_rate=7)
s.add_basic_video_stream(frames_per_chunk=-1, format="BGR")
s.add_basic_video_stream(frames_per_chunk=-1, format="bgr24")
sinfo = s.get_out_stream_info(0)
assert sinfo.source_index == s.default_video_stream
......@@ -201,7 +232,7 @@ class StreamReaderInterfaceTest(TempDirMixin, TorchaudioTestCase):
def test_remove_streams(self):
"""`remove_stream` removes the correct output stream"""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_src())
s.add_basic_audio_stream(frames_per_chunk=-1, sample_rate=24000)
s.add_basic_video_stream(frames_per_chunk=-1, width=16, height=16)
s.add_basic_audio_stream(frames_per_chunk=-1, sample_rate=8000)
......@@ -221,21 +252,21 @@ class StreamReaderInterfaceTest(TempDirMixin, TorchaudioTestCase):
def test_remove_stream_invalid(self):
"""Attempt to remove invalid output streams raises IndexError"""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_src())
for i in range(-3, 3):
with self.assertRaises(IndexError):
with self.assertRaises(RuntimeError):
s.remove_stream(i)
s.add_audio_stream(frames_per_chunk=-1)
for i in range(-3, 3):
if i == 0:
continue
with self.assertRaises(IndexError):
with self.assertRaises(RuntimeError):
s.remove_stream(i)
def test_process_packet(self):
"""`process_packet` method returns 0 while there is a packet in source stream"""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_src())
# nasa_13013.mp4 contains 1023 packets.
for _ in range(1023):
code = s.process_packet()
......@@ -246,19 +277,19 @@ class StreamReaderInterfaceTest(TempDirMixin, TorchaudioTestCase):
def test_pop_chunks_no_output_stream(self):
"""`pop_chunks` method returns empty list when there is no output stream"""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_src())
assert s.pop_chunks() == []
def test_pop_chunks_empty_buffer(self):
"""`pop_chunks` method returns None when a buffer is empty"""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_src())
s.add_basic_audio_stream(frames_per_chunk=-1)
s.add_basic_video_stream(frames_per_chunk=-1)
assert s.pop_chunks() == [None, None]
def test_pop_chunks_exhausted_stream(self):
"""`pop_chunks` method returns None when the source stream is exhausted"""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_src())
# video is 16.57 seconds.
# audio streams per 10 second chunk
# video streams per 20 second chunk
......@@ -284,14 +315,14 @@ class StreamReaderInterfaceTest(TempDirMixin, TorchaudioTestCase):
def test_stream_empty(self):
"""`stream` fails when no output stream is configured"""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_src())
with self.assertRaises(RuntimeError):
next(s.stream())
def test_stream_smoke_test(self):
"""`stream` streams chunks fine"""
w, h = 256, 198
s = StreamReader(get_video_asset())
s = StreamReader(self.get_src())
s.add_basic_audio_stream(frames_per_chunk=2000, sample_rate=8000)
s.add_basic_video_stream(frames_per_chunk=15, frame_rate=60, width=w, height=h)
for i, (achunk, vchunk) in enumerate(s.stream()):
......@@ -302,7 +333,7 @@ class StreamReaderInterfaceTest(TempDirMixin, TorchaudioTestCase):
def test_seek(self):
"""Calling `seek` multiple times should not segfault"""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_src())
for i in range(10):
s.seek(i)
for _ in range(0):
......@@ -312,13 +343,14 @@ class StreamReaderInterfaceTest(TempDirMixin, TorchaudioTestCase):
def test_seek_negative(self):
"""Calling `seek` with negative value should raise an exception"""
s = StreamReader(get_video_asset())
with self.assertRaises(ValueError):
s = StreamReader(self.get_src())
with self.assertRaises(RuntimeError):
s.seek(-1.0)
@skipIfNoFFmpeg
class StreamReaderAudioTest(TempDirMixin, TorchaudioTestCase):
@_media_source
class StreamReaderAudioTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase):
"""Test suite for audio streaming"""
def _get_reference_wav(self, sample_rate, channels_first=False, **kwargs):
......@@ -327,9 +359,14 @@ class StreamReaderAudioTest(TempDirMixin, TorchaudioTestCase):
save_wav(path, data, sample_rate, channels_first=channels_first)
return path, data
def _test_wav(self, path, original, dtype):
s = StreamReader(path)
s.add_basic_audio_stream(frames_per_chunk=-1, dtype=dtype)
def get_src(self, *args, **kwargs):
path, data = self._get_reference_wav(*args, **kwargs)
src = super().get_src(path)
return src, data
def _test_wav(self, src, original, fmt):
s = StreamReader(src)
s.add_basic_audio_stream(frames_per_chunk=-1, format=fmt)
s.process_all_packets()
(output,) = s.pop_chunks()
self.assertEqual(original, output)
......@@ -340,12 +377,19 @@ class StreamReaderAudioTest(TempDirMixin, TorchaudioTestCase):
)
def test_basic_audio_stream(self, dtype, num_channels):
"""`basic_audio_stream` can load WAV file properly."""
path, original = self._get_reference_wav(8000, dtype=dtype, num_channels=num_channels)
src, original = self.get_src(8000, dtype=dtype, num_channels=num_channels)
fmt = {
"uint8": "u8p",
"int16": "s16p",
"int32": "s32p",
}[dtype]
# provide the matching dtype
self._test_wav(path, original, getattr(torch, dtype))
# use the internal dtype ffmpeg picks
self._test_wav(path, original, None)
self._test_wav(src, original, fmt=fmt)
if not self.test_fileobj:
# use the internal dtype ffmpeg picks
self._test_wav(src, original, fmt=None)
@nested_params(
["int16", "uint8", "int32"], # "float", "double", "int64"]
......@@ -353,11 +397,11 @@ class StreamReaderAudioTest(TempDirMixin, TorchaudioTestCase):
)
def test_audio_stream(self, dtype, num_channels):
"""`add_audio_stream` can apply filter"""
path, original = self._get_reference_wav(8000, dtype=dtype, num_channels=num_channels)
src, original = self.get_src(8000, dtype=dtype, num_channels=num_channels)
expected = torch.flip(original, dims=(0,))
s = StreamReader(path)
s = StreamReader(src)
s.add_audio_stream(frames_per_chunk=-1, filter_desc="areverse")
s.process_all_packets()
(output,) = s.pop_chunks()
......@@ -369,10 +413,13 @@ class StreamReaderAudioTest(TempDirMixin, TorchaudioTestCase):
)
def test_audio_seek(self, dtype, num_channels):
"""`seek` changes the position properly"""
path, original = self._get_reference_wav(1, dtype=dtype, num_channels=num_channels, num_frames=30)
src, original = self.get_src(1, dtype=dtype, num_channels=num_channels, num_frames=30)
for t in range(10, 20):
expected = original[t:, :]
s = StreamReader(path)
if self.test_fileobj:
src.seek(0)
s = StreamReader(src)
s.add_audio_stream(frames_per_chunk=-1)
s.seek(float(t))
s.process_all_packets()
......@@ -381,9 +428,9 @@ class StreamReaderAudioTest(TempDirMixin, TorchaudioTestCase):
def test_audio_seek_multiple(self):
"""Calling `seek` after streaming is started should change the position properly"""
path, original = self._get_reference_wav(1, dtype="int16", num_channels=2, num_frames=30)
src, original = self.get_src(1, dtype="int16", num_channels=2, num_frames=30)
s = StreamReader(path)
s = StreamReader(src)
s.add_audio_stream(frames_per_chunk=-1)
ts = list(range(20)) + list(range(20, 0, -1)) + list(range(20))
......@@ -405,11 +452,11 @@ class StreamReaderAudioTest(TempDirMixin, TorchaudioTestCase):
def test_audio_frames_per_chunk(self, frame_param, num_channels):
"""Different chunk parameter covers the source media properly"""
num_frames, frames_per_chunk, buffer_chunk_size = frame_param
path, original = self._get_reference_wav(
src, original = self.get_src(
8000, dtype="int16", num_channels=num_channels, num_frames=num_frames, channels_first=False
)
s = StreamReader(path)
s = StreamReader(src)
s.add_audio_stream(frames_per_chunk=frames_per_chunk, buffer_chunk_size=buffer_chunk_size)
i, outputs = 0, []
for (output,) in s.stream():
......@@ -422,13 +469,19 @@ class StreamReaderAudioTest(TempDirMixin, TorchaudioTestCase):
@skipIfNoFFmpeg
class StreamReaderImageTest(TempDirMixin, TorchaudioTestCase):
@_media_source
class StreamReaderImageTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase):
def _get_reference_png(self, width: int, height: int, grayscale: bool):
original = get_image(width, height, grayscale=grayscale)
path = self.get_temp_path("ref.png")
save_image(path, original, mode="L" if grayscale else "RGB")
return path, original
def get_src(self, *args, **kwargs):
path, data = self._get_reference_png(*args, **kwargs)
src = super().get_src(path)
return src, data
def _test_png(self, path, original, format=None):
s = StreamReader(path)
s.add_basic_video_stream(frames_per_chunk=-1, format=format)
......@@ -441,9 +494,9 @@ class StreamReaderImageTest(TempDirMixin, TorchaudioTestCase):
# TODO:
# Add test with alpha channel (RGBA, ARGB, BGRA, ABGR)
w, h = 32, 18
path, original = self._get_reference_png(w, h, grayscale=grayscale)
src, original = self.get_src(w, h, grayscale=grayscale)
expected = original[None, ...]
self._test_png(path, expected)
self._test_png(src, expected)
@parameterized.expand(
[
......@@ -453,10 +506,10 @@ class StreamReaderImageTest(TempDirMixin, TorchaudioTestCase):
)
def test_png_effect(self, filter_desc, index):
h, w = 111, 250
path, original = self._get_reference_png(w, h, grayscale=False)
src, original = self.get_src(w, h, grayscale=False)
expected = torch.flip(original, dims=(index,))[None, ...]
s = StreamReader(path)
s = StreamReader(src)
s.add_video_stream(frames_per_chunk=-1, filter_desc=filter_desc)
s.process_all_packets()
output = s.pop_chunks()[0]
......
......@@ -57,7 +57,12 @@ def get_ext_modules():
]
)
if _USE_FFMPEG:
modules.append(Extension(name="torchaudio.lib.libtorchaudio_ffmpeg", sources=[]))
modules.extend(
[
Extension(name="torchaudio.lib.libtorchaudio_ffmpeg", sources=[]),
Extension(name="torchaudio._torchaudio_ffmpeg", sources=[]),
]
)
return modules
......
# the following line is added in order to export symbols when building on Windows
# this approach has some limitations as documented in https://github.com/pytorch/pytorch/pull/3650
if (MSVC)
set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
endif()
################################################################################
# libtorchaudio
################################################################################
......@@ -204,11 +210,11 @@ if (BUILD_TORCHAUDIO_PYTHON_EXTENSION)
find_package(Python3 ${PYTHON_VERSION} EXACT COMPONENTS Development)
set(ADDITIONAL_ITEMS Python3::Python)
endif()
function(define_extension name sources libraries definitions)
function(define_extension name sources include_dirs libraries definitions)
add_library(${name} SHARED ${sources})
target_compile_definitions(${name} PRIVATE "${definitions}")
target_include_directories(
${name} PRIVATE ${PROJECT_SOURCE_DIR} ${Python_INCLUDE_DIR})
${name} PRIVATE ${PROJECT_SOURCE_DIR} ${Python_INCLUDE_DIR} ${include_dirs})
target_link_libraries(
${name}
${libraries}
......@@ -254,6 +260,7 @@ if (BUILD_TORCHAUDIO_PYTHON_EXTENSION)
define_extension(
_torchaudio
"${EXTENSION_SOURCES}"
""
libtorchaudio
"${LIBTORCHAUDIO_COMPILE_DEFINITIONS}"
)
......@@ -265,8 +272,23 @@ if (BUILD_TORCHAUDIO_PYTHON_EXTENSION)
define_extension(
_torchaudio_decoder
"${DECODER_EXTENSION_SOURCES}"
""
"libtorchaudio_decoder"
"${LIBTORCHAUDIO_DECODER_DEFINITIONS}"
)
endif()
if(USE_FFMPEG)
set(
FFMPEG_EXTENSION_SOURCES
ffmpeg/pybind/pybind.cpp
ffmpeg/pybind/stream_reader.cpp
)
define_extension(
_torchaudio_ffmpeg
"${FFMPEG_EXTENSION_SOURCES}"
"${FFMPEG_INCLUDE_DIRS}"
"libtorchaudio_ffmpeg"
"${LIBTORCHAUDIO_DECODER_DEFINITIONS}"
)
endif()
endif()
......@@ -66,17 +66,24 @@ std::string join(std::vector<std::string> vars) {
AVFormatContextPtr get_input_format_context(
const std::string& src,
const c10::optional<std::string>& device,
const OptionDict& option) {
AVFormatContext* pFormat = NULL;
const OptionDict& option,
AVIOContext* io_ctx) {
AVFormatContext* pFormat = avformat_alloc_context();
if (!pFormat) {
throw std::runtime_error("Failed to allocate AVFormatContext.");
}
if (io_ctx) {
pFormat->pb = io_ctx;
}
AVINPUT_FORMAT_CONST AVInputFormat* pInput = [&]() -> AVInputFormat* {
auto* pInput = [&]() -> AVINPUT_FORMAT_CONST AVInputFormat* {
if (device.has_value()) {
std::string device_str = device.value();
AVINPUT_FORMAT_CONST AVInputFormat* p =
av_find_input_format(device_str.c_str());
if (!p) {
std::ostringstream msg;
msg << "Unsupported device: \"" << device_str << "\"";
msg << "Unsupported device/format: \"" << device_str << "\"";
throw std::runtime_error(msg.str());
}
return p;
......@@ -103,6 +110,17 @@ AVFormatContextPtr get_input_format_context(
AVFormatContextPtr::AVFormatContextPtr(AVFormatContext* p)
: Wrapper<AVFormatContext, AVFormatContextDeleter>(p) {}
////////////////////////////////////////////////////////////////////////////////
// AVIO
////////////////////////////////////////////////////////////////////////////////
void AVIOContextDeleter::operator()(AVIOContext* p) {
av_freep(&p->buffer);
av_freep(&p);
};
AVIOContextPtr::AVIOContextPtr(AVIOContext* p)
: Wrapper<AVIOContext, AVIOContextDeleter>(p) {}
////////////////////////////////////////////////////////////////////////////////
// AVPacket
////////////////////////////////////////////////////////////////////////////////
......
......@@ -13,6 +13,7 @@ extern "C" {
#include <libavfilter/buffersink.h>
#include <libavfilter/buffersrc.h>
#include <libavformat/avformat.h>
#include <libavformat/avio.h>
#include <libavutil/avutil.h>
#include <libavutil/frame.h>
#include <libavutil/imgutils.h>
......@@ -74,7 +75,19 @@ struct AVFormatContextPtr
AVFormatContextPtr get_input_format_context(
const std::string& src,
const c10::optional<std::string>& device,
const OptionDict& option);
const OptionDict& option,
AVIOContext* io_ctx = nullptr);
////////////////////////////////////////////////////////////////////////////////
// AVIO
////////////////////////////////////////////////////////////////////////////////
struct AVIOContextDeleter {
void operator()(AVIOContext* p);
};
struct AVIOContextPtr : public Wrapper<AVIOContext, AVIOContextDeleter> {
explicit AVIOContextPtr(AVIOContext* p);
};
////////////////////////////////////////////////////////////////////////////////
// AVPacket
......
@@ -46,84 +46,70 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
av_log_set_level(AV_LOG_ERROR);
});
m.def("torchaudio::ffmpeg_load", load);
m.class_<StreamReaderBinding>("ffmpeg_Streamer")
.def(torch::init<>(init))
.def("num_src_streams", [](S self) { return self->num_src_streams(); })
.def("num_out_streams", [](S self) { return self->num_out_streams(); })
.def(
"get_src_stream_info",
[](S s, int64_t i) { return s->get_src_stream_info(i); })
.def(
"get_out_stream_info",
[](S s, int64_t i) { return s->get_out_stream_info(i); })
.def(
"find_best_audio_stream",
[](S s) { return s->find_best_audio_stream(); })
.def(
"find_best_video_stream",
[](S s) { return s->find_best_video_stream(); })
.def("seek", [](S s, double t) { return s->seek(t); })
.def(
"add_audio_stream",
[](S s,
int64_t i,
int64_t frames_per_chunk,
int64_t num_chunks,
const c10::optional<std::string>& filter_desc,
const c10::optional<std::string>& decoder,
const c10::optional<c10::Dict<std::string, std::string>>&
decoder_options) {
s->add_audio_stream(
i,
frames_per_chunk,
num_chunks,
filter_desc,
decoder,
map(decoder_options));
})
.def(
"add_video_stream",
[](S s,
int64_t i,
int64_t frames_per_chunk,
int64_t num_chunks,
const c10::optional<std::string>& filter_desc,
const c10::optional<std::string>& decoder,
const c10::optional<c10::Dict<std::string, std::string>>&
decoder_options,
const c10::optional<std::string>& hw_accel) {
s->add_video_stream(
i,
frames_per_chunk,
num_chunks,
filter_desc,
decoder,
map(decoder_options),
hw_accel);
})
.def("remove_stream", [](S s, int64_t i) { s->remove_stream(i); })
.def(
"process_packet",
[](S s, const c10::optional<double>& timeout, const double backoff) {
return s->process_packet(timeout, backoff);
})
.def("process_all_packets", [](S s) { s->process_all_packets(); })
.def("is_buffer_ready", [](S s) { return s->is_buffer_ready(); })
.def("pop_chunks", [](S s) { return s->pop_chunks(); });
}
} // namespace
......
#include <pybind11/stl.h>
#include <torch/extension.h>
#include <torchaudio/csrc/ffmpeg/pybind/stream_reader.h>
namespace torchaudio {
namespace ffmpeg {
namespace {
PYBIND11_MODULE(_torchaudio_ffmpeg, m) {
py::class_<StreamReaderFileObj, c10::intrusive_ptr<StreamReaderFileObj>>(
m, "StreamReaderFileObj")
.def(py::init<
py::object,
const c10::optional<std::string>&,
const c10::optional<OptionDict>&,
int64_t>())
.def("num_src_streams", &StreamReaderFileObj::num_src_streams)
.def("num_out_streams", &StreamReaderFileObj::num_out_streams)
.def(
"find_best_audio_stream",
&StreamReaderFileObj::find_best_audio_stream)
.def(
"find_best_video_stream",
&StreamReaderFileObj::find_best_video_stream)
.def("get_src_stream_info", &StreamReaderFileObj::get_src_stream_info)
.def("get_out_stream_info", &StreamReaderFileObj::get_out_stream_info)
.def("seek", &StreamReaderFileObj::seek)
.def("add_audio_stream", &StreamReaderFileObj::add_audio_stream)
.def("add_video_stream", &StreamReaderFileObj::add_video_stream)
.def("remove_stream", &StreamReaderFileObj::remove_stream)
.def("process_packet", &StreamReaderFileObj::process_packet)
.def("process_all_packets", &StreamReaderFileObj::process_all_packets)
.def("is_buffer_ready", &StreamReaderFileObj::is_buffer_ready)
.def("pop_chunks", &StreamReaderFileObj::pop_chunks);
}
} // namespace
} // namespace ffmpeg
} // namespace torchaudio
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
#include <torchaudio/csrc/ffmpeg/pybind/stream_reader.h>
namespace torchaudio {
namespace ffmpeg {
namespace {
static int read_function(void* opaque, uint8_t* buf, int buf_size) {
FileObj* fileobj = static_cast<FileObj*>(opaque);
buf_size = FFMIN(buf_size, fileobj->buffer_size);
int num_read = 0;
while (num_read < buf_size) {
int request = buf_size - num_read;
auto chunk = static_cast<std::string>(
static_cast<py::bytes>(fileobj->fileobj.attr("read")(request)));
auto chunk_len = chunk.length();
if (chunk_len == 0) {
break;
}
if (chunk_len > request) {
std::ostringstream message;
message
<< "Requested up to " << request << " bytes, but "
<< "received " << chunk_len << " bytes. "
<< "The given object does not conform to the read protocol of file objects.";
throw std::runtime_error(message.str());
}
memcpy(buf, chunk.data(), chunk_len);
buf += chunk_len;
num_read += chunk_len;
}
return num_read == 0 ? AVERROR_EOF : num_read;
}
static int64_t seek_function(void* opaque, int64_t offset, int whence) {
// We do not know the file size.
if (whence == AVSEEK_SIZE) {
return AVERROR(EIO);
}
FileObj* fileobj = static_cast<FileObj*>(opaque);
return py::cast<int64_t>(fileobj->fileobj.attr("seek")(offset, whence));
}
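The two callbacks above assume the Python object follows the standard file protocol: `read(n)` returns at most `n` bytes (an empty `bytes` at EOF, which `read_function` maps to `AVERROR_EOF`), and `seek(offset, whence)` returns the new absolute position. A minimal sketch of an object they can drive, using only the standard library (illustrative, not part of torchaudio):

```python
import io

class MinimalFileObj:
    """A minimal file-like object satisfying the protocol the AVIO
    callbacks expect: read(n) returns at most n bytes (b"" at EOF) and
    seek(offset, whence) returns the new absolute position."""

    def __init__(self, data: bytes):
        self._buf = io.BytesIO(data)

    def read(self, n: int) -> bytes:
        return self._buf.read(n)

    def seek(self, offset: int, whence: int = 0) -> int:
        return self._buf.seek(offset, whence)
```

Note that returning more bytes than requested from `read` triggers the runtime error in `read_function`, and objects without a `seek` attribute are still usable: `get_io_context` simply registers no seek callback for them.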
AVIOContextPtr get_io_context(FileObj* opaque, int buffer_size) {
uint8_t* buffer = static_cast<uint8_t*>(av_malloc(buffer_size));
if (!buffer) {
throw std::runtime_error("Failed to allocate buffer.");
}
// If avio_alloc_context succeeds, then buffer will be cleaned up by
// AVIOContextPtr destructor.
// If avio_alloc_context fails, we need to clean up by ourselves.
AVIOContext* av_io_ctx = avio_alloc_context(
buffer,
buffer_size,
0,
static_cast<void*>(opaque),
&read_function,
nullptr,
py::hasattr(opaque->fileobj, "seek") ? &seek_function : nullptr);
if (!av_io_ctx) {
av_freep(&buffer);
throw std::runtime_error("Failed to allocate AVIO context.");
}
return AVIOContextPtr{av_io_ctx};
}
} // namespace
FileObj::FileObj(py::object fileobj_, int buffer_size)
: fileobj(fileobj_),
buffer_size(buffer_size),
pAVIO(get_io_context(this, buffer_size)) {}
StreamReaderFileObj::StreamReaderFileObj(
py::object fileobj_,
const c10::optional<std::string>& format,
const c10::optional<OptionDict>& option,
int64_t buffer_size)
: FileObj(fileobj_, static_cast<int>(buffer_size)),
StreamReaderBinding(get_input_format_context(
static_cast<std::string>(py::str(fileobj_.attr("__str__")())),
format,
option.value_or(OptionDict{}),
pAVIO)) {}
} // namespace ffmpeg
} // namespace torchaudio
#pragma once
#include <torch/extension.h>
#include <torchaudio/csrc/ffmpeg/stream_reader_wrapper.h>
namespace torchaudio {
namespace ffmpeg {
struct FileObj {
py::object fileobj;
int buffer_size;
AVIOContextPtr pAVIO;
FileObj(py::object fileobj, int buffer_size);
};
// The reason we inherit from FileObj, rather than holding it as a member,
// is to guarantee that FileObj is initialized first:
// AVIOContext must be initialized before AVFormatContext, and must outlive it.
class StreamReaderFileObj : protected FileObj, public StreamReaderBinding {
public:
StreamReaderFileObj(
py::object fileobj,
const c10::optional<std::string>& format,
const c10::optional<OptionDict>& option,
int64_t buffer_size);
};
} // namespace ffmpeg
} // namespace torchaudio
@@ -21,12 +21,12 @@ void Streamer::validate_open_stream() const {
void Streamer::validate_src_stream_index(int i) const {
validate_open_stream();
if (i < 0 || i >= static_cast<int>(pFormatContext->nb_streams))
throw std::runtime_error("Source stream index out of range");
}
void Streamer::validate_output_stream_index(int i) const {
if (i < 0 || i >= static_cast<int>(stream_indices.size()))
throw std::runtime_error("Output stream index out of range");
}
void Streamer::validate_src_stream_type(int i, AVMediaType type) {
@@ -81,19 +81,25 @@ SrcStreamInfo Streamer::get_src_stream_info(int i) const {
ret.codec_long_name = desc->long_name;
}
switch (codecpar->codec_type) {
case AVMEDIA_TYPE_AUDIO: {
AVSampleFormat smp_fmt = static_cast<AVSampleFormat>(codecpar->format);
if (smp_fmt != AV_SAMPLE_FMT_NONE) {
ret.fmt_name = av_get_sample_fmt_name(smp_fmt);
}
ret.sample_rate = static_cast<double>(codecpar->sample_rate);
ret.num_channels = codecpar->channels;
break;
}
case AVMEDIA_TYPE_VIDEO: {
AVPixelFormat pix_fmt = static_cast<AVPixelFormat>(codecpar->format);
if (pix_fmt != AV_PIX_FMT_NONE) {
ret.fmt_name = av_get_pix_fmt_name(pix_fmt);
}
ret.width = codecpar->width;
ret.height = codecpar->height;
ret.frame_rate = av_q2d(stream->r_frame_rate);
break;
}
default:;
}
return ret;
@@ -137,7 +143,7 @@ bool Streamer::is_buffer_ready() const {
////////////////////////////////////////////////////////////////////////////////
void Streamer::seek(double timestamp) {
if (timestamp < 0) {
throw std::runtime_error("timestamp must be non-negative.");
}
int64_t ts = static_cast<int64_t>(timestamp * AV_TIME_BASE);
@@ -220,6 +226,13 @@ void Streamer::add_stream(
validate_src_stream_type(i, media_type);
AVStream* stream = pFormatContext->streams[i];
// When the media source is a file-like object, it is possible that the
// source codec format is not detected properly.
if (stream->codecpar->format == -1) {
throw std::runtime_error(
"Failed to detect the source stream format. Please provide the decoder to use.");
}
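In Python terms, the guard above amounts to the following (a hedged sketch; the function name is illustrative, not part of torchaudio):

```python
def validate_detected_format(codec_format: int) -> None:
    # FFmpeg reports an undetected sample/pixel format as -1. With a
    # file-like source that lacks seek(), probing may fail to fill it in,
    # so the caller must specify the decoder explicitly.
    if codec_format == -1:
        raise RuntimeError(
            "Failed to detect the source stream format. "
            "Please provide the decoder to use."
        )
```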
stream->discard = AVDISCARD_DEFAULT;
if (!processors[i])
processors[i] = std::make_unique<StreamProcessor>(
......
@@ -14,6 +14,7 @@ def _init_extension():
try:
torchaudio._extension._load_lib("libtorchaudio_ffmpeg")
import torchaudio._torchaudio_ffmpeg
except OSError as err:
raise ImportError(
"Stream API requires FFmpeg libraries (libavformat and such). Please install FFmpeg 4."
......