Unverified Commit 3324283c authored by moto's avatar moto Committed by GitHub
Browse files

Add TorchScript-able "save" func to sox_io backend (#732)

This is part of a series of PRs adding the new "sox_io" backend (#726), and it depends on #718, #728 and #731.

This PR adds a `save` function to the "sox_io" backend, which can save a Tensor to a file in the following audio formats:
 - `wav`
 - `mp3`
 - `flac`
 - `ogg/vorbis`
parent ea42513f
...@@ -92,28 +92,34 @@ def set_audio_backend(backend): ...@@ -92,28 +92,34 @@ def set_audio_backend(backend):
class TempDirMixin: class TempDirMixin:
"""Mixin to provide easy access to temp dir""" """Mixin to provide easy access to temp dir"""
temp_dir_ = None temp_dir_ = None
base_temp_dir = None
temp_dir = None temp_dir = None
def setUp(self): @classmethod
super().setUp() def setUpClass(cls):
self._init_temp_dir() super().setUpClass()
# If TORCHAUDIO_TEST_TEMP_DIR is set, use it instead of temporary directory.
# this is handy for debugging.
key = 'TORCHAUDIO_TEST_TEMP_DIR'
if key in os.environ:
cls.base_temp_dir = os.environ[key]
else:
cls.temp_dir_ = tempfile.TemporaryDirectory()
cls.base_temp_dir = cls.temp_dir_.name
def tearDown(self): @classmethod
def tearDownClass(cls):
super().tearDownClass() super().tearDownClass()
self._clean_up_temp_dir() if isinstance(cls.temp_dir_, tempfile.TemporaryDirectory):
cls.temp_dir_.cleanup()
def _init_temp_dir(self): def setUp(self):
self.temp_dir_ = tempfile.TemporaryDirectory() self.temp_dir = os.path.join(self.base_temp_dir, self.id())
self.temp_dir = self.temp_dir_.name
def _clean_up_temp_dir(self):
if self.temp_dir_ is not None:
self.temp_dir_.cleanup()
self.temp_dir_ = None
self.temp_dir = None
def get_temp_path(self, *paths): def get_temp_path(self, *paths):
return os.path.join(self.temp_dir, *paths) path = os.path.join(self.temp_dir, *paths)
os.makedirs(os.path.dirname(path), exist_ok=True)
return path
class TestBaseMixin: class TestBaseMixin:
......
...@@ -31,7 +31,16 @@ def gen_audio_file( ...@@ -31,7 +31,16 @@ def gen_audio_file(
'Use get_wav_data and save_wav to generate wav file for accurate result.') 'Use get_wav_data and save_wav to generate wav file for accurate result.')
command = [ command = [
'sox', 'sox',
'-V', # verbose '-V3', # verbose
'-R',
# -R is supposed to make the output repeatable, though the implementation looks suspicious
# and does not appear to set the seed to a fixed value.
# https://fossies.org/dox/sox-14.4.2/sox_8c_source.html
# search "sox_globals.repeatable"
]
if bit_depth is not None:
command += ['--bits', str(bit_depth)]
command += [
'--rate', str(sample_rate), '--rate', str(sample_rate),
'--null', # no input '--null', # no input
'--channels', str(num_channels), '--channels', str(num_channels),
...@@ -60,7 +69,7 @@ def convert_audio_file( ...@@ -60,7 +69,7 @@ def convert_audio_file(
src_path, dst_path, src_path, dst_path,
*, bit_depth=None, compression=None): *, bit_depth=None, compression=None):
"""Convert audio file with `sox` command.""" """Convert audio file with `sox` command."""
command = ['sox', '-V', str(src_path)] command = ['sox', '-V3', '-R', str(src_path)]
if bit_depth is not None: if bit_depth is not None:
command += ['--bits', str(bit_depth)] command += ['--bits', str(bit_depth)]
if compression is not None: if compression is not None:
......
...@@ -28,7 +28,7 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -28,7 +28,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
def test_wav(self, dtype, sample_rate, num_channels): def test_wav(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` can check wav file correctly""" """`sox_io_backend.info` can check wav file correctly"""
duration = 1 duration = 1
path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav') path = self.get_temp_path('data.wav')
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate) data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate) save_wav(path, data, sample_rate)
info = sox_io_backend.info(path) info = sox_io_backend.info(path)
...@@ -44,7 +44,7 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -44,7 +44,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
def test_wav_multiple_channels(self, dtype, sample_rate, num_channels): def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` can check wav file with channels more than 2 correctly""" """`sox_io_backend.info` can check wav file with channels more than 2 correctly"""
duration = 1 duration = 1
path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav') path = self.get_temp_path('data.wav')
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate) data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate) save_wav(path, data, sample_rate)
info = sox_io_backend.info(path) info = sox_io_backend.info(path)
...@@ -60,7 +60,7 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -60,7 +60,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
def test_mp3(self, sample_rate, num_channels, bit_rate): def test_mp3(self, sample_rate, num_channels, bit_rate):
"""`sox_io_backend.info` can check mp3 file correctly""" """`sox_io_backend.info` can check mp3 file correctly"""
duration = 1 duration = 1
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{bit_rate}k.mp3') path = self.get_temp_path('data.mp3')
sox_utils.gen_audio_file( sox_utils.gen_audio_file(
path, sample_rate, num_channels, path, sample_rate, num_channels,
compression=bit_rate, duration=duration, compression=bit_rate, duration=duration,
...@@ -79,7 +79,7 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -79,7 +79,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
def test_flac(self, sample_rate, num_channels, compression_level): def test_flac(self, sample_rate, num_channels, compression_level):
"""`sox_io_backend.info` can check flac file correctly""" """`sox_io_backend.info` can check flac file correctly"""
duration = 1 duration = 1
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{compression_level}.flac') path = self.get_temp_path('data.flac')
sox_utils.gen_audio_file( sox_utils.gen_audio_file(
path, sample_rate, num_channels, path, sample_rate, num_channels,
compression=compression_level, duration=duration, compression=compression_level, duration=duration,
...@@ -97,7 +97,7 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -97,7 +97,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
def test_vorbis(self, sample_rate, num_channels, quality_level): def test_vorbis(self, sample_rate, num_channels, quality_level):
"""`sox_io_backend.info` can check vorbis file correctly""" """`sox_io_backend.info` can check vorbis file correctly"""
duration = 1 duration = 1
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{quality_level}.vorbis') path = self.get_temp_path('data.vorbis')
sox_utils.gen_audio_file( sox_utils.gen_audio_file(
path, sample_rate, num_channels, path, sample_rate, num_channels,
compression=quality_level, duration=duration, compression=quality_level, duration=duration,
......
...@@ -24,7 +24,7 @@ class LoadTestBase(TempDirMixin, PytorchTestCase): ...@@ -24,7 +24,7 @@ class LoadTestBase(TempDirMixin, PytorchTestCase):
Wav data loaded with sox_io backend should match those with scipy Wav data loaded with sox_io backend should match those with scipy
""" """
path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}_{normalize}.wav') path = self.get_temp_path('reference.wav')
data = get_wav_data(dtype, num_channels, normalize=normalize, num_frames=duration * sample_rate) data = get_wav_data(dtype, num_channels, normalize=normalize, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate) save_wav(path, data, sample_rate)
expected = load_wav(path, normalize=normalize)[0] expected = load_wav(path, normalize=normalize)[0]
...@@ -58,8 +58,8 @@ class LoadTestBase(TempDirMixin, PytorchTestCase): ...@@ -58,8 +58,8 @@ class LoadTestBase(TempDirMixin, PytorchTestCase):
By combining i & ii, step 2. and 4. allows to load reference mp3 data By combining i & ii, step 2. and 4. allows to load reference mp3 data
without using torchaudio without using torchaudio
""" """
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{bit_rate}_{duration}.mp3') path = self.get_temp_path('1.original.mp3')
ref_path = f'{path}.wav' ref_path = self.get_temp_path('2.reference.wav')
# 1. Generate mp3 with sox # 1. Generate mp3 with sox
sox_utils.gen_audio_file( sox_utils.gen_audio_file(
...@@ -80,8 +80,8 @@ class LoadTestBase(TempDirMixin, PytorchTestCase): ...@@ -80,8 +80,8 @@ class LoadTestBase(TempDirMixin, PytorchTestCase):
This test takes the same strategy as mp3 to compare the result This test takes the same strategy as mp3 to compare the result
""" """
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{compression_level}_{duration}.flac') path = self.get_temp_path('1.original.flac')
ref_path = f'{path}.wav' ref_path = self.get_temp_path('2.reference.wav')
# 1. Generate flac with sox # 1. Generate flac with sox
sox_utils.gen_audio_file( sox_utils.gen_audio_file(
...@@ -102,8 +102,8 @@ class LoadTestBase(TempDirMixin, PytorchTestCase): ...@@ -102,8 +102,8 @@ class LoadTestBase(TempDirMixin, PytorchTestCase):
This test takes the same strategy as mp3 to compare the result This test takes the same strategy as mp3 to compare the result
""" """
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{quality_level}_{duration}.vorbis') path = self.get_temp_path('1.original.vorbis')
ref_path = f'{path}.wav' ref_path = self.get_temp_path('2.reference.wav')
# 1. Generate vorbis with sox # 1. Generate vorbis with sox
sox_utils.gen_audio_file( sox_utils.gen_audio_file(
......
import itertools
from torchaudio.backend import sox_io_backend
from parameterized import parameterized
from ..common_utils import (
TempDirMixin,
PytorchTestCase,
skipIfNoExec,
skipIfNoExtension,
)
from .common import (
get_test_name,
get_wav_data,
)
@skipIfNoExec('sox')
@skipIfNoExtension
class TestRoundTripIO(TempDirMixin, PytorchTestCase):
    """save/load round trip should not degrade data for lossless formats"""

    @parameterized.expand(list(itertools.product(
        ['float32', 'int32', 'int16', 'uint8'],
        [8000, 16000],
        [1, 2],
    )), name_func=get_test_name)
    def test_wav(self, dtype, sample_rate, num_channels):
        """save/load round trip should not degrade data for wav formats"""
        original = get_wav_data(dtype, num_channels, normalize=False)
        data = original
        # Feed every loaded result back into `save` so that any loss would accumulate
        # over the ten generations.
        for i in range(10):
            path = self.get_temp_path(f'{i}.wav')
            sox_io_backend.save(path, data, sample_rate)
            data, sr = sox_io_backend.load(path, normalize=False)
            # Use a unittest assertion: a bare `assert` is removed under `python -O`.
            self.assertEqual(sr, sample_rate)
            self.assertEqual(original, data)

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        list(range(9)),
    )), name_func=get_test_name)
    def test_flac(self, sample_rate, num_channels, compression_level):
        """save/load round trip should not degrade data for flac formats"""
        original = get_wav_data('float32', num_channels)
        data = original
        # Same scheme as `test_wav`: re-save the previously loaded data each time.
        for i in range(10):
            path = self.get_temp_path(f'{i}.flac')
            sox_io_backend.save(path, data, sample_rate, compression=compression_level)
            data, sr = sox_io_backend.load(path)
            # Use a unittest assertion: a bare `assert` is removed under `python -O`.
            self.assertEqual(sr, sample_rate)
            self.assertEqual(original, data)
import itertools
from torchaudio.backend import sox_io_backend
from parameterized import parameterized
from ..common_utils import (
TempDirMixin,
PytorchTestCase,
skipIfNoExec,
skipIfNoExtension,
)
from .common import (
get_test_name,
get_wav_data,
load_wav,
save_wav,
)
from . import sox_utils
class SaveTestBase(TempDirMixin, PytorchTestCase):
    """Reusable assertions that exercise `sox_io_backend.save` for each audio format."""

    def assert_wav(self, dtype, sample_rate, num_channels, num_frames):
        """`sox_io_backend.save` can save wav format."""
        path = self.get_temp_path('data.wav')
        expected = get_wav_data(dtype, num_channels, num_frames=num_frames)
        sox_io_backend.save(path, expected, sample_rate)
        found, sr = load_wav(path)
        # Use a unittest assertion: a bare `assert` is removed under `python -O`.
        self.assertEqual(sr, sample_rate)
        self.assertEqual(found, expected)

    def assert_mp3(self, sample_rate, num_channels, bit_rate, duration):
        """`sox_io_backend.save` can save mp3 format.

        mp3 encoding introduces delay and boundary effects so
        we convert the resulting mp3 to wav and compare the results there

                          |
                          | 1. Generate original wav file with SciPy
                          |
                          v
          -------------- wav ----------------
         |                                   |
         | 2.1. load with scipy              | 3.1. Convert to mp3 with Sox
         | then save with torchaudio         |
         v                                   v
        mp3                                 mp3
         |                                   |
         | 2.2. Convert to wav with Sox      | 3.2. Convert to wav with Sox
         |                                   |
         v                                   v
        wav                                 wav
         |                                   |
         | 2.3. load with scipy              | 3.3. load with scipy
         |                                   |
         v                                   v
        tensor -------> compare <--------- tensor

        """
        src_path = self.get_temp_path('1.reference.wav')
        mp3_path = self.get_temp_path('2.1.torchaudio.mp3')
        wav_path = self.get_temp_path('2.2.torchaudio.wav')
        mp3_path_sox = self.get_temp_path('3.1.sox.mp3')
        wav_path_sox = self.get_temp_path('3.2.sox.wav')

        # 1. Generate original wav
        data = get_wav_data('float32', num_channels, normalize=True, num_frames=duration * sample_rate)
        save_wav(src_path, data, sample_rate)
        # 2.1. Convert the original wav to mp3 with torchaudio
        sox_io_backend.save(
            mp3_path, load_wav(src_path)[0], sample_rate, compression=bit_rate)
        # 2.2. Convert the mp3 to wav with Sox
        sox_utils.convert_audio_file(mp3_path, wav_path)
        # 2.3. Load
        found = load_wav(wav_path)[0]

        # 3.1. Convert the original wav to mp3 with SoX
        sox_utils.convert_audio_file(src_path, mp3_path_sox, compression=bit_rate)
        # 3.2. Convert the mp3 to wav with Sox
        sox_utils.convert_audio_file(mp3_path_sox, wav_path_sox)
        # 3.3. Load
        expected = load_wav(wav_path_sox)[0]

        self.assertEqual(found, expected)

    def assert_flac(self, sample_rate, num_channels, compression_level, duration):
        """`sox_io_backend.save` can save flac format.

        This test takes the same strategy as mp3 to compare the result
        """
        src_path = self.get_temp_path('1.reference.wav')
        flc_path = self.get_temp_path('2.1.torchaudio.flac')
        wav_path = self.get_temp_path('2.2.torchaudio.wav')
        flc_path_sox = self.get_temp_path('3.1.sox.flac')
        wav_path_sox = self.get_temp_path('3.2.sox.wav')

        # 1. Generate original wav
        data = get_wav_data('float32', num_channels, normalize=True, num_frames=duration * sample_rate)
        save_wav(src_path, data, sample_rate)
        # 2.1. Convert the original wav to flac with torchaudio
        sox_io_backend.save(
            flc_path, load_wav(src_path)[0], sample_rate, compression=compression_level)
        # 2.2. Convert the flac to wav with Sox
        # converting to 32 bit because flac file has 24 bit depth which scipy cannot handle.
        sox_utils.convert_audio_file(flc_path, wav_path, bit_depth=32)
        # 2.3. Load
        found = load_wav(wav_path)[0]

        # 3.1. Convert the original wav to flac with SoX
        sox_utils.convert_audio_file(src_path, flc_path_sox, compression=compression_level)
        # 3.2. Convert the flac to wav with Sox
        # converting to 32 bit because flac file has 24 bit depth which scipy cannot handle.
        sox_utils.convert_audio_file(flc_path_sox, wav_path_sox, bit_depth=32)
        # 3.3. Load
        expected = load_wav(wav_path_sox)[0]

        self.assertEqual(found, expected)

    def _assert_vorbis(self, sample_rate, num_channels, quality_level, duration):
        """`sox_io_backend.save` can save vorbis format.

        This test takes the same strategy as mp3 to compare the result
        """
        src_path = self.get_temp_path('1.reference.wav')
        vbs_path = self.get_temp_path('2.1.torchaudio.vorbis')
        wav_path = self.get_temp_path('2.2.torchaudio.wav')
        vbs_path_sox = self.get_temp_path('3.1.sox.vorbis')
        wav_path_sox = self.get_temp_path('3.2.sox.wav')

        # 1. Generate original wav
        data = get_wav_data('int16', num_channels, normalize=False, num_frames=duration * sample_rate)
        save_wav(src_path, data, sample_rate)
        # 2.1. Convert the original wav to vorbis with torchaudio
        sox_io_backend.save(
            vbs_path, load_wav(src_path)[0], sample_rate, compression=quality_level)
        # 2.2. Convert the vorbis to wav with Sox
        sox_utils.convert_audio_file(vbs_path, wav_path)
        # 2.3. Load
        found = load_wav(wav_path)[0]

        # 3.1. Convert the original wav to vorbis with SoX
        sox_utils.convert_audio_file(src_path, vbs_path_sox, compression=quality_level)
        # 3.2. Convert the vorbis to wav with Sox
        sox_utils.convert_audio_file(vbs_path_sox, wav_path_sox)
        # 3.3. Load
        expected = load_wav(wav_path_sox)[0]

        # sox's vorbis encoding has some random boundary effect, which causes a small number of
        # samples to show a higher discrepancy than the others,
        # so we allow a small portion of the data to be outside of the absolute tolerance.
        # Make sure to pass a somewhat long duration.
        atol = 1.0e-4
        max_failure_allowed = 0.01  # this fraction of samples is allowed to fall outside of atol.
        failure_ratio = ((found - expected).abs() > atol).sum().item() / found.numel()
        if failure_ratio > max_failure_allowed:
            # it has failed, and this will give a better error message.
            self.assertEqual(found, expected, atol=atol, rtol=1.3e-6)

    def assert_vorbis(self, *args, **kwargs):
        """Run `_assert_vorbis` up to five times and fail only if every attempt fails."""
        # sox's vorbis encoding has some randomness, so we run the test multiple times
        max_retry = 5
        error = None
        for _ in range(max_retry):
            try:
                self._assert_vorbis(*args, **kwargs)
                break
            except AssertionError as e:
                error = e
        else:
            # The loop finished without `break`: re-raise the last failure.
            raise error
@skipIfNoExec('sox')
@skipIfNoExtension
class TestSave(SaveTestBase):
    """Parameterized `sox_io_backend.save` tests built on the `SaveTestBase` assertions."""

    @parameterized.expand(
        list(itertools.product(['float32', 'int32', 'int16', 'uint8'], [8000, 16000], [1, 2])),
        name_func=get_test_name)
    def test_wav(self, dtype, sample_rate, num_channels):
        """`sox_io_backend.save` can save wav format."""
        self.assert_wav(dtype, sample_rate, num_channels, num_frames=None)

    @parameterized.expand(
        list(itertools.product(['float32'], [16000], [2])),
        name_func=get_test_name)
    def test_wav_large(self, dtype, sample_rate, num_channels):
        """`sox_io_backend.save` can save large wav file."""
        # two hours' worth of frames
        frame_count = 2 * 60 * 60 * sample_rate
        self.assert_wav(dtype, sample_rate, num_channels, num_frames=frame_count)

    @parameterized.expand(
        list(itertools.product(['float32', 'int32', 'int16', 'uint8'], [4, 8, 16, 32])),
        name_func=get_test_name)
    def test_multiple_channels(self, dtype, num_channels):
        """`sox_io_backend.save` can save wav with more than 2 channels."""
        self.assert_wav(dtype, 8000, num_channels, num_frames=None)

    @parameterized.expand(
        list(itertools.product(
            [8000, 16000],
            [1, 2],
            [-4.2, -0.2, 0, 0.2, 96, 128, 160, 192, 224, 256, 320])),
        name_func=get_test_name)
    def test_mp3(self, sample_rate, num_channels, bit_rate):
        """`sox_io_backend.save` can save mp3 format."""
        self.assert_mp3(sample_rate, num_channels, bit_rate, duration=1)

    @parameterized.expand(
        list(itertools.product([16000], [2], [128])),
        name_func=get_test_name)
    def test_mp3_large(self, sample_rate, num_channels, bit_rate):
        """`sox_io_backend.save` can save large mp3 file."""
        seconds = 2 * 60 * 60  # two hours
        self.assert_mp3(sample_rate, num_channels, bit_rate, duration=seconds)

    @parameterized.expand(
        list(itertools.product([8000, 16000], [1, 2], list(range(9)))),
        name_func=get_test_name)
    def test_flac(self, sample_rate, num_channels, compression_level):
        """`sox_io_backend.save` can save flac format."""
        self.assert_flac(sample_rate, num_channels, compression_level, duration=1)

    @parameterized.expand(
        list(itertools.product([16000], [2], [0])),
        name_func=get_test_name)
    def test_flac_large(self, sample_rate, num_channels, compression_level):
        """`sox_io_backend.save` can save large flac file."""
        seconds = 2 * 60 * 60  # two hours
        self.assert_flac(sample_rate, num_channels, compression_level, duration=seconds)

    @parameterized.expand(
        list(itertools.product([8000, 16000], [1, 2], [-1, 0, 1, 2, 3, 3.6, 5, 10])),
        name_func=get_test_name)
    def test_vorbis(self, sample_rate, num_channels, quality_level):
        """`sox_io_backend.save` can save vorbis format."""
        self.assert_vorbis(sample_rate, num_channels, quality_level, duration=20)

    # note: torchaudio can load a large vorbis file, but cannot save a large vorbis file;
    # the following test causes a segmentation fault, so it is kept disabled.
    #
    # @parameterized.expand(list(itertools.product(
    #     [16000],
    #     [2],
    #     [10],
    # )), name_func=get_test_name)
    # def test_vorbis_large(self, sample_rate, num_channels, quality_level):
    #     """`sox_io_backend.save` can save large vorbis file correctly."""
    #     two_hours = 2 * 60 * 60
    #     self.assert_vorbis(sample_rate, num_channels, quality_level, two_hours)
@skipIfNoExec('sox')
@skipIfNoExtension
class TestSaveParams(TempDirMixin, PytorchTestCase):
    """Test the correctness of optional parameters of `sox_io_backend.save`"""

    @parameterized.expand([(True, ), (False, )], name_func=get_test_name)
    def test_channels_first(self, channels_first):
        """channels_first swaps axes"""
        wav_file = self.get_temp_path('data.wav')
        waveform = get_wav_data('int32', 2, channels_first=channels_first)
        sox_io_backend.save(
            wav_file, waveform, 8000, channels_first=channels_first)
        loaded = load_wav(wav_file)[0]
        # The loaded data is compared against the channels-first view of the input.
        if channels_first:
            reference = waveform
        else:
            reference = waveform.transpose(1, 0)
        self.assertEqual(loaded, reference)

    @parameterized.expand([
        'float32', 'int32', 'int16', 'uint8'
    ], name_func=get_test_name)
    def test_noncontiguous(self, dtype):
        """Noncontiguous tensors are saved correctly"""
        wav_file = self.get_temp_path('data.wav')
        # Strided slicing yields a non-contiguous view; the assert guards that premise.
        strided = get_wav_data(dtype, 4)[::2, ::2]
        assert not strided.is_contiguous()
        sox_io_backend.save(wav_file, strided, 8000)
        loaded = load_wav(wav_file)[0]
        self.assertEqual(loaded, strided)

    @parameterized.expand([
        'float32', 'int32', 'int16', 'uint8',
    ])
    def test_tensor_preserve(self, dtype):
        """save function should not alter Tensor"""
        wav_file = self.get_temp_path('data.wav')
        reference = get_wav_data(dtype, 4)[::2, ::2]
        # Save a copy, then check the copy still equals the untouched reference.
        passed_in = reference.clone()
        sox_io_backend.save(wav_file, passed_in, 8000)
        self.assertEqual(passed_in, reference)
import itertools import itertools
from typing import Optional
import torch import torch
from torchaudio.backend import sox_io_backend from torchaudio.backend import sox_io_backend
...@@ -13,8 +14,10 @@ from ..common_utils import ( ...@@ -13,8 +14,10 @@ from ..common_utils import (
from .common import ( from .common import (
get_test_name, get_test_name,
get_wav_data, get_wav_data,
save_wav save_wav,
load_wav,
) )
from . import sox_utils
def py_info_func(filepath: str) -> torch.classes.torchaudio.SignalInfo: def py_info_func(filepath: str) -> torch.classes.torchaudio.SignalInfo:
...@@ -26,6 +29,16 @@ def py_load_func(filepath: str, normalize: bool, channels_first: bool): ...@@ -26,6 +29,16 @@ def py_load_func(filepath: str, normalize: bool, channels_first: bool):
filepath, normalize=normalize, channels_first=channels_first) filepath, normalize=normalize, channels_first=channels_first)
def py_save_func(
        filepath: str,
        tensor: torch.Tensor,
        sample_rate: int,
        channels_first: bool = True,
        compression: Optional[float] = None,
):
    """Type-annotated wrapper that forwards its arguments to `sox_io_backend.save`.

    The tests compile this function with `torch.jit.script` and compare the files
    written by the scripted version with those written by this Python version.
    """
    sox_io_backend.save(filepath, tensor, sample_rate, channels_first, compression)
@skipIfNoExec('sox') @skipIfNoExec('sox')
@skipIfNoExtension @skipIfNoExtension
class SoxIO(TempDirMixin, TorchaudioTestCase): class SoxIO(TempDirMixin, TorchaudioTestCase):
...@@ -41,7 +54,7 @@ class SoxIO(TempDirMixin, TorchaudioTestCase): ...@@ -41,7 +54,7 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=1 * sample_rate) data = get_wav_data(dtype, num_channels, normalize=False, num_frames=1 * sample_rate)
save_wav(audio_path, data, sample_rate) save_wav(audio_path, data, sample_rate)
script_path = self.get_temp_path('info_func') script_path = self.get_temp_path('info_func.zip')
torch.jit.script(py_info_func).save(script_path) torch.jit.script(py_info_func).save(script_path)
ts_info_func = torch.jit.load(script_path) ts_info_func = torch.jit.load(script_path)
...@@ -65,7 +78,7 @@ class SoxIO(TempDirMixin, TorchaudioTestCase): ...@@ -65,7 +78,7 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=1 * sample_rate) data = get_wav_data(dtype, num_channels, normalize=False, num_frames=1 * sample_rate)
save_wav(audio_path, data, sample_rate) save_wav(audio_path, data, sample_rate)
script_path = self.get_temp_path('load_func') script_path = self.get_temp_path('load_func.zip')
torch.jit.script(py_load_func).save(script_path) torch.jit.script(py_load_func).save(script_path)
ts_load_func = torch.jit.load(script_path) ts_load_func = torch.jit.load(script_path)
...@@ -76,3 +89,59 @@ class SoxIO(TempDirMixin, TorchaudioTestCase): ...@@ -76,3 +89,59 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
self.assertEqual(py_sr, ts_sr) self.assertEqual(py_sr, ts_sr)
self.assertEqual(py_data, ts_data) self.assertEqual(py_data, ts_data)
@parameterized.expand(
    list(itertools.product(['float32', 'int32', 'int16', 'uint8'], [8000, 16000], [1, 2])),
    name_func=get_test_name)
def test_save_wav(self, dtype, sample_rate, num_channels):
    """Python and TorchScript `save` both write wav files that load back as the input."""
    # Round-trip the scripted function through disk before using it.
    archive = self.get_temp_path('save_func.zip')
    torch.jit.script(py_save_func).save(archive)
    ts_save_func = torch.jit.load(archive)

    expected = get_wav_data(dtype, num_channels)
    py_path = self.get_temp_path(f'test_save_py_{dtype}_{sample_rate}_{num_channels}.wav')
    ts_path = self.get_temp_path(f'test_save_ts_{dtype}_{sample_rate}_{num_channels}.wav')

    py_save_func(py_path, expected, sample_rate, True, None)
    ts_save_func(ts_path, expected, sample_rate, True, None)

    py_loaded = load_wav(py_path)
    ts_loaded = load_wav(ts_path)
    py_data, py_sr = py_loaded
    ts_data, ts_sr = ts_loaded

    self.assertEqual(sample_rate, py_sr)
    self.assertEqual(sample_rate, ts_sr)
    self.assertEqual(expected, py_data)
    self.assertEqual(expected, ts_data)
@parameterized.expand(
    list(itertools.product([8000, 16000], [1, 2], list(range(9)))),
    name_func=get_test_name)
def test_save_flac(self, sample_rate, num_channels, compression_level):
    """Python and TorchScript `save` both write flac files that decode to the input."""
    # Round-trip the scripted function through disk before using it.
    archive = self.get_temp_path('save_func.zip')
    torch.jit.script(py_save_func).save(archive)
    ts_save_func = torch.jit.load(archive)

    expected = get_wav_data('float32', num_channels)
    py_path = self.get_temp_path(f'test_save_py_{sample_rate}_{num_channels}_{compression_level}.flac')
    ts_path = self.get_temp_path(f'test_save_ts_{sample_rate}_{num_channels}_{compression_level}.flac')

    py_save_func(py_path, expected, sample_rate, True, compression_level)
    ts_save_func(ts_path, expected, sample_rate, True, compression_level)

    # converting to 32 bit because flac file has 24 bit depth which scipy cannot handle.
    py_path_wav = f'{py_path}.wav'
    ts_path_wav = f'{ts_path}.wav'
    sox_utils.convert_audio_file(py_path, py_path_wav, bit_depth=32)
    sox_utils.convert_audio_file(ts_path, ts_path_wav, bit_depth=32)

    py_loaded = load_wav(py_path_wav, normalize=True)
    ts_loaded = load_wav(ts_path_wav, normalize=True)
    py_data, py_sr = py_loaded
    ts_data, ts_sr = ts_loaded

    self.assertEqual(sample_rate, py_sr)
    self.assertEqual(sample_rate, ts_sr)
    self.assertEqual(expected, py_data)
    self.assertEqual(expected, ts_data)
from typing import Tuple from typing import Tuple, Optional
import torch import torch
from torchaudio._internal import ( from torchaudio._internal import (
...@@ -25,12 +25,7 @@ def load( ...@@ -25,12 +25,7 @@ def load(
This function can handle all the codecs that underlying libsox can handle, however note the This function can handle all the codecs that underlying libsox can handle, however note the
followings. followings.
Note 1: Note:
Current torchaudio's binary release only contains codecs for MP3, FLAC and OGG/VORBIS.
If you need other formats, you need to build torchaudio from source with libsox and
the corresponding codecs. Refer to README for this.
Note 2:
This function is tested on the following formats; This function is tested on the following formats;
- WAV - WAV
- 32-bit floating-point - 32-bit floating-point
...@@ -77,4 +72,58 @@ def load( ...@@ -77,4 +72,58 @@ def load(
return signal.get_tensor(), signal.get_sample_rate() return signal.get_tensor(), signal.get_sample_rate()
@_mod_utils.requires_module('torchaudio._torchaudio')
def save(
        filepath: str,
        tensor: torch.Tensor,
        sample_rate: int,
        channels_first: bool = True,
        compression: Optional[float] = None,
        frames_per_chunk: int = 65536,
):
    """Save audio data to file.

    Supported formats are;
     - WAV
         - 32-bit floating-point
         - 32-bit signed integer
         - 16-bit signed integer
         -  8-bit unsigned integer
     - MP3
     - FLAC
     - OGG/VORBIS

    Args:
        filepath: Path to save file.
        tensor: Audio data to save. must be 2D tensor.
        sample_rate: sampling rate
        channels_first: If True, the given tensor is interpreted as ``[channel, time]``.
        compression: Used for formats other than WAV. This corresponds to ``-C`` option
            of ``sox`` command.
            See the detail at http://sox.sourceforge.net/soxformat.html.
            - MP3: Either bitrate [kbps] with quality factor, such as ``128.2`` or
              VBR encoding with quality factor such as ``-4.2``. Default: ``-4.5``
            - FLAC: compression level. Whole number from ``0`` to ``8``.
              ``8`` is default and highest compression.
            - OGG/VORBIS: number from -1 to 10; -1 is the highest compression and lowest
              quality. Default: ``3``.
        frames_per_chunk: The number of frames to process (convert to ``int32`` internally
            then write to file) at a time.

    Raises:
        RuntimeError: If ``compression`` is not given and the extension of ``filepath``
            is not one of ``wav``, ``mp3``, ``flac``, ``ogg`` or ``vorbis``.
    """
    if compression is None:
        # Take everything after the last dot as the extension. Slicing the last three
        # characters would turn ``.flac`` into ``lac`` and ``.vorbis`` into ``bis``,
        # so those files would be rejected as unsupported.
        ext = str(filepath).split('.')[-1].lower()
        if ext == 'wav':
            compression = 0.
        elif ext == 'mp3':
            compression = -4.5
        elif ext == 'flac':
            compression = 8.
        elif ext in ['ogg', 'vorbis']:
            compression = 3.
        else:
            raise RuntimeError(f'Unsupported file type: "{ext}"')
    signal = torch.classes.torchaudio.TensorSignal(tensor, sample_rate, channels_first)
    torch.ops.torchaudio.sox_io_save_audio_file(filepath, signal, compression, frames_per_chunk)
load_wav = load load_wav = load
...@@ -46,6 +46,14 @@ static auto registerLoadAudioFile = torch::RegisterOperators().op( ...@@ -46,6 +46,14 @@ static auto registerLoadAudioFile = torch::RegisterOperators().op(
decltype(sox_io::load_audio_file), decltype(sox_io::load_audio_file),
&sox_io::load_audio_file>()); &sox_io::load_audio_file>());
// Registers sox_io::save_audio_file as the catch-all kernel of the operator
// `torchaudio::sox_io_save_audio_file`. The Python `save` function reaches it
// through `torch.ops.torchaudio.sox_io_save_audio_file`, passing the path, a
// TensorSignal, the compression value and the chunk size, in the schema's order.
static auto registerSaveAudioFile = torch::RegisterOperators().op(
torch::RegisterOperators::options()
.schema(
"torchaudio::sox_io_save_audio_file(str path, __torch__.torch.classes.torchaudio.TensorSignal signal, float compression, int frames_per_chunk) -> ()")
.catchAllKernel<
decltype(sox_io::save_audio_file),
&sox_io::save_audio_file>());
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// sox_effects.h // sox_effects.h
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
......
...@@ -15,7 +15,7 @@ c10::intrusive_ptr<torchaudio::SignalInfo> get_info(const std::string& path) { ...@@ -15,7 +15,7 @@ c10::intrusive_ptr<torchaudio::SignalInfo> get_info(const std::string& path) {
/*encoding=*/nullptr, /*encoding=*/nullptr,
/*filetype=*/nullptr)); /*filetype=*/nullptr));
if (sf.get() == nullptr) { if (static_cast<sox_format_t*>(sf) == nullptr) {
throw std::runtime_error("Error opening audio file"); throw std::runtime_error("Error opening audio file");
} }
...@@ -52,7 +52,7 @@ c10::intrusive_ptr<TensorSignal> load_audio_file( ...@@ -52,7 +52,7 @@ c10::intrusive_ptr<TensorSignal> load_audio_file(
const int64_t num_total_samples = sf->signal.length; const int64_t num_total_samples = sf->signal.length;
const int64_t sample_start = sf->signal.channels * frame_offset; const int64_t sample_start = sf->signal.channels * frame_offset;
if (sox_seek(sf.get(), sample_start, 0) == SOX_EOF) { if (sox_seek(sf, sample_start, 0) == SOX_EOF) {
throw std::runtime_error("Error reading audio file: offset past EOF."); throw std::runtime_error("Error reading audio file: offset past EOF.");
} }
...@@ -79,7 +79,7 @@ c10::intrusive_ptr<TensorSignal> load_audio_file( ...@@ -79,7 +79,7 @@ c10::intrusive_ptr<TensorSignal> load_audio_file(
// Read samples into buffer // Read samples into buffer
std::vector<sox_sample_t> buffer; std::vector<sox_sample_t> buffer;
buffer.reserve(max_samples); buffer.reserve(max_samples);
const int64_t num_samples = sox_read(sf.get(), buffer.data(), max_samples); const int64_t num_samples = sox_read(sf, buffer.data(), max_samples);
if (num_samples == 0) { if (num_samples == 0) {
throw std::runtime_error( throw std::runtime_error(
"Error reading audio file: empty file or read operation failed."); "Error reading audio file: empty file or read operation failed.");
...@@ -100,5 +100,51 @@ c10::intrusive_ptr<TensorSignal> load_audio_file( ...@@ -100,5 +100,51 @@ c10::intrusive_ptr<TensorSignal> load_audio_file(
tensor, static_cast<int64_t>(sf->signal.rate), channels_first); tensor, static_cast<int64_t>(sf->signal.rate), channels_first);
} }
/// Save the audio held by `signal` to `file_name`.
///
/// The output format is chosen from the file extension (see `get_filetype`);
/// signal and encoding parameters are derived from the tensor's dtype, the
/// sample rate and `compression`. Samples are converted to int32 with
/// `unnormalize_wav` and handed to libsox in chunks so that the whole tensor
/// does not have to be converted at once.
///
/// @param file_name Destination path; its extension selects the format.
/// @param signal Tensor, sample rate and channel layout to write.
/// @param compression Compression setting forwarded to `get_encodinginfo`.
/// @param frames_per_chunk Number of frames converted and written per
///   `sox_write` call. Must be positive.
/// @throws std::runtime_error if `frames_per_chunk` is not positive, if the
///   file cannot be opened, or if libsox writes fewer samples than requested.
///   Helper functions called here may throw for invalid tensors or
///   unsupported file types.
void save_audio_file(
    const std::string& file_name,
    const c10::intrusive_ptr<TensorSignal>& signal,
    const double compression,
    const int64_t frames_per_chunk) {
  // A non-positive chunk size would keep the write loop below from ever
  // advancing past the end of the tensor, so reject it up front.
  if (frames_per_chunk <= 0) {
    throw std::runtime_error(
        "Error saving audio file: frames_per_chunk must be positive.");
  }

  const auto tensor = signal->getTensor();
  const auto sample_rate = signal->getSampleRate();
  const auto channels_first = signal->getChannelsFirst();

  validate_input_tensor(tensor);

  const auto filetype = get_filetype(file_name);
  const auto signal_info =
      get_signalinfo(tensor, sample_rate, channels_first, filetype);
  const auto encoding_info =
      get_encodinginfo(filetype, tensor.dtype(), compression);

  SoxFormat sf(sox_open_write(
      file_name.c_str(),
      &signal_info,
      &encoding_info,
      /*filetype=*/filetype.c_str(),
      /*oob=*/nullptr,
      /*overwrite_permitted=*/nullptr));

  if (static_cast<sox_format_t*>(sf) == nullptr) {
    throw std::runtime_error("Error saving audio file: failed to open file.");
  }

  // Chunks are sliced along dim 0, so a channels-first tensor is transposed
  // first to make dim 0 the one that is stepped through.
  auto tensor_ = tensor;
  if (channels_first) {
    tensor_ = tensor_.t();
  }

  const int64_t num_rows = tensor_.size(0);
  for (int64_t i = 0; i < num_rows; i += frames_per_chunk) {
    auto chunk = tensor_.index({Slice(i, i + frames_per_chunk), Slice()});
    // sox_write consumes a flat int32 buffer, hence the conversion and the
    // contiguous copy.
    chunk = unnormalize_wav(chunk).contiguous();

    const size_t numel = chunk.numel();
    if (sox_write(sf, chunk.data_ptr<int32_t>(), numel) != numel) {
      throw std::runtime_error(
          "Error saving audio file: failed to write the entire buffer.");
    }
  }
}
} // namespace sox_io } // namespace sox_io
} // namespace torchaudio } // namespace torchaudio
...@@ -17,6 +17,11 @@ c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal> load_audio_file( ...@@ -17,6 +17,11 @@ c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal> load_audio_file(
const bool normalize = true, const bool normalize = true,
const bool channels_first = true); const bool channels_first = true);
/// Save the audio held by `signal` to the file at `file_name`.
///
/// @param file_name Destination path.
/// @param signal Tensor together with its sample rate and channel layout.
/// @param compression Compression setting passed on to the encoder
///   (defaults to 0.).
/// @param frames_per_chunk Number of frames written per chunk
///   (defaults to 65536).
void save_audio_file(
const std::string& file_name,
const c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal>& signal,
const double compression = 0.,
const int64_t frames_per_chunk = 65536);
} // namespace sox_io } // namespace sox_io
} // namespace torchaudio } // namespace torchaudio
......
...@@ -32,12 +32,12 @@ SoxFormat::~SoxFormat() { ...@@ -32,12 +32,12 @@ SoxFormat::~SoxFormat() {
sox_format_t* SoxFormat::operator->() const noexcept { sox_format_t* SoxFormat::operator->() const noexcept {
return fd_; return fd_;
} }
sox_format_t* SoxFormat::get() const noexcept { SoxFormat::operator sox_format_t*() const noexcept {
return fd_; return fd_;
} }
void validate_input_file(const SoxFormat& sf) { void validate_input_file(const SoxFormat& sf) {
if (sf.get() == nullptr) { if (static_cast<sox_format_t*>(sf) == nullptr) {
throw std::runtime_error("Error loading audio file: failed to open file."); throw std::runtime_error("Error loading audio file: failed to open file.");
} }
if (sf->encoding.encoding == SOX_ENCODING_UNKNOWN) { if (sf->encoding.encoding == SOX_ENCODING_UNKNOWN) {
...@@ -48,6 +48,23 @@ void validate_input_file(const SoxFormat& sf) { ...@@ -48,6 +48,23 @@ void validate_input_file(const SoxFormat& sf) {
} }
} }
/// Reject tensors that cannot be saved: the tensor must live on CPU, be
/// two-dimensional and have dtype float32, int32, int16 or uint8.
///
/// @throws std::runtime_error describing the first requirement that fails.
void validate_input_tensor(const torch::Tensor tensor) {
  const bool on_cpu = tensor.device().is_cpu();
  if (!on_cpu) {
    throw std::runtime_error("Input tensor has to be on CPU.");
  }

  const bool is_2d = tensor.ndimension() == 2;
  if (!is_2d) {
    throw std::runtime_error("Input tensor has to be 2D.");
  }

  const auto element_type = tensor.dtype();
  const bool is_supported = element_type == torch::kFloat32 ||
      element_type == torch::kInt32 || element_type == torch::kInt16 ||
      element_type == torch::kUInt8;
  if (is_supported) {
    return;
  }
  throw std::runtime_error(
      "Input tensor has to be one of float32, int32, int16 or uint8 type.");
}
caffe2::TypeMeta get_dtype( caffe2::TypeMeta get_dtype(
const sox_encoding_t encoding, const sox_encoding_t encoding,
const unsigned precision) { const unsigned precision) {
...@@ -109,5 +126,120 @@ torch::Tensor convert_to_tensor( ...@@ -109,5 +126,120 @@ torch::Tensor convert_to_tensor(
return t.contiguous(); return t.contiguous();
} }
/// Convert a float32/int32/int16/uint8 tensor into an int32 tensor.
///
/// - int32: returned as is.
/// - float32: multiplied by a per-element factor built from the int32
///   limits, clamped to [-2147483648, 2147483647] and cast to int32.
/// - int16: widened to int32 and scaled by 65536.
/// - uint8: widened to int32, shifted by -128 and scaled by 16777216.
///
/// @throws std::runtime_error for any other dtype.
torch::Tensor unnormalize_wav(const torch::Tensor input_tensor) {
  const auto src_dtype = input_tensor.dtype();

  if (src_dtype == torch::kInt32) {
    // already denormalized
    return input_tensor;
  }

  if (src_dtype == torch::kFloat32) {
    double upper = 2147483647.;
    double lower = -2147483648.;
    // The factor is computed on the original float32 tensor, before the
    // values are widened to float64 for the multiplication.
    auto factor = (input_tensor > 0) * upper - (input_tensor < 0) * lower;
    auto scaled = input_tensor.to(torch::dtype(torch::kFloat64));
    scaled *= factor;
    scaled.clamp_(lower, upper);
    return scaled.to(torch::dtype(torch::kInt32));
  }

  if (src_dtype == torch::kInt16) {
    auto widened = input_tensor.to(torch::dtype(torch::kInt32));
    widened *= ((widened != 0) * 65536);
    return widened;
  }

  if (src_dtype == torch::kUInt8) {
    auto widened = input_tensor.to(torch::dtype(torch::kInt32));
    widened -= 128;
    widened *= 16777216;
    return widened;
  }

  throw std::runtime_error("Unexpected dtype.");
}
/// Extract the lower-cased extension from a file path.
///
/// Returns everything after the last '.' in `path`, converted to lower case.
/// If `path` contains no '.', the whole (lower-cased) path is returned; if it
/// ends with '.', the result is empty.
const std::string get_filetype(const std::string path) {
  // find_last_of yields npos when there is no '.'; npos + 1 wraps to 0, so
  // substr then starts at the beginning of the path.
  std::string ext = path.substr(path.find_last_of('.') + 1);
  // tolower has undefined behaviour for negative arguments other than EOF,
  // so each byte is passed as unsigned char (matters for non-ASCII paths).
  std::transform(ext.begin(), ext.end(), ext.begin(), [](unsigned char c) {
    return static_cast<char>(::tolower(c));
  });
  return ext;
}
/// Pick the sox encoding for the given file type and tensor dtype.
///
/// mp3, flac and ogg/vorbis map to their own encoding regardless of dtype.
/// For wav the encoding follows the dtype: uint8 is unsigned, int16 and
/// int32 are signed two's complement, float32 is float.
///
/// @throws std::runtime_error for an unsupported file type, or for an
///   unsupported dtype when the file type is wav.
sox_encoding_t get_encoding(
    const std::string filetype,
    const caffe2::TypeMeta dtype) {
  if (filetype == "mp3") {
    return SOX_ENCODING_MP3;
  }
  if (filetype == "flac") {
    return SOX_ENCODING_FLAC;
  }
  if (filetype == "ogg" || filetype == "vorbis") {
    return SOX_ENCODING_VORBIS;
  }
  if (!(filetype == "wav")) {
    throw std::runtime_error("Unsupported file type.");
  }

  // Only wav reaches this point; its encoding depends on the sample type.
  if (dtype == torch::kUInt8) {
    return SOX_ENCODING_UNSIGNED;
  }
  if (dtype == torch::kInt16 || dtype == torch::kInt32) {
    return SOX_ENCODING_SIGN2;
  }
  if (dtype == torch::kFloat32) {
    return SOX_ENCODING_FLOAT;
  }
  throw std::runtime_error("Unsupported dtype.");
}
/// Pick the sample precision (in bits) for the given file type and dtype.
///
/// mp3 and ogg/vorbis return SOX_UNSPEC, flac returns 24, and wav returns
/// the bit width of the dtype (8 for uint8, 16 for int16, 32 for int32 and
/// float32).
///
/// @throws std::runtime_error for an unsupported file type, or for an
///   unsupported dtype when the file type is wav.
unsigned get_precision(
    const std::string filetype,
    const caffe2::TypeMeta dtype) {
  const bool precision_unspecified =
      filetype == "mp3" || filetype == "ogg" || filetype == "vorbis";
  if (precision_unspecified) {
    return SOX_UNSPEC;
  }
  if (filetype == "flac") {
    return 24;
  }
  if (!(filetype == "wav")) {
    throw std::runtime_error("Unsupported file type.");
  }

  // Only wav reaches this point; precision follows the sample type.
  if (dtype == torch::kUInt8) {
    return 8;
  }
  if (dtype == torch::kInt16) {
    return 16;
  }
  if (dtype == torch::kInt32 || dtype == torch::kFloat32) {
    return 32;
  }
  throw std::runtime_error("Unsupported dtype.");
}
/// Build the sox_signalinfo_t describing `tensor` for writing.
///
/// The channel count is read from dim 0 when `channels_first` is true and
/// from dim 1 otherwise; precision comes from `get_precision`; length is the
/// total number of elements in the tensor.
sox_signalinfo_t get_signalinfo(
    const torch::Tensor& tensor,
    const int64_t sample_rate,
    const bool channels_first,
    const std::string filetype) {
  const auto rate = static_cast<sox_rate_t>(sample_rate);
  const auto num_channels =
      static_cast<unsigned>(tensor.size(channels_first ? 0 : 1));
  const unsigned precision = get_precision(filetype, tensor.dtype());
  const auto total_samples = static_cast<uint64_t>(tensor.numel());

  return sox_signalinfo_t{
      /*rate=*/rate,
      /*channels=*/num_channels,
      /*precision=*/precision,
      /*length=*/total_samples};
}
/// Build the sox_encodinginfo_t used when saving an audio file.
///
/// `compression` is passed through for mp3, flac and ogg/vorbis and replaced
/// by 0 for wav. Encoding and bits-per-sample come from `get_encoding` and
/// `get_precision`.
///
/// @throws std::runtime_error for an unsupported file type (and whatever the
///   helper functions throw).
sox_encodinginfo_t get_encodinginfo(
    const std::string filetype,
    const caffe2::TypeMeta dtype,
    const double compression) {
  const bool uses_compression = filetype == "mp3" || filetype == "flac" ||
      filetype == "ogg" || filetype == "vorbis";
  // The file type is checked before anything else is computed, so this is
  // the first error raised for an unknown type.
  if (!uses_compression && !(filetype == "wav")) {
    throw std::runtime_error("Unsupported file type.");
  }
  const double effective_compression = uses_compression ? compression : 0.;

  const sox_encoding_t encoding = get_encoding(filetype, dtype);
  const unsigned bits_per_sample = get_precision(filetype, dtype);

  return sox_encodinginfo_t{/*encoding=*/encoding,
                            /*bits_per_sample=*/bits_per_sample,
                            /*compression=*/effective_compression,
                            /*reverse_bytes=*/sox_option_default,
                            /*reverse_nibbles=*/sox_option_default,
                            /*reverse_bits=*/sox_option_default,
                            /*opposite_endian=*/sox_false};
}
} // namespace sox_utils } // namespace sox_utils
} // namespace torchaudio } // namespace torchaudio
...@@ -31,7 +31,7 @@ struct SoxFormat { ...@@ -31,7 +31,7 @@ struct SoxFormat {
SoxFormat& operator=(SoxFormat&& other) = delete; SoxFormat& operator=(SoxFormat&& other) = delete;
~SoxFormat(); ~SoxFormat();
sox_format_t* operator->() const noexcept; sox_format_t* operator->() const noexcept;
sox_format_t* get() const noexcept; operator sox_format_t*() const noexcept;
private: private:
sox_format_t* fd_; sox_format_t* fd_;
...@@ -41,6 +41,10 @@ struct SoxFormat { ...@@ -41,6 +41,10 @@ struct SoxFormat {
/// Verify that input file is found, has known encoding, and not empty /// Verify that input file is found, has known encoding, and not empty
void validate_input_file(const SoxFormat& sf); void validate_input_file(const SoxFormat& sf);
///
/// Verify that the input Tensor is 2D, on CPU, and of dtype uint8, int16,
/// int32 or float32. Throws std::runtime_error otherwise.
void validate_input_tensor(const torch::Tensor);
/// ///
/// Get target dtype for the given encoding and precision. /// Get target dtype for the given encoding and precision.
caffe2::TypeMeta get_dtype( caffe2::TypeMeta get_dtype(
...@@ -70,6 +74,27 @@ torch::Tensor convert_to_tensor( ...@@ -70,6 +74,27 @@ torch::Tensor convert_to_tensor(
const bool normalize, const bool normalize,
const bool channels_first); const bool channels_first);
///
/// Convert a float32/int32/int16/uint8 Tensor to int32 for Torch -> Sox
/// conversion. int32 input is returned unchanged; the other dtypes are
/// rescaled. Throws std::runtime_error for any other dtype.
torch::Tensor unnormalize_wav(const torch::Tensor);
/// Extract the extension from a file path: the part after the last '.',
/// converted to lower case.
const std::string get_filetype(const std::string path);
/// Get the sox_signalinfo_t describing a torch::Tensor that is about to be
/// saved.
///
/// @param tensor Audio data.
/// @param sample_rate Sample rate of the audio.
/// @param channels_first Whether dim 0 (true) or dim 1 (false) of `tensor`
///   holds the channels.
/// @param filetype File type as returned by get_filetype.
sox_signalinfo_t get_signalinfo(
const torch::Tensor& tensor,
const int64_t sample_rate,
const bool channels_first,
const std::string filetype);
/// Get the sox_encodinginfo_t for saving an audio file.
///
/// @param filetype File type as returned by get_filetype.
/// @param dtype dtype of the Tensor being saved.
/// @param compression Requested compression setting.
sox_encodinginfo_t get_encodinginfo(
const std::string filetype,
const caffe2::TypeMeta dtype,
const double compression);
} // namespace sox_utils } // namespace sox_utils
} // namespace torchaudio } // namespace torchaudio
#endif #endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment