Unverified Commit 4406a6bb authored by moto's avatar moto Committed by GitHub
Browse files

Add AMB/AMR-NB/AMR-WB support to "sox_io" backend (#1066)

parent 2a02d7f5
...@@ -89,6 +89,8 @@ def _get_extra_objects(): ...@@ -89,6 +89,8 @@ def _get_extra_objects():
'libvorbisfile.a', 'libvorbisfile.a',
'libvorbis.a', 'libvorbis.a',
'libogg.a', 'libogg.a',
'libopencore-amrnb.a',
'libopencore-amrwb.a',
] ]
for lib in libs: for lib in libs:
objs.append(str(_TP_INSTALL_DIR / 'lib' / lib)) objs.append(str(_TP_INSTALL_DIR / 'lib' / lib))
......
...@@ -122,6 +122,36 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -122,6 +122,36 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert info.num_frames == sample_rate * duration assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels assert info.num_channels == num_channels
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
[1, 2],
)), name_func=name_func)
def test_amb(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` can check amb file correctly"""
duration = 1
path = self.get_temp_path('data.amb')
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
bit_depth=sox_utils.get_bit_depth(dtype), duration=duration)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
def test_amr_nb(self):
"""`sox_io_backend.info` can check amr-nb file correctly"""
duration = 1
num_channels = 1
sample_rate = 8000
path = self.get_temp_path('data.amr-nb')
sox_utils.gen_audio_file(
path, sample_rate=sample_rate, num_channels=num_channels, bit_depth=16, duration=duration)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
@skipIfNoExtension @skipIfNoExtension
class TestInfoOpus(PytorchTestCase): class TestInfoOpus(PytorchTestCase):
......
...@@ -142,6 +142,53 @@ class LoadTestBase(TempDirMixin, PytorchTestCase): ...@@ -142,6 +142,53 @@ class LoadTestBase(TempDirMixin, PytorchTestCase):
assert sr == sample_rate assert sr == sample_rate
self.assertEqual(data, data_ref, atol=4e-05, rtol=1.3e-06) self.assertEqual(data, data_ref, atol=4e-05, rtol=1.3e-06)
def assert_amb(self, dtype, sample_rate, num_channels, normalize, duration):
"""`sox_io_backend.load` can load amb format.
This test takes the same strategy as mp3 to compare the result
"""
path = self.get_temp_path('1.original.amb')
ref_path = self.get_temp_path('2.reference.wav')
# 1. Generate amb with sox
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
encoding=sox_utils.get_encoding(dtype),
bit_depth=sox_utils.get_bit_depth(dtype), duration=duration)
# 2. Convert to wav with sox
sox_utils.convert_audio_file(path, ref_path)
# 3. Load amb with torchaudio
data, sr = sox_io_backend.load(path, normalize=normalize)
# 4. Load wav with scipy
data_ref = load_wav(ref_path, normalize=normalize)[0]
# 5. Compare
assert sr == sample_rate
self.assertEqual(data, data_ref, atol=4e-05, rtol=1.3e-06)
def assert_amr_nb(self, duration):
"""`sox_io_backend.load` can load amr-nb format.
This test takes the same strategy as mp3 to compare the result
"""
sample_rate = 8000
num_channels = 1
path = self.get_temp_path('1.original.amr-nb')
ref_path = self.get_temp_path('2.reference.wav')
# 1. Generate amr-nb with sox
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
bit_depth=32, duration=duration)
# 2. Convert to wav with sox
sox_utils.convert_audio_file(path, ref_path)
# 3. Load amr-nb with torchaudio
data, sr = sox_io_backend.load(path)
# 4. Load wav with scipy
data_ref = load_wav(ref_path)[0]
# 5. Compare
assert sr == sample_rate
self.assertEqual(data, data_ref, atol=4e-05, rtol=1.3e-06)
@skipIfNoExec('sox') @skipIfNoExec('sox')
@skipIfNoExtension @skipIfNoExtension
...@@ -260,6 +307,20 @@ class TestLoad(LoadTestBase): ...@@ -260,6 +307,20 @@ class TestLoad(LoadTestBase):
"""`sox_io_backend.load` can load sph format correctly.""" """`sox_io_backend.load` can load sph format correctly."""
self.assert_sphere(sample_rate, num_channels, duration=1) self.assert_sphere(sample_rate, num_channels, duration=1)
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16'],
[8000, 16000],
[1, 2],
[False, True],
)), name_func=name_func)
def test_amb(self, dtype, sample_rate, num_channels, normalize):
"""`sox_io_backend.load` can load sph format correctly."""
self.assert_amb(dtype, sample_rate, num_channels, normalize, duration=1)
def test_amr_nb(self):
"""`sox_io_backend.load` can load amr_nb format correctly."""
self.assert_amr_nb(duration=1)
@skipIfNoExec('sox') @skipIfNoExec('sox')
@skipIfNoExtension @skipIfNoExtension
......
...@@ -200,6 +200,68 @@ class SaveTestBase(TempDirMixin, PytorchTestCase): ...@@ -200,6 +200,68 @@ class SaveTestBase(TempDirMixin, PytorchTestCase):
self.assertEqual(found, expected) self.assertEqual(found, expected)
def assert_amb(self, dtype, sample_rate, num_channels, duration):
"""`sox_io_backend.save` can save amb format.
This test takes the same strategy as mp3 to compare the result
"""
src_path = self.get_temp_path('1.reference.wav')
amb_path = self.get_temp_path('2.1.torchaudio.amb')
wav_path = self.get_temp_path('2.2.torchaudio.wav')
amb_path_sox = self.get_temp_path('3.1.sox.amb')
wav_path_sox = self.get_temp_path('3.2.sox.wav')
# 1. Generate original wav
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(src_path, data, sample_rate)
# 2.1. Convert the original wav to amb with torchaudio
sox_io_backend.save(amb_path, load_wav(src_path, normalize=False)[0], sample_rate)
# 2.2. Convert the amb to wav with Sox
sox_utils.convert_audio_file(amb_path, wav_path)
# 2.3. Load
found = load_wav(wav_path)[0]
# 3.1. Convert the original wav to amb with SoX
sox_utils.convert_audio_file(src_path, amb_path_sox)
# 3.2. Convert the amb to wav with Sox
sox_utils.convert_audio_file(amb_path_sox, wav_path_sox)
# 3.3. Load
expected = load_wav(wav_path_sox)[0]
self.assertEqual(found, expected)
def assert_amr_nb(self, duration):
"""`sox_io_backend.save` can save amr_nb format.
This test takes the same strategy as mp3 to compare the result
"""
sample_rate = 8000
num_channels = 1
src_path = self.get_temp_path('1.reference.wav')
amr_path = self.get_temp_path('2.1.torchaudio.amr-nb')
wav_path = self.get_temp_path('2.2.torchaudio.wav')
amr_path_sox = self.get_temp_path('3.1.sox.amr-nb')
wav_path_sox = self.get_temp_path('3.2.sox.wav')
# 1. Generate original wav
data = get_wav_data('int16', num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(src_path, data, sample_rate)
# 2.1. Convert the original wav to amr_nb with torchaudio
sox_io_backend.save(amr_path, load_wav(src_path, normalize=False)[0], sample_rate)
# 2.2. Convert the amr_nb to wav with Sox
sox_utils.convert_audio_file(amr_path, wav_path)
# 2.3. Load
found = load_wav(wav_path)[0]
# 3.1. Convert the original wav to amr_nb with SoX
sox_utils.convert_audio_file(src_path, amr_path_sox)
# 3.2. Convert the amr_nb to wav with Sox
sox_utils.convert_audio_file(amr_path_sox, wav_path_sox)
# 3.3. Load
expected = load_wav(wav_path_sox)[0]
self.assertEqual(found, expected)
@skipIfNoExec('sox') @skipIfNoExec('sox')
@skipIfNoExtension @skipIfNoExtension
...@@ -302,6 +364,19 @@ class TestSave(SaveTestBase): ...@@ -302,6 +364,19 @@ class TestSave(SaveTestBase):
"""`sox_io_backend.save` can save sph format.""" """`sox_io_backend.save` can save sph format."""
self.assert_sphere(sample_rate, num_channels, duration=1) self.assert_sphere(sample_rate, num_channels, duration=1)
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
[1, 2],
)), name_func=name_func)
def test_amb(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.save` can save amb format."""
self.assert_amb(dtype, sample_rate, num_channels, duration=1)
def test_amr_nb(self):
"""`sox_io_backend.save` can save amr-nb format."""
self.assert_amr_nb(duration=1)
@skipIfNoExec('sox') @skipIfNoExec('sox')
@skipIfNoExtension @skipIfNoExtension
......
...@@ -16,6 +16,14 @@ ExternalProject_Add(libmad ...@@ -16,6 +16,14 @@ ExternalProject_Add(libmad
CONFIGURE_COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/src/libmad/configure ${COMMON_ARGS} CONFIGURE_COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/src/libmad/configure ${COMMON_ARGS}
) )
ExternalProject_Add(amr
PREFIX ${CMAKE_CURRENT_SOURCE_DIR}
DOWNLOAD_DIR ${ARCHIVE_DIR}
URL https://sourceforge.net/projects/opencore-amr/files/opencore-amr/opencore-amr-0.1.5.tar.gz
URL_HASH SHA256=2c006cb9d5f651bfb5e60156dbff6af3c9d35c7bbcc9015308c0aff1e14cd341
CONFIGURE_COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/src/amr/configure ${COMMON_ARGS}
)
ExternalProject_Add(libmp3lame ExternalProject_Add(libmp3lame
PREFIX ${CMAKE_CURRENT_SOURCE_DIR} PREFIX ${CMAKE_CURRENT_SOURCE_DIR}
DOWNLOAD_DIR ${ARCHIVE_DIR} DOWNLOAD_DIR ${ARCHIVE_DIR}
...@@ -72,11 +80,11 @@ ExternalProject_Add(opusfile ...@@ -72,11 +80,11 @@ ExternalProject_Add(opusfile
ExternalProject_Add(libsox ExternalProject_Add(libsox
PREFIX ${CMAKE_CURRENT_SOURCE_DIR} PREFIX ${CMAKE_CURRENT_SOURCE_DIR}
DEPENDS libogg libflac libvorbis opusfile libmp3lame libmad DEPENDS libogg libflac libvorbis opusfile libmp3lame libmad amr
DOWNLOAD_DIR ${ARCHIVE_DIR} DOWNLOAD_DIR ${ARCHIVE_DIR}
URL https://downloads.sourceforge.net/project/sox/sox/14.4.2/sox-14.4.2.tar.bz2 URL https://downloads.sourceforge.net/project/sox/sox/14.4.2/sox-14.4.2.tar.bz2
URL_HASH SHA256=81a6956d4330e75b5827316e44ae381e6f1e8928003c6aa45896da9041ea149c URL_HASH SHA256=81a6956d4330e75b5827316e44ae381e6f1e8928003c6aa45896da9041ea149c
# OpenMP is by default compiled against GNU OpenMP, which conflicts with the version of OpenMP that PyTorch uses. # OpenMP is by default compiled against GNU OpenMP, which conflicts with the version of OpenMP that PyTorch uses.
# See https://github.com/pytorch/audio/pull/1026 # See https://github.com/pytorch/audio/pull/1026
CONFIGURE_COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/build_codec_helper.sh ${CMAKE_CURRENT_SOURCE_DIR}/src/libsox/configure ${COMMON_ARGS} --with-lame --with-flac --with-mad --with-oggvorbis --without-alsa --without-coreaudio --without-png --without-oss --without-sndfile --with-opus --disable-openmp CONFIGURE_COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/build_codec_helper.sh ${CMAKE_CURRENT_SOURCE_DIR}/src/libsox/configure ${COMMON_ARGS} --with-lame --with-flac --with-mad --with-oggvorbis --without-alsa --without-coreaudio --without-png --without-oss --without-sndfile --with-opus --with-amrwb --with-amrnb --disable-openmp
) )
...@@ -40,18 +40,19 @@ def load( ...@@ -40,18 +40,19 @@ def load(
This function can handle all the codecs that underlying libsox can handle, This function can handle all the codecs that underlying libsox can handle,
however it is tested on the following formats; however it is tested on the following formats;
* WAV * WAV, AMB
* 32-bit floating-point * 32-bit floating-point
* 32-bit signed integer * 32-bit signed integer
* 16-bit signed integer * 16-bit signed integer
* 8-bit unsigned integer * 8-bit unsigned integer (WAV only)
* MP3 * MP3
* FLAC * FLAC
* OGG/VORBIS * OGG/VORBIS
* OPUS * OPUS
* SPHERE * SPHERE
* AMR-NB
To load ``MP3``, ``FLAC``, ``OGG/VORBIS``, ``OPUS`` and other codecs ``libsox`` does not To load ``MP3``, ``FLAC``, ``OGG/VORBIS``, ``OPUS`` and other codecs ``libsox`` does not
handle natively, your installation of ``torchaudio`` has to be linked to ``libsox`` handle natively, your installation of ``torchaudio`` has to be linked to ``libsox``
...@@ -119,7 +120,7 @@ def save( ...@@ -119,7 +120,7 @@ def save(
Note: Note:
Supported formats are; Supported formats are;
* WAV * WAV, AMB
* 32-bit floating-point * 32-bit floating-point
* 32-bit signed integer * 32-bit signed integer
...@@ -130,6 +131,7 @@ def save( ...@@ -130,6 +131,7 @@ def save(
* FLAC * FLAC
* OGG/VORBIS * OGG/VORBIS
* SPHERE * SPHERE
* AMR-NB
To save ``MP3``, ``FLAC``, ``OGG/VORBIS``, and other codecs ``libsox`` does not To save ``MP3``, ``FLAC``, ``OGG/VORBIS``, and other codecs ``libsox`` does not
handle natively, your installation of ``torchaudio`` has to be linked to ``libsox`` handle natively, your installation of ``torchaudio`` has to be linked to ``libsox``
...@@ -160,7 +162,7 @@ def save( ...@@ -160,7 +162,7 @@ def save(
filepath = str(filepath) filepath = str(filepath)
if compression is None: if compression is None:
ext = str(filepath).split('.')[-1].lower() ext = str(filepath).split('.')[-1].lower()
if ext in ['wav', 'sph']: if ext in ['wav', 'sph', 'amb', 'amr-nb']:
compression = 0. compression = 0.
elif ext == 'mp3': elif ext == 'mp3':
compression = -4.5 compression = -4.5
......
...@@ -85,11 +85,17 @@ void save_audio_file( ...@@ -85,11 +85,17 @@ void save_audio_file(
const std::string& file_name, const std::string& file_name,
const c10::intrusive_ptr<TensorSignal>& signal, const c10::intrusive_ptr<TensorSignal>& signal,
const double compression) { const double compression) {
const auto tensor = signal->getTensor(); auto tensor = signal->tensor;
validate_input_tensor(tensor); validate_input_tensor(tensor);
const auto filetype = get_filetype(file_name); const auto filetype = get_filetype(file_name);
if (filetype == "amr-nb") {
const auto num_channels = tensor.size(signal->channels_first ? 0 : 1);
TORCH_CHECK(
num_channels == 1, "amr-nb format only supports single channel audio.");
tensor = (unnormalize_wav(tensor) / 65536).to(torch::kInt16);
}
const auto signal_info = get_signalinfo(signal.get(), filetype); const auto signal_info = get_signalinfo(signal.get(), filetype);
const auto encoding_info = const auto encoding_info =
get_encodinginfo(filetype, tensor.dtype(), compression); get_encodinginfo(filetype, tensor.dtype(), compression);
......
...@@ -223,7 +223,7 @@ sox_encoding_t get_encoding( ...@@ -223,7 +223,7 @@ sox_encoding_t get_encoding(
return SOX_ENCODING_FLAC; return SOX_ENCODING_FLAC;
if (filetype == "ogg" || filetype == "vorbis") if (filetype == "ogg" || filetype == "vorbis")
return SOX_ENCODING_VORBIS; return SOX_ENCODING_VORBIS;
if (filetype == "wav") { if (filetype == "wav" || filetype == "amb") {
if (dtype == torch::kUInt8) if (dtype == torch::kUInt8)
return SOX_ENCODING_UNSIGNED; return SOX_ENCODING_UNSIGNED;
if (dtype == torch::kInt16) if (dtype == torch::kInt16)
...@@ -236,7 +236,9 @@ sox_encoding_t get_encoding( ...@@ -236,7 +236,9 @@ sox_encoding_t get_encoding(
} }
if (filetype == "sph") if (filetype == "sph")
return SOX_ENCODING_SIGN2; return SOX_ENCODING_SIGN2;
throw std::runtime_error("Unsupported file type."); if (filetype == "amr-nb")
return SOX_ENCODING_AMR_NB;
throw std::runtime_error("Unsupported file type: " + filetype);
} }
unsigned get_precision( unsigned get_precision(
...@@ -248,7 +250,7 @@ unsigned get_precision( ...@@ -248,7 +250,7 @@ unsigned get_precision(
return 24; return 24;
if (filetype == "ogg" || filetype == "vorbis") if (filetype == "ogg" || filetype == "vorbis")
return SOX_UNSPEC; return SOX_UNSPEC;
if (filetype == "wav") { if (filetype == "wav" || filetype == "amb") {
if (dtype == torch::kUInt8) if (dtype == torch::kUInt8)
return 8; return 8;
if (dtype == torch::kInt16) if (dtype == torch::kInt16)
...@@ -261,7 +263,13 @@ unsigned get_precision( ...@@ -261,7 +263,13 @@ unsigned get_precision(
} }
if (filetype == "sph") if (filetype == "sph")
return 32; return 32;
throw std::runtime_error("Unsupported file type."); if (filetype == "amr-nb") {
TORCH_INTERNAL_ASSERT(
dtype == torch::kInt16,
"When saving to AMR-NB format, the input tensor must be int16 type.");
return 16;
}
throw std::runtime_error("Unsupported file type: " + filetype);
} }
sox_signalinfo_t get_signalinfo( sox_signalinfo_t get_signalinfo(
...@@ -287,11 +295,13 @@ sox_encodinginfo_t get_encodinginfo( ...@@ -287,11 +295,13 @@ sox_encodinginfo_t get_encodinginfo(
return compression; return compression;
if (filetype == "ogg" || filetype == "vorbis") if (filetype == "ogg" || filetype == "vorbis")
return compression; return compression;
if (filetype == "wav") if (filetype == "wav" || filetype == "amb")
return 0.; return 0.;
if (filetype == "sph") if (filetype == "sph")
return 0.; return 0.;
throw std::runtime_error("Unsupported file type."); if (filetype == "amr-nb")
return 0.;
throw std::runtime_error("Unsupported file type: " + filetype);
}(); }();
return sox_encodinginfo_t{/*encoding=*/get_encoding(filetype, dtype), return sox_encodinginfo_t{/*encoding=*/get_encoding(filetype, dtype),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment