Unverified Commit e7b43dde authored by hwangjeff's avatar hwangjeff Committed by GitHub
Browse files

Make buffer size for function info configurable (#1634)

parent 8ec6b873
from contextlib import contextmanager
import io
import os
import itertools
......@@ -5,6 +6,7 @@ import tarfile
from parameterized import parameterized
from torchaudio.backend import sox_io_backend
from torchaudio.utils.sox_utils import get_buffer_size, set_buffer_size
from torchaudio._internal import module_utils as _mod_utils
from torchaudio_unittest.backend.common import (
......@@ -293,24 +295,33 @@ class TestLoadWithoutExtension(PytorchTestCase):
class FileObjTestBase(TempDirMixin):
def _gen_file(self, ext, dtype, sample_rate, num_channels, num_frames):
def _gen_file(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None):
path = self.get_temp_path(f'test.{ext}')
bit_depth = sox_utils.get_bit_depth(dtype)
duration = num_frames / sample_rate
comment_file = self._gen_comment_file(comments) if comments else None
sox_utils.gen_audio_file(
path, sample_rate, num_channels=num_channels,
encoding=sox_utils.get_encoding(dtype),
bit_depth=bit_depth,
duration=duration)
duration=duration,
comment_file=comment_file,
)
return path
def _gen_comment_file(self, comments):
comment_path = self.get_temp_path("comment.txt")
with open(comment_path, "w") as file_:
file_.writelines(comments)
return comment_path
@skipIfNoSox
@skipIfNoExec('sox')
class TestFileObject(FileObjTestBase, PytorchTestCase):
def _query_fileobj(self, ext, dtype, sample_rate, num_channels, num_frames):
path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames)
def _query_fileobj(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None):
path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames, comments=comments)
format_ = ext if ext in ['mp3'] else None
with open(path, 'rb') as fileobj:
return sox_io_backend.info(fileobj, format_)
......@@ -333,6 +344,15 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
fileobj = tarobj.extractfile(audio_file)
return sox_io_backend.info(fileobj, format_)
@contextmanager
def _set_buffer_size(self, buffer_size):
try:
original_buffer_size = get_buffer_size()
set_buffer_size(buffer_size)
yield
finally:
set_buffer_size(original_buffer_size)
@parameterized.expand([
('wav', "float32"),
('wav', "int32"),
......@@ -359,6 +379,34 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype)
@parameterized.expand([
('vorbis', "float32"),
])
def test_fileobj_large_header(self, ext, dtype):
"""
For audio file with header size exceeding default buffer size:
- Querying audio via file object without enlarging buffer size fails.
- Querying audio via file object after enlarging buffer size succeeds.
"""
sample_rate = 16000
num_frames = 3 * sample_rate
num_channels = 2
comments = "metadata=" + " ".join(["value" for _ in range(1000)])
with self.assertRaisesRegex(RuntimeError, "^Error loading audio file:"):
sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames, comments=comments)
with self._set_buffer_size(16384):
sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames, comments=comments)
bits_per_sample = get_bits_per_sample(ext, dtype)
num_frames = 0 if ext in ['mp3', 'vorbis'] else num_frames
assert sinfo.sample_rate == sample_rate
assert sinfo.num_channels == num_channels
assert sinfo.num_frames == num_frames
assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype)
@parameterized.expand([
('wav', "float32"),
('wav', "int32"),
......
......@@ -25,7 +25,7 @@ def get_bit_depth(dtype):
def gen_audio_file(
path, sample_rate, num_channels,
*, encoding=None, bit_depth=None, compression=None, attenuation=None, duration=1,
*, encoding=None, bit_depth=None, compression=None, attenuation=None, duration=1, comment_file=None,
):
"""Generate synthetic audio file with `sox` command."""
if path.endswith('.wav'):
......@@ -53,6 +53,8 @@ def gen_audio_file(
command += ['--bits', str(bit_depth)]
if encoding is not None:
command += ['--encoding', str(encoding)]
if comment_file is not None:
command += ['--comment-file', str(comment_file)]
command += [
str(path),
'synth', str(duration), # synthesizes for the given duration [sec]
......
......@@ -161,7 +161,10 @@ std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> get_info_fileobj(
//
// See:
// https://xiph.org/vorbis/doc/Vorbis_I_spec.html
auto capacity = 4096;
const int kDefaultCapacityInBytes = 4096;
auto capacity = (sox_get_globals()->bufsiz > kDefaultCapacityInBytes)
? sox_get_globals()->bufsiz
: kDefaultCapacityInBytes;
std::string buffer(capacity, '\0');
auto* buf = const_cast<char*>(buffer.data());
auto num_read = read_fileobj(&fileobj, capacity, buf);
......
......@@ -22,6 +22,10 @@ void set_buffer_size(const int64_t buffer_size) {
sox_get_globals()->bufsiz = static_cast<size_t>(buffer_size);
}
int64_t get_buffer_size() {
return sox_get_globals()->bufsiz;
}
std::vector<std::vector<std::string>> list_effects() {
std::vector<std::vector<std::string>> effects;
for (const sox_effect_fn_t* fns = sox_get_effect_fns(); *fns; ++fns) {
......@@ -538,6 +542,9 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def(
"torchaudio::sox_utils_list_write_formats",
&torchaudio::sox_utils::list_write_formats);
m.def(
"torchaudio::sox_utils_get_buffer_size",
&torchaudio::sox_utils::get_buffer_size);
}
} // namespace sox_utils
......
......@@ -24,6 +24,8 @@ void set_use_threads(const bool use_threads);
void set_buffer_size(const int64_t buffer_size);
int64_t get_buffer_size();
std::vector<std::vector<std::string>> list_effects();
std::vector<std::string> list_read_formats();
......
......@@ -90,3 +90,13 @@ def list_write_formats() -> List[str]:
List[str]: List of supported audio formats
"""
return torch.ops.torchaudio.sox_utils_list_write_formats()
@_mod_utils.requires_sox()
def get_buffer_size() -> int:
"""Get buffer size for sox effect chain
Returns:
int: size in bytes of buffers used for processing audio.
"""
return torch.ops.torchaudio.sox_utils_get_buffer_size()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment