Commit a984872d authored by moto's avatar moto Committed by Facebook GitHub Bot

Add file-like object support to Streaming API (#2400)

Summary:
This commit adds file-like object support to Streaming API.

## Features
- File-like objects are expected to implement `read(self, n)`.
- Additionally, `seek(self, offset, whence)` is used if available.
- Without a `seek` method, some formats cannot be decoded properly.
  - To work around this, one can use the existing `decoder` option to specify which decoder should be used.
  - The `decoder` and `decoder_option` arguments were added to the `add_basic_[audio|video]_stream` methods, mirroring `add_[audio|video]_stream`.
  - The argument order was changed so that the arguments common to both audio and video come before the rest.
  - The `dtype` and `format` arguments were also changed to make them consistent across the audio/video methods.

## Code structure

The approach is very similar to how file-like objects are supported in the sox-based I/O.
In the Streaming API, if the input `src` is a string, it is passed to the implementation bound with TorchBind;
if `src` has a `read` attribute, it is passed to the same implementation bound via PyBind11.

![Untitled drawing](https://user-images.githubusercontent.com/855818/169098391-6116afee-7b29-460d-b50d-1037bb8a359d.png)

## Refactoring involved
- Extracted to https://github.com/pytorch/audio/issues/2402
  - Some of the implementation in the original TorchBind surface layer was converted to a wrapper class so that it can be re-used from the PyBind11 bindings. The wrapper class serves to simplify the binding.
  - The `add_basic_[audio|video]_stream` methods were removed from the C++ layer, as they were just constructing a string and passing it to the `add_[audio|video]_stream` method, which is simpler to do in Python.
  - The original core Streamer implementation kept the use of types in the `c10` namespace to a minimum: all `c10::optional` and `c10::Dict` were converted to their `std` equivalents at the binding layer. Since they work fine with PyBind11, the Streamer core methods now deal with them directly.

## TODO:
- [x] Check if it is possible to stream MP4 (yuv420p) from S3 and directly decode (with/without HW decoding).

Pull Request resolved: https://github.com/pytorch/audio/pull/2400

Reviewed By: carolineechen

Differential Revision: D36520073

Pulled By: mthrok

fbshipit-source-id: a11d981bbe99b1ff0cc356e46264ac8e76614bc6
parent 67762993
@@ -33,6 +33,7 @@ libavfilter provides.
# It can
# - Load audio/video in variety of formats
# - Load audio/video from local/remote source
# - Load audio/video from file-like object
# - Load audio/video from microphone, camera and screen
# - Generate synthetic audio/video signals.
# - Load audio/video chunk by chunk
@@ -51,7 +52,7 @@ libavfilter provides.
# `<some media source> -> <optional processing> -> <tensor>`
#
# If you have other forms that can be useful to your usecases,
# (such as integration with `torch.Tensor` type)
# please file a feature request.
#
@@ -60,11 +61,15 @@ libavfilter provides.
# --------------
#
import torch
import torchaudio

print(torch.__version__)
print(torchaudio.__version__)

######################################################################
#
try:
    from torchaudio.io import StreamReader
except ModuleNotFoundError:
@@ -87,8 +92,8 @@ except ModuleNotFoundError:
        pass
    raise

import IPython
import matplotlib.pyplot as plt
base_url = "https://download.pytorch.org/torchaudio/tutorial-assets"
AUDIO_URL = f"{base_url}/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
@@ -102,7 +107,7 @@ VIDEO_URL = f"{base_url}/stream-api/NASAs_Most_Scientifically_Complex_Space_Obse
# handle. Whichever source is used, the remaining processes
# (configuring the output, applying preprocessing) are same.
#
# 1. Common media formats (resource indicator of string type or file-like object)
# 2. Audio / Video devices
# 3. Synthetic audio / video sources
#
@@ -110,8 +115,18 @@ VIDEO_URL = f"{base_url}/stream-api/NASAs_Most_Scientifically_Complex_Space_Obse
# For the other streams, please refer to the
# `Advanced I/O streams` section.
#
# .. note::
#
#    The coverage of the supported media (such as containers, codecs and protocols)
#    depends on the FFmpeg libraries found in the system.
#
#    If `StreamReader` raises an error opening a source, please check
#    that the `ffmpeg` command can handle it.
#
######################################################################
# Local files
# ~~~~~~~~~~~
#
# To open a media file, you can simply pass the path of the file to
# the constructor of `StreamReader`.
@@ -132,12 +147,73 @@ VIDEO_URL = f"{base_url}/stream-api/NASAs_Most_Scientifically_Complex_Space_Obse
#    # Video file
#    StreamReader(src="video.mpeg")
#
######################################################################
# Network protocols
# ~~~~~~~~~~~~~~~~~
#
# You can directly pass a URL as well.
#
# .. code::
#
#    # Video on remote server
#    StreamReader(src="https://example.com/video.mp4")
#
#    # Playlist format
#    StreamReader(src="https://example.com/playlist.m3u")
#
#    # RTMP
#    StreamReader(src="rtmp://example.com:1935/live/app")
#
######################################################################
# File-like objects
# ~~~~~~~~~~~~~~~~~
#
# You can also pass a file-like object. A file-like object must implement
# the ``read`` method, conforming to :py:attr:`io.RawIOBase.read`.
#
# If the given file-like object has a ``seek`` method, StreamReader uses it
# as well. In this case the ``seek`` method is expected to conform to
# :py:attr:`io.IOBase.seek`.
#
# .. code::
#
#    # Open as fileobj with seek support
#    with open("input.mp4", "rb") as src:
#        StreamReader(src=src)
#
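As a concrete illustration of these expectations, here is a minimal file-like object; ``MinimalSource`` is a hypothetical name, and any object with a conforming ``read`` (and optionally ``seek``) works the same way:

```python
import io

class MinimalSource:
    """Illustrative file-like object: ``read(n)`` per io.RawIOBase.read,
    plus an optional ``seek`` per io.IOBase.seek."""

    def __init__(self, data: bytes):
        self._buf = io.BytesIO(data)

    def read(self, n: int) -> bytes:
        return self._buf.read(n)

    def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
        return self._buf.seek(offset, whence)

src = MinimalSource(b"\x00\x01\x02\x03")
print(src.read(2))   # b'\x00\x01'
print(src.seek(0))   # 0
```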
# In cases where a third-party library implements ``seek`` so that it raises
# an error, you can write a wrapper class that masks the ``seek`` method.
#
# .. code::
#
#    class Wrapper:
#        def __init__(self, obj):
#            self.obj = obj
#
#        def read(self, n):
#            return self.obj.read(n)
#
# .. code::
#
#    import requests
#
#    response = requests.get("https://example.com/video.mp4", stream=True)
#    s = StreamReader(Wrapper(response.raw))
#
# .. code::
#
#    import boto3
#
#    response = boto3.client("s3").get_object(Bucket="my_bucket", Key="key")
#    s = StreamReader(Wrapper(response["Body"]))
#
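The ``Wrapper`` pattern shown above can be exercised locally, without network access, to confirm that it forwards ``read`` while hiding ``seek``:

```python
import io

class Wrapper:
    """Expose only ``read`` so that StreamReader does not attempt to seek."""

    def __init__(self, obj):
        self.obj = obj

    def read(self, n):
        return self.obj.read(n)

w = Wrapper(io.BytesIO(b"abcdef"))
print(w.read(3))           # b'abc'
print(hasattr(w, "seek"))  # False
```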
######################################################################
# Opening headerless data
# ~~~~~~~~~~~~~~~~~~~~~~~
#
# If attempting to load headerless raw data, you can use ``format`` and
# ``option`` to specify the format of the data.
#
@@ -213,8 +289,8 @@ for i in range(streamer.num_src_streams):
#
######################################################################
# Default streams
# ~~~~~~~~~~~~~~~
#
# When there are multiple streams in the source, it is not immediately
# clear which stream should be used.
@@ -227,8 +303,8 @@ for i in range(streamer.num_src_streams):
#
######################################################################
# Configuring output streams
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Once you know which source stream you want to use, then you can
# configure output streams with
@@ -250,21 +326,25 @@ for i in range(streamer.num_src_streams):
#   When the StreamReader buffered this number of chunks and is asked to pull
#   more frames, StreamReader drops the old frames/chunks.
# - ``stream_index``: The index of the source stream.
# - ``decoder``: If provided, override the decoder. Useful if it fails to detect
#   the codec.
# - ``decoder_option``: The option for the decoder.
#
# For audio output stream, you can provide the following additional
# parameters to change the audio properties.
#
# - ``format``: By default the StreamReader returns tensor of `float32` dtype,
#   with sample values ranging `[-1, 1]`. By providing ``format`` argument
#   the resulting dtype and value range is changed.
# - ``sample_rate``: When provided, StreamReader resamples the audio on-the-fly.
#
# For video output stream, the following parameters are available.
#
# - ``format``: Image frame format. By default StreamReader returns
#   frames in 8-bit 3 channel, in RGB order.
# - ``frame_rate``: Change the frame rate by dropping or duplicating
#   frames. No interpolation is performed.
# - ``width``, ``height``: Change the image size.
#
######################################################################
@@ -298,7 +378,7 @@ for i in range(streamer.num_src_streams):
#    streamer.add_basic_video_stream(
#        frames_per_chunk=10,
#        frame_rate=30,
#        format="rgb24"
#    )
#
#    # Stream video from source stream `j`,
@@ -310,7 +390,7 @@ for i in range(streamer.num_src_streams):
#        frame_rate=30,
#        width=128,
#        height=128,
#        format="bgr24"
#    )
#
@@ -341,8 +421,8 @@ for i in range(streamer.num_src_streams):
#
######################################################################
# 6. Streaming
# ------------
#
# To stream media data, the streamer alternates the process of
# fetching and decoding the source data, and passing the resulting
@@ -368,7 +448,7 @@ for i in range(streamer.num_src_streams):
#
######################################################################
# 7. Example
# ----------
#
# Let's take an example video to configure the output streams.
@@ -392,9 +472,9 @@ for i in range(streamer.num_src_streams):
#
######################################################################
# Opening the source media
# ~~~~~~~~~~~~~~~~~~~~~~~~
#
# Firstly, let's list the available streams and their properties.
#
@@ -406,8 +486,8 @@ for i in range(streamer.num_src_streams):
#
# Now we configure the output stream.
#
# Configuring output streams
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
# fmt: off
# Audio stream with 8k Hz
@@ -428,7 +508,7 @@ streamer.add_basic_video_stream(
    frame_rate=1,
    width=960,
    height=540,
    format="rgb24",
)
# Video stream with 320x320 (stretched) at 3 FPS, grayscale
@@ -437,7 +517,7 @@ streamer.add_basic_video_stream(
    frame_rate=3,
    width=320,
    height=320,
    format="gray",
)
# fmt: on
@@ -466,8 +546,8 @@ for i in range(streamer.num_out_streams):
    print(streamer.get_out_stream_info(i))
######################################################################
# Streaming
# ~~~~~~~~~
#
######################################################################
@@ -542,7 +622,9 @@ plt.show(block=False)
#
# .. seealso::
#
#    - `Accelerated Video Decoding with NVDEC <../hw_acceleration_tutorial.html>`__.
#    - `Online ASR with Emformer RNN-T <./online_asr_tutorial.html>`__.
#    - `Device ASR with Emformer RNN-T <./device_asr.html>`__.
#
# Given that the system has proper media devices and libavdevice is
# configured to use the devices, the streaming API can
@@ -622,14 +704,13 @@ plt.show(block=False)
#
######################################################################
# Synthetic audio examples
# ------------------------
#
######################################################################
# Sine wave
# ~~~~~~~~~
# https://ffmpeg.org/ffmpeg-filters.html#sine
#
# .. code::
@@ -675,8 +756,8 @@ plt.show(block=False)
#
######################################################################
# Noise
# ~~~~~
# https://ffmpeg.org/ffmpeg-filters.html#anoisesrc
#
# .. code::
@@ -694,8 +775,8 @@ plt.show(block=False)
#
######################################################################
# Synthetic video examples
# ------------------------
#
######################################################################
@@ -811,8 +892,8 @@ plt.show(block=False)
#
######################################################################
# Custom audio streams
# --------------------
#
#
@@ -897,8 +978,8 @@ _display(2)
_display(3)
######################################################################
# Custom video streams
# --------------------
#
# fmt: off
...
@@ -36,7 +36,6 @@ class TempDirMixin:
    @classmethod
    def tearDownClass(cls):
        if cls.temp_dir_ is not None:
            try:
                cls.temp_dir_.cleanup()
@@ -52,6 +51,7 @@ class TempDirMixin:
                #
                # Following the above thread, we ignore it.
                pass
        super().tearDownClass()

    def get_temp_path(self, *paths):
        temp_dir = os.path.join(self.get_base_temp_dir(), self.id())
...
import torch
from parameterized import parameterized, parameterized_class
from torchaudio_unittest.common_utils import (
    get_asset_path,
    get_image,
@@ -22,14 +22,46 @@ if is_ffmpeg_available():
)
################################################################################
# Helper decorator and Mixin to duplicate the tests for fileobj
_media_source = parameterized_class(
    ("test_fileobj",),
    [(False,), (True,)],
    class_name_func=lambda cls, _, params: f'{cls.__name__}{"_fileobj" if params["test_fileobj"] else "_path"}',
)


class _MediaSourceMixin:
    def setUp(self):
        super().setUp()
        self.src = None

    def get_src(self, path):
        if not self.test_fileobj:
            return path
        if self.src is not None:
            raise ValueError("get_video_asset can be called only once.")
        self.src = open(path, "rb")
        return self.src

    def tearDown(self):
        if self.src is not None:
            self.src.close()
        super().tearDown()


################################################################################
@skipIfNoFFmpeg
@_media_source
class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase):
    """Test suite for interface behaviors around StreamReader"""

    def get_src(self, file="nasa_13013.mp4"):
        return super().get_src(get_asset_path(file))
    def test_streamer_invalid_input(self):
        """StreamReader constructor does not segfault but raise an exception when the input is invalid"""
        with self.assertRaises(RuntimeError):
@@ -48,14 +80,13 @@ class StreamReaderInterfaceTest(TempDirMixin, TorchaudioTestCase):
    def test_streamer_invalide_option(self, invalid_keys, options):
        """When invalid options are given, StreamReader raises an exception with these keys"""
        options.update({k: k for k in invalid_keys})
        with self.assertRaises(RuntimeError) as ctx:
            StreamReader(self.get_src(), option=options)
        assert all(f'"{k}"' in str(ctx.exception) for k in invalid_keys)
    def test_src_info(self):
        """`get_src_stream_info` properly fetches information"""
        s = StreamReader(self.get_src())
        assert s.num_src_streams == 6
        expected = [
@@ -112,35 +143,35 @@ class StreamReaderInterfaceTest(TempDirMixin, TorchaudioTestCase):
                bit_rate=None,
            ),
        ]
        output = [s.get_src_stream_info(i) for i in range(6)]
        assert expected == output
    def test_src_info_invalid_index(self):
        """`get_src_stream_info` does not segfault but raise an exception when input is invalid"""
        s = StreamReader(self.get_src())
        for i in [-1, 6, 7, 8]:
            with self.assertRaises(RuntimeError):
                s.get_src_stream_info(i)

    def test_default_streams(self):
        """default stream is not None"""
        s = StreamReader(self.get_src())
        assert s.default_audio_stream is not None
        assert s.default_video_stream is not None

    def test_default_audio_stream_none(self):
        """default audio stream is None for video without audio"""
        s = StreamReader(self.get_src("nasa_13013_no_audio.mp4"))
        assert s.default_audio_stream is None

    def test_default_video_stream_none(self):
        """default video stream is None for video with only audio"""
        s = StreamReader(self.get_src("nasa_13013_no_video.mp4"))
        assert s.default_video_stream is None

    def test_num_out_stream(self):
        """num_out_streams gives the correct count of output streams"""
        s = StreamReader(self.get_src())
        n, m = 6, 4
        for i in range(n):
            assert s.num_out_streams == i
@@ -158,10 +189,10 @@ class StreamReaderInterfaceTest(TempDirMixin, TorchaudioTestCase):
    def test_basic_audio_stream(self):
        """`add_basic_audio_stream` constructs a correct filter."""
        s = StreamReader(self.get_src())
        s.add_basic_audio_stream(frames_per_chunk=-1, format=None)
        s.add_basic_audio_stream(frames_per_chunk=-1, sample_rate=8000)
        s.add_basic_audio_stream(frames_per_chunk=-1, format="s16p")
        sinfo = s.get_out_stream_info(0)
        assert sinfo.source_index == s.default_audio_stream
@@ -177,11 +208,11 @@ class StreamReaderInterfaceTest(TempDirMixin, TorchaudioTestCase):
    def test_basic_video_stream(self):
        """`add_basic_video_stream` constructs a correct filter."""
        s = StreamReader(self.get_src())
        s.add_basic_video_stream(frames_per_chunk=-1, format=None)
        s.add_basic_video_stream(frames_per_chunk=-1, width=3, height=5)
        s.add_basic_video_stream(frames_per_chunk=-1, frame_rate=7)
        s.add_basic_video_stream(frames_per_chunk=-1, format="bgr24")
        sinfo = s.get_out_stream_info(0)
        assert sinfo.source_index == s.default_video_stream
@@ -201,7 +232,7 @@ class StreamReaderInterfaceTest(TempDirMixin, TorchaudioTestCase):
    def test_remove_streams(self):
        """`remove_stream` removes the correct output stream"""
        s = StreamReader(self.get_src())
        s.add_basic_audio_stream(frames_per_chunk=-1, sample_rate=24000)
        s.add_basic_video_stream(frames_per_chunk=-1, width=16, height=16)
        s.add_basic_audio_stream(frames_per_chunk=-1, sample_rate=8000)
@@ -221,21 +252,21 @@ class StreamReaderInterfaceTest(TempDirMixin, TorchaudioTestCase):
    def test_remove_stream_invalid(self):
        """Attempt to remove invalid output streams raises IndexError"""
        s = StreamReader(self.get_src())
        for i in range(-3, 3):
            with self.assertRaises(RuntimeError):
                s.remove_stream(i)

        s.add_audio_stream(frames_per_chunk=-1)
        for i in range(-3, 3):
            if i == 0:
                continue
            with self.assertRaises(RuntimeError):
                s.remove_stream(i)

    def test_process_packet(self):
        """`process_packet` method returns 0 while there is a packet in source stream"""
        s = StreamReader(self.get_src())
        # nasa_1013.mp3 contains 1023 packets.
        for _ in range(1023):
            code = s.process_packet()
@@ -246,19 +277,19 @@ class StreamReaderInterfaceTest(TempDirMixin, TorchaudioTestCase):
    def test_pop_chunks_no_output_stream(self):
        """`pop_chunks` method returns empty list when there is no output stream"""
        s = StreamReader(self.get_src())
        assert s.pop_chunks() == []

    def test_pop_chunks_empty_buffer(self):
        """`pop_chunks` method returns None when a buffer is empty"""
        s = StreamReader(self.get_src())
        s.add_basic_audio_stream(frames_per_chunk=-1)
        s.add_basic_video_stream(frames_per_chunk=-1)
        assert s.pop_chunks() == [None, None]

    def test_pop_chunks_exhausted_stream(self):
        """`pop_chunks` method returns None when the source stream is exhausted"""
        s = StreamReader(self.get_src())
        # video is 16.57 seconds.
        # audio streams per 10 second chunk
        # video streams per 20 second chunk
@@ -284,14 +315,14 @@ class StreamReaderInterfaceTest(TempDirMixin, TorchaudioTestCase):
    def test_stream_empty(self):
        """`stream` fails when no output stream is configured"""
        s = StreamReader(self.get_src())
        with self.assertRaises(RuntimeError):
            next(s.stream())

    def test_stream_smoke_test(self):
        """`stream` streams chunks fine"""
        w, h = 256, 198
        s = StreamReader(self.get_src())
        s.add_basic_audio_stream(frames_per_chunk=2000, sample_rate=8000)
        s.add_basic_video_stream(frames_per_chunk=15, frame_rate=60, width=w, height=h)
        for i, (achunk, vchunk) in enumerate(s.stream()):
@@ -302,7 +333,7 @@ class StreamReaderInterfaceTest(TempDirMixin, TorchaudioTestCase):
    def test_seek(self):
        """Calling `seek` multiple times should not segfault"""
        s = StreamReader(self.get_src())
        for i in range(10):
            s.seek(i)
        for _ in range(0):
@@ -312,13 +343,14 @@ class StreamReaderInterfaceTest(TempDirMixin, TorchaudioTestCase):
    def test_seek_negative(self):
        """Calling `seek` with negative value should raise an exception"""
        s = StreamReader(self.get_src())
        with self.assertRaises(RuntimeError):
            s.seek(-1.0)
@skipIfNoFFmpeg @skipIfNoFFmpeg
class StreamReaderAudioTest(TempDirMixin, TorchaudioTestCase): @_media_source
class StreamReaderAudioTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase):
"""Test suite for audio streaming""" """Test suite for audio streaming"""
def _get_reference_wav(self, sample_rate, channels_first=False, **kwargs): def _get_reference_wav(self, sample_rate, channels_first=False, **kwargs):
...@@ -327,9 +359,14 @@ class StreamReaderAudioTest(TempDirMixin, TorchaudioTestCase): ...@@ -327,9 +359,14 @@ class StreamReaderAudioTest(TempDirMixin, TorchaudioTestCase):
save_wav(path, data, sample_rate, channels_first=channels_first) save_wav(path, data, sample_rate, channels_first=channels_first)
return path, data return path, data
def _test_wav(self, path, original, dtype): def get_src(self, *args, **kwargs):
s = StreamReader(path) path, data = self._get_reference_wav(*args, **kwargs)
s.add_basic_audio_stream(frames_per_chunk=-1, dtype=dtype) src = super().get_src(path)
return src, data
def _test_wav(self, src, original, fmt):
s = StreamReader(src)
s.add_basic_audio_stream(frames_per_chunk=-1, format=fmt)
s.process_all_packets() s.process_all_packets()
(output,) = s.pop_chunks() (output,) = s.pop_chunks()
self.assertEqual(original, output) self.assertEqual(original, output)
...@@ -340,12 +377,19 @@ class StreamReaderAudioTest(TempDirMixin, TorchaudioTestCase): ...@@ -340,12 +377,19 @@ class StreamReaderAudioTest(TempDirMixin, TorchaudioTestCase):
) )
def test_basic_audio_stream(self, dtype, num_channels): def test_basic_audio_stream(self, dtype, num_channels):
"""`basic_audio_stream` can load WAV file properly.""" """`basic_audio_stream` can load WAV file properly."""
path, original = self._get_reference_wav(8000, dtype=dtype, num_channels=num_channels) src, original = self.get_src(8000, dtype=dtype, num_channels=num_channels)
fmt = {
"uint8": "u8p",
"int16": "s16p",
"int32": "s32p",
}[dtype]
# provide the matching dtype # provide the matching dtype
self._test_wav(path, original, getattr(torch, dtype)) self._test_wav(src, original, fmt=fmt)
# use the internal dtype ffmpeg picks if not self.test_fileobj:
self._test_wav(path, original, None) # use the internal dtype ffmpeg picks
self._test_wav(src, original, fmt=None)
@nested_params(
["int16", "uint8", "int32"],  # "float", "double", "int64"]
@@ -353,11 +397,11 @@ class StreamReaderAudioTest(TempDirMixin, TorchaudioTestCase):
)
def test_audio_stream(self, dtype, num_channels):
"""`add_audio_stream` can apply filter"""
src, original = self.get_src(8000, dtype=dtype, num_channels=num_channels)
expected = torch.flip(original, dims=(0,))
s = StreamReader(src)
s.add_audio_stream(frames_per_chunk=-1, filter_desc="areverse")
s.process_all_packets()
(output,) = s.pop_chunks()
@@ -369,10 +413,13 @@ class StreamReaderAudioTest(TempDirMixin, TorchaudioTestCase):
)
def test_audio_seek(self, dtype, num_channels):
"""`seek` changes the position properly"""
src, original = self.get_src(1, dtype=dtype, num_channels=num_channels, num_frames=30)
for t in range(10, 20):
expected = original[t:, :]
if self.test_fileobj:
src.seek(0)
s = StreamReader(src)
s.add_audio_stream(frames_per_chunk=-1)
s.seek(float(t))
s.process_all_packets()
@@ -381,9 +428,9 @@ class StreamReaderAudioTest(TempDirMixin, TorchaudioTestCase):
def test_audio_seek_multiple(self):
"""Calling `seek` after streaming is started should change the position properly"""
src, original = self.get_src(1, dtype="int16", num_channels=2, num_frames=30)
s = StreamReader(src)
s.add_audio_stream(frames_per_chunk=-1)
ts = list(range(20)) + list(range(20, 0, -1)) + list(range(20))
@@ -405,11 +452,11 @@ class StreamReaderAudioTest(TempDirMixin, TorchaudioTestCase):
def test_audio_frames_per_chunk(self, frame_param, num_channels):
"""Different chunk parameter covers the source media properly"""
num_frames, frames_per_chunk, buffer_chunk_size = frame_param
src, original = self.get_src(
8000, dtype="int16", num_channels=num_channels, num_frames=num_frames, channels_first=False
)
s = StreamReader(src)
s.add_audio_stream(frames_per_chunk=frames_per_chunk, buffer_chunk_size=buffer_chunk_size)
i, outputs = 0, []
for (output,) in s.stream():
@@ -422,13 +469,19 @@ class StreamReaderAudioTest(TempDirMixin, TorchaudioTestCase):

@skipIfNoFFmpeg
@_media_source
class StreamReaderImageTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase):
def _get_reference_png(self, width: int, height: int, grayscale: bool):
original = get_image(width, height, grayscale=grayscale)
path = self.get_temp_path("ref.png")
save_image(path, original, mode="L" if grayscale else "RGB")
return path, original
def get_src(self, *args, **kwargs):
path, data = self._get_reference_png(*args, **kwargs)
src = super().get_src(path)
return src, data
def _test_png(self, path, original, format=None):
s = StreamReader(path)
s.add_basic_video_stream(frames_per_chunk=-1, format=format)
@@ -441,9 +494,9 @@ class StreamReaderImageTest(TempDirMixin, TorchaudioTestCase):
# TODO:
# Add test with alpha channel (RGBA, ARGB, BGRA, ABGR)
w, h = 32, 18
src, original = self.get_src(w, h, grayscale=grayscale)
expected = original[None, ...]
self._test_png(src, expected)

@parameterized.expand(
[
@@ -453,10 +506,10 @@ class StreamReaderImageTest(TempDirMixin, TorchaudioTestCase):
)
def test_png_effect(self, filter_desc, index):
h, w = 111, 250
src, original = self.get_src(w, h, grayscale=False)
expected = torch.flip(original, dims=(index,))[None, ...]
s = StreamReader(src)
s.add_video_stream(frames_per_chunk=-1, filter_desc=filter_desc)
s.process_all_packets()
output = s.pop_chunks()[0]
...
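The tests above run every scenario twice, once with a file path and once with a file-like object supplied by `_MediaSourceMixin`. As the commit summary states, the only method a source object must implement is `read(self, n)`; `seek(self, offset, whence)` is used when available. A minimal sketch of such an object (the class name `UnseekableWrapper` is hypothetical, not part of torchaudio) illustrates the protocol:

```python
import io

class UnseekableWrapper:
    """Minimal file-like object exposing only `read`, mimicking
    sources such as network streams where `seek` is unavailable."""

    def __init__(self, data: bytes):
        self._src = io.BytesIO(data)

    def read(self, n: int) -> bytes:
        # Must return at most `n` bytes; b"" signals EOF.
        return self._src.read(n)

# Stand-in bytes; a real use would hold encoded media data.
obj = UnseekableWrapper(b"RIFF" + bytes(60))
chunk = obj.read(4)
```

An object like this would be passed directly, e.g. `StreamReader(obj)`. Because it lacks `seek`, some container formats cannot be detected, which is when the `decoder`/`decoder_option` arguments mentioned in the summary come into play.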
@@ -57,7 +57,12 @@ def get_ext_modules():
]
)
if _USE_FFMPEG:
modules.extend(
[
Extension(name="torchaudio.lib.libtorchaudio_ffmpeg", sources=[]),
Extension(name="torchaudio._torchaudio_ffmpeg", sources=[]),
]
)
return modules
...
# the following line is added in order to export symbols when building on Windows
# this approach has some limitations as documented in https://github.com/pytorch/pytorch/pull/3650
if (MSVC)
set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
endif()
################################################################################
# libtorchaudio
################################################################################
@@ -204,11 +210,11 @@ if (BUILD_TORCHAUDIO_PYTHON_EXTENSION)
find_package(Python3 ${PYTHON_VERSION} EXACT COMPONENTS Development)
set(ADDITIONAL_ITEMS Python3::Python)
endif()
function(define_extension name sources include_dirs libraries definitions)
add_library(${name} SHARED ${sources})
target_compile_definitions(${name} PRIVATE "${definitions}")
target_include_directories(
${name} PRIVATE ${PROJECT_SOURCE_DIR} ${Python_INCLUDE_DIR} ${include_dirs})
target_link_libraries(
${name}
${libraries}
@@ -254,6 +260,7 @@ if (BUILD_TORCHAUDIO_PYTHON_EXTENSION)
define_extension(
_torchaudio
"${EXTENSION_SOURCES}"
""
libtorchaudio
"${LIBTORCHAUDIO_COMPILE_DEFINITIONS}"
)
@@ -265,8 +272,23 @@ if (BUILD_TORCHAUDIO_PYTHON_EXTENSION)
define_extension(
_torchaudio_decoder
"${DECODER_EXTENSION_SOURCES}"
""
"libtorchaudio_decoder"
"${LIBTORCHAUDIO_DECODER_DEFINITIONS}"
)
endif()
if(USE_FFMPEG)
set(
FFMPEG_EXTENSION_SOURCES
ffmpeg/pybind/pybind.cpp
ffmpeg/pybind/stream_reader.cpp
)
define_extension(
_torchaudio_ffmpeg
"${FFMPEG_EXTENSION_SOURCES}"
"${FFMPEG_INCLUDE_DIRS}"
"libtorchaudio_ffmpeg"
"${LIBTORCHAUDIO_DECODER_DEFINITIONS}"
)
endif()
endif()
@@ -66,17 +66,24 @@ std::string join(std::vector<std::string> vars) {
AVFormatContextPtr get_input_format_context(
const std::string& src,
const c10::optional<std::string>& device,
const OptionDict& option,
AVIOContext* io_ctx) {
AVFormatContext* pFormat = avformat_alloc_context();
if (!pFormat) {
throw std::runtime_error("Failed to allocate AVFormatContext.");
}
if (io_ctx) {
pFormat->pb = io_ctx;
}
auto* pInput = [&]() -> AVINPUT_FORMAT_CONST AVInputFormat* {
if (device.has_value()) {
std::string device_str = device.value();
AVINPUT_FORMAT_CONST AVInputFormat* p =
av_find_input_format(device_str.c_str());
if (!p) {
std::ostringstream msg;
msg << "Unsupported device/format: \"" << device_str << "\"";
throw std::runtime_error(msg.str());
}
return p;
@@ -103,6 +110,17 @@ AVFormatContextPtr get_input_format_context(
AVFormatContextPtr::AVFormatContextPtr(AVFormatContext* p)
: Wrapper<AVFormatContext, AVFormatContextDeleter>(p) {}
////////////////////////////////////////////////////////////////////////////////
// AVIO
////////////////////////////////////////////////////////////////////////////////
void AVIOContextDeleter::operator()(AVIOContext* p) {
av_freep(&p->buffer);
av_freep(&p);
}
AVIOContextPtr::AVIOContextPtr(AVIOContext* p)
: Wrapper<AVIOContext, AVIOContextDeleter>(p) {}
////////////////////////////////////////////////////////////////////////////////
// AVPacket
////////////////////////////////////////////////////////////////////////////////
...
@@ -13,6 +13,7 @@ extern "C" {
#include <libavfilter/buffersink.h>
#include <libavfilter/buffersrc.h>
#include <libavformat/avformat.h>
#include <libavformat/avio.h>
#include <libavutil/avutil.h>
#include <libavutil/frame.h>
#include <libavutil/imgutils.h>
@@ -74,7 +75,19 @@ struct AVFormatContextPtr
AVFormatContextPtr get_input_format_context(
const std::string& src,
const c10::optional<std::string>& device,
const OptionDict& option,
AVIOContext* io_ctx = nullptr);
////////////////////////////////////////////////////////////////////////////////
// AVIO
////////////////////////////////////////////////////////////////////////////////
struct AVIOContextDeleter {
void operator()(AVIOContext* p);
};
struct AVIOContextPtr : public Wrapper<AVIOContext, AVIOContextDeleter> {
explicit AVIOContextPtr(AVIOContext* p);
};
////////////////////////////////////////////////////////////////////////////////
// AVPacket
...
@@ -46,84 +46,70 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
av_log_set_level(AV_LOG_ERROR);
});
m.def("torchaudio::ffmpeg_load", load);
m.class_<StreamReaderBinding>("ffmpeg_Streamer")
.def(torch::init<>(init))
.def("num_src_streams", [](S self) { return self->num_src_streams(); })
.def("num_out_streams", [](S self) { return self->num_out_streams(); })
.def(
"get_src_stream_info",
[](S s, int64_t i) { return s->get_src_stream_info(i); })
.def(
"get_out_stream_info",
[](S s, int64_t i) { return s->get_out_stream_info(i); })
.def(
"find_best_audio_stream",
[](S s) { return s->find_best_audio_stream(); })
.def(
"find_best_video_stream",
[](S s) { return s->find_best_video_stream(); })
.def("seek", [](S s, double t) { return s->seek(t); })
.def(
"add_audio_stream",
[](S s,
int64_t i,
int64_t frames_per_chunk,
int64_t num_chunks,
const c10::optional<std::string>& filter_desc,
const c10::optional<std::string>& decoder,
const c10::optional<c10::Dict<std::string, std::string>>&
decoder_options) {
s->add_audio_stream(
i,
frames_per_chunk,
num_chunks,
filter_desc,
decoder,
map(decoder_options));
})
.def(
"add_video_stream",
[](S s,
int64_t i,
int64_t frames_per_chunk,
int64_t num_chunks,
const c10::optional<std::string>& filter_desc,
const c10::optional<std::string>& decoder,
const c10::optional<c10::Dict<std::string, std::string>>&
decoder_options,
const c10::optional<std::string>& hw_accel) {
s->add_video_stream(
i,
frames_per_chunk,
num_chunks,
filter_desc,
decoder,
map(decoder_options),
hw_accel);
})
.def("remove_stream", [](S s, int64_t i) { s->remove_stream(i); })
.def(
"process_packet",
[](S s, const c10::optional<double>& timeout, const double backoff) {
return s->process_packet(timeout, backoff);
})
.def("process_all_packets", [](S s) { s->process_all_packets(); })
.def("is_buffer_ready", [](S s) { return s->is_buffer_ready(); })
.def("pop_chunks", [](S s) { return s->pop_chunks(); });
}
} // namespace
...
#include <pybind11/stl.h>
#include <torch/extension.h>
#include <torchaudio/csrc/ffmpeg/pybind/stream_reader.h>
namespace torchaudio {
namespace ffmpeg {
namespace {
PYBIND11_MODULE(_torchaudio_ffmpeg, m) {
py::class_<StreamReaderFileObj, c10::intrusive_ptr<StreamReaderFileObj>>(
m, "StreamReaderFileObj")
.def(py::init<
py::object,
const c10::optional<std::string>&,
const c10::optional<OptionDict>&,
int64_t>())
.def("num_src_streams", &StreamReaderFileObj::num_src_streams)
.def("num_out_streams", &StreamReaderFileObj::num_out_streams)
.def(
"find_best_audio_stream",
&StreamReaderFileObj::find_best_audio_stream)
.def(
"find_best_video_stream",
&StreamReaderFileObj::find_best_video_stream)
.def("get_src_stream_info", &StreamReaderFileObj::get_src_stream_info)
.def("get_out_stream_info", &StreamReaderFileObj::get_out_stream_info)
.def("seek", &StreamReaderFileObj::seek)
.def("add_audio_stream", &StreamReaderFileObj::add_audio_stream)
.def("add_video_stream", &StreamReaderFileObj::add_video_stream)
.def("remove_stream", &StreamReaderFileObj::remove_stream)
.def("process_packet", &StreamReaderFileObj::process_packet)
.def("process_all_packets", &StreamReaderFileObj::process_all_packets)
.def("is_buffer_ready", &StreamReaderFileObj::is_buffer_ready)
.def("pop_chunks", &StreamReaderFileObj::pop_chunks);
}
} // namespace
} // namespace ffmpeg
} // namespace torchaudio
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
#include <torchaudio/csrc/ffmpeg/pybind/stream_reader.h>
namespace torchaudio {
namespace ffmpeg {
namespace {
static int read_function(void* opaque, uint8_t* buf, int buf_size) {
FileObj* fileobj = static_cast<FileObj*>(opaque);
buf_size = FFMIN(buf_size, fileobj->buffer_size);
int num_read = 0;
while (num_read < buf_size) {
int request = buf_size - num_read;
auto chunk = static_cast<std::string>(
static_cast<py::bytes>(fileobj->fileobj.attr("read")(request)));
auto chunk_len = chunk.length();
if (chunk_len == 0) {
break;
}
if (chunk_len > request) {
std::ostringstream message;
message
<< "Requested up to " << request << " bytes, but "
<< "received " << chunk_len << " bytes. "
<< "The given object does not conform to the read protocol of file objects.";
throw std::runtime_error(message.str());
}
memcpy(buf, chunk.data(), chunk_len);
buf += chunk_len;
num_read += chunk_len;
}
return num_read == 0 ? AVERROR_EOF : num_read;
}
static int64_t seek_function(void* opaque, int64_t offset, int whence) {
// We do not know the file size.
if (whence == AVSEEK_SIZE) {
return AVERROR(EIO);
}
FileObj* fileobj = static_cast<FileObj*>(opaque);
return py::cast<int64_t>(fileobj->fileobj.attr("seek")(offset, whence));
}
AVIOContextPtr get_io_context(FileObj* opaque, int buffer_size) {
uint8_t* buffer = static_cast<uint8_t*>(av_malloc(buffer_size));
if (!buffer) {
throw std::runtime_error("Failed to allocate buffer.");
}
// If avio_alloc_context succeeds, then buffer will be cleaned up by
// AVIOContextPtr destructor.
// If avio_alloc_context fails, we need to clean up by ourselves.
AVIOContext* av_io_ctx = avio_alloc_context(
buffer,
buffer_size,
0,
static_cast<void*>(opaque),
&read_function,
nullptr,
py::hasattr(opaque->fileobj, "seek") ? &seek_function : nullptr);
if (!av_io_ctx) {
av_freep(&buffer);
throw std::runtime_error("Failed to allocate AVIO context.");
}
return AVIOContextPtr{av_io_ctx};
}
} // namespace
FileObj::FileObj(py::object fileobj_, int buffer_size)
: fileobj(fileobj_),
buffer_size(buffer_size),
pAVIO(get_io_context(this, buffer_size)) {}
StreamReaderFileObj::StreamReaderFileObj(
py::object fileobj_,
const c10::optional<std::string>& format,
const c10::optional<OptionDict>& option,
int64_t buffer_size)
: FileObj(fileobj_, static_cast<int>(buffer_size)),
StreamReaderBinding(get_input_format_context(
static_cast<std::string>(py::str(fileobj_.attr("__str__")())),
format,
option.value_or(OptionDict{}),
pAVIO)) {}
} // namespace ffmpeg
} // namespace torchaudio
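The `read_function` callback above keeps calling the Python object's `read` until the FFmpeg buffer is filled or EOF is reached, and rejects objects that return more bytes than requested. The contract can be sketched in pure Python (the helper name `read_into` is illustrative, not part of the codebase):

```python
import io

def read_into(fileobj, buf_size: int) -> bytes:
    """Mirror of the C++ read callback: keep calling `read` until
    `buf_size` bytes are collected or the object reports EOF."""
    out = bytearray()
    while len(out) < buf_size:
        request = buf_size - len(out)
        chunk = fileobj.read(request)
        if not chunk:
            break  # EOF, analogous to returning AVERROR_EOF
        if len(chunk) > request:
            # A well-behaved file object never over-delivers.
            raise RuntimeError(
                f"Requested up to {request} bytes, but received {len(chunk)} bytes."
            )
        out += chunk
    return bytes(out)
```

Looping rather than issuing a single `read` matters because a file-like object is allowed to return fewer bytes than requested without being at EOF.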
#pragma once
#include <torch/extension.h>
#include <torchaudio/csrc/ffmpeg/stream_reader_wrapper.h>
namespace torchaudio {
namespace ffmpeg {
struct FileObj {
py::object fileobj;
int buffer_size;
AVIOContextPtr pAVIO;
FileObj(py::object fileobj, int buffer_size);
};
// The reason we inherit FileObj instead of making it an attribute
// is so that FileObj is instantiated first.
// AVIOContext must be initialized before AVFormat, and outlive AVFormat.
class StreamReaderFileObj : protected FileObj, public StreamReaderBinding {
public:
StreamReaderFileObj(
py::object fileobj,
const c10::optional<std::string>& format,
const c10::optional<OptionDict>& option,
int64_t buffer_size);
};
} // namespace ffmpeg
} // namespace torchaudio
@@ -21,12 +21,12 @@ void Streamer::validate_open_stream() const {
void Streamer::validate_src_stream_index(int i) const {
validate_open_stream();
if (i < 0 || i >= static_cast<int>(pFormatContext->nb_streams))
throw std::runtime_error("Source stream index out of range");
}

void Streamer::validate_output_stream_index(int i) const {
if (i < 0 || i >= static_cast<int>(stream_indices.size()))
throw std::runtime_error("Output stream index out of range");
}

void Streamer::validate_src_stream_type(int i, AVMediaType type) {
@@ -81,19 +81,25 @@ SrcStreamInfo Streamer::get_src_stream_info(int i) const {
ret.codec_long_name = desc->long_name;
}
switch (codecpar->codec_type) {
case AVMEDIA_TYPE_AUDIO: {
AVSampleFormat smp_fmt = static_cast<AVSampleFormat>(codecpar->format);
if (smp_fmt != AV_SAMPLE_FMT_NONE) {
ret.fmt_name = av_get_sample_fmt_name(smp_fmt);
}
ret.sample_rate = static_cast<double>(codecpar->sample_rate);
ret.num_channels = codecpar->channels;
break;
}
case AVMEDIA_TYPE_VIDEO: {
AVPixelFormat pix_fmt = static_cast<AVPixelFormat>(codecpar->format);
if (pix_fmt != AV_PIX_FMT_NONE) {
ret.fmt_name = av_get_pix_fmt_name(pix_fmt);
}
ret.width = codecpar->width;
ret.height = codecpar->height;
ret.frame_rate = av_q2d(stream->r_frame_rate);
break;
}
default:;
}
return ret;
@@ -137,7 +143,7 @@ bool Streamer::is_buffer_ready() const {
////////////////////////////////////////////////////////////////////////////////
void Streamer::seek(double timestamp) {
if (timestamp < 0) {
throw std::runtime_error("timestamp must be non-negative.");
}
int64_t ts = static_cast<int64_t>(timestamp * AV_TIME_BASE);
@@ -220,6 +226,13 @@ void Streamer::add_stream(
validate_src_stream_type(i, media_type);
AVStream* stream = pFormatContext->streams[i];
// When media source is file-like object, it is possible that source codec is
// not detected properly.
if (stream->codecpar->format == -1) {
throw std::runtime_error(
"Failed to detect the source stream format. Please provide the decoder to use.");
}
stream->discard = AVDISCARD_DEFAULT;
if (!processors[i])
processors[i] = std::make_unique<StreamProcessor>(
...
@@ -14,6 +14,7 @@ def _init_extension():
try:
torchaudio._extension._load_lib("libtorchaudio_ffmpeg")
import torchaudio._torchaudio_ffmpeg
except OSError as err:
raise ImportError(
"Stream API requires FFmpeg libraries (libavformat and such). Please install FFmpeg 4."
...
@@ -29,21 +29,21 @@ class StreamReaderSourceStream:
Still images, such as PNG and JPEG formats, are reported as `video`.
"""
codec: str
"""Short name of the codec. Such as ``"pcm_s16le"`` and ``"h264"``."""
codec_long_name: str
"""Detailed name of the codec.
Such as "`PCM signed 16-bit little-endian`" and "`H.264 / AVC / MPEG-4 AVC / MPEG-4 part 10`".
"""
format: Optional[str]
"""Media format. Such as ``"s16"`` and ``"yuv420p"``.
Commonly found audio values are:
- ``"u8"``, ``"u8p"``: 8-bit unsigned integer.
- ``"s16"``, ``"s16p"``: 16-bit signed integer.
- ``"s32"``, ``"s32p"``: 32-bit signed integer.
- ``"flt"``, ``"fltp"``: 32-bit floating-point.
.. note::
@@ -63,7 +63,7 @@ class StreamReaderSourceAudioStream(StreamReaderSourceStream):
The metadata of an audio source stream.
In addition to the attributes reported by :py:func:`StreamReaderSourceStream`,
when the source stream is audio type, the following additional attributes
are reported.
"""
...@@ -80,7 +80,7 @@ class StreamReaderSourceVideoStream(StreamReaderSourceStream): ...@@ -80,7 +80,7 @@ class StreamReaderSourceVideoStream(StreamReaderSourceStream):
The metadata of a video source stream. The metadata of a video source stream.
In addition to the attributes reported by :py:func:`SourceStream`, In addition to the attributes reported by :py:func:`StreamReaderSourceStream`,
when the source stream is audio type, then the following additional attributes when the source stream is audio type, then the following additional attributes
are reported. are reported.
""" """
@@ -154,24 +154,16 @@ def _parse_oi(i):
return StreamReaderOutputStream(i[0], i[1])
def _get_afilter_desc(sample_rate: Optional[int], fmt: Optional[str]):
descs = []
if sample_rate is not None:
descs.append(f"aresample={sample_rate}")
if fmt is not None:
descs.append(f"aformat=sample_fmts={fmt}")
return ",".join(descs) if descs else None
def _get_vfilter_desc(frame_rate: Optional[float], width: Optional[int], height: Optional[int], fmt: Optional[str]):
descs = []
if frame_rate is not None:
descs.append(f"fps={frame_rate}")
@@ -182,24 +174,107 @@ def _get_vfilter_desc(frame_rate: Optional[float], width: Optional[int], height:
scales.append(f"height={height}")
if scales:
descs.append(f"scale={':'.join(scales)}")
if fmt is not None:
descs.append(f"format=pix_fmts={fmt}")
return ",".join(descs) if descs else None
def _format_doc(**kwargs):
def decorator(obj):
obj.__doc__ = obj.__doc__.format(**kwargs)
return obj
return decorator
_frames_per_chunk = """Number of frames returned as one chunk.
If the source stream is exhausted before enough frames are buffered,
then the chunk is returned as-is."""
_buffer_chunk_size = """Internal buffer size.
When the number of chunks buffered exceeds this number, old frames are
dropped.
Default: ``3``."""
_audio_stream_index = """The source audio stream index.
If omitted, :py:attr:`default_audio_stream` is used."""
_video_stream_index = """The source video stream index.
If omitted, :py:attr:`default_video_stream` is used."""
_decoder = """The name of the decoder to be used.
When provided, use the specified decoder instead of the default one.
To list the available decoders, you can use `ffmpeg -decoders` command.
Default: ``None``."""
_decoder_option = """Options passed to decoder.
Mapping from str to str.
To list decoder options for a decoder, you can use
`ffmpeg -h decoder=<DECODER>` command.
Default: ``None``."""
_hw_accel = """Enable hardware acceleration.
When video is decoded on CUDA hardware, for example
`decode="h264_cuvid"`, passing CUDA device indicator to `hw_accel`
(i.e. `hw_accel="cuda:0"`) will place the resulting frames
directly on the specifiec CUDA device.
If `None`, the frame will be moved to CPU memory.
Default: ``None``."""
_format_audio_args = _format_doc(
frames_per_chunk=_frames_per_chunk,
buffer_chunk_size=_buffer_chunk_size,
stream_index=_audio_stream_index,
decoder=_decoder,
decoder_option=_decoder_option,
)
_format_video_args = _format_doc(
frames_per_chunk=_frames_per_chunk,
buffer_chunk_size=_buffer_chunk_size,
stream_index=_video_stream_index,
decoder=_decoder,
decoder_option=_decoder_option,
hw_accel=_hw_accel,
)
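The `_format_doc` decorator above simply fills the `{placeholder}` slots in a docstring via `str.format`, so the shared argument descriptions are written once and reused across the four `add_*_stream` methods. A minimal self-contained illustration of the pattern (function and names here are illustrative, not torchaudio's):

```python
def format_doc(**kwargs):
    # Same idea as _format_doc: substitute shared text into __doc__ at import time.
    def decorator(obj):
        obj.__doc__ = obj.__doc__.format(**kwargs)
        return obj
    return decorator

_buffer_note = "Internal buffer size. Default: ``3``."

@format_doc(buffer_chunk_size=_buffer_note)
def add_stream(frames_per_chunk, buffer_chunk_size=3):
    """Args:
    buffer_chunk_size (int, optional): {buffer_chunk_size}
    """

print(add_stream.__doc__)
```

Because substitution happens at decoration time, tools like `help()` and Sphinx see the fully expanded docstring.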
class StreamReader:
"""Fetch and decode audio/video streams chunk by chunk.
For the detailed usage of this class, please refer to the tutorial.
Args:
src (str or file-like object): The media source.
If string-type, it must be a resource indicator that FFmpeg can
handle. This includes a file path, URL, device identifier or
filter expression. The supported value depends on the FFmpeg found
in the system.
If file-like object, it must support the `read` method with the signature
`read(size: int) -> bytes`.
Additionally, if the file-like object has a `seek` method, it is used
when parsing media metadata. This improves the reliability
of codec detection. The signature of the `seek` method must be
`seek(offset: int, whence: int) -> int`.
Please refer to the following for the expected signature and behavior
of the `read` and `seek` methods.
- https://docs.python.org/3/library/io.html#io.BufferedIOBase.read
- https://docs.python.org/3/library/io.html#io.IOBase.seek
format (str or None, optional):
Override the input format, or specify the source sound device.
Default: ``None`` (no override nor device input).
@@ -232,6 +307,11 @@ class StreamReader:
You can use this argument to change the input source before it is passed to decoder.
Default: ``None``.
buffer_size (int):
The internal buffer size in bytes. Used only when `src` is a file-like object.
Default: `4096`.
"""
def __init__(
@@ -239,12 +319,19 @@ class StreamReader:
src: str,
format: Optional[str] = None,
option: Optional[Dict[str, str]] = None,
buffer_size: int = 4096,
):
if isinstance(src, str):
self._be = torch.classes.torchaudio.ffmpeg_Streamer(src, format, option)
elif hasattr(src, "read"):
self._be = torchaudio._torchaudio_ffmpeg.StreamReaderFileObj(src, format, option, buffer_size)
else:
raise ValueError("`src` must be either string or file-like object.")
i = self._be.find_best_audio_stream()
self._default_audio_stream = None if i < 0 else i
i = self._be.find_best_video_stream()
self._default_video_stream = None if i < 0 else i
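The constructor dispatch above is a duck-typing check: strings go to the TorchBind-backed implementation, anything exposing `read` goes to the PyBind11 file-object binding, and everything else is rejected. A stdlib-only sketch of the same selection logic (the backend labels are stand-ins, not real torchaudio names):

```python
import io

def pick_backend(src):
    # Strings are resource indicators handled by the TorchBind binding;
    # objects exposing read() are routed to the PyBind11 file-object binding.
    if isinstance(src, str):
        return "torchbind"
    if hasattr(src, "read"):
        return "pybind11-fileobj"
    raise ValueError("`src` must be either string or file-like object.")

print(pick_backend("input.mp4"))          # torchbind
print(pick_backend(io.BytesIO(b"data")))  # pybind11-fileobj
```

Note that the check is on the `read` attribute, not on inheritance from `io.IOBase`, so any object with a compatible `read(size)` works as a source.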
@property
def num_src_streams(self):
@@ -252,7 +339,7 @@ class StreamReader:
:type: int
"""
return self._be.num_src_streams()
@property
def num_out_streams(self):
@@ -260,7 +347,7 @@ class StreamReader:
:type: int
"""
return self._be.num_out_streams()
@property
def default_audio_stream(self):
@@ -268,7 +355,7 @@ class StreamReader:
:type: Optional[int]
"""
return self._default_audio_stream
@property
def default_video_stream(self):
@@ -276,7 +363,7 @@ class StreamReader:
:type: Optional[int]
"""
return self._default_video_stream
def get_src_stream_info(self, i: int) -> torchaudio.io.StreamReaderSourceStream:
"""Get the metadata of source stream
@@ -286,7 +373,7 @@ class StreamReader:
Returns:
SourceStream
"""
return _parse_si(self._be.get_src_stream_info(i))
def get_out_stream_info(self, i: int) -> torchaudio.io.StreamReaderOutputStream:
"""Get the metadata of output stream
@@ -296,7 +383,7 @@ class StreamReader:
Returns:
OutputStream
"""
return _parse_oi(self._be.get_out_stream_info(i))
def seek(self, timestamp: float):
"""Seek the stream to the given timestamp [second]
@@ -304,227 +391,196 @@ class StreamReader:
Args:
timestamp (float): Target time in seconds.
"""
self._be.seek(timestamp)
@_format_audio_args
def add_basic_audio_stream(
self,
frames_per_chunk: int,
buffer_chunk_size: int = 3,
stream_index: Optional[int] = None,
decoder: Optional[str] = None,
decoder_option: Optional[Dict[str, str]] = None,
format: Optional[str] = "fltp",
sample_rate: Optional[int] = None,
):
"""Add output audio stream
Args:
frames_per_chunk (int): {frames_per_chunk}
buffer_chunk_size (int, optional): {buffer_chunk_size}
stream_index (int or None, optional): {stream_index}
decoder (str or None, optional): {decoder}
decoder_option (dict or None, optional): {decoder_option}
format (str, optional): Output sample format (precision).
If ``None``, the output chunk has dtype corresponding to
the precision of the source audio.
Otherwise, the sample is converted and the output dtype is changed
as follows.
- ``"u8p"``: The output is ``torch.uint8`` type.
- ``"s16p"``: The output is ``torch.int16`` type.
- ``"s32p"``: The output is ``torch.int32`` type.
- ``"s64p"``: The output is ``torch.int64`` type.
- ``"fltp"``: The output is ``torch.float32`` type.
- ``"dblp"``: The output is ``torch.float64`` type.
Default: ``"fltp"``.
sample_rate (int or None, optional): If provided, resample the audio.
"""
self.add_audio_stream(
frames_per_chunk,
buffer_chunk_size,
stream_index,
decoder,
decoder_option,
_get_afilter_desc(sample_rate, format),
)
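The docstring above documents a fixed correspondence between FFmpeg planar sample formats and output tensor dtypes. Written out as a lookup table (dtype names as strings here, purely to keep this sketch free of a torch dependency):

```python
# Documented mapping from FFmpeg planar sample format to output dtype.
FMT_TO_DTYPE = {
    "u8p": "torch.uint8",
    "s16p": "torch.int16",
    "s32p": "torch.int32",
    "s64p": "torch.int64",
    "fltp": "torch.float32",
    "dblp": "torch.float64",
}

# The default format "fltp" yields float32 chunks.
print(FMT_TO_DTYPE["fltp"])  # torch.float32
```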
@_format_video_args
def add_basic_video_stream(
self,
frames_per_chunk: int,
buffer_chunk_size: int = 3,
stream_index: Optional[int] = None,
decoder: Optional[str] = None,
decoder_option: Optional[Dict[str, str]] = None,
hw_accel: Optional[str] = None,
format: Optional[str] = "rgb24",
frame_rate: Optional[int] = None,
width: Optional[int] = None,
height: Optional[int] = None,
):
"""Add output video stream
Args:
frames_per_chunk (int): {frames_per_chunk}
buffer_chunk_size (int, optional): {buffer_chunk_size}
stream_index (int or None, optional): {stream_index}
decoder (str or None, optional): {decoder}
decoder_option (dict or None, optional): {decoder_option}
hw_accel (str or None, optional): {hw_accel}
format (str, optional): Change the format of image channels. Valid values are,
- ``"rgb24"``: 8 bits * 3 channels (R, G, B)
- ``"bgr24"``: 8 bits * 3 channels (B, G, R)
- ``"yuv420p"``: 8 bits * 3 channels (Y, U, V)
- ``"gray"``: 8 bits * 1 channel
Default: ``"rgb24"``.
frame_rate (int or None, optional): If provided, change the frame rate.
width (int or None, optional): If provided, change the image width. Unit: Pixel.
height (int or None, optional): If provided, change the image height. Unit: Pixel.
"""
self.add_video_stream(
frames_per_chunk,
buffer_chunk_size,
stream_index,
decoder,
decoder_option,
hw_accel,
_get_vfilter_desc(frame_rate, width, height, format),
)
@_format_audio_args
def add_audio_stream(
self,
frames_per_chunk: int,
buffer_chunk_size: int = 3,
stream_index: Optional[int] = None,
decoder: Optional[str] = None,
decoder_option: Optional[Dict[str, str]] = None,
filter_desc: Optional[str] = None,
):
"""Add output audio stream
Args:
frames_per_chunk (int): {frames_per_chunk}
buffer_chunk_size (int, optional): {buffer_chunk_size}
stream_index (int or None, optional): {stream_index}
decoder (str or None, optional): {decoder}
decoder_option (dict or None, optional): {decoder_option}
filter_desc (str or None, optional): Filter description.
The list of available filters can be found at
https://ffmpeg.org/ffmpeg-filters.html
Note that complex filters are not supported.
"""
i = self.default_audio_stream if stream_index is None else stream_index
if i is None:
raise RuntimeError("There is no audio stream.")
self._be.add_audio_stream(
i,
frames_per_chunk,
buffer_chunk_size,
filter_desc,
decoder,
decoder_option or {},
)
@_format_video_args
def add_video_stream(
self,
frames_per_chunk: int,
buffer_chunk_size: int = 3,
stream_index: Optional[int] = None,
decoder: Optional[str] = None,
decoder_option: Optional[Dict[str, str]] = None,
hw_accel: Optional[str] = None,
filter_desc: Optional[str] = None,
):
"""Add output video stream
Args:
frames_per_chunk (int): {frames_per_chunk}
buffer_chunk_size (int, optional): {buffer_chunk_size}
stream_index (int or None, optional): {stream_index}
decoder (str or None, optional): {decoder}
decoder_option (dict or None, optional): {decoder_option}
hw_accel (str or None, optional): {hw_accel}
filter_desc (str or None, optional): Filter description.
The list of available filters can be found at
https://ffmpeg.org/ffmpeg-filters.html
Note that complex filters are not supported.
"""
i = self.default_video_stream if stream_index is None else stream_index
if i is None:
raise RuntimeError("There is no video stream.")
self._be.add_video_stream(
i,
frames_per_chunk,
buffer_chunk_size,
filter_desc,
decoder,
decoder_option or {},
hw_accel,
)
@@ -534,7 +590,7 @@ class StreamReader:
Args:
i (int): Index of the output stream to be removed.
"""
self._be.remove_stream(i)
def process_packet(self, timeout: Optional[float] = None, backoff: float = 10.0) -> int:
"""Read the source media and process one packet.
@@ -593,15 +649,15 @@ class StreamReader:
flushed the pending frames. The caller should stop calling
this method.
"""
return self._be.process_packet(timeout, backoff)
def process_all_packets(self):
"""Process packets until it reaches EOF."""
self._be.process_all_packets()
def is_buffer_ready(self) -> bool:
"""Returns true if all the output streams have at least one chunk filled."""
return self._be.is_buffer_ready()
def pop_chunks(self) -> Tuple[Optional[torch.Tensor]]:
"""Pop one chunk from all the output stream buffers.
@@ -611,7 +667,7 @@ class StreamReader:
Buffer contents.
If a buffer does not contain any frame, then `None` is returned instead.
"""
return self._be.pop_chunks()
def _fill_buffer(self, timeout: Optional[float], backoff: float) -> int:
"""Keep processing packets until all buffers have at least one chunk
...
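Taken together, `process_packet`, `is_buffer_ready`, and `pop_chunks` form the core decode loop that the Python layer drives against the C++ backend. A toy, stdlib-only mock of that loop, with a stand-in backend class (`ToyBackend` is ours; the real backend is the TorchBind/PyBind11 object held in `self._be`):

```python
class ToyBackend:
    """Stand-in for the C++ streaming backend, for illustration only."""

    def __init__(self, num_packets):
        self._left = num_packets
        self._chunks = []

    def process_packet(self, timeout=None, backoff=10.0):
        # 0: processed a packet; 1: reached EOF and flushed pending frames.
        if self._left == 0:
            return 1
        self._left -= 1
        self._chunks.append(f"chunk-{len(self._chunks)}")
        return 0

    def is_buffer_ready(self):
        return bool(self._chunks)

    def pop_chunks(self):
        # One entry per output stream; None when the buffer is empty.
        return (self._chunks.pop(0),) if self._chunks else (None,)

be = ToyBackend(num_packets=2)
# Same shape as the Python layer's loop: keep processing until EOF (return 1).
while be.process_packet() == 0:
    pass
print(be.pop_chunks())  # ('chunk-0',)
```

The real `_fill_buffer` additionally honors `timeout` and `backoff` for live sources, where a packet may not be available immediately.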