Unverified Commit f16f74af authored by moto's avatar moto Committed by GitHub
Browse files

[BC Breaking] Split `list_formats()` for read and write (#811)

* Separate sox list format function for read and write

* Guard MP3 smoke test
parent d346cacb
import itertools import itertools
import unittest
from torchaudio.utils import sox_utils
from torchaudio.backend import sox_io_backend from torchaudio.backend import sox_io_backend
from torchaudio._internal.module_utils import is_module_available
from parameterized import parameterized from parameterized import parameterized
from ..common_utils import ( from ..common_utils import (
...@@ -12,6 +15,13 @@ from ..common_utils import ( ...@@ -12,6 +15,13 @@ from ..common_utils import (
from .common import name_func from .common import name_func
skipIfNoMP3 = unittest.skipIf(
not is_module_available('torchaudio._torchaudio') or
'mp3' not in sox_utils.list_read_formats() or
'mp3' not in sox_utils.list_write_formats(),
'"sox_io" backend does not support MP3')
@skipIfNoExtension @skipIfNoExtension
class SmokeTest(TempDirMixin, TorchaudioTestCase): class SmokeTest(TempDirMixin, TorchaudioTestCase):
"""Run smoke test on various audio format """Run smoke test on various audio format
...@@ -53,6 +63,7 @@ class SmokeTest(TempDirMixin, TorchaudioTestCase): ...@@ -53,6 +63,7 @@ class SmokeTest(TempDirMixin, TorchaudioTestCase):
[1, 2], [1, 2],
[-4.2, -0.2, 0, 0.2, 96, 128, 160, 192, 224, 256, 320], [-4.2, -0.2, 0, 0.2, 96, 128, 160, 192, 224, 256, 320],
)), name_func=name_func) )), name_func=name_func)
@skipIfNoMP3
def test_mp3(self, sample_rate, num_channels, bit_rate): def test_mp3(self, sample_rate, num_channels, bit_rate):
"""Run smoke test on mp3 format""" """Run smoke test on mp3 format"""
self.run_smoke_test('mp3', sample_rate, num_channels, compression=bit_rate) self.run_smoke_test('mp3', sample_rate, num_channels, compression=bit_rate)
......
...@@ -38,7 +38,12 @@ class TestSoxUtils(PytorchTestCase): ...@@ -38,7 +38,12 @@ class TestSoxUtils(PytorchTestCase):
assert 'phaser' in effects assert 'phaser' in effects
assert 'gain' in effects assert 'gain' in effects
def test_list_formats(self): def test_list_read_formats(self):
"""`list_formats` returns the list of supported formats""" """`list_read_formats` returns the list of supported formats"""
formats = sox_utils.list_formats() formats = sox_utils.list_read_formats()
assert 'wav' in formats assert 'wav' in formats
def test_list_write_formats(self):
"""`list_write_formats` returns the list of supported formats"""
formats = sox_utils.list_write_formats()
assert 'opus' not in formats
...@@ -27,7 +27,10 @@ static auto registerSetSoxOptions = ...@@ -27,7 +27,10 @@ static auto registerSetSoxOptions =
.op("torchaudio::sox_utils_set_buffer_size", .op("torchaudio::sox_utils_set_buffer_size",
&sox_utils::set_buffer_size) &sox_utils::set_buffer_size)
.op("torchaudio::sox_utils_list_effects", &sox_utils::list_effects) .op("torchaudio::sox_utils_list_effects", &sox_utils::list_effects)
.op("torchaudio::sox_utils_list_formats", &sox_utils::list_formats); .op("torchaudio::sox_utils_list_read_formats",
&sox_utils::list_read_formats)
.op("torchaudio::sox_utils_list_write_formats",
&sox_utils::list_write_formats);
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// sox_io.h // sox_io.h
......
...@@ -37,11 +37,24 @@ std::vector<std::vector<std::string>> list_effects() { ...@@ -37,11 +37,24 @@ std::vector<std::vector<std::string>> list_effects() {
return effects; return effects;
} }
std::vector<std::string> list_formats() { std::vector<std::string> list_write_formats() {
std::vector<std::string> formats; std::vector<std::string> formats;
for (const sox_format_tab_t* fns = sox_get_format_fns(); fns->fn; ++fns) { for (const sox_format_tab_t* fns = sox_get_format_fns(); fns->fn; ++fns) {
for (const char* const* names = fns->fn()->names; *names; ++names) { const sox_format_handler_t* handler = fns->fn();
if (!strchr(*names, '/')) for (const char* const* names = handler->names; *names; ++names) {
if (!strchr(*names, '/') && handler->write)
formats.emplace_back(*names);
}
}
return formats;
}
std::vector<std::string> list_read_formats() {
std::vector<std::string> formats;
for (const sox_format_tab_t* fns = sox_get_format_fns(); fns->fn; ++fns) {
const sox_format_handler_t* handler = fns->fn();
for (const char* const* names = handler->names; *names; ++names) {
if (!strchr(*names, '/') && handler->read)
formats.emplace_back(*names); formats.emplace_back(*names);
} }
} }
......
...@@ -22,7 +22,9 @@ void set_buffer_size(const int64_t buffer_size); ...@@ -22,7 +22,9 @@ void set_buffer_size(const int64_t buffer_size);
std::vector<std::vector<std::string>> list_effects(); std::vector<std::vector<std::string>> list_effects();
std::vector<std::string> list_formats(); std::vector<std::string> list_read_formats();
std::vector<std::string> list_write_formats();
/// Class for exchanging signal infomation (tensor + meta data) between /// Class for exchanging signal infomation (tensor + meta data) between
/// C++ and Python for read/write operation. /// C++ and Python for read/write operation.
......
...@@ -75,10 +75,20 @@ def list_effects() -> Dict[str, str]: ...@@ -75,10 +75,20 @@ def list_effects() -> Dict[str, str]:
@_mod_utils.requires_module('torchaudio._torchaudio') @_mod_utils.requires_module('torchaudio._torchaudio')
def list_formats() -> List[str]: def list_read_formats() -> List[str]:
"""List the supported audio formats """List the supported audio formats for read
Returns: Returns:
List[str]: List of supported audio formats List[str]: List of supported audio formats
""" """
return torch.ops.torchaudio.sox_utils_list_formats() return torch.ops.torchaudio.sox_utils_list_read_formats()
@_mod_utils.requires_module('torchaudio._torchaudio')
def list_write_formats() -> List[str]:
"""List the supported audio formats for write
Returns:
List[str]: List of supported audio formats
"""
return torch.ops.torchaudio.sox_utils_list_write_formats()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment