Commit 22788a8f authored by moto, committed by Facebook GitHub Bot
Browse files

Add `buffer_chunk_size=-1` option (#2969)

Summary:
This commit adds support for `buffer_chunk_size=-1`, which disables the dropping of older buffered frames.

Pull Request resolved: https://github.com/pytorch/audio/pull/2969

Reviewed By: xiaohui-zhang

Differential Revision: D42403467

Pulled By: mthrok

fbshipit-source-id: a0847e6878874ce7e4b0ec3f56e5fbb8ebdb5992
parent d1cc1da6
...@@ -530,6 +530,59 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC ...@@ -530,6 +530,59 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
assert chunk.shape == torch.Size(shape) assert chunk.shape == torch.Size(shape)
def test_invalid_chunk_option(self):
    """Invalid `frames_per_chunk` / `buffer_chunk_size` values raise ``RuntimeError``.

    Every case sets exactly one of the two options to an invalid value
    (``0`` or ``-2``) while the other one is ``3``. The check is performed
    for both audio and video stream registration.
    """
    reader = StreamReader(self.get_src())
    invalid_options = (
        {"frames_per_chunk": 0, "buffer_chunk_size": 3},
        {"frames_per_chunk": 3, "buffer_chunk_size": 0},
        {"frames_per_chunk": -2, "buffer_chunk_size": 3},
        {"frames_per_chunk": 3, "buffer_chunk_size": -2},
    )
    for options in invalid_options:
        # Audio first, then video, for each invalid combination.
        for add_stream in (reader.add_audio_stream, reader.add_video_stream):
            with self.assertRaises(RuntimeError):
                add_stream(**options)
def test_unchunked_stream(self):
    """`frames_per_chunk=-1` disables chunking.

    Both streams are registered with chunking disabled and a buffer size
    (``10000``) large enough for the whole test source. After all packets
    are processed, a single `pop_chunks` call is expected to hand back
    every buffered frame of each stream at once.
    """
    reader = StreamReader(self.get_src())
    unchunked = {"frames_per_chunk": -1, "buffer_chunk_size": 10000}
    reader.add_video_stream(**unchunked)
    reader.add_audio_stream(**unchunked)
    reader.process_all_packets()

    video_frames, audio_frames = reader.pop_chunks()
    expected_video = torch.Size([390, 3, 270, 480])
    expected_audio = torch.Size([208896, 2])
    assert video_frames.shape == expected_video
    assert audio_frames.shape == expected_audio
def test_buffer_chunk_size(self):
    """`buffer_chunk_size=-1` does not drop frames.

    The same source is decoded twice:

    1. With ``buffer_chunk_size=-1``, 13 full video/audio chunks are
       expected, followed by one pop where video is exhausted (``None``)
       and audio yields the 896-frame remainder.
    2. With ``buffer_chunk_size=3``, only three pops are checked: two with
       full chunks, then one with a full video chunk and the 896-frame
       audio remainder.
    """
    full_video = torch.Size([30, 3, 270, 480])
    full_audio = torch.Size([16000, 2])
    tail_audio = torch.Size([896, 2])

    src = self.get_src()

    # Pass 1: unbounded buffer.
    reader = StreamReader(src)
    reader.add_video_stream(frames_per_chunk=30, buffer_chunk_size=-1)
    reader.add_audio_stream(frames_per_chunk=16000, buffer_chunk_size=-1)
    reader.process_all_packets()
    for _ in range(13):
        video_chunk, audio_chunk = reader.pop_chunks()
        assert video_chunk.shape == full_video
        assert audio_chunk.shape == full_audio
    video_chunk, audio_chunk = reader.pop_chunks()
    assert video_chunk is None
    assert audio_chunk.shape == tail_audio

    # File-like sources were consumed by the first pass; rewind before reuse.
    if self.test_type == "fileobj":
        src.seek(0)

    # Pass 2: buffer limited to three chunks per stream.
    reader = StreamReader(src)
    reader.add_video_stream(frames_per_chunk=30, buffer_chunk_size=3)
    reader.add_audio_stream(frames_per_chunk=16000, buffer_chunk_size=3)
    reader.process_all_packets()
    for _ in range(2):
        video_chunk, audio_chunk = reader.pop_chunks()
        assert video_chunk.shape == full_video
        assert audio_chunk.shape == full_audio
    video_chunk, audio_chunk = reader.pop_chunks()
    assert video_chunk.shape == full_video
    assert audio_chunk.shape == tail_audio
def _to_fltp(original): def _to_fltp(original):
"""Convert Tensor to float32 with value range [-1, 1]""" """Convert Tensor to float32 with value range [-1, 1]"""
......
...@@ -56,15 +56,17 @@ void ChunkedAudioBuffer::push_tensor(torch::Tensor frame) { ...@@ -56,15 +56,17 @@ void ChunkedAudioBuffer::push_tensor(torch::Tensor frame) {
// Trim // Trim
// If frames_per_chunk > 0, we only retain the following number of frames and // If frames_per_chunk > 0, we only retain the following number of frames and
// Discard older frames. // Discard older frames.
int64_t max_frames = num_chunks * frames_per_chunk; if (num_chunks > 0) {
while (num_buffered_frames > max_frames) { int64_t max_frames = num_chunks * frames_per_chunk;
TORCH_WARN_ONCE( while (num_buffered_frames > max_frames) {
"The number of buffered frames exceeded the buffer size. " TORCH_WARN_ONCE(
"Dropping the old frames. " "The number of buffered frames exceeded the buffer size. "
"To avoid this, you can set a higher buffer_chunk_size value."); "Dropping the old frames. "
torch::Tensor& t = chunks.front(); "To avoid this, you can set a higher buffer_chunk_size value.");
num_buffered_frames -= t.size(0); torch::Tensor& t = chunks.front();
chunks.pop_front(); num_buffered_frames -= t.size(0);
chunks.pop_front();
}
} }
} }
...@@ -78,15 +80,17 @@ void ChunkedVideoBuffer::push_tensor(const torch::Tensor& frame) { ...@@ -78,15 +80,17 @@ void ChunkedVideoBuffer::push_tensor(const torch::Tensor& frame) {
num_buffered_frames += frame.size(0); num_buffered_frames += frame.size(0);
// Trim // Trim
int64_t max_frames = num_chunks * frames_per_chunk; if (num_chunks > 0) {
if (num_buffered_frames > max_frames) { int64_t max_frames = num_chunks * frames_per_chunk;
TORCH_WARN_ONCE( if (num_buffered_frames > max_frames) {
"The number of buffered frames exceeded the buffer size. " TORCH_WARN_ONCE(
"Dropping the old frames. " "The number of buffered frames exceeded the buffer size. "
"To avoid this, you can set a higher buffer_chunk_size value."); "Dropping the old frames. "
torch::Tensor& t = chunks.front(); "To avoid this, you can set a higher buffer_chunk_size value.");
num_buffered_frames -= t.size(0); torch::Tensor& t = chunks.front();
chunks.pop_front(); num_buffered_frames -= t.size(0);
chunks.pop_front();
}
} }
} }
......
...@@ -12,6 +12,16 @@ std::unique_ptr<Buffer> get_buffer( ...@@ -12,6 +12,16 @@ std::unique_ptr<Buffer> get_buffer(
int frames_per_chunk, int frames_per_chunk,
int num_chunks, int num_chunks,
const torch::Device& device) { const torch::Device& device) {
TORCH_CHECK(
frames_per_chunk > 0 || frames_per_chunk == -1,
"`frames_per_chunk` must be positive or -1. Found: ",
frames_per_chunk);
TORCH_CHECK(
num_chunks > 0 || num_chunks == -1,
"`num_chunks` must be positive or -1. Found: ",
num_chunks);
switch (type) { switch (type) {
case AVMEDIA_TYPE_AUDIO: { case AVMEDIA_TYPE_AUDIO: {
if (frames_per_chunk < 0) { if (frames_per_chunk < 0) {
......
...@@ -216,11 +216,16 @@ def _format_doc(**kwargs): ...@@ -216,11 +216,16 @@ def _format_doc(**kwargs):
_frames_per_chunk = """Number of frames returned as one chunk. _frames_per_chunk = """Number of frames returned as one chunk.
If the source stream is exhausted before enough frames are buffered, If the source stream is exhausted before enough frames are buffered,
then the chunk is returned as-is.""" then the chunk is returned as-is.
Providing ``-1`` disables chunking and :py:func:`pop_chunks` method
will concatenate all the buffered frames and return it."""
_buffer_chunk_size = """Internal buffer size. _buffer_chunk_size = """Internal buffer size.
When the number of chunks buffered exceeds this number, old frames are When the number of chunks buffered exceeds this number, old frames are
dropped. dropped. For example, if `frames_per_chunk` is 5 and `buffer_chunk_size` is
3, then frames older than 15 are dropped.
Providing ``-1`` disables this behavior.
Default: ``3``.""" Default: ``3``."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment