Commit 22788a8f authored by moto, committed by Facebook GitHub Bot

Add `buffer_chunk_size=-1` option (#2969)

Summary:
This commit adds `buffer_chunk_size=-1`, which disables the internal buffer limit so that buffered frames are never dropped.

Pull Request resolved: https://github.com/pytorch/audio/pull/2969

Reviewed By: xiaohui-zhang

Differential Revision: D42403467

Pulled By: mthrok

fbshipit-source-id: a0847e6878874ce7e4b0ec3f56e5fbb8ebdb5992
parent d1cc1da6
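
For reference, a minimal sketch of how the new option can be used from Python. The file name and stream parameters below are illustrative, not taken from this commit; the API calls (`StreamReader`, `add_audio_stream`, `process_all_packets`, `pop_chunks`) are the ones exercised in the tests that follow.

```python
from torchaudio.io import StreamReader

# Hypothetical input; any decodable media source works the same way.
reader = StreamReader("example.wav")

# buffer_chunk_size=-1 keeps every buffered chunk instead of dropping the
# oldest ones once the internal buffer exceeds its limit.
reader.add_audio_stream(frames_per_chunk=16000, buffer_chunk_size=-1)

reader.process_all_packets()

# Every chunk of the stream is still available here.
while True:
    (chunk,) = reader.pop_chunks()
    if chunk is None:
        break
    print(chunk.shape)
```
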
@@ -530,6 +530,59 @@ class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestC
         assert chunk.shape == torch.Size(shape)
 
+    def test_invalid_chunk_option(self):
+        """Passing invalid `frames_per_chunk` and `buffer_chunk_size` raises error"""
+        s = StreamReader(self.get_src())
+        for fpc, bcs in ((0, 3), (3, 0), (-2, 3), (3, -2)):
+            with self.assertRaises(RuntimeError):
+                s.add_audio_stream(frames_per_chunk=fpc, buffer_chunk_size=bcs)
+            with self.assertRaises(RuntimeError):
+                s.add_video_stream(frames_per_chunk=fpc, buffer_chunk_size=bcs)
+
+    def test_unchunked_stream(self):
+        """`frames_per_chunk=-1` disables chunking.
+
+        When chunking is disabled, the frames contained in one AVFrame become one chunk.
+        For video, that is always one frame, but for audio, it depends.
+        """
+        s = StreamReader(self.get_src())
+        s.add_video_stream(frames_per_chunk=-1, buffer_chunk_size=10000)
+        s.add_audio_stream(frames_per_chunk=-1, buffer_chunk_size=10000)
+        s.process_all_packets()
+        video, audio = s.pop_chunks()
+        assert video.shape == torch.Size([390, 3, 270, 480])
+        assert audio.shape == torch.Size([208896, 2])
+
+    def test_buffer_chunk_size(self):
+        """`buffer_chunk_size=-1` does not drop frames."""
+        src = self.get_src()
+        s = StreamReader(src)
+        s.add_video_stream(frames_per_chunk=30, buffer_chunk_size=-1)
+        s.add_audio_stream(frames_per_chunk=16000, buffer_chunk_size=-1)
+        s.process_all_packets()
+        for _ in range(13):
+            video, audio = s.pop_chunks()
+            assert video.shape == torch.Size([30, 3, 270, 480])
+            assert audio.shape == torch.Size([16000, 2])
+        video, audio = s.pop_chunks()
+        assert video is None
+        assert audio.shape == torch.Size([896, 2])
+
+        if self.test_type == "fileobj":
+            src.seek(0)
+        s = StreamReader(src)
+        s.add_video_stream(frames_per_chunk=30, buffer_chunk_size=3)
+        s.add_audio_stream(frames_per_chunk=16000, buffer_chunk_size=3)
+        s.process_all_packets()
+        for _ in range(2):
+            video, audio = s.pop_chunks()
+            assert video.shape == torch.Size([30, 3, 270, 480])
+            assert audio.shape == torch.Size([16000, 2])
+        video, audio = s.pop_chunks()
+        assert video.shape == torch.Size([30, 3, 270, 480])
+        assert audio.shape == torch.Size([896, 2])
+
 def _to_fltp(original):
     """Convert Tensor to float32 with value range [-1, 1]"""
......
@@ -56,15 +56,17 @@ void ChunkedAudioBuffer::push_tensor(torch::Tensor frame) {
   // Trim
   // If frames_per_chunk > 0, we only retain the following number of frames and
   // Discard older frames.
-  int64_t max_frames = num_chunks * frames_per_chunk;
-  while (num_buffered_frames > max_frames) {
-    TORCH_WARN_ONCE(
-        "The number of buffered frames exceeded the buffer size. "
-        "Dropping the old frames. "
-        "To avoid this, you can set a higher buffer_chunk_size value.");
-    torch::Tensor& t = chunks.front();
-    num_buffered_frames -= t.size(0);
-    chunks.pop_front();
+  if (num_chunks > 0) {
+    int64_t max_frames = num_chunks * frames_per_chunk;
+    while (num_buffered_frames > max_frames) {
+      TORCH_WARN_ONCE(
+          "The number of buffered frames exceeded the buffer size. "
+          "Dropping the old frames. "
+          "To avoid this, you can set a higher buffer_chunk_size value.");
+      torch::Tensor& t = chunks.front();
+      num_buffered_frames -= t.size(0);
+      chunks.pop_front();
+    }
   }
 }
@@ -78,15 +80,17 @@ void ChunkedVideoBuffer::push_tensor(const torch::Tensor& frame) {
   num_buffered_frames += frame.size(0);
 
   // Trim
-  int64_t max_frames = num_chunks * frames_per_chunk;
-  if (num_buffered_frames > max_frames) {
-    TORCH_WARN_ONCE(
-        "The number of buffered frames exceeded the buffer size. "
-        "Dropping the old frames. "
-        "To avoid this, you can set a higher buffer_chunk_size value.");
-    torch::Tensor& t = chunks.front();
-    num_buffered_frames -= t.size(0);
-    chunks.pop_front();
+  if (num_chunks > 0) {
+    int64_t max_frames = num_chunks * frames_per_chunk;
+    if (num_buffered_frames > max_frames) {
+      TORCH_WARN_ONCE(
+          "The number of buffered frames exceeded the buffer size. "
+          "Dropping the old frames. "
+          "To avoid this, you can set a higher buffer_chunk_size value.");
+      torch::Tensor& t = chunks.front();
+      num_buffered_frames -= t.size(0);
+      chunks.pop_front();
+    }
   }
 }
......
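
The guarded trim above is the substance of the change: when `num_chunks` (the C++ counterpart of `buffer_chunk_size`) is positive, the oldest chunks are evicted once the buffer holds more than `num_chunks * frames_per_chunk` frames; with `-1`, the whole trim block is skipped. A rough Python sketch of that behavior, not the actual implementation (which lives in the `ChunkedAudioBuffer` / `ChunkedVideoBuffer` classes above):

```python
import warnings
from collections import deque


def push_chunk(chunks: deque, chunk, frames_per_chunk: int, num_chunks: int) -> None:
    """Append ``chunk`` and, if a buffer limit is set, drop the oldest chunks.

    ``num_chunks`` mirrors ``buffer_chunk_size``: a positive value bounds the
    buffer at ``num_chunks * frames_per_chunk`` frames, while -1 keeps everything.
    """
    chunks.append(chunk)
    if num_chunks > 0:  # -1 disables trimming entirely
        max_frames = num_chunks * frames_per_chunk
        while sum(len(c) for c in chunks) > max_frames:
            warnings.warn("Buffer is full; dropping the oldest chunk.")
            chunks.popleft()
```
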
@@ -12,6 +12,16 @@ std::unique_ptr<Buffer> get_buffer(
     int frames_per_chunk,
     int num_chunks,
     const torch::Device& device) {
+  TORCH_CHECK(
+      frames_per_chunk > 0 || frames_per_chunk == -1,
+      "`frames_per_chunk` must be positive or -1. Found: ",
+      frames_per_chunk);
+  TORCH_CHECK(
+      num_chunks > 0 || num_chunks == -1,
+      "`num_chunks` must be positive or -1. Found: ",
+      num_chunks);
   switch (type) {
     case AVMEDIA_TYPE_AUDIO: {
       if (frames_per_chunk < 0) {
......
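
From the Python side, these checks mean both `frames_per_chunk` and `buffer_chunk_size` accept either a positive integer or `-1`; zero and other negative values raise a `RuntimeError`, which is what the new `test_invalid_chunk_option` above exercises. For example (the source path is illustrative):

```python
from torchaudio.io import StreamReader

reader = StreamReader("example.wav")  # hypothetical source
reader.add_audio_stream(frames_per_chunk=-1, buffer_chunk_size=-1)   # OK: unchunked, unbounded buffer
reader.add_audio_stream(frames_per_chunk=1024, buffer_chunk_size=3)  # OK: buffer bounded at 3 chunks
reader.add_audio_stream(frames_per_chunk=0, buffer_chunk_size=3)     # RuntimeError: must be positive or -1
```
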
@@ -216,11 +216,16 @@ def _format_doc(**kwargs):
 _frames_per_chunk = """Number of frames returned as one chunk.
     If the source stream is exhausted before enough frames are buffered,
-    then the chunk is returned as-is."""
+    then the chunk is returned as-is.
+
+    Providing ``-1`` disables chunking and :py:func:`pop_chunks` method
+    will concatenate all the buffered frames and return it."""
 
 _buffer_chunk_size = """Internal buffer size.
     When the number of chunks buffered exceeds this number, old frames are
-    dropped.
+    dropped. For example, if `frames_per_chunk` is 5 and `buffer_chunk_size` is
+    3, then frames older than 15 are dropped.
+
+    Providing ``-1`` disables this behavior.
 
     Default: ``3``."""
......