Unverified Commit f16f74af authored by moto's avatar moto Committed by GitHub
Browse files

[BC Breaking] Split `list_formats()` for read and write (#811)

* Separate sox list format function for read and write

* Guard MP3 smoke test
parent d346cacb
import itertools
import unittest
from torchaudio.utils import sox_utils
from torchaudio.backend import sox_io_backend
from torchaudio._internal.module_utils import is_module_available
from parameterized import parameterized
from ..common_utils import (
......@@ -12,6 +15,13 @@ from ..common_utils import (
from .common import name_func
skipIfNoMP3 = unittest.skipIf(
not is_module_available('torchaudio._torchaudio') or
'mp3' not in sox_utils.list_read_formats() or
'mp3' not in sox_utils.list_write_formats(),
'"sox_io" backend does not support MP3')
@skipIfNoExtension
class SmokeTest(TempDirMixin, TorchaudioTestCase):
"""Run smoke test on various audio format
......@@ -53,6 +63,7 @@ class SmokeTest(TempDirMixin, TorchaudioTestCase):
[1, 2],
[-4.2, -0.2, 0, 0.2, 96, 128, 160, 192, 224, 256, 320],
)), name_func=name_func)
@skipIfNoMP3
def test_mp3(self, sample_rate, num_channels, bit_rate):
"""Run smoke test on mp3 format"""
self.run_smoke_test('mp3', sample_rate, num_channels, compression=bit_rate)
......
......@@ -38,7 +38,12 @@ class TestSoxUtils(PytorchTestCase):
assert 'phaser' in effects
assert 'gain' in effects
def test_list_formats(self):
"""`list_formats` returns the list of supported formats"""
formats = sox_utils.list_formats()
def test_list_read_formats(self):
"""`list_read_formats` returns the list of supported formats"""
formats = sox_utils.list_read_formats()
assert 'wav' in formats
def test_list_write_formats(self):
"""`list_write_formats` returns the list of supported formats"""
formats = sox_utils.list_write_formats()
assert 'opus' not in formats
......@@ -27,7 +27,10 @@ static auto registerSetSoxOptions =
.op("torchaudio::sox_utils_set_buffer_size",
&sox_utils::set_buffer_size)
.op("torchaudio::sox_utils_list_effects", &sox_utils::list_effects)
.op("torchaudio::sox_utils_list_formats", &sox_utils::list_formats);
.op("torchaudio::sox_utils_list_read_formats",
&sox_utils::list_read_formats)
.op("torchaudio::sox_utils_list_write_formats",
&sox_utils::list_write_formats);
////////////////////////////////////////////////////////////////////////////////
// sox_io.h
......
......@@ -37,11 +37,24 @@ std::vector<std::vector<std::string>> list_effects() {
return effects;
}
std::vector<std::string> list_formats() {
std::vector<std::string> list_write_formats() {
std::vector<std::string> formats;
for (const sox_format_tab_t* fns = sox_get_format_fns(); fns->fn; ++fns) {
for (const char* const* names = fns->fn()->names; *names; ++names) {
if (!strchr(*names, '/'))
const sox_format_handler_t* handler = fns->fn();
for (const char* const* names = handler->names; *names; ++names) {
if (!strchr(*names, '/') && handler->write)
formats.emplace_back(*names);
}
}
return formats;
}
std::vector<std::string> list_read_formats() {
std::vector<std::string> formats;
for (const sox_format_tab_t* fns = sox_get_format_fns(); fns->fn; ++fns) {
const sox_format_handler_t* handler = fns->fn();
for (const char* const* names = handler->names; *names; ++names) {
if (!strchr(*names, '/') && handler->read)
formats.emplace_back(*names);
}
}
......
......@@ -22,7 +22,9 @@ void set_buffer_size(const int64_t buffer_size);
std::vector<std::vector<std::string>> list_effects();
std::vector<std::string> list_formats();
std::vector<std::string> list_read_formats();
std::vector<std::string> list_write_formats();
/// Class for exchanging signal infomation (tensor + meta data) between
/// C++ and Python for read/write operation.
......
......@@ -75,10 +75,20 @@ def list_effects() -> Dict[str, str]:
@_mod_utils.requires_module('torchaudio._torchaudio')
def list_formats() -> List[str]:
"""List the supported audio formats
def list_read_formats() -> List[str]:
"""List the supported audio formats for read
Returns:
List[str]: List of supported audio formats
"""
return torch.ops.torchaudio.sox_utils_list_formats()
return torch.ops.torchaudio.sox_utils_list_read_formats()
@_mod_utils.requires_module('torchaudio._torchaudio')
def list_write_formats() -> List[str]:
"""List the supported audio formats for write
Returns:
List[str]: List of supported audio formats
"""
return torch.ops.torchaudio.sox_utils_list_write_formats()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment