Commit 7a3e262d authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Add Streaming API (#2164)

Summary:
This PR adds the prototype streaming API.
The implementation is based on ffmpeg libraries.

For the detailed usage, please refer to [the resulting tutorial](https://534376-90321822-gh.circle-artifacts.com/0/docs/tutorials/streaming_api_tutorial.html).

Pull Request resolved: https://github.com/pytorch/audio/pull/2164

Reviewed By: hwangjeff

Differential Revision: D33934457

Pulled By: mthrok

fbshipit-source-id: 92ade4aff2d25baf02c0054682d4fbdc9ba8f3fe
parent db12d1a0
......@@ -66,7 +66,7 @@ fi
(
set -x
conda install -y -c conda-forge ${NUMBA_DEV_CHANNEL} 'librosa>=0.8.0' parameterized 'requests>=2.20'
pip install kaldi-io SoundFile coverage pytest pytest-cov scipy transformers expecttest unidecode inflect
pip install kaldi-io SoundFile coverage pytest pytest-cov scipy transformers expecttest unidecode inflect Pillow
)
# Install fairseq
git clone https://github.com/pytorch/fairseq
......
......@@ -57,7 +57,7 @@ fi
(
set -x
conda install -y -c conda-forge ${NUMBA_DEV_CHANNEL} 'librosa>=0.8.0' parameterized 'requests>=2.20'
pip install kaldi-io SoundFile coverage pytest pytest-cov scipy transformers expecttest unidecode inflect
pip install kaldi-io SoundFile coverage pytest pytest-cov scipy transformers expecttest unidecode inflect Pillow
)
# Install fairseq
git clone https://github.com/pytorch/fairseq
......
......@@ -49,13 +49,19 @@ torchaudio.
conda install pytorch -c pytorch-nightly
```
### Install Torchaudio
### Install build dependencies
```bash
# Install build-time dependencies
pip install cmake ninja pkgconfig
pip install cmake ninja
# [optional for sox]
conda install pkg-config
# [optional for ffmpeg]
conda install ffmpeg
```
### Install Torchaudio
```bash
# Build torchaudio
git clone https://github.com/pytorch/audio.git
......@@ -68,6 +74,13 @@ python setup.py develop
Some environment variables that change the build behavior
- `BUILD_SOX`: Determines whether to build and bind libsox in non-Windows environments. (no effect on Windows as libsox integration is not available) Default value is 1 (build and bind). Use 0 to disable it.
- `USE_CUDA`: Determines whether to build the custom CUDA kernel. Defaults to the availability of CUDA-compatible GPUs.
- `BUILD_KALDI`: Determines whether to build the Kaldi extension. This is required for the `kaldi_pitch` function. Default value is 1 on Linux/macOS and 0 on Windows.
- `BUILD_RNNT`: Determines whether to build the RNN-T loss function. Default value is 1.
- `BUILD_CTC_DECODER`: Determines whether to build decoder features based on KenLM and FlashLight CTC decoder. Default value is 1.
Please check [./tools/setup_helpers/extension.py](./tools/setup_helpers/extension.py) for up-to-date details.
### Running Test
If you built sox, set the `PATH` variable so that the tests properly use the newly built `sox` binary:
......@@ -92,6 +105,7 @@ Optional packages to install if you want to run related tests:
source. Commit `e6eddd80` is known to work.)
- `unidecode` (dependency for testing text preprocessing functions for examples/pipeline_tacotron2)
- `inflect` (dependency for testing text preprocessing functions for examples/pipeline_tacotron2)
- `Pillow` (dependency for testing ffmpeg image processing)
## Development Process
......
......@@ -53,6 +53,8 @@ extensions = [
"sphinx_gallery.gen_gallery",
]
autodoc_member_order = "bysource"
# katex options
#
#
......
......@@ -57,6 +57,7 @@ Prototype API References
:caption: Prototype API Reference
prototype
prototype.io
prototype.ctc_decoder
Getting Started
......
torchaudio.prototype.io
=======================
.. currentmodule:: torchaudio.prototype.io
SourceStream
------------
.. autoclass:: SourceStream
:members:
SourceAudioStream
-----------------
.. autoclass:: SourceAudioStream
:members:
SourceVideoStream
-----------------
.. autoclass:: SourceVideoStream
:members:
OutputStream
------------
.. autoclass:: OutputStream
:members:
Streamer
--------
.. autoclass:: Streamer
:members:
......@@ -17,4 +17,5 @@ imported explicitly, e.g.
import torchaudio.prototype.ctc_decoder
.. toctree::
prototype.io
prototype.ctc_decoder
......@@ -25,6 +25,10 @@ from .data_utils import (
get_spectrogram,
)
from .func_utils import torch_script
from .image_utils import (
save_image,
get_image,
)
from .parameterized_utils import load_params, nested_params
from .wav_utils import (
get_wav_data,
......@@ -62,4 +66,6 @@ __all__ = [
"load_params",
"nested_params",
"torch_script",
"save_image",
"get_image",
]
import torch
from torchaudio._internal.module_utils import is_module_available
if is_module_available("PIL"):
from PIL import Image
def save_image(path, data, mode=None):
    """Save an image to file.

    The input image is expected to be CHW order (Tensor or ndarray).
    ``mode`` is passed through to ``PIL.Image.fromarray`` (e.g. ``"L"``, ``"RGB"``).
    """
    # PIL works on numpy arrays; convert a Tensor first.
    array = data.numpy() if torch.is_tensor(data) else data
    # For grayscale ("L"), PIL expects a 2D array; drop the channel axis.
    if mode == "L" and array.ndim == 3:
        assert array.shape[0] == 1
        array = array[0]
    # PIL expects HWC order for color images; move channels last.
    if array.ndim == 3:
        array = array.transpose(1, 2, 0)
    Image.fromarray(array, mode=mode).save(path)
def get_image(width, height, grayscale=False):
    """Generate a deterministic test image Tensor in CHW order (uint8)."""
    num_channels = 1 if grayscale else 3
    # Fill with a repeating 0..255 ramp so the content is deterministic.
    values = torch.arange(num_channels * height * width, dtype=torch.int64) % 256
    return values.to(torch.uint8).reshape(num_channels, height, width)
import torch
from parameterized import parameterized
from torchaudio_unittest.common_utils import (
TorchaudioTestCase,
TempDirMixin,
get_asset_path,
get_wav_data,
save_wav,
skipIfNoFFmpeg,
nested_params,
is_ffmpeg_available,
get_image,
save_image,
)
if is_ffmpeg_available():
from torchaudio.prototype.io import (
Streamer,
SourceStream,
SourceVideoStream,
SourceAudioStream,
)
def get_video_asset(file="nasa_13013.mp4"):
    """Return the full path of the given test asset (defaults to the NASA sample video)."""
    asset_path = get_asset_path(file)
    return asset_path
@skipIfNoFFmpeg
class StreamerInterfaceTest(TempDirMixin, TorchaudioTestCase):
    """Test suite for interface behaviors around Streamer"""

    def test_streamer_invalid_input(self):
        """Streamer constructor does not segfault but raise an exception when the input is invalid"""
        with self.assertRaises(RuntimeError):
            Streamer("foobar")

    def test_src_info(self):
        """`get_src_stream_info` properly fetches information"""
        s = Streamer(get_video_asset())
        assert s.num_src_streams == 6
        # The asset contains two video streams (different resolutions),
        # two audio streams (different sample rates) and two subtitle streams.
        # The expected values below were obtained from the asset file itself.
        expected = [
            SourceVideoStream(
                media_type="video",
                codec="h264",
                codec_long_name="H.264 / AVC / MPEG-4 AVC / MPEG-4 part 10",
                format="yuv420p",
                bit_rate=71925,
                width=320,
                height=180,
                frame_rate=25.0,
            ),
            SourceAudioStream(
                media_type="audio",
                codec="aac",
                codec_long_name="AAC (Advanced Audio Coding)",
                format="fltp",
                bit_rate=72093,
                sample_rate=8000.0,
                num_channels=2,
            ),
            SourceStream(
                media_type="subtitle",
                codec="mov_text",
                codec_long_name="MOV text",
                format=None,
                bit_rate=None,
            ),
            SourceVideoStream(
                media_type="video",
                codec="h264",
                codec_long_name="H.264 / AVC / MPEG-4 AVC / MPEG-4 part 10",
                format="yuv420p",
                bit_rate=128783,
                width=480,
                height=270,
                frame_rate=29.97002997002997,
            ),
            SourceAudioStream(
                media_type="audio",
                codec="aac",
                codec_long_name="AAC (Advanced Audio Coding)",
                format="fltp",
                bit_rate=128837,
                sample_rate=16000.0,
                num_channels=2,
            ),
            SourceStream(
                media_type="subtitle",
                codec="mov_text",
                codec_long_name="MOV text",
                format=None,
                bit_rate=None,
            ),
        ]
        for i, exp in enumerate(expected):
            assert exp == s.get_src_stream_info(i)

    def test_src_info_invalid_index(self):
        """`get_src_stream_info` does not segfault but raise an exception when input is invalid"""
        s = Streamer(get_video_asset())
        # -1 and anything >= num_src_streams (6) is out of range.
        for i in [-1, 6, 7, 8]:
            with self.assertRaises(IndexError):
                s.get_src_stream_info(i)

    def test_default_streams(self):
        """default stream is not None"""
        s = Streamer(get_video_asset())
        assert s.default_audio_stream is not None
        assert s.default_video_stream is not None

    def test_default_audio_stream_none(self):
        """default audio stream is None for video without audio"""
        s = Streamer(get_video_asset("nasa_13013_no_audio.mp4"))
        assert s.default_audio_stream is None

    def test_default_video_stream_none(self):
        """default video stream is None for video with only audio"""
        s = Streamer(get_video_asset("nasa_13013_no_video.mp4"))
        assert s.default_video_stream is None

    def test_num_out_stream(self):
        """num_out_streams gives the correct count of output streams"""
        s = Streamer(get_video_asset())
        n, m = 6, 4
        # Add n streams, remove m, add m back, then remove all n,
        # checking the count before each mutation.
        for i in range(n):
            assert s.num_out_streams == i
            s.add_audio_stream(frames_per_chunk=-1)
        for i in range(m):
            assert s.num_out_streams == n - i
            s.remove_stream(0)
        for i in range(m):
            assert s.num_out_streams == n - m + i
            s.add_video_stream(frames_per_chunk=-1)
        for i in range(n):
            assert s.num_out_streams == n - i
            s.remove_stream(n - i - 1)
        assert s.num_out_streams == 0

    def test_basic_audio_stream(self):
        """`add_basic_audio_stream` constructs a correct filter."""
        s = Streamer(get_video_asset())
        s.add_basic_audio_stream(frames_per_chunk=-1, dtype=None)
        s.add_basic_audio_stream(frames_per_chunk=-1, sample_rate=8000)
        s.add_basic_audio_stream(frames_per_chunk=-1, dtype=torch.int16)
        # No conversion requested -> empty filter description.
        sinfo = s.get_out_stream_info(0)
        assert sinfo.source_index == s.default_audio_stream
        assert sinfo.filter_description == ""
        # Resampling maps to ffmpeg's `aresample` filter.
        sinfo = s.get_out_stream_info(1)
        assert sinfo.source_index == s.default_audio_stream
        assert "aresample=8000" in sinfo.filter_description
        # dtype conversion maps to ffmpeg's `aformat` filter.
        sinfo = s.get_out_stream_info(2)
        assert sinfo.source_index == s.default_audio_stream
        assert "aformat=sample_fmts=s16" in sinfo.filter_description

    def test_basic_video_stream(self):
        """`add_basic_video_stream` constructs a correct filter."""
        s = Streamer(get_video_asset())
        s.add_basic_video_stream(frames_per_chunk=-1, format=None)
        s.add_basic_video_stream(frames_per_chunk=-1, width=3, height=5)
        s.add_basic_video_stream(frames_per_chunk=-1, frame_rate=7)
        s.add_basic_video_stream(frames_per_chunk=-1, format="BGR")
        # No conversion requested -> empty filter description.
        sinfo = s.get_out_stream_info(0)
        assert sinfo.source_index == s.default_video_stream
        assert sinfo.filter_description == ""
        # Resizing maps to ffmpeg's `scale` filter.
        sinfo = s.get_out_stream_info(1)
        assert sinfo.source_index == s.default_video_stream
        assert "scale=width=3:height=5" in sinfo.filter_description
        # Frame-rate change maps to ffmpeg's `fps` filter.
        sinfo = s.get_out_stream_info(2)
        assert sinfo.source_index == s.default_video_stream
        assert "fps=7" in sinfo.filter_description
        # Channel-format change maps to ffmpeg's `format` filter.
        sinfo = s.get_out_stream_info(3)
        assert sinfo.source_index == s.default_video_stream
        assert "format=pix_fmts=bgr24" in sinfo.filter_description

    def test_remove_streams(self):
        """`remove_stream` removes the correct output stream"""
        s = Streamer(get_video_asset())
        s.add_basic_audio_stream(frames_per_chunk=-1, sample_rate=24000)
        s.add_basic_video_stream(frames_per_chunk=-1, width=16, height=16)
        s.add_basic_audio_stream(frames_per_chunk=-1, sample_rate=8000)
        # Mirror each removal on a local copy of the metadata, and check that
        # the Streamer's view stays consistent after each step.
        sinfo = [s.get_out_stream_info(i) for i in range(3)]
        s.remove_stream(1)
        del sinfo[1]
        assert sinfo == [s.get_out_stream_info(i) for i in range(s.num_out_streams)]
        s.remove_stream(1)
        del sinfo[1]
        assert sinfo == [s.get_out_stream_info(i) for i in range(s.num_out_streams)]
        s.remove_stream(0)
        del sinfo[0]
        assert [] == [s.get_out_stream_info(i) for i in range(s.num_out_streams)]

    def test_remove_stream_invalid(self):
        """Attempt to remove invalid output streams raises IndexError"""
        s = Streamer(get_video_asset())
        for i in range(-3, 3):
            with self.assertRaises(IndexError):
                s.remove_stream(i)
        s.add_audio_stream(frames_per_chunk=-1)
        for i in range(-3, 3):
            # index 0 is now valid; every other index is still out of range.
            if i == 0:
                continue
            with self.assertRaises(IndexError):
                s.remove_stream(i)

    def test_process_packet(self):
        """`process_packet` method returns 0 while there is a packet in source stream"""
        s = Streamer(get_video_asset())
        # The asset (nasa_13013.mp4) contains 1023 packets.
        for _ in range(1023):
            code = s.process_packet()
            assert code == 0
        # now all the packets should be processed, so process_packet returns 1.
        code = s.process_packet()
        assert code == 1

    def test_pop_chunks_no_output_stream(self):
        """`pop_chunks` method returns empty list when there is no output stream"""
        s = Streamer(get_video_asset())
        assert s.pop_chunks() == []

    def test_pop_chunks_empty_buffer(self):
        """`pop_chunks` method returns None when a buffer is empty"""
        s = Streamer(get_video_asset())
        s.add_basic_audio_stream(frames_per_chunk=-1)
        s.add_basic_video_stream(frames_per_chunk=-1)
        assert s.pop_chunks() == [None, None]

    def test_pop_chunks_exhausted_stream(self):
        """`pop_chunks` method returns None when the source stream is exhausted"""
        s = Streamer(get_video_asset())
        # video is 16.57 seconds.
        # audio streams per 10 second chunk
        # video streams per 20 second chunk
        # The first `pop_chunk` call should return 2 Tensors (10 second audio and 16.57 second video)
        # The second call should return 1 Tensor (6.57 second audio) and None.
        # After that, `pop_chunk` should keep returning None-s.
        s.add_basic_audio_stream(frames_per_chunk=100, sample_rate=10, buffer_chunk_size=3)
        s.add_basic_video_stream(frames_per_chunk=200, frame_rate=10, buffer_chunk_size=3)
        s.process_all_packets()
        chunks = s.pop_chunks()
        assert chunks[0] is not None
        assert chunks[1] is not None
        assert chunks[0].shape[0] == 100  # audio tensor contains 10 second chunk
        assert chunks[1].shape[0] < 200  # video tensor contains less than 20 second chunk
        chunks = s.pop_chunks()
        assert chunks[0] is not None
        assert chunks[1] is None
        assert chunks[0].shape[0] < 100  # audio tensor contains less than 10 second chunk
        for _ in range(10):
            chunks = s.pop_chunks()
            assert chunks[0] is None
            assert chunks[1] is None

    def test_stream_empty(self):
        """`stream` fails when no output stream is configured"""
        s = Streamer(get_video_asset())
        with self.assertRaises(RuntimeError):
            next(s.stream())

    def test_stream_smoke_test(self):
        """`stream` streams chunks fine"""
        w, h = 256, 198
        s = Streamer(get_video_asset())
        s.add_basic_audio_stream(frames_per_chunk=2000, sample_rate=8000)
        s.add_basic_video_stream(frames_per_chunk=15, frame_rate=60, width=w, height=h)
        for i, (achunk, vchunk) in enumerate(s.stream()):
            assert achunk.shape == torch.Size([2000, 2])
            assert vchunk.shape == torch.Size([15, 3, h, w])
            # 40 iterations is enough for a smoke test; stop early.
            if i >= 40:
                break
@skipIfNoFFmpeg
class StreamerAudioTest(TempDirMixin, TorchaudioTestCase):
    """Test suite for audio streaming"""

    def _get_reference_wav(self, sample_rate, channels_first=False, **kwargs):
        # Generate a reference waveform, save it as WAV, and return both
        # the file path and the raw data for later comparison.
        data = get_wav_data(**kwargs, normalize=False, channels_first=channels_first)
        path = self.get_temp_path("ref.wav")
        save_wav(path, data, sample_rate, channels_first=channels_first)
        return path, data

    def _test_wav(self, path, original, dtype):
        # Decode the whole file in a single chunk and compare with the reference.
        s = Streamer(path)
        s.add_basic_audio_stream(frames_per_chunk=-1, dtype=dtype)
        s.process_all_packets()
        (output,) = s.pop_chunks()
        self.assertEqual(original, output)

    @nested_params(
        ["int16", "uint8", "int32"],  # "float", "double", "int64"]
        [1, 2, 4, 8],
    )
    def test_basic_audio_stream(self, dtype, num_channels):
        """`basic_audio_stream` can load WAV file properly."""
        path, original = self._get_reference_wav(8000, dtype=dtype, num_channels=num_channels)
        # provide the matching dtype
        self._test_wav(path, original, getattr(torch, dtype))
        # use the internal dtype ffmpeg picks
        self._test_wav(path, original, None)

    @nested_params(
        ["int16", "uint8", "int32"],  # "float", "double", "int64"]
        [1, 2, 4, 8],
    )
    def test_audio_stream(self, dtype, num_channels):
        """`add_audio_stream` can apply filter"""
        path, original = self._get_reference_wav(8000, dtype=dtype, num_channels=num_channels)
        # `areverse` reverses the samples along time (dim 0, since channels are last).
        expected = torch.flip(original, dims=(0,))
        s = Streamer(path)
        s.add_audio_stream(frames_per_chunk=-1, filter_desc="areverse")
        s.process_all_packets()
        (output,) = s.pop_chunks()
        self.assertEqual(expected, output)

    @nested_params(
        ["int16", "uint8", "int32"],  # "float", "double", "int64"]
        [1, 2, 4, 8],
    )
    def test_audio_seek(self, dtype, num_channels):
        """`seek` changes the position properly"""
        # 1 Hz sample rate makes the frame index equal to the timestamp in seconds.
        path, original = self._get_reference_wav(1, dtype=dtype, num_channels=num_channels, num_frames=30)
        for t in range(10, 20):
            expected = original[t:, :]
            s = Streamer(path)
            s.add_audio_stream(frames_per_chunk=-1)
            s.seek(float(t))
            s.process_all_packets()
            (output,) = s.pop_chunks()
            self.assertEqual(expected, output)

    @nested_params(
        [
            (18, 6, 3),  # num_frames is divisible by frames_per_chunk
            (18, 5, 4),  # num_frames is not divisible by frames_per_chunk
            (18, 32, 1),  # num_frames is shorter than frames_per_chunk
        ],
        [1, 2, 4, 8],
    )
    def test_audio_frames_per_chunk(self, frame_param, num_channels):
        """Different chunk parameter covers the source media properly"""
        num_frames, frames_per_chunk, buffer_chunk_size = frame_param
        path, original = self._get_reference_wav(
            8000, dtype="int16", num_channels=num_channels, num_frames=num_frames, channels_first=False
        )
        s = Streamer(path)
        s.add_audio_stream(frames_per_chunk=frames_per_chunk, buffer_chunk_size=buffer_chunk_size)
        i, outputs = 0, []
        for (output,) in s.stream():
            # Each chunk must match the corresponding slice of the original.
            expected = original[frames_per_chunk * i : frames_per_chunk * (i + 1), :]
            outputs.append(output)
            self.assertEqual(expected, output)
            i += 1
        # The number of chunks is ceil(num_frames / frames_per_chunk).
        assert i == num_frames // frames_per_chunk + (1 if num_frames % frames_per_chunk else 0)
        # Concatenating all the chunks restores the original waveform.
        self.assertEqual(torch.cat(outputs, 0), original)
@skipIfNoFFmpeg
class StreamerImageTest(TempDirMixin, TorchaudioTestCase):
    """Test suite for decoding still images via Streamer"""

    def _get_reference_png(self, width: int, height: int, grayscale: bool):
        # Generate a deterministic reference image, save it as PNG, and
        # return both the file path and the original CHW Tensor.
        original = get_image(width, height, grayscale=grayscale)
        path = self.get_temp_path("ref.png")
        save_image(path, original, mode="L" if grayscale else "RGB")
        return path, original

    def _test_png(self, path, original, format=None):
        # Decode the whole file in a single chunk and compare with the reference.
        s = Streamer(path)
        s.add_basic_video_stream(frames_per_chunk=-1, format=format)
        s.process_all_packets()
        (output,) = s.pop_chunks()
        self.assertEqual(original, output)

    @nested_params([True, False])
    def test_png(self, grayscale):
        """PNG files are decoded losslessly into NCHW uint8 Tensors"""
        # TODO:
        # Add test with alpha channel (RGBA, ARGB, BGRA, ABGR)
        w, h = 32, 18
        path, original = self._get_reference_png(w, h, grayscale=grayscale)
        # Streamer returns a batch (frame) dimension in front.
        expected = original[None, ...]
        self._test_png(path, expected)

    @parameterized.expand(
        [
            ("hflip", 2),
            ("vflip", 1),
        ]
    )
    def test_png_effect(self, filter_desc, index):
        """Flip filters produce the image flipped along the expected axis"""
        h, w = 111, 250
        path, original = self._get_reference_png(w, h, grayscale=False)
        # hflip flips the width axis (2), vflip the height axis (1) of CHW.
        expected = torch.flip(original, dims=(index,))[None, ...]
        s = Streamer(path)
        s.add_video_stream(frames_per_chunk=-1, filter_desc=filter_desc)
        s.process_all_packets()
        output = s.pop_chunks()[0]
        # NOTE: removed leftover debug `print` calls that cluttered test output.
        self.assertEqual(expected, output)
......@@ -146,6 +146,10 @@ void AudioBuffer::push_tensor(torch::Tensor t) {
// Discard older frames.
int max_frames = num_chunks * frames_per_chunk;
while (num_buffered_frames > max_frames) {
TORCH_WARN_ONCE(
"The number of buffered frames exceeded the buffer size. "
"Dropping the old frames. "
"To avoid this, you can set a higher buffer_chunk_size value.");
torch::Tensor& t = chunks.front();
num_buffered_frames -= t.size(0);
chunks.pop_front();
......@@ -215,6 +219,10 @@ void VideoBuffer::push_tensor(torch::Tensor t) {
// Trim
int max_frames = num_chunks * frames_per_chunk;
if (num_buffered_frames > max_frames) {
TORCH_WARN_ONCE(
"The number of buffered frames exceeded the buffer size. "
"Dropping the old frames. "
"To avoid this, you can set a higher buffer_chunk_size value.");
torch::Tensor& t = chunks.front();
num_buffered_frames -= t.size(0);
chunks.pop_front();
......
......@@ -29,7 +29,9 @@ AVFormatContext* get_format_context(
av_dict_free(&dict);
if (ret < 0)
throw std::runtime_error("Failed to open the input: " + src);
throw std::runtime_error(
"Failed to open the input \"" + src + "\" (" + av_err2string(ret) +
").");
return pFormat;
}
} // namespace
......
......@@ -22,6 +22,14 @@ extern "C" {
namespace torchaudio {
namespace ffmpeg {
// Replacement of av_err2str, which causes
// `error: taking address of temporary array`
// https://github.com/joncampbell123/composite-video-simulator/issues/5
av_always_inline std::string av_err2string(int errnum) {
  // Format the error message into a stack buffer, then return it as a string.
  char buf[AV_ERROR_MAX_STRING_SIZE];
  av_make_error_string(buf, AV_ERROR_MAX_STRING_SIZE, errnum);
  return std::string(buf);
}
// Base structure that handles memory management.
// Resource is freed by the destructor of unique_ptr,
// which will call custom delete mechanism provided via Deleter
......
......@@ -84,7 +84,9 @@ void FilterGraph::add_src(AVRational time_base, AVCodecParameters* codecpar) {
int ret = avfilter_graph_create_filter(
&buffersrc_ctx, buffersrc, "in", args.c_str(), NULL, pFilterGraph);
if (ret < 0) {
throw std::runtime_error("Failed to create input filter: \"" + args + "\"");
throw std::runtime_error(
"Failed to create input filter: \"" + args + "\" (" +
av_err2string(ret) + ")");
}
}
......@@ -160,7 +162,9 @@ void FilterGraph::add_process() {
avfilter_graph_parse_ptr(pFilterGraph, desc.c_str(), out, in, nullptr);
if (ret < 0) {
throw std::runtime_error("Failed to create the filter.");
throw std::runtime_error(
"Failed to create the filter from \"" + desc + "\" (" +
av_err2string(ret) + ".)");
}
}
......
......@@ -99,7 +99,7 @@ int64_t find_best_video_stream(S s) {
return s->s.find_best_video_stream();
}
void seek(S s, int64_t timestamp) {
void seek(S s, double timestamp) {
s->s.seek(timestamp);
}
......@@ -256,12 +256,20 @@ void remove_stream(S s, int64_t i) {
s->s.remove_stream(i);
}
int64_t process_packet(S s) {
return s->s.process_packet();
// Process one packet from the source. Returns the underlying code
// (0: processed, 1: end of stream); negative codes are converted into
// a C++ exception carrying the FFmpeg error string.
int64_t process_packet(Streamer& s) {
  const int64_t code = s.process_packet();
  if (code >= 0) {
    return code;
  }
  throw std::runtime_error(
      "Failed to process a packet. (" + av_err2string(code) + "). ");
}
int64_t process_all_packets(S s) {
return s->s.process_all_packets();
// Drain the source entirely: process_packet returns 0 while more
// packets remain, and a non-zero code (or throws) otherwise.
void process_all_packets(Streamer& s) {
  while (process_packet(s) == 0) {
  }
}
bool is_buffer_ready(S s) {
......@@ -278,7 +286,7 @@ std::tuple<c10::optional<torch::Tensor>, int64_t> load(const std::string& src) {
auto sinfo = s.get_src_stream_info(i);
int64_t sample_rate = static_cast<int64_t>(sinfo.sample_rate);
s.add_audio_stream(i, -1, -1, "");
s.process_all_packets();
process_all_packets(s);
auto tensors = s.pop_chunks();
return std::make_tuple<>(tensors[0], sample_rate);
}
......@@ -312,8 +320,12 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def("torchaudio::ffmpeg_streamer_add_audio_stream", add_audio_stream);
m.def("torchaudio::ffmpeg_streamer_add_video_stream", add_video_stream);
m.def("torchaudio::ffmpeg_streamer_remove_stream", remove_stream);
m.def("torchaudio::ffmpeg_streamer_process_packet", process_packet);
m.def("torchaudio::ffmpeg_streamer_process_all_packets", process_all_packets);
m.def("torchaudio::ffmpeg_streamer_process_packet", [](S s) {
return process_packet(s->s);
});
m.def("torchaudio::ffmpeg_streamer_process_all_packets", [](S s) {
return process_all_packets(s->s);
});
m.def("torchaudio::ffmpeg_streamer_is_buffer_ready", is_buffer_ready);
m.def("torchaudio::ffmpeg_streamer_pop_chunks", pop_chunks);
}
......
......@@ -137,14 +137,7 @@ void Streamer::seek(double timestamp) {
int64_t ts = static_cast<int64_t>(timestamp * AV_TIME_BASE);
int ret = avformat_seek_file(pFormatContext, -1, INT64_MIN, ts, INT64_MAX, 0);
if (ret < 0) {
// Temporarily removing `av_err2str` function as it causes
// `error: taking address of temporary array` on GCC.
// TODO:
// Workaround with `av_err2string` function from
// https://github.com/joncampbell123/composite-video-simulator/issues/5#issuecomment-611885908
throw std::runtime_error(std::string("Failed to seek."));
// throw std::runtime_error(std::string("Failed to seek: ") +
// av_err2str(ret));
throw std::runtime_error("Failed to seek. (" + av_err2string(ret) + ".)");
}
}
......@@ -249,14 +242,6 @@ int Streamer::drain() {
return ret;
}
int Streamer::process_all_packets() {
int ret = 0;
do {
ret = process_packet();
} while (!ret);
return ret;
}
std::vector<c10::optional<torch::Tensor>> Streamer::pop_chunks() {
std::vector<c10::optional<torch::Tensor>> ret;
for (auto& i : stream_indices) {
......
......@@ -87,7 +87,6 @@ class Streamer {
// Stream methods
//////////////////////////////////////////////////////////////////////////////
int process_packet();
int process_all_packets();
int drain();
......
......@@ -8,9 +8,9 @@ namespace ffmpeg {
struct SrcStreamInfo {
AVMediaType media_type;
const char* codec_name = NULL;
const char* codec_long_name = NULL;
const char* fmt_name = NULL;
const char* codec_name = "N/A";
const char* codec_long_name = "N/A";
const char* fmt_name = "N/A";
int bit_rate = 0;
// Audio
double sample_rate = 0;
......
import torch
import torchaudio

# Load the ffmpeg-backed extension library and initialize it
# (this must happen before any Streamer op is invoked).
torchaudio._extension._load_lib("libtorchaudio_ffmpeg")
torch.ops.torchaudio.ffmpeg_init()

# Re-export the public streaming API.
from .streamer import (
    Streamer,
    SourceStream,
    SourceAudioStream,
    SourceVideoStream,
    OutputStream,
)

__all__ = [
    "Streamer",
    "SourceStream",
    "SourceAudioStream",
    "SourceVideoStream",
    "OutputStream",
]
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, Tuple, Dict, Iterator
import torch
import torchaudio
@dataclass
class SourceStream:
    """SourceStream()

    The metadata of a source stream. This class is used when representing streams of
    media type other than `audio` or `video`.

    When source stream is `audio` or `video` type, :py:class:`SourceAudioStream` and
    :py:class:`SourceVideoStream`, which reports additional media-specific attributes,
    are used respectively.
    """

    media_type: str
    """The type of the stream.
    One of `audio`, `video`, `data`, `subtitle`, `attachment` and empty string.

    .. note::
       Only `audio` and `video` streams are supported for output.

    .. note::
       Still images, such as PNG and JPEG formats are reported as `video`.
    """
    codec: str
    """Short name of the codec. Such as `pcm_s16le` and `h264`."""
    codec_long_name: str
    """Detailed name of the codec.

    Such as `"PCM signed 16-bit little-endian"` and `"H.264 / AVC / MPEG-4 AVC / MPEG-4 part 10"`.
    """
    format: Optional[str]
    """Media format. Such as `s16` and `yuv420p`.

    Commonly found audio values are;

    - `u8`, `u8p`: 8-bit unsigned integer.
    - `s16`, `s16p`: 16-bit signed integer.
    - `s32`, `s32p`: 32-bit signed integer.
    - `flt`, `fltp`: 32-bit floating-point.

    .. note::
       `p` at the end indicates the format is `planar`.
       Channels are grouped together instead of interspersed in memory.
    """
    bit_rate: Optional[int]
    """Bit rate of the stream in bits-per-second.
    This is an estimated value based on the initial few frames of the stream.
    For container formats and variable bit rate, it can be 0.
    """
@dataclass
class SourceAudioStream(SourceStream):
    """SourceAudioStream()

    The metadata of an audio source stream.

    In addition to the attributes reported by :py:func:`SourceStream`,
    when the source stream is audio type, then the following additional attributes
    are reported.
    """

    sample_rate: float
    """Sample rate of the audio."""
    num_channels: int
    """Number of channels."""
@dataclass
class SourceVideoStream(SourceStream):
    """SourceVideoStream()

    The metadata of a video source stream.

    In addition to the attributes reported by :py:func:`SourceStream`,
    when the source stream is video type, then the following additional attributes
    are reported.
    """

    width: int
    """Width of the video frame in pixel."""
    height: int
    """Height of the video frame in pixel."""
    frame_rate: float
    """Frame rate."""
# Indices of SrcInfo returned by low-level `get_src_stream_info`.
# These must stay in sync with the tuple layout produced by the
# C++ extension (libtorchaudio_ffmpeg).
# - COMMON
_MEDIA_TYPE = 0
_CODEC = 1
_CODEC_LONG = 2
_FORMAT = 3
_BIT_RATE = 4
# - AUDIO
_SAMPLE_RATE = 5
_NUM_CHANNELS = 6
# - VIDEO
_WIDTH = 7
_HEIGHT = 8
_FRAME_RATE = 9
def _parse_si(i):
    """Convert the raw stream-info tuple from the extension into a dataclass.

    Audio and video streams get their media-specific subclasses; any other
    media type falls back to the generic :py:class:`SourceStream`.
    """
    media_type = i[_MEDIA_TYPE]
    common = (media_type, i[_CODEC], i[_CODEC_LONG])
    if media_type == "audio":
        return SourceAudioStream(
            *common,
            i[_FORMAT],
            i[_BIT_RATE],
            i[_SAMPLE_RATE],
            i[_NUM_CHANNELS],
        )
    if media_type == "video":
        return SourceVideoStream(
            *common,
            i[_FORMAT],
            i[_BIT_RATE],
            i[_WIDTH],
            i[_HEIGHT],
            i[_FRAME_RATE],
        )
    # Non-audio/video streams (subtitle, data, ...) have no format/bit_rate info.
    return SourceStream(*common, None, None)
@dataclass
class OutputStream:
    """OutputStream()

    Output stream configured on :py:class:`Streamer`.
    """

    source_index: int
    """Index of the source stream that this output stream is connected."""
    filter_description: str
    """Description of filter graph applied to the source stream."""
def _parse_oi(i):
    """Convert the raw output-stream tuple from the extension into OutputStream."""
    return OutputStream(source_index=i[0], filter_description=i[1])
class Streamer:
"""Fetch and decode audio/video streams chunk by chunk.
For the detailed usage of this class, please refer to the tutorial.
Args:
src (str): Source. Can be a file path, URL, device identifier or filter expression.
format (str or None, optional):
Override the input format, or specify the source sound device.
Default: ``None`` (no override nor device input).
This argument serves two different usecases.
1) Override the source format.
This is useful when the input data do not contain a header.
2) Specify the input source device.
This allows to load media stream from hardware devices,
such as microphone, camera and screen, or a virtual device.
.. note::
This option roughly corresponds to ``-f`` option of ``ffmpeg`` command.
Please refer to the ffmpeg documentations for the possible values.
https://ffmpeg.org/ffmpeg-formats.html
For device access, the available values vary based on hardware (AV device) and
software configuration (ffmpeg build).
https://ffmpeg.org/ffmpeg-devices.html
option (dict of str to str, optional):
Custom option passed when initializing format context (opening source).
You can use this argument to change the input source before it is passed to decoder.
Default: ``None``.
"""
def __init__(
self,
src: str,
format: Optional[str] = None,
option: Optional[Dict[str, str]] = None,
):
self._s = torch.ops.torchaudio.ffmpeg_streamer_init(src, format, option)
i = torch.ops.torchaudio.ffmpeg_streamer_find_best_audio_stream(self._s)
self._i_audio = None if i < 0 else i
i = torch.ops.torchaudio.ffmpeg_streamer_find_best_video_stream(self._s)
self._i_video = None if i < 0 else i
    @property
    def num_src_streams(self) -> int:
        """Number of streams found in the provided media source.

        :type: int
        """
        return torch.ops.torchaudio.ffmpeg_streamer_num_src_streams(self._s)
    @property
    def num_out_streams(self) -> int:
        """Number of output streams configured by client code.

        :type: int
        """
        return torch.ops.torchaudio.ffmpeg_streamer_num_out_streams(self._s)
    @property
    def default_audio_stream(self) -> Optional[int]:
        """The index of default audio stream. ``None`` if there is no audio stream

        :type: Optional[int]
        """
        return self._i_audio
    @property
    def default_video_stream(self) -> Optional[int]:
        """The index of default video stream. ``None`` if there is no video stream

        :type: Optional[int]
        """
        return self._i_video
def get_src_stream_info(self, i: int) -> torchaudio.prototype.io.SourceStream:
"""Get the metadata of source stream
Args:
i (int): Stream index.
Returns:
SourceStream
"""
return _parse_si(torch.ops.torchaudio.ffmpeg_streamer_get_src_stream_info(self._s, i))
def get_out_stream_info(self, i: int) -> torchaudio.prototype.io.OutputStream:
"""Get the metadata of output stream
Args:
i (int): Stream index.
Returns:
OutputStream
"""
return _parse_oi(torch.ops.torchaudio.ffmpeg_streamer_get_out_stream_info(self._s, i))
def seek(self, timestamp: float):
"""Seek the stream to the given timestamp [second]
Args:
timestamp (float): Target time in second.
"""
torch.ops.torchaudio.ffmpeg_streamer_seek(self._s, timestamp)
def add_basic_audio_stream(
self,
frames_per_chunk: int,
buffer_chunk_size: int = 3,
stream_index: Optional[int] = None,
sample_rate: Optional[int] = None,
dtype: torch.dtype = torch.float32,
):
"""Add output audio stream
Args:
frames_per_chunk (int): Number of frames returned by Streamer as a chunk.
If the source stream is exhausted before enough frames are buffered,
then the chunk is returned as-is.
buffer_chunk_size (int, optional): Internal buffer size.
When this many chunks are created, but
client code does not pop the chunk, if a new frame comes in, the old
chunk is dropped.
stream_index (int or None, optional): The source audio stream index.
If omitted, :py:attr:`default_audio_stream` is used.
sample_rate (int or None, optional): If provided, resample the audio.
dtype (torch.dtype, optional): If not ``None``, change the output sample precision.
If floating point, then the sample value range is
`[-1, 1]`.
"""
i = self.default_audio_stream if stream_index is None else stream_index
torch.ops.torchaudio.ffmpeg_streamer_add_basic_audio_stream(
self._s, i, frames_per_chunk, buffer_chunk_size, sample_rate, dtype
)
def add_basic_video_stream(
    self,
    frames_per_chunk: int,
    buffer_chunk_size: int = 3,
    stream_index: Optional[int] = None,
    frame_rate: Optional[int] = None,
    width: Optional[int] = None,
    height: Optional[int] = None,
    format: str = "RGB",
):
    """Add an output video stream with basic post-processing.

    Args:
        frames_per_chunk (int): Number of frames returned by Streamer as a chunk.
            If the source stream is exhausted before enough frames are buffered,
            then the chunk is returned as-is.
        buffer_chunk_size (int, optional): Internal buffer size.
            Once this many chunks have been created without client code popping
            any, the oldest chunk is dropped when a new frame comes in.
        stream_index (int or None, optional): The source video stream index.
            If omitted, :py:attr:`default_video_stream` is used.
        frame_rate (int or None, optional): If provided, change the frame rate.
        width (int or None, optional): If provided, change the image width. Unit: Pixel.
        height (int or None, optional): If provided, change the image height. Unit: Pixel.
        format (str, optional): Change the format of image channels. Valid values are,

            - `RGB`: 8 bits * 3 channels
            - `BGR`: 8 bits * 3 channels
            - `GRAY`: 8 bits * 1 channels
    """
    # Fall back to the default video stream when no explicit index is given.
    if stream_index is None:
        src_index = self.default_video_stream
    else:
        src_index = stream_index
    torch.ops.torchaudio.ffmpeg_streamer_add_basic_video_stream(
        self._s,
        src_index,
        frames_per_chunk,
        buffer_chunk_size,
        frame_rate,
        width,
        height,
        format,
    )
def add_audio_stream(
    self,
    frames_per_chunk: int,
    buffer_chunk_size: int = 3,
    stream_index: Optional[int] = None,
    filter_desc: Optional[str] = None,
):
    """Add an output audio stream with a custom filter graph.

    Args:
        frames_per_chunk (int): Number of frames returned by Streamer as a chunk.
            If the source stream is exhausted before enough frames are buffered,
            then the chunk is returned as-is.
        buffer_chunk_size (int, optional): Internal buffer size.
            Once this many chunks have been created without client code popping
            any, the oldest chunk is dropped when a new frame comes in.
        stream_index (int or None, optional): The source audio stream index.
            If omitted, :py:attr:`default_audio_stream` is used.
        filter_desc (str or None, optional): Filter description.
            The list of available filters can be found at
            https://ffmpeg.org/ffmpeg-filters.html
            Note that complex filters are not supported.
    """
    # Fall back to the default audio stream when no explicit index is given.
    if stream_index is None:
        src_index = self.default_audio_stream
    else:
        src_index = stream_index
    torch.ops.torchaudio.ffmpeg_streamer_add_audio_stream(
        self._s, src_index, frames_per_chunk, buffer_chunk_size, filter_desc
    )
def add_video_stream(
    self,
    frames_per_chunk: int,
    buffer_chunk_size: int = 3,
    stream_index: Optional[int] = None,
    filter_desc: Optional[str] = None,
):
    """Add an output video stream with a custom filter graph.

    Args:
        frames_per_chunk (int): Number of frames returned by Streamer as a chunk.
            If the source stream is exhausted before enough frames are buffered,
            then the chunk is returned as-is.
        buffer_chunk_size (int, optional): Internal buffer size.
            Once this many chunks have been created without client code popping
            any, the oldest chunk is dropped when a new frame comes in.
        stream_index (int or None, optional): The source video stream index.
            If omitted, :py:attr:`default_video_stream` is used.
        filter_desc (str or None, optional): Filter description.
            The list of available filters can be found at
            https://ffmpeg.org/ffmpeg-filters.html
            Note that complex filters are not supported.
    """
    # Fall back to the default video stream when no explicit index is given.
    if stream_index is None:
        src_index = self.default_video_stream
    else:
        src_index = stream_index
    torch.ops.torchaudio.ffmpeg_streamer_add_video_stream(
        self._s, src_index, frames_per_chunk, buffer_chunk_size, filter_desc
    )
def remove_stream(self, i: int):
    """Remove a previously configured output stream.

    Args:
        i (int): Index of the output stream to be removed.
    """
    remove_op = torch.ops.torchaudio.ffmpeg_streamer_remove_stream
    remove_op(self._s, i)
def process_packet(self) -> int:
    """Read the source media and process one packet.

    The data in the packet is decoded and handed to the corresponding
    output stream processors. Packets belonging to a source stream that is
    not connected to any output stream are discarded.

    When the source reaches EOF, all the output stream processors enter
    drain mode and flush their pending frames.

    Returns:
        int:
            ``0``
                A packet was processed properly. The caller can keep
                calling this function to buffer more frames.
            ``1``
                The streamer reached EOF. All the output stream processors
                flushed the pending frames. The caller should stop calling
                this method.
    """
    status = torch.ops.torchaudio.ffmpeg_streamer_process_packet(self._s)
    return status
def process_all_packets(self):
    """Process packets until the source reaches EOF."""
    process_op = torch.ops.torchaudio.ffmpeg_streamer_process_all_packets
    process_op(self._s)
def is_buffer_ready(self) -> bool:
    """Return ``True`` when every output stream buffer holds at least one chunk."""
    ready = torch.ops.torchaudio.ffmpeg_streamer_is_buffer_ready(self._s)
    return ready
def pop_chunks(self) -> Tuple[Optional[torch.Tensor], ...]:
    """Pop one chunk from all the output stream buffers.

    Returns:
        Tuple[Optional[Tensor], ...]:
            Buffer contents, one entry per output stream.
            If a buffer does not contain any frame, then `None` is returned
            for that entry instead.
    """
    # Fix: the annotation previously read ``Tuple[Optional[torch.Tensor]]``,
    # which denotes a 1-tuple; the op returns one entry per configured output
    # stream, so the variadic form ``Tuple[..., ...]`` is the correct type.
    return torch.ops.torchaudio.ffmpeg_streamer_pop_chunks(self._s)
def fill_buffer(self) -> int:
    """Keep processing packets until every output buffer holds at least one chunk.

    Returns:
        int:
            ``0``
                Packets were processed properly and the buffers are
                ready to be popped once.
            ``1``
                The streamer reached EOF. All the output stream processors
                flushed the pending frames. The caller should stop calling
                this method.
    """
    # Process packets in small batches between buffer-readiness checks to
    # amortize the cost of the readiness query.
    packets_per_check = 3
    while not self.is_buffer_ready():
        for _ in range(packets_per_check):
            status = self.process_packet()
            if status:
                return status
    return 0
def stream(self) -> Iterator[Tuple[Optional[torch.Tensor], ...]]:
    """Return an iterator that generates output tensors.

    Each yielded item is a tuple with one entry per output stream; an entry
    is ``None`` when that stream's buffer had no frame to pop.

    Raises:
        RuntimeError: If no output stream has been configured.

    Note:
        The annotation previously read ``Iterator[Tuple[Optional[torch.Tensor]]]``
        (a 1-tuple); the variadic form matches the actual variable-length tuple.
    """
    if self.num_out_streams == 0:
        raise RuntimeError("No output stream is configured.")
    # Yield chunks as long as the source still provides packets
    # (fill_buffer returns 0 while more data is available, 1 at EOF).
    while not self.fill_buffer():
        yield self.pop_chunks()
    # Source reached EOF: drain whatever remains in the output buffers.
    while True:
        remaining = self.pop_chunks()
        if all(chunk is None for chunk in remaining):
            return
        yield remaining
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment