Unverified Commit e7b43dde authored by hwangjeff's avatar hwangjeff Committed by GitHub
Browse files

Make buffer size for function info configurable (#1634)

parent 8ec6b873
from contextlib import contextmanager
import io import io
import os import os
import itertools import itertools
...@@ -5,6 +6,7 @@ import tarfile ...@@ -5,6 +6,7 @@ import tarfile
from parameterized import parameterized from parameterized import parameterized
from torchaudio.backend import sox_io_backend from torchaudio.backend import sox_io_backend
from torchaudio.utils.sox_utils import get_buffer_size, set_buffer_size
from torchaudio._internal import module_utils as _mod_utils from torchaudio._internal import module_utils as _mod_utils
from torchaudio_unittest.backend.common import ( from torchaudio_unittest.backend.common import (
...@@ -293,24 +295,33 @@ class TestLoadWithoutExtension(PytorchTestCase): ...@@ -293,24 +295,33 @@ class TestLoadWithoutExtension(PytorchTestCase):
class FileObjTestBase(TempDirMixin): class FileObjTestBase(TempDirMixin):
def _gen_file(self, ext, dtype, sample_rate, num_channels, num_frames): def _gen_file(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None):
path = self.get_temp_path(f'test.{ext}') path = self.get_temp_path(f'test.{ext}')
bit_depth = sox_utils.get_bit_depth(dtype) bit_depth = sox_utils.get_bit_depth(dtype)
duration = num_frames / sample_rate duration = num_frames / sample_rate
comment_file = self._gen_comment_file(comments) if comments else None
sox_utils.gen_audio_file( sox_utils.gen_audio_file(
path, sample_rate, num_channels=num_channels, path, sample_rate, num_channels=num_channels,
encoding=sox_utils.get_encoding(dtype), encoding=sox_utils.get_encoding(dtype),
bit_depth=bit_depth, bit_depth=bit_depth,
duration=duration) duration=duration,
comment_file=comment_file,
)
return path return path
def _gen_comment_file(self, comments):
    """Write ``comments`` to a ``comment.txt`` temp file and return its path.

    Args:
        comments: text handed to ``writelines``; a plain string is written
            unchanged, an iterable of strings is written back-to-back with
            no separator added.

    Returns:
        The path obtained from ``self.get_temp_path("comment.txt")``.
    """
    target_path = self.get_temp_path("comment.txt")
    with open(target_path, "w") as handle:
        handle.writelines(comments)
    return target_path
@skipIfNoSox @skipIfNoSox
@skipIfNoExec('sox') @skipIfNoExec('sox')
class TestFileObject(FileObjTestBase, PytorchTestCase): class TestFileObject(FileObjTestBase, PytorchTestCase):
def _query_fileobj(self, ext, dtype, sample_rate, num_channels, num_frames): def _query_fileobj(self, ext, dtype, sample_rate, num_channels, num_frames, *, comments=None):
path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames) path = self._gen_file(ext, dtype, sample_rate, num_channels, num_frames, comments=comments)
format_ = ext if ext in ['mp3'] else None format_ = ext if ext in ['mp3'] else None
with open(path, 'rb') as fileobj: with open(path, 'rb') as fileobj:
return sox_io_backend.info(fileobj, format_) return sox_io_backend.info(fileobj, format_)
...@@ -333,6 +344,15 @@ class TestFileObject(FileObjTestBase, PytorchTestCase): ...@@ -333,6 +344,15 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
fileobj = tarobj.extractfile(audio_file) fileobj = tarobj.extractfile(audio_file)
return sox_io_backend.info(fileobj, format_) return sox_io_backend.info(fileobj, format_)
@contextmanager
def _set_buffer_size(self, buffer_size):
    """Temporarily override the sox buffer size inside a ``with`` block.

    Args:
        buffer_size: value passed to ``set_buffer_size`` on entry.

    On exit the value that ``get_buffer_size`` reported on entry is put
    back, including when the body of the ``with`` block raises.
    """
    # Capture the current value *before* entering ``try``: if this call
    # fails there is nothing to restore, and the ``finally`` clause must not
    # run with ``original_buffer_size`` unbound (that would raise
    # UnboundLocalError and hide the real error).
    original_buffer_size = get_buffer_size()
    try:
        set_buffer_size(buffer_size)
        yield
    finally:
        set_buffer_size(original_buffer_size)
@parameterized.expand([ @parameterized.expand([
('wav', "float32"), ('wav', "float32"),
('wav', "int32"), ('wav', "int32"),
...@@ -359,6 +379,34 @@ class TestFileObject(FileObjTestBase, PytorchTestCase): ...@@ -359,6 +379,34 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
assert sinfo.bits_per_sample == bits_per_sample assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype) assert sinfo.encoding == get_encoding(ext, dtype)
@parameterized.expand([
    ('vorbis', "float32"),
])
def test_fileobj_large_header(self, ext, dtype):
    """Querying a file object whose header exceeds the default buffer size.

    A long comment is embedded in the generated file to inflate its header.
    - Without enlarging the buffer size, the query is expected to raise.
    - With the buffer size set to 16384, the query is expected to succeed
      and report the metadata of the generated file.
    """
    sample_rate = 16000
    num_channels = 2
    num_frames = 3 * sample_rate
    comments = "metadata=" + " ".join(["value"] * 1000)
    query_args = (ext, dtype, sample_rate, num_channels, num_frames)

    # Buffer size left untouched: the oversized header must make the query fail.
    with self.assertRaisesRegex(RuntimeError, "^Error loading audio file:"):
        self._query_fileobj(*query_args, comments=comments)

    # Enlarged buffer: the same query must now return stream info.
    with self._set_buffer_size(16384):
        sinfo = self._query_fileobj(*query_args, comments=comments)

    expected_bits_per_sample = get_bits_per_sample(ext, dtype)
    # For 'mp3' and 'vorbis' the test expects a reported frame count of 0.
    expected_num_frames = 0 if ext in ['mp3', 'vorbis'] else num_frames
    assert sinfo.sample_rate == sample_rate
    assert sinfo.num_channels == num_channels
    assert sinfo.num_frames == expected_num_frames
    assert sinfo.bits_per_sample == expected_bits_per_sample
    assert sinfo.encoding == get_encoding(ext, dtype)
@parameterized.expand([ @parameterized.expand([
('wav', "float32"), ('wav', "float32"),
('wav', "int32"), ('wav', "int32"),
......
...@@ -25,7 +25,7 @@ def get_bit_depth(dtype): ...@@ -25,7 +25,7 @@ def get_bit_depth(dtype):
def gen_audio_file( def gen_audio_file(
path, sample_rate, num_channels, path, sample_rate, num_channels,
*, encoding=None, bit_depth=None, compression=None, attenuation=None, duration=1, *, encoding=None, bit_depth=None, compression=None, attenuation=None, duration=1, comment_file=None,
): ):
"""Generate synthetic audio file with `sox` command.""" """Generate synthetic audio file with `sox` command."""
if path.endswith('.wav'): if path.endswith('.wav'):
...@@ -53,6 +53,8 @@ def gen_audio_file( ...@@ -53,6 +53,8 @@ def gen_audio_file(
command += ['--bits', str(bit_depth)] command += ['--bits', str(bit_depth)]
if encoding is not None: if encoding is not None:
command += ['--encoding', str(encoding)] command += ['--encoding', str(encoding)]
if comment_file is not None:
command += ['--comment-file', str(comment_file)]
command += [ command += [
str(path), str(path),
'synth', str(duration), # synthesizes for the given duration [sec] 'synth', str(duration), # synthesizes for the given duration [sec]
......
...@@ -161,7 +161,10 @@ std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> get_info_fileobj( ...@@ -161,7 +161,10 @@ std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> get_info_fileobj(
// //
// See: // See:
// https://xiph.org/vorbis/doc/Vorbis_I_spec.html // https://xiph.org/vorbis/doc/Vorbis_I_spec.html
auto capacity = 4096; const int kDefaultCapacityInBytes = 4096;
auto capacity = (sox_get_globals()->bufsiz > kDefaultCapacityInBytes)
? sox_get_globals()->bufsiz
: kDefaultCapacityInBytes;
std::string buffer(capacity, '\0'); std::string buffer(capacity, '\0');
auto* buf = const_cast<char*>(buffer.data()); auto* buf = const_cast<char*>(buffer.data());
auto num_read = read_fileobj(&fileobj, capacity, buf); auto num_read = read_fileobj(&fileobj, capacity, buf);
......
...@@ -22,6 +22,10 @@ void set_buffer_size(const int64_t buffer_size) { ...@@ -22,6 +22,10 @@ void set_buffer_size(const int64_t buffer_size) {
sox_get_globals()->bufsiz = static_cast<size_t>(buffer_size); sox_get_globals()->bufsiz = static_cast<size_t>(buffer_size);
} }
// Return the `bufsiz` field of the struct returned by sox_get_globals().
// set_buffer_size stores the value as size_t; the conversion back to the
// declared int64_t return type is spelled out here instead of left implicit.
int64_t get_buffer_size() {
  const auto current_size = sox_get_globals()->bufsiz;
  return static_cast<int64_t>(current_size);
}
std::vector<std::vector<std::string>> list_effects() { std::vector<std::vector<std::string>> list_effects() {
std::vector<std::vector<std::string>> effects; std::vector<std::vector<std::string>> effects;
for (const sox_effect_fn_t* fns = sox_get_effect_fns(); *fns; ++fns) { for (const sox_effect_fn_t* fns = sox_get_effect_fns(); *fns; ++fns) {
...@@ -538,6 +542,9 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) { ...@@ -538,6 +542,9 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def( m.def(
"torchaudio::sox_utils_list_write_formats", "torchaudio::sox_utils_list_write_formats",
&torchaudio::sox_utils::list_write_formats); &torchaudio::sox_utils::list_write_formats);
m.def(
"torchaudio::sox_utils_get_buffer_size",
&torchaudio::sox_utils::get_buffer_size);
} }
} // namespace sox_utils } // namespace sox_utils
......
...@@ -24,6 +24,8 @@ void set_use_threads(const bool use_threads); ...@@ -24,6 +24,8 @@ void set_use_threads(const bool use_threads);
void set_buffer_size(const int64_t buffer_size); void set_buffer_size(const int64_t buffer_size);
int64_t get_buffer_size();
std::vector<std::vector<std::string>> list_effects(); std::vector<std::vector<std::string>> list_effects();
std::vector<std::string> list_read_formats(); std::vector<std::string> list_read_formats();
......
...@@ -90,3 +90,13 @@ def list_write_formats() -> List[str]: ...@@ -90,3 +90,13 @@ def list_write_formats() -> List[str]:
List[str]: List of supported audio formats List[str]: List of supported audio formats
""" """
return torch.ops.torchaudio.sox_utils_list_write_formats() return torch.ops.torchaudio.sox_utils_list_write_formats()
@_mod_utils.requires_sox()
def get_buffer_size() -> int:
    """Get the buffer size currently configured for sox.

    Returns:
        int: size in bytes of buffers used for processing audio.
    """
    buffer_size = torch.ops.torchaudio.sox_utils_get_buffer_size()
    return buffer_size
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment