Unverified Commit 3324283c authored by moto's avatar moto Committed by GitHub
Browse files

Add TorchScript-able "save" func to sox_io backend (#732)

This is part of a series of PRs that add the new "sox_io" backend (#726); it depends on #718, #728 and #731.

This PR adds `save` function to "sox_io" backend, which can save Tensor to a file with the following audio formats;
 - `wav`
 - `mp3`
 - `flac`
 - `ogg/vorbis`
parent ea42513f
......@@ -92,28 +92,34 @@ def set_audio_backend(backend):
class TempDirMixin:
"""Mixin to provide easy access to temp dir"""
temp_dir_ = None
base_temp_dir = None
temp_dir = None
def setUp(self):
super().setUp()
self._init_temp_dir()
@classmethod
def setUpClass(cls):
super().setUpClass()
# If TORCHAUDIO_TEST_TEMP_DIR is set, use it instead of temporary directory.
# this is handy for debugging.
key = 'TORCHAUDIO_TEST_TEMP_DIR'
if key in os.environ:
cls.base_temp_dir = os.environ[key]
else:
cls.temp_dir_ = tempfile.TemporaryDirectory()
cls.base_temp_dir = cls.temp_dir_.name
def tearDown(self):
@classmethod
def tearDownClass(cls):
super().tearDownClass()
self._clean_up_temp_dir()
if isinstance(cls.temp_dir_, tempfile.TemporaryDirectory):
cls.temp_dir_.cleanup()
def _init_temp_dir(self):
self.temp_dir_ = tempfile.TemporaryDirectory()
self.temp_dir = self.temp_dir_.name
def _clean_up_temp_dir(self):
if self.temp_dir_ is not None:
self.temp_dir_.cleanup()
self.temp_dir_ = None
self.temp_dir = None
def setUp(self):
self.temp_dir = os.path.join(self.base_temp_dir, self.id())
def get_temp_path(self, *paths):
return os.path.join(self.temp_dir, *paths)
path = os.path.join(self.temp_dir, *paths)
os.makedirs(os.path.dirname(path), exist_ok=True)
return path
class TestBaseMixin:
......
......@@ -31,7 +31,16 @@ def gen_audio_file(
'Use get_wav_data and save_wav to generate wav file for accurate result.')
command = [
'sox',
'-V', # verbose
'-V3', # verbose
'-R',
# -R is supposed to be repeatable, though the implementation looks suspicious
# and not setting the seed to a fixed value.
# https://fossies.org/dox/sox-14.4.2/sox_8c_source.html
# search "sox_globals.repeatable"
]
if bit_depth is not None:
command += ['--bits', str(bit_depth)]
command += [
'--rate', str(sample_rate),
'--null', # no input
'--channels', str(num_channels),
......@@ -60,7 +69,7 @@ def convert_audio_file(
src_path, dst_path,
*, bit_depth=None, compression=None):
"""Convert audio file with `sox` command."""
command = ['sox', '-V', str(src_path)]
command = ['sox', '-V3', '-R', str(src_path)]
if bit_depth is not None:
command += ['--bits', str(bit_depth)]
if compression is not None:
......
......@@ -28,7 +28,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
def test_wav(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` can check wav file correctly"""
duration = 1
path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav')
path = self.get_temp_path('data.wav')
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)
info = sox_io_backend.info(path)
......@@ -44,7 +44,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` can check wav file with channels more than 2 correctly"""
duration = 1
path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav')
path = self.get_temp_path('data.wav')
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)
info = sox_io_backend.info(path)
......@@ -60,7 +60,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
def test_mp3(self, sample_rate, num_channels, bit_rate):
"""`sox_io_backend.info` can check mp3 file correctly"""
duration = 1
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{bit_rate}k.mp3')
path = self.get_temp_path('data.mp3')
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
compression=bit_rate, duration=duration,
......@@ -79,7 +79,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
def test_flac(self, sample_rate, num_channels, compression_level):
"""`sox_io_backend.info` can check flac file correctly"""
duration = 1
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{compression_level}.flac')
path = self.get_temp_path('data.flac')
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
compression=compression_level, duration=duration,
......@@ -97,7 +97,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
def test_vorbis(self, sample_rate, num_channels, quality_level):
"""`sox_io_backend.info` can check vorbis file correctly"""
duration = 1
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{quality_level}.vorbis')
path = self.get_temp_path('data.vorbis')
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
compression=quality_level, duration=duration,
......
......@@ -24,7 +24,7 @@ class LoadTestBase(TempDirMixin, PytorchTestCase):
Wav data loaded with sox_io backend should match those with scipy
"""
path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}_{normalize}.wav')
path = self.get_temp_path('reference.wav')
data = get_wav_data(dtype, num_channels, normalize=normalize, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)
expected = load_wav(path, normalize=normalize)[0]
......@@ -58,8 +58,8 @@ class LoadTestBase(TempDirMixin, PytorchTestCase):
By combining i & ii, step 2. and 4. allows to load reference mp3 data
without using torchaudio
"""
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{bit_rate}_{duration}.mp3')
ref_path = f'{path}.wav'
path = self.get_temp_path('1.original.mp3')
ref_path = self.get_temp_path('2.reference.wav')
# 1. Generate mp3 with sox
sox_utils.gen_audio_file(
......@@ -80,8 +80,8 @@ class LoadTestBase(TempDirMixin, PytorchTestCase):
This test takes the same strategy as mp3 to compare the result
"""
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{compression_level}_{duration}.flac')
ref_path = f'{path}.wav'
path = self.get_temp_path('1.original.flac')
ref_path = self.get_temp_path('2.reference.wav')
# 1. Generate flac with sox
sox_utils.gen_audio_file(
......@@ -102,8 +102,8 @@ class LoadTestBase(TempDirMixin, PytorchTestCase):
This test takes the same strategy as mp3 to compare the result
"""
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{quality_level}_{duration}.vorbis')
ref_path = f'{path}.wav'
path = self.get_temp_path('1.original.vorbis')
ref_path = self.get_temp_path('2.reference.wav')
# 1. Generate vorbis with sox
sox_utils.gen_audio_file(
......
import itertools
from torchaudio.backend import sox_io_backend
from parameterized import parameterized
from ..common_utils import (
TempDirMixin,
PytorchTestCase,
skipIfNoExec,
skipIfNoExtension,
)
from .common import (
get_test_name,
get_wav_data,
)
@skipIfNoExec('sox')
@skipIfNoExtension
class TestRoundTripIO(TempDirMixin, PytorchTestCase):
"""save/load round trip should not degrade data for lossless formats"""
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
[1, 2],
)), name_func=get_test_name)
def test_wav(self, dtype, sample_rate, num_channels):
"""save/load round trip should not degrade data for wav formats"""
original = get_wav_data(dtype, num_channels, normalize=False)
data = original
for i in range(10):
path = self.get_temp_path(f'{i}.wav')
sox_io_backend.save(path, data, sample_rate)
data, sr = sox_io_backend.load(path, normalize=False)
assert sr == sample_rate
self.assertEqual(original, data)
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
list(range(9)),
)), name_func=get_test_name)
def test_flac(self, sample_rate, num_channels, compression_level):
"""save/load round trip should not degrade data for flac formats"""
original = get_wav_data('float32', num_channels)
data = original
for i in range(10):
path = self.get_temp_path(f'{i}.flac')
sox_io_backend.save(path, data, sample_rate, compression=compression_level)
data, sr = sox_io_backend.load(path)
assert sr == sample_rate
self.assertEqual(original, data)
import itertools
from torchaudio.backend import sox_io_backend
from parameterized import parameterized
from ..common_utils import (
TempDirMixin,
PytorchTestCase,
skipIfNoExec,
skipIfNoExtension,
)
from .common import (
get_test_name,
get_wav_data,
load_wav,
save_wav,
)
from . import sox_utils
class SaveTestBase(TempDirMixin, PytorchTestCase):
def assert_wav(self, dtype, sample_rate, num_channels, num_frames):
"""`sox_io_backend.save` can save wav format."""
path = self.get_temp_path('data.wav')
expected = get_wav_data(dtype, num_channels, num_frames=num_frames)
sox_io_backend.save(path, expected, sample_rate)
found, sr = load_wav(path)
assert sample_rate == sr
self.assertEqual(found, expected)
def assert_mp3(self, sample_rate, num_channels, bit_rate, duration):
"""`sox_io_backend.save` can save mp3 format.
mp3 encoding introduces delay and boundary effects so
we convert the resulting mp3 to wav and compare the results there
|
| 1. Generate original wav file with SciPy
|
v
-------------- wav ----------------
| |
| 2.1. load with scipy | 3.1. Convert to mp3 with Sox
| then save with torchaudio |
v v
mp3 mp3
| |
| 2.2. Convert to wav with Sox | 3.2. Convert to wav with Sox
| |
v v
wav wav
| |
| 2.3. load with scipy | 3.3. load with scipy
| |
v v
tensor -------> compare <--------- tensor
"""
src_path = self.get_temp_path('1.reference.wav')
mp3_path = self.get_temp_path('2.1.torchaudio.mp3')
wav_path = self.get_temp_path('2.2.torchaudio.wav')
mp3_path_sox = self.get_temp_path('3.1.sox.mp3')
wav_path_sox = self.get_temp_path('3.2.sox.wav')
# 1. Generate original wav
data = get_wav_data('float32', num_channels, normalize=True, num_frames=duration * sample_rate)
save_wav(src_path, data, sample_rate)
# 2.1. Convert the original wav to mp3 with torchaudio
sox_io_backend.save(
mp3_path, load_wav(src_path)[0], sample_rate, compression=bit_rate)
# 2.2. Convert the mp3 to wav with Sox
sox_utils.convert_audio_file(mp3_path, wav_path)
# 2.3. Load
found = load_wav(wav_path)[0]
# 3.1. Convert the original wav to mp3 with SoX
sox_utils.convert_audio_file(src_path, mp3_path_sox, compression=bit_rate)
# 3.2. Convert the mp3 to wav with Sox
sox_utils.convert_audio_file(mp3_path_sox, wav_path_sox)
# 3.3. Load
expected = load_wav(wav_path_sox)[0]
self.assertEqual(found, expected)
def assert_flac(self, sample_rate, num_channels, compression_level, duration):
"""`sox_io_backend.save` can save flac format.
This test takes the same strategy as mp3 to compare the result
"""
src_path = self.get_temp_path('1.reference.wav')
flc_path = self.get_temp_path('2.1.torchaudio.flac')
wav_path = self.get_temp_path('2.2.torchaudio.wav')
flc_path_sox = self.get_temp_path('3.1.sox.flac')
wav_path_sox = self.get_temp_path('3.2.sox.wav')
# 1. Generate original wav
data = get_wav_data('float32', num_channels, normalize=True, num_frames=duration * sample_rate)
save_wav(src_path, data, sample_rate)
# 2.1. Convert the original wav to flac with torchaudio
sox_io_backend.save(
flc_path, load_wav(src_path)[0], sample_rate, compression=compression_level)
# 2.2. Convert the flac to wav with Sox
# converting to 32 bit because flac file has 24 bit depth which scipy cannot handle.
sox_utils.convert_audio_file(flc_path, wav_path, bit_depth=32)
# 2.3. Load
found = load_wav(wav_path)[0]
# 3.1. Convert the original wav to flac with SoX
sox_utils.convert_audio_file(src_path, flc_path_sox, compression=compression_level)
# 3.2. Convert the flac to wav with Sox
# converting to 32 bit because flac file has 24 bit depth which scipy cannot handle.
sox_utils.convert_audio_file(flc_path_sox, wav_path_sox, bit_depth=32)
# 3.3. Load
expected = load_wav(wav_path_sox)[0]
self.assertEqual(found, expected)
def _assert_vorbis(self, sample_rate, num_channels, quality_level, duration):
"""`sox_io_backend.save` can save vorbis format.
This test takes the same strategy as mp3 to compare the result
"""
src_path = self.get_temp_path('1.reference.wav')
vbs_path = self.get_temp_path('2.1.torchaudio.vorbis')
wav_path = self.get_temp_path('2.2.torchaudio.wav')
vbs_path_sox = self.get_temp_path('3.1.sox.vorbis')
wav_path_sox = self.get_temp_path('3.2.sox.wav')
# 1. Generate original wav
data = get_wav_data('int16', num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(src_path, data, sample_rate)
# 2.1. Convert the original wav to vorbis with torchaudio
sox_io_backend.save(
vbs_path, load_wav(src_path)[0], sample_rate, compression=quality_level)
# 2.2. Convert the vorbis to wav with Sox
sox_utils.convert_audio_file(vbs_path, wav_path)
# 2.3. Load
found = load_wav(wav_path)[0]
# 3.1. Convert the original wav to vorbis with SoX
sox_utils.convert_audio_file(src_path, vbs_path_sox, compression=quality_level)
# 3.2. Convert the vorbis to wav with Sox
sox_utils.convert_audio_file(vbs_path_sox, wav_path_sox)
# 3.3. Load
expected = load_wav(wav_path_sox)[0]
# sox's vorbis encoding has some random boundary effect, which causes a small number of
# samples to yield a higher discrepancy than the others,
# so we allow a small portion of the data to be outside of the absolute tolerance.
# Make sure to pass a somewhat long duration.
atol = 1.0e-4
max_failure_allowed = 0.01 # this percent of samples are allowed to outside of atol.
failure_ratio = ((found - expected).abs() > atol).sum().item() / found.numel()
if failure_ratio > max_failure_allowed:
# it's failed and this will give a better error message.
self.assertEqual(found, expected, atol=atol, rtol=1.3e-6)
def assert_vorbis(self, *args, **kwargs):
# sox's vorbis encoding has some randomness, so we run the test multiple times
max_retry = 5
error = None
for _ in range(max_retry):
try:
self._assert_vorbis(*args, **kwargs)
break
except AssertionError as e:
error = e
else:
raise error
@skipIfNoExec('sox')
@skipIfNoExtension
class TestSave(SaveTestBase):
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
[1, 2],
)), name_func=get_test_name)
def test_wav(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.save` can save wav format."""
self.assert_wav(dtype, sample_rate, num_channels, num_frames=None)
@parameterized.expand(list(itertools.product(
['float32'],
[16000],
[2],
)), name_func=get_test_name)
def test_wav_large(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.save` can save large wav file."""
two_hours = 2 * 60 * 60 * sample_rate
self.assert_wav(dtype, sample_rate, num_channels, num_frames=two_hours)
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[4, 8, 16, 32],
)), name_func=get_test_name)
def test_multiple_channels(self, dtype, num_channels):
"""`sox_io_backend.save` can save wav with more than 2 channels."""
sample_rate = 8000
self.assert_wav(dtype, sample_rate, num_channels, num_frames=None)
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
[-4.2, -0.2, 0, 0.2, 96, 128, 160, 192, 224, 256, 320],
)), name_func=get_test_name)
def test_mp3(self, sample_rate, num_channels, bit_rate):
"""`sox_io_backend.save` can save mp3 format."""
self.assert_mp3(sample_rate, num_channels, bit_rate, duration=1)
@parameterized.expand(list(itertools.product(
[16000],
[2],
[128],
)), name_func=get_test_name)
def test_mp3_large(self, sample_rate, num_channels, bit_rate):
"""`sox_io_backend.save` can save large mp3 file."""
two_hours = 2 * 60 * 60
self.assert_mp3(sample_rate, num_channels, bit_rate, duration=two_hours)
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
list(range(9)),
)), name_func=get_test_name)
def test_flac(self, sample_rate, num_channels, compression_level):
"""`sox_io_backend.save` can save flac format."""
self.assert_flac(sample_rate, num_channels, compression_level, duration=1)
@parameterized.expand(list(itertools.product(
[16000],
[2],
[0],
)), name_func=get_test_name)
def test_flac_large(self, sample_rate, num_channels, compression_level):
"""`sox_io_backend.save` can save large flac file."""
two_hours = 2 * 60 * 60
self.assert_flac(sample_rate, num_channels, compression_level, duration=two_hours)
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
[-1, 0, 1, 2, 3, 3.6, 5, 10],
)), name_func=get_test_name)
def test_vorbis(self, sample_rate, num_channels, quality_level):
"""`sox_io_backend.save` can save vorbis format."""
self.assert_vorbis(sample_rate, num_channels, quality_level, duration=20)
# note: torchaudio can load a large vorbis file, but cannot save a large vorbis file
# the following test causes Segmentation fault
#
'''
@parameterized.expand(list(itertools.product(
[16000],
[2],
[10],
)), name_func=get_test_name)
def test_vorbis_large(self, sample_rate, num_channels, quality_level):
"""`sox_io_backend.save` can save large vorbis file correctly."""
two_hours = 2 * 60 * 60
self.assert_vorbis(sample_rate, num_channels, quality_level, two_hours)
'''
@skipIfNoExec('sox')
@skipIfNoExtension
class TestSaveParams(TempDirMixin, PytorchTestCase):
"""Test the correctness of optional parameters of `sox_io_backend.save`"""
@parameterized.expand([(True, ), (False, )], name_func=get_test_name)
def test_channels_first(self, channels_first):
"""channels_first swaps axes"""
path = self.get_temp_path('data.wav')
data = get_wav_data('int32', 2, channels_first=channels_first)
sox_io_backend.save(
path, data, 8000, channels_first=channels_first)
found = load_wav(path)[0]
expected = data if channels_first else data.transpose(1, 0)
self.assertEqual(found, expected)
@parameterized.expand([
'float32', 'int32', 'int16', 'uint8'
], name_func=get_test_name)
def test_noncontiguous(self, dtype):
"""Noncontiguous tensors are saved correctly"""
path = self.get_temp_path('data.wav')
expected = get_wav_data(dtype, 4)[::2, ::2]
assert not expected.is_contiguous()
sox_io_backend.save(path, expected, 8000)
found = load_wav(path)[0]
self.assertEqual(found, expected)
@parameterized.expand([
'float32', 'int32', 'int16', 'uint8',
])
def test_tensor_preserve(self, dtype):
"""save function should not alter Tensor"""
path = self.get_temp_path('data.wav')
expected = get_wav_data(dtype, 4)[::2, ::2]
data = expected.clone()
sox_io_backend.save(path, data, 8000)
self.assertEqual(data, expected)
import itertools
from typing import Optional
import torch
from torchaudio.backend import sox_io_backend
......@@ -13,8 +14,10 @@ from ..common_utils import (
from .common import (
get_test_name,
get_wav_data,
save_wav
save_wav,
load_wav,
)
from . import sox_utils
def py_info_func(filepath: str) -> torch.classes.torchaudio.SignalInfo:
......@@ -26,6 +29,16 @@ def py_load_func(filepath: str, normalize: bool, channels_first: bool):
filepath, normalize=normalize, channels_first=channels_first)
# Module-level wrapper around `sox_io_backend.save`.
# The tests in this file pass it to `torch.jit.script`, save the scripted
# function to disk, load it back, and compare its output with a direct call.
# All arguments are forwarded positionally and unchanged.
def py_save_func(
filepath: str,
tensor: torch.Tensor,
sample_rate: int,
channels_first: bool = True,
compression: Optional[float] = None,
):
sox_io_backend.save(filepath, tensor, sample_rate, channels_first, compression)
@skipIfNoExec('sox')
@skipIfNoExtension
class SoxIO(TempDirMixin, TorchaudioTestCase):
......@@ -41,7 +54,7 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=1 * sample_rate)
save_wav(audio_path, data, sample_rate)
script_path = self.get_temp_path('info_func')
script_path = self.get_temp_path('info_func.zip')
torch.jit.script(py_info_func).save(script_path)
ts_info_func = torch.jit.load(script_path)
......@@ -65,7 +78,7 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=1 * sample_rate)
save_wav(audio_path, data, sample_rate)
script_path = self.get_temp_path('load_func')
script_path = self.get_temp_path('load_func.zip')
torch.jit.script(py_load_func).save(script_path)
ts_load_func = torch.jit.load(script_path)
......@@ -76,3 +89,59 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
self.assertEqual(py_sr, ts_sr)
self.assertEqual(py_data, ts_data)
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
[1, 2],
)), name_func=get_test_name)
def test_save_wav(self, dtype, sample_rate, num_channels):
script_path = self.get_temp_path('save_func.zip')
torch.jit.script(py_save_func).save(script_path)
ts_save_func = torch.jit.load(script_path)
expected = get_wav_data(dtype, num_channels)
py_path = self.get_temp_path(f'test_save_py_{dtype}_{sample_rate}_{num_channels}.wav')
ts_path = self.get_temp_path(f'test_save_ts_{dtype}_{sample_rate}_{num_channels}.wav')
py_save_func(py_path, expected, sample_rate, True, None)
ts_save_func(ts_path, expected, sample_rate, True, None)
py_data, py_sr = load_wav(py_path)
ts_data, ts_sr = load_wav(ts_path)
self.assertEqual(sample_rate, py_sr)
self.assertEqual(sample_rate, ts_sr)
self.assertEqual(expected, py_data)
self.assertEqual(expected, ts_data)
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
list(range(9)),
)), name_func=get_test_name)
def test_save_flac(self, sample_rate, num_channels, compression_level):
script_path = self.get_temp_path('save_func.zip')
torch.jit.script(py_save_func).save(script_path)
ts_save_func = torch.jit.load(script_path)
expected = get_wav_data('float32', num_channels)
py_path = self.get_temp_path(f'test_save_py_{sample_rate}_{num_channels}_{compression_level}.flac')
ts_path = self.get_temp_path(f'test_save_ts_{sample_rate}_{num_channels}_{compression_level}.flac')
py_save_func(py_path, expected, sample_rate, True, compression_level)
ts_save_func(ts_path, expected, sample_rate, True, compression_level)
# converting to 32 bit because flac file has 24 bit depth which scipy cannot handle.
py_path_wav = f'{py_path}.wav'
ts_path_wav = f'{ts_path}.wav'
sox_utils.convert_audio_file(py_path, py_path_wav, bit_depth=32)
sox_utils.convert_audio_file(ts_path, ts_path_wav, bit_depth=32)
py_data, py_sr = load_wav(py_path_wav, normalize=True)
ts_data, ts_sr = load_wav(ts_path_wav, normalize=True)
self.assertEqual(sample_rate, py_sr)
self.assertEqual(sample_rate, ts_sr)
self.assertEqual(expected, py_data)
self.assertEqual(expected, ts_data)
from typing import Tuple
from typing import Tuple, Optional
import torch
from torchaudio._internal import (
......@@ -25,12 +25,7 @@ def load(
This function can handle all the codecs that underlying libsox can handle, however note the
followings.
Note 1:
Current torchaudio's binary release only contains codecs for MP3, FLAC and OGG/VORBIS.
If you need other formats, you need to build torchaudio from source with libsox and
the corresponding codecs. Refer to README for this.
Note 2:
Note:
This function is tested on the following formats;
- WAV
- 32-bit floating-point
......@@ -77,4 +72,58 @@ def load(
return signal.get_tensor(), signal.get_sample_rate()
@_mod_utils.requires_module('torchaudio._torchaudio')
def save(
        filepath: str,
        tensor: torch.Tensor,
        sample_rate: int,
        channels_first: bool = True,
        compression: Optional[float] = None,
        frames_per_chunk: int = 65536,
):
    """Save audio data to file.

    Supported formats are;
     - WAV
        - 32-bit floating-point
        - 32-bit signed integer
        - 16-bit signed integer
        -  8-bit unsigned integer
     - MP3
     - FLAC
     - OGG/VORBIS

    Args:
        filepath: Path to save file. The format is chosen from the file extension.
        tensor: Audio data to save. must be 2D tensor.
        sample_rate: sampling rate
        channels_first: If True, the given tensor is interpreted as ``[channel, time]``.
        compression: Used for formats other than WAV. This corresponds to ``-C`` option
            of ``sox`` command.
            See the detail at http://sox.sourceforge.net/soxformat.html.
                - MP3: Either bitrate [kbps] with quality factor, such as ``128.2`` or
                  VBR encoding with quality factor such as ``-4.2``. Default: ``-4.5``
                - FLAC: compression level. Whole number from ``0`` to ``8``.
                  ``8`` is default and highest compression.
                - OGG/VORBIS: number from -1 to 10; -1 is the highest compression and lowest
                  quality. Default: ``3``.
        frames_per_chunk: The number of frames to process (convert to ``int32`` internally
            then write to file) at a time.

    Raises:
        RuntimeError: If ``compression`` is ``None`` and the extension of ``filepath``
            is not one of ``wav``, ``mp3``, ``flac``, ``ogg`` or ``vorbis``.
    """
    if compression is None:
        # Take everything after the last '.' as the extension. Slicing the last
        # three characters can never equal 'flac' or 'vorbis', which made the
        # defaults for those formats unreachable.
        ext = str(filepath).split('.')[-1].lower()
        if ext == 'wav':
            compression = 0.
        elif ext == 'mp3':
            compression = -4.5
        elif ext == 'flac':
            compression = 8.
        elif ext in ['ogg', 'vorbis']:
            compression = 3.
        else:
            raise RuntimeError(f'Unsupported file type: "{ext}"')
    signal = torch.classes.torchaudio.TensorSignal(tensor, sample_rate, channels_first)
    torch.ops.torchaudio.sox_io_save_audio_file(filepath, signal, compression, frames_per_chunk)
load_wav = load
......@@ -46,6 +46,14 @@ static auto registerLoadAudioFile = torch::RegisterOperators().op(
decltype(sox_io::load_audio_file),
&sox_io::load_audio_file>());
// Registers sox_io::save_audio_file as the catch-all kernel for the
// `torchaudio::sox_io_save_audio_file` operator. The schema takes a path, a
// TensorSignal custom-class instance, a compression value and the chunk size,
// and returns nothing.
static auto registerSaveAudioFile = torch::RegisterOperators().op(
torch::RegisterOperators::options()
.schema(
"torchaudio::sox_io_save_audio_file(str path, __torch__.torch.classes.torchaudio.TensorSignal signal, float compression, int frames_per_chunk) -> ()")
.catchAllKernel<
decltype(sox_io::save_audio_file),
&sox_io::save_audio_file>());
////////////////////////////////////////////////////////////////////////////////
// sox_effects.h
////////////////////////////////////////////////////////////////////////////////
......
......@@ -15,7 +15,7 @@ c10::intrusive_ptr<torchaudio::SignalInfo> get_info(const std::string& path) {
/*encoding=*/nullptr,
/*filetype=*/nullptr));
if (sf.get() == nullptr) {
if (static_cast<sox_format_t*>(sf) == nullptr) {
throw std::runtime_error("Error opening audio file");
}
......@@ -52,7 +52,7 @@ c10::intrusive_ptr<TensorSignal> load_audio_file(
const int64_t num_total_samples = sf->signal.length;
const int64_t sample_start = sf->signal.channels * frame_offset;
if (sox_seek(sf.get(), sample_start, 0) == SOX_EOF) {
if (sox_seek(sf, sample_start, 0) == SOX_EOF) {
throw std::runtime_error("Error reading audio file: offset past EOF.");
}
......@@ -79,7 +79,7 @@ c10::intrusive_ptr<TensorSignal> load_audio_file(
// Read samples into buffer
std::vector<sox_sample_t> buffer;
buffer.reserve(max_samples);
const int64_t num_samples = sox_read(sf.get(), buffer.data(), max_samples);
const int64_t num_samples = sox_read(sf, buffer.data(), max_samples);
if (num_samples == 0) {
throw std::runtime_error(
"Error reading audio file: empty file or read operation failed.");
......@@ -100,5 +100,51 @@ c10::intrusive_ptr<TensorSignal> load_audio_file(
tensor, static_cast<int64_t>(sf->signal.rate), channels_first);
}
/// Save the audio data held by `signal` to the file at `file_name`.
///
/// The file type is taken from the extension of `file_name` (get_filetype);
/// signal and encoding information are derived from that file type together
/// with the tensor dtype. Samples are converted to int32 with unnormalize_wav
/// and passed to sox_write in chunks of at most `frames_per_chunk` frames.
///
/// @param file_name        Destination path; its extension selects the format.
/// @param signal           Tensor, sample rate and channel layout to write.
/// @param compression      Compression value forwarded to get_encodinginfo.
/// @param frames_per_chunk Number of frames converted and written per
///                         iteration. Must be positive.
///
/// Throws std::runtime_error if `frames_per_chunk` is not positive, if the
/// file cannot be opened for writing, or if sox_write reports fewer samples
/// written than were passed to it. Errors thrown by the validation and lookup
/// helpers propagate unchanged.
void save_audio_file(
    const std::string& file_name,
    const c10::intrusive_ptr<TensorSignal>& signal,
    const double compression,
    const int64_t frames_per_chunk) {
  const auto tensor = signal->getTensor();
  const auto sample_rate = signal->getSampleRate();
  const auto channels_first = signal->getChannelsFirst();

  validate_input_tensor(tensor);

  // A chunk size of zero (or less) would never advance the write loop below,
  // so reject it before the output file is created.
  if (frames_per_chunk <= 0) {
    throw std::runtime_error(
        "Error saving audio file: frames_per_chunk must be positive.");
  }

  const auto filetype = get_filetype(file_name);
  const auto signal_info =
      get_signalinfo(tensor, sample_rate, channels_first, filetype);
  const auto encoding_info =
      get_encodinginfo(filetype, tensor.dtype(), compression);

  SoxFormat sf(sox_open_write(
      file_name.c_str(),
      &signal_info,
      &encoding_info,
      /*filetype=*/filetype.c_str(),
      /*oob=*/nullptr,
      /*overwrite_permitted=*/nullptr));

  if (static_cast<sox_format_t*>(sf) == nullptr) {
    throw std::runtime_error("Error saving audio file: failed to open file.");
  }

  // Put the frame axis first so that slicing along dim 0 selects whole frames.
  auto tensor_ = tensor;
  if (channels_first) {
    tensor_ = tensor_.t();
  }

  const int64_t num_frames = tensor_.size(0);
  for (int64_t i = 0; i < num_frames; i += frames_per_chunk) {
    auto chunk = tensor_.index({Slice(i, i + frames_per_chunk), Slice()});
    // Convert to int32 and make the memory contiguous before handing the raw
    // pointer to sox_write.
    chunk = unnormalize_wav(chunk).contiguous();

    const size_t numel = chunk.numel();
    if (sox_write(sf, chunk.data_ptr<int32_t>(), numel) != numel) {
      throw std::runtime_error(
          "Error saving audio file: failed to write the entier buffer.");
    }
  }
}
} // namespace sox_io
} // namespace torchaudio
......@@ -17,6 +17,11 @@ c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal> load_audio_file(
const bool normalize = true,
const bool channels_first = true);
void save_audio_file(
const std::string& file_name,
const c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal>& signal,
const double compression = 0.,
const int64_t frames_per_chunk = 65536);
} // namespace sox_io
} // namespace torchaudio
......
......@@ -32,12 +32,12 @@ SoxFormat::~SoxFormat() {
sox_format_t* SoxFormat::operator->() const noexcept {
return fd_;
}
sox_format_t* SoxFormat::get() const noexcept {
SoxFormat::operator sox_format_t*() const noexcept {
return fd_;
}
void validate_input_file(const SoxFormat& sf) {
if (sf.get() == nullptr) {
if (static_cast<sox_format_t*>(sf) == nullptr) {
throw std::runtime_error("Error loading audio file: failed to open file.");
}
if (sf->encoding.encoding == SOX_ENCODING_UNKNOWN) {
......@@ -48,6 +48,23 @@ void validate_input_file(const SoxFormat& sf) {
}
}
/// Check that `tensor` can be written out: it must live on the CPU, be
/// two-dimensional, and have dtype float32, int32, int16 or uint8.
/// Throws std::runtime_error describing the first violated requirement.
void validate_input_tensor(const torch::Tensor tensor) {
  const bool on_cpu = tensor.device().is_cpu();
  if (!on_cpu) {
    throw std::runtime_error("Input tensor has to be on CPU.");
  }

  const bool is_2d = tensor.ndimension() == 2;
  if (!is_2d) {
    throw std::runtime_error("Input tensor has to be 2D.");
  }

  const auto dtype = tensor.dtype();
  const bool supported_dtype = dtype == torch::kFloat32 ||
      dtype == torch::kInt32 || dtype == torch::kInt16 ||
      dtype == torch::kUInt8;
  if (supported_dtype) {
    return;
  }
  throw std::runtime_error(
      "Input tensor has to be one of float32, int32, int16 or uint8 type.");
}
caffe2::TypeMeta get_dtype(
const sox_encoding_t encoding,
const unsigned precision) {
......@@ -109,5 +126,120 @@ torch::Tensor convert_to_tensor(
return t.contiguous();
}
/// Convert a float32 / int32 / int16 / uint8 tensor into an int32 tensor for
/// passing to sox. Throws std::runtime_error for any other dtype.
torch::Tensor unnormalize_wav(const torch::Tensor input_tensor) {
const auto dtype = input_tensor.dtype();
auto tensor = input_tensor;
if (dtype == torch::kFloat32) {
// Positive samples are scaled by 2147483647 and negative samples by
// 2147483648 (note the subtraction of the negative constant); zeros get a
// multiplier of 0.
double multi_pos = 2147483647.;
double multi_neg = -2147483648.;
// NOTE(review): `mult` is built from bool tensors and double scalars, so its
// dtype follows torch's type-promotion rules. Confirm the positive
// multiplier keeps full precision.
auto mult = (tensor > 0) * multi_pos - (tensor < 0) * multi_neg;
// Scale in float64, clamp to the int32 value range, then cast.
tensor = tensor.to(torch::dtype(torch::kFloat64));
tensor *= mult;
tensor.clamp_(multi_neg, multi_pos);
tensor = tensor.to(torch::dtype(torch::kInt32));
} else if (dtype == torch::kInt32) {
// already denormalized
} else if (dtype == torch::kInt16) {
// Widen to int32, then multiply by 65536 (2^16). The `(tensor != 0)` mask
// makes the multiplier 0 for zero samples, which stay zero either way.
tensor = tensor.to(torch::dtype(torch::kInt32));
tensor *= ((tensor != 0) * 65536);
} else if (dtype == torch::kUInt8) {
// Widen to int32, subtract the offset of 128, then multiply by
// 16777216 (2^24).
tensor = tensor.to(torch::dtype(torch::kInt32));
tensor -= 128;
tensor *= 16777216;
} else {
throw std::runtime_error("Unexpected dtype.");
}
return tensor;
}
/// Extract the lower-cased file extension from `path`.
///
/// Returns everything after the last '.' in `path`, converted to lower case.
/// If `path` contains no '.', the whole path (lower-cased) is returned, which
/// callers then reject as an unsupported file type unless it happens to match
/// a known one.
const std::string get_filetype(const std::string path) {
  // find_last_of returns npos when there is no '.'; npos + 1 wraps to 0, so
  // substr then yields the entire string.
  std::string ext = path.substr(path.find_last_of(".") + 1);
  // tolower must be given a value representable as unsigned char; passing a
  // plain (possibly negative) char for non-ASCII bytes is undefined behaviour.
  std::transform(ext.begin(), ext.end(), ext.begin(), [](unsigned char c) {
    return static_cast<char>(::tolower(c));
  });
  return ext;
}
/// Map a file type (lower-case extension) and tensor dtype to the sox
/// encoding used for writing.
///
/// For "wav" the encoding depends on `dtype`; for "mp3", "flac" and
/// "ogg"/"vorbis" it is fixed by the format. Throws std::runtime_error for an
/// unsupported file type, or for an unsupported dtype when writing wav.
sox_encoding_t get_encoding(
    const std::string filetype,
    const caffe2::TypeMeta dtype) {
  if (filetype == "wav") {
    if (dtype == torch::kFloat32) {
      return SOX_ENCODING_FLOAT;
    }
    // Both signed integer widths use the same signed encoding.
    if (dtype == torch::kInt32 || dtype == torch::kInt16) {
      return SOX_ENCODING_SIGN2;
    }
    if (dtype == torch::kUInt8) {
      return SOX_ENCODING_UNSIGNED;
    }
    throw std::runtime_error("Unsupported dtype.");
  }
  if (filetype == "mp3") {
    return SOX_ENCODING_MP3;
  }
  if (filetype == "flac") {
    return SOX_ENCODING_FLAC;
  }
  if (filetype == "ogg" || filetype == "vorbis") {
    return SOX_ENCODING_VORBIS;
  }
  throw std::runtime_error("Unsupported file type.");
}
/// Return the sample precision (in bits) to request from sox for the given
/// file type and tensor dtype.
///
/// "mp3" and "ogg"/"vorbis" leave the precision unspecified (SOX_UNSPEC),
/// "flac" uses 24, and "wav" follows the dtype width. Throws
/// std::runtime_error for an unsupported file type, or for an unsupported
/// dtype when writing wav.
unsigned get_precision(
    const std::string filetype,
    const caffe2::TypeMeta dtype) {
  if (filetype == "wav") {
    if (dtype == torch::kInt32 || dtype == torch::kFloat32) {
      return 32;
    }
    if (dtype == torch::kInt16) {
      return 16;
    }
    if (dtype == torch::kUInt8) {
      return 8;
    }
    throw std::runtime_error("Unsupported dtype.");
  }
  if (filetype == "flac") {
    return 24;
  }
  const bool unspecified =
      filetype == "mp3" || filetype == "ogg" || filetype == "vorbis";
  if (unspecified) {
    return SOX_UNSPEC;
  }
  throw std::runtime_error("Unsupported file type.");
}
/// Build the sox_signalinfo_t describing `tensor` for writing.
///
/// The channel count is read from dim 0 when `channels_first` is true and
/// from dim 1 otherwise; the length is the total number of elements. The
/// precision comes from get_precision, whose errors propagate.
sox_signalinfo_t get_signalinfo(
    const torch::Tensor& tensor,
    const int64_t sample_rate,
    const bool channels_first,
    const std::string filetype) {
  const auto rate = static_cast<sox_rate_t>(sample_rate);
  const int64_t channel_dim = channels_first ? 0 : 1;
  const auto num_channels = static_cast<unsigned>(tensor.size(channel_dim));
  const unsigned precision = get_precision(filetype, tensor.dtype());
  const auto num_samples = static_cast<uint64_t>(tensor.numel());
  return sox_signalinfo_t{
      /*rate=*/rate,
      /*channels=*/num_channels,
      /*precision=*/precision,
      /*length=*/num_samples};
}
/// Build the sox_encodinginfo_t used when opening a file for writing.
///
/// `compression` is passed through for "mp3", "flac" and "ogg"/"vorbis"; for
/// "wav" it is ignored and 0 is used. Throws std::runtime_error for any other
/// file type. Errors from get_encoding / get_precision propagate.
sox_encodinginfo_t get_encodinginfo(
    const std::string filetype,
    const caffe2::TypeMeta dtype,
    const double compression) {
  const bool uses_compression = filetype == "mp3" || filetype == "flac" ||
      filetype == "ogg" || filetype == "vorbis";
  if (!uses_compression && filetype != "wav") {
    throw std::runtime_error("Unsupported file type.");
  }
  const double effective_compression = uses_compression ? compression : 0.;

  return sox_encodinginfo_t{
      /*encoding=*/get_encoding(filetype, dtype),
      /*bits_per_sample=*/get_precision(filetype, dtype),
      /*compression=*/effective_compression,
      /*reverse_bytes=*/sox_option_default,
      /*reverse_nibbles=*/sox_option_default,
      /*reverse_bits=*/sox_option_default,
      /*opposite_endian=*/sox_false};
}
} // namespace sox_utils
} // namespace torchaudio
......@@ -31,7 +31,7 @@ struct SoxFormat {
SoxFormat& operator=(SoxFormat&& other) = delete;
~SoxFormat();
sox_format_t* operator->() const noexcept;
sox_format_t* get() const noexcept;
operator sox_format_t*() const noexcept;
private:
sox_format_t* fd_;
......@@ -41,6 +41,10 @@ struct SoxFormat {
/// Verify that input file is found, has known encoding, and not empty
void validate_input_file(const SoxFormat& sf);
///
/// Verify that input Tensor is 2D, CPU and either uint8, int16, int32 or float32
void validate_input_tensor(const torch::Tensor);
///
/// Get target dtype for the given encoding and precision.
caffe2::TypeMeta get_dtype(
......@@ -70,6 +74,27 @@ torch::Tensor convert_to_tensor(
const bool normalize,
const bool channels_first);
///
/// Convert float32/int32/int16/uint8 Tensor to int32 for Torch -> Sox
/// conversion.
torch::Tensor unnormalize_wav(const torch::Tensor);
/// Extract extension from file path
const std::string get_filetype(const std::string path);
/// Get sox_signalinfo_t for passing a torch::Tensor object.
sox_signalinfo_t get_signalinfo(
const torch::Tensor& tensor,
const int64_t sample_rate,
const bool channels_first,
const std::string filetype);
/// Get sox_encodinginfo_t for saving audio file
sox_encodinginfo_t get_encodinginfo(
const std::string filetype,
const caffe2::TypeMeta dtype,
const double compression);
} // namespace sox_utils
} // namespace torchaudio
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment