Unverified Commit c3cb2015 authored by moto's avatar moto Committed by GitHub
Browse files

Add encoding and bits_per_sample option to save function (#1226)

parent 4f9b5520
def name_func(func, _, params): def name_func(func, _, params):
return f'{func.__name__}_{"_".join(str(arg) for arg in params.args)}' return f'{func.__name__}_{"_".join(str(arg) for arg in params.args)}'
def get_enc_params(dtype):
if dtype == 'float32':
return 'PCM_F', 32
if dtype == 'int32':
return 'PCM_S', 32
if dtype == 'int16':
return 'PCM_S', 16
if dtype == 'uint8':
return 'PCM_U', 8
raise ValueError(f'Unexpected dtype: {dtype}')
...@@ -12,6 +12,7 @@ from torchaudio_unittest.common_utils import ( ...@@ -12,6 +12,7 @@ from torchaudio_unittest.common_utils import (
) )
from .common import ( from .common import (
name_func, name_func,
get_enc_params,
) )
...@@ -27,10 +28,11 @@ class TestRoundTripIO(TempDirMixin, PytorchTestCase): ...@@ -27,10 +28,11 @@ class TestRoundTripIO(TempDirMixin, PytorchTestCase):
def test_wav(self, dtype, sample_rate, num_channels): def test_wav(self, dtype, sample_rate, num_channels):
"""save/load round trip should not degrade data for wav formats""" """save/load round trip should not degrade data for wav formats"""
original = get_wav_data(dtype, num_channels, normalize=False) original = get_wav_data(dtype, num_channels, normalize=False)
enc, bps = get_enc_params(dtype)
data = original data = original
for i in range(10): for i in range(10):
path = self.get_temp_path(f'{i}.wav') path = self.get_temp_path(f'{i}.wav')
sox_io_backend.save(path, data, sample_rate) sox_io_backend.save(path, data, sample_rate, encoding=enc, bits_per_sample=bps)
data, sr = sox_io_backend.load(path, normalize=False) data, sr = sox_io_backend.load(path, normalize=False)
assert sr == sample_rate assert sr == sample_rate
self.assertEqual(original, data) self.assertEqual(original, data)
......
import io import io
import itertools import unittest
from itertools import product
import torch
from torchaudio.backend import sox_io_backend from torchaudio.backend import sox_io_backend
from parameterized import parameterized from parameterized import parameterized
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TempDirMixin, TempDirMixin,
TorchaudioTestCase,
PytorchTestCase, PytorchTestCase,
skipIfNoExec, skipIfNoExec,
skipIfNoExtension, skipIfNoExtension,
...@@ -17,37 +18,62 @@ from torchaudio_unittest.common_utils import ( ...@@ -17,37 +18,62 @@ from torchaudio_unittest.common_utils import (
) )
from .common import ( from .common import (
name_func, name_func,
get_enc_params,
) )
class SaveTestBase(TempDirMixin, PytorchTestCase): def _get_sox_encoding(encoding):
def assert_wav(self, dtype, sample_rate, num_channels, num_frames): encodings = {
"""`sox_io_backend.save` can save wav format.""" 'PCM_F': 'floating-point',
path = self.get_temp_path('data.wav') 'PCM_S': 'signed-integer',
expected = get_wav_data(dtype, num_channels, num_frames=num_frames) 'PCM_U': 'unsigned-integer',
sox_io_backend.save(path, expected, sample_rate, dtype=None) 'ULAW': 'u-law',
found, sr = load_wav(path) 'ALAW': 'a-law',
assert sample_rate == sr }
self.assertEqual(found, expected) return encodings.get(encoding)
def assert_mp3(self, sample_rate, num_channels, bit_rate, duration):
"""`sox_io_backend.save` can save mp3 format. class SaveTestBase(TempDirMixin, TorchaudioTestCase):
def assert_save_consistency(
mp3 encoding introduces delay and boundary effects so self,
we convert the resulting mp3 to wav and compare the results there format: str,
*,
| compression: float = None,
| 1. Generate original wav file with SciPy encoding: str = None,
bits_per_sample: int = None,
sample_rate: float = 8000,
num_channels: int = 2,
num_frames: float = 3 * 8000,
src_dtype: str = 'int32',
test_mode: str = "path",
):
"""`save` function produces file that is comparable with `sox` command
To compare that the file produced by `save` function agains the file produced by
the equivalent `sox` command, we need to load both files.
But there are many formats that cannot be opened with common Python modules (like
SciPy).
So we use `sox` command to prepare the original data and convert the saved files
into a format that SciPy can read (PCM wav).
The following diagram illustrates this process. The difference is 2.1. and 3.1.
This assumes that
- loading data with SciPy preserves the data well.
- converting the resulting files into WAV format with `sox` preserve the data well.
x
| 1. Generate source wav file with SciPy
| |
v v
-------------- wav ---------------- -------------- wav ----------------
| | | |
| 2.1. load with scipy | 3.1. Convert to mp3 with Sox | 2.1. load with scipy | 3.1. Convert to the target
| then save with torchaudio | | then save it into the target | format depth with sox
| format with torchaudio |
v v v v
mp3 mp3 target format target format
| | | |
| 2.2. Convert to wav with Sox | 3.2. Convert to wav with Sox | 2.2. Convert to wav with sox | 3.2. Convert to wav with sox
| | | |
v v v v
wav wav wav wav
...@@ -58,326 +84,260 @@ class SaveTestBase(TempDirMixin, PytorchTestCase): ...@@ -58,326 +84,260 @@ class SaveTestBase(TempDirMixin, PytorchTestCase):
tensor -------> compare <--------- tensor tensor -------> compare <--------- tensor
""" """
src_path = self.get_temp_path('1.reference.wav') cmp_encoding = 'floating-point'
mp3_path = self.get_temp_path('2.1.torchaudio.mp3') cmp_bit_depth = 32
wav_path = self.get_temp_path('2.2.torchaudio.wav')
mp3_path_sox = self.get_temp_path('3.1.sox.mp3')
wav_path_sox = self.get_temp_path('3.2.sox.wav')
# 1. Generate original wav src_path = self.get_temp_path('1.source.wav')
data = get_wav_data('float32', num_channels, normalize=True, num_frames=duration * sample_rate) tgt_path = self.get_temp_path(f'2.1.torchaudio.{format}')
save_wav(src_path, data, sample_rate) tst_path = self.get_temp_path('2.2.result.wav')
# 2.1. Convert the original wav to mp3 with torchaudio sox_path = self.get_temp_path(f'3.1.sox.{format}')
sox_io_backend.save( ref_path = self.get_temp_path('3.2.ref.wav')
mp3_path, load_wav(src_path)[0], sample_rate, compression=bit_rate, dtype=None)
# 2.2. Convert the mp3 to wav with Sox
sox_utils.convert_audio_file(mp3_path, wav_path)
# 2.3. Load
found = load_wav(wav_path)[0]
# 3.1. Convert the original wav to mp3 with SoX
sox_utils.convert_audio_file(src_path, mp3_path_sox, compression=bit_rate)
# 3.2. Convert the mp3 to wav with Sox
sox_utils.convert_audio_file(mp3_path_sox, wav_path_sox)
# 3.3. Load
expected = load_wav(wav_path_sox)[0]
self.assertEqual(found, expected)
def assert_flac(self, sample_rate, num_channels, compression_level, duration):
"""`sox_io_backend.save` can save flac format.
This test takes the same strategy as mp3 to compare the result
"""
src_path = self.get_temp_path('1.reference.wav')
flc_path = self.get_temp_path('2.1.torchaudio.flac')
wav_path = self.get_temp_path('2.2.torchaudio.wav')
flc_path_sox = self.get_temp_path('3.1.sox.flac')
wav_path_sox = self.get_temp_path('3.2.sox.wav')
# 1. Generate original wav # 1. Generate original wav
data = get_wav_data('float32', num_channels, normalize=True, num_frames=duration * sample_rate) data = get_wav_data(src_dtype, num_channels, normalize=False, num_frames=num_frames)
save_wav(src_path, data, sample_rate) save_wav(src_path, data, sample_rate)
# 2.1. Convert the original wav to flac with torchaudio
sox_io_backend.save(
flc_path, load_wav(src_path)[0], sample_rate, compression=compression_level, dtype=None)
# 2.2. Convert the flac to wav with Sox
# converting to 32 bit because flac file has 24 bit depth which scipy cannot handle.
sox_utils.convert_audio_file(flc_path, wav_path, bit_depth=32)
# 2.3. Load
found = load_wav(wav_path)[0]
# 3.1. Convert the original wav to flac with SoX
sox_utils.convert_audio_file(src_path, flc_path_sox, compression=compression_level)
# 3.2. Convert the flac to wav with Sox
# converting to 32 bit because flac file has 24 bit depth which scipy cannot handle.
sox_utils.convert_audio_file(flc_path_sox, wav_path_sox, bit_depth=32)
# 3.3. Load
expected = load_wav(wav_path_sox)[0]
self.assertEqual(found, expected)
def _assert_vorbis(self, sample_rate, num_channels, quality_level, duration):
"""`sox_io_backend.save` can save vorbis format.
This test takes the same strategy as mp3 to compare the result
"""
src_path = self.get_temp_path('1.reference.wav')
vbs_path = self.get_temp_path('2.1.torchaudio.vorbis')
wav_path = self.get_temp_path('2.2.torchaudio.wav')
vbs_path_sox = self.get_temp_path('3.1.sox.vorbis')
wav_path_sox = self.get_temp_path('3.2.sox.wav')
# 1. Generate original wav # 2.1. Convert the original wav to target format with torchaudio
data = get_wav_data('int16', num_channels, normalize=False, num_frames=duration * sample_rate) data = load_wav(src_path, normalize=False)[0]
save_wav(src_path, data, sample_rate) if test_mode == "path":
# 2.1. Convert the original wav to vorbis with torchaudio sox_io_backend.save(
sox_io_backend.save( tgt_path, data, sample_rate,
vbs_path, load_wav(src_path)[0], sample_rate, compression=quality_level, dtype=None) compression=compression, encoding=encoding, bits_per_sample=bits_per_sample)
# 2.2. Convert the vorbis to wav with Sox elif test_mode == "fileobj":
sox_utils.convert_audio_file(vbs_path, wav_path) with open(tgt_path, 'bw') as file_:
# 2.3. Load sox_io_backend.save(
found = load_wav(wav_path)[0] file_, data, sample_rate,
format=format, compression=compression,
# 3.1. Convert the original wav to vorbis with SoX encoding=encoding, bits_per_sample=bits_per_sample)
sox_utils.convert_audio_file(src_path, vbs_path_sox, compression=quality_level) elif test_mode == "bytesio":
# 3.2. Convert the vorbis to wav with Sox file_ = io.BytesIO()
sox_utils.convert_audio_file(vbs_path_sox, wav_path_sox) sox_io_backend.save(
# 3.3. Load file_, data, sample_rate,
expected = load_wav(wav_path_sox)[0] format=format, compression=compression,
encoding=encoding, bits_per_sample=bits_per_sample)
# sox's vorbis encoding has some random boundary effect, which cause small number of file_.seek(0)
# samples yields higher descrepency than the others. with open(tgt_path, 'bw') as f:
# so we allow small portions of data to be outside of absolute torelance. f.write(file_.read())
# make sure to pass somewhat long duration
atol = 1.0e-4
max_failure_allowed = 0.01 # this percent of samples are allowed to outside of atol.
failure_ratio = ((found - expected).abs() > atol).sum().item() / found.numel()
if failure_ratio > max_failure_allowed:
# it's failed and this will give a better error message.
self.assertEqual(found, expected, atol=atol, rtol=1.3e-6)
def assert_vorbis(self, *args, **kwargs):
# sox's vorbis encoding has some randomness, so we run tests multiple time
max_retry = 5
error = None
for _ in range(max_retry):
try:
self._assert_vorbis(*args, **kwargs)
break
except AssertionError as e:
error = e
else: else:
raise error raise ValueError(f"Unexpected test mode: {test_mode}")
# 2.2. Convert the target format to wav with sox
def assert_sphere(self, sample_rate, num_channels, duration): sox_utils.convert_audio_file(
"""`sox_io_backend.save` can save sph format. tgt_path, tst_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
# 2.3. Load with SciPy
This test takes the same strategy as mp3 to compare the result found = load_wav(tst_path, normalize=False)[0]
"""
src_path = self.get_temp_path('1.reference.wav') # 3.1. Convert the original wav to target format with sox
flc_path = self.get_temp_path('2.1.torchaudio.sph') sox_encoding = _get_sox_encoding(encoding)
wav_path = self.get_temp_path('2.2.torchaudio.wav') sox_utils.convert_audio_file(
flc_path_sox = self.get_temp_path('3.1.sox.sph') src_path, sox_path,
wav_path_sox = self.get_temp_path('3.2.sox.wav') compression=compression, encoding=sox_encoding, bit_depth=bits_per_sample)
# 3.2. Convert the target format to wav with sox
# 1. Generate original wav sox_utils.convert_audio_file(
data = get_wav_data('float32', num_channels, normalize=True, num_frames=duration * sample_rate) sox_path, ref_path, encoding=cmp_encoding, bit_depth=cmp_bit_depth)
save_wav(src_path, data, sample_rate) # 3.3. Load with SciPy
# 2.1. Convert the original wav to sph with torchaudio expected = load_wav(ref_path, normalize=False)[0]
sox_io_backend.save(flc_path, load_wav(src_path)[0], sample_rate, dtype=None)
# 2.2. Convert the sph to wav with Sox
# converting to 32 bit because sph file has 24 bit depth which scipy cannot handle.
sox_utils.convert_audio_file(flc_path, wav_path, bit_depth=32)
# 2.3. Load
found = load_wav(wav_path)[0]
# 3.1. Convert the original wav to sph with SoX
sox_utils.convert_audio_file(src_path, flc_path_sox)
# 3.2. Convert the sph to wav with Sox
# converting to 32 bit because sph file has 24 bit depth which scipy cannot handle.
sox_utils.convert_audio_file(flc_path_sox, wav_path_sox, bit_depth=32)
# 3.3. Load
expected = load_wav(wav_path_sox)[0]
self.assertEqual(found, expected) self.assertEqual(found, expected)
def assert_amb(self, dtype, sample_rate, num_channels, duration):
"""`sox_io_backend.save` can save amb format.
This test takes the same strategy as mp3 to compare the result def nested_params(*params):
""" def _name_func(func, _, params):
src_path = self.get_temp_path('1.reference.wav') strs = []
amb_path = self.get_temp_path('2.1.torchaudio.amb') for arg in params.args:
wav_path = self.get_temp_path('2.2.torchaudio.wav') if isinstance(arg, tuple):
amb_path_sox = self.get_temp_path('3.1.sox.amb') strs.append("_".join(str(a) for a in arg))
wav_path_sox = self.get_temp_path('3.2.sox.wav') else:
strs.append(str(arg))
return f'{func.__name__}_{"_".join(strs)}'
# 1. Generate original wav return parameterized.expand(
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate) list(product(*params)),
save_wav(src_path, data, sample_rate) name_func=_name_func
# 2.1. Convert the original wav to amb with torchaudio )
sox_io_backend.save(amb_path, load_wav(src_path, normalize=False)[0], sample_rate, dtype=None)
# 2.2. Convert the amb to wav with Sox
sox_utils.convert_audio_file(amb_path, wav_path)
# 2.3. Load
found = load_wav(wav_path)[0]
# 3.1. Convert the original wav to amb with SoX
sox_utils.convert_audio_file(src_path, amb_path_sox)
# 3.2. Convert the amb to wav with Sox
sox_utils.convert_audio_file(amb_path_sox, wav_path_sox)
# 3.3. Load
expected = load_wav(wav_path_sox)[0]
self.assertEqual(found, expected)
def assert_amr_nb(self, duration):
"""`sox_io_backend.save` can save amr_nb format.
This test takes the same strategy as mp3 to compare the result
"""
sample_rate = 8000
num_channels = 1
src_path = self.get_temp_path('1.reference.wav')
amr_path = self.get_temp_path('2.1.torchaudio.amr-nb')
wav_path = self.get_temp_path('2.2.torchaudio.wav')
amr_path_sox = self.get_temp_path('3.1.sox.amr-nb')
wav_path_sox = self.get_temp_path('3.2.sox.wav')
# 1. Generate original wav
data = get_wav_data('int16', num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(src_path, data, sample_rate)
# 2.1. Convert the original wav to amr_nb with torchaudio
sox_io_backend.save(amr_path, load_wav(src_path, normalize=False)[0], sample_rate, dtype=None)
# 2.2. Convert the amr_nb to wav with Sox
sox_utils.convert_audio_file(amr_path, wav_path)
# 2.3. Load
found = load_wav(wav_path)[0]
# 3.1. Convert the original wav to amr_nb with SoX
sox_utils.convert_audio_file(src_path, amr_path_sox)
# 3.2. Convert the amr_nb to wav with Sox
sox_utils.convert_audio_file(amr_path_sox, wav_path_sox)
# 3.3. Load
expected = load_wav(wav_path_sox)[0]
self.assertEqual(found, expected)
@skipIfNoExec('sox') @skipIfNoExec('sox')
@skipIfNoExtension @skipIfNoExtension
class TestSave(SaveTestBase): class SaveTest(SaveTestBase):
@parameterized.expand(list(itertools.product( @nested_params(
['float32', 'int32', 'int16', 'uint8'], ["path", "fileobj", "bytesio"],
[8000, 16000], [
[1, 2], ('PCM_U', 8),
)), name_func=name_func) ('PCM_S', 16),
def test_wav(self, dtype, sample_rate, num_channels): ('PCM_S', 32),
"""`sox_io_backend.save` can save wav format.""" ('PCM_F', 32),
self.assert_wav(dtype, sample_rate, num_channels, num_frames=None) ('PCM_F', 64),
('ULAW', 8),
@parameterized.expand(list(itertools.product( ('ALAW', 8),
['float32'], ],
[16000], )
[2], def test_save_wav(self, test_mode, enc_params):
)), name_func=name_func) encoding, bits_per_sample = enc_params
def test_wav_large(self, dtype, sample_rate, num_channels): self.assert_save_consistency(
"""`sox_io_backend.save` can save large wav file.""" "wav", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode)
two_hours = 2 * 60 * 60 * sample_rate
self.assert_wav(dtype, sample_rate, num_channels, num_frames=two_hours) @nested_params(
["path", "fileobj", "bytesio"],
@parameterized.expand(list(itertools.product( [
['float32', 'int32', 'int16', 'uint8'], ('float32', ),
[4, 8, 16, 32], ('int32', ),
)), name_func=name_func) ('int16', ),
def test_multiple_channels(self, dtype, num_channels): ('uint8', ),
"""`sox_io_backend.save` can save wav with more than 2 channels.""" ],
)
def test_save_wav_dtype(self, test_mode, params):
dtype, = params
self.assert_save_consistency(
"wav", src_dtype=dtype, test_mode=test_mode)
@nested_params(
["path", "fileobj", "bytesio"],
[
None,
-4.2,
-0.2,
0,
0.2,
96,
128,
160,
192,
224,
256,
320,
],
)
def test_save_mp3(self, test_mode, bit_rate):
if test_mode in ["fileobj", "bytesio"]:
if bit_rate is not None and bit_rate < 1:
raise unittest.SkipTest(
"mp3 format with variable bit rate is known to "
"not yield the exact same result as sox command.")
self.assert_save_consistency(
"mp3", compression=bit_rate, test_mode=test_mode)
@nested_params(
["path", "fileobj", "bytesio"],
[8, 16, 24],
[
None,
0,
1,
2,
3,
4,
5,
6,
7,
8,
],
)
def test_save_flac(self, test_mode, bits_per_sample, compression_level):
self.assert_save_consistency(
"flac", compression=compression_level,
bits_per_sample=bits_per_sample, test_mode=test_mode)
@nested_params(
["path", "fileobj", "bytesio"],
[
None,
-1,
0,
1,
2,
3,
3.6,
5,
10,
],
)
def test_save_vorbis(self, test_mode, quality_level):
self.assert_save_consistency(
"vorbis", compression=quality_level, test_mode=test_mode)
@nested_params(
["path", "fileobj", "bytesio"],
[
('PCM_S', 8, ),
('PCM_S', 16, ),
('PCM_S', 24, ),
('PCM_S', 32, ),
('ULAW', 8),
('ALAW', 8),
('ALAW', 16),
('ALAW', 24),
('ALAW', 32),
],
)
def test_save_sphere(self, test_mode, enc_params):
encoding, bits_per_sample = enc_params
self.assert_save_consistency(
"sph", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode)
@nested_params(
["path", "fileobj", "bytesio"],
[
('PCM_U', 8, ),
('PCM_S', 16, ),
('PCM_S', 24, ),
('PCM_S', 32, ),
('PCM_F', 32, ),
('PCM_F', 64, ),
('ULAW', 8, ),
('ALAW', 8, ),
],
)
def test_save_amb(self, test_mode, enc_params):
encoding, bits_per_sample = enc_params
self.assert_save_consistency(
"amb", encoding=encoding, bits_per_sample=bits_per_sample, test_mode=test_mode)
@nested_params(
["path", "fileobj", "bytesio"],
[
None,
0,
1,
2,
3,
4,
5,
6,
7,
],
)
def test_save_amr_nb(self, test_mode, bit_rate):
self.assert_save_consistency(
"amr-nb", compression=bit_rate, num_channels=1, test_mode=test_mode)
@parameterized.expand([
("wav", "PCM_S", 16),
("mp3", ),
("flac", ),
("vorbis", ),
("sph", "PCM_S", 16),
("amr-nb", ),
("amb", "PCM_S", 16),
], name_func=name_func)
def test_save_large(self, format, encoding=None, bits_per_sample=None):
"""`sox_io_backend.save` can save large files."""
sample_rate = 8000 sample_rate = 8000
self.assert_wav(dtype, sample_rate, num_channels, num_frames=None) one_hour = 60 * 60 * sample_rate
self.assert_save_consistency(
@parameterized.expand(list(itertools.product( format, num_channels=1, sample_rate=8000, num_frames=one_hour,
[8000, 16000], encoding=encoding, bits_per_sample=bits_per_sample)
[1, 2],
[-4.2, -0.2, 0, 0.2, 96, 128, 160, 192, 224, 256, 320], @parameterized.expand([
)), name_func=name_func) (32, ),
def test_mp3(self, sample_rate, num_channels, bit_rate): (64, ),
"""`sox_io_backend.save` can save mp3 format.""" (128, ),
self.assert_mp3(sample_rate, num_channels, bit_rate, duration=1) (256, ),
], name_func=name_func)
@parameterized.expand(list(itertools.product( def test_save_multi_channels(self, num_channels):
[16000], """`sox_io_backend.save` can save audio with many channels"""
[2], self.assert_save_consistency(
[128], "wav", encoding="PCM_S", bits_per_sample=16,
)), name_func=name_func) num_channels=num_channels)
def test_mp3_large(self, sample_rate, num_channels, bit_rate):
"""`sox_io_backend.save` can save large mp3 file."""
two_hours = 2 * 60 * 60
self.assert_mp3(sample_rate, num_channels, bit_rate, duration=two_hours)
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
[None] + list(range(9)),
)), name_func=name_func)
def test_flac(self, sample_rate, num_channels, compression_level):
"""`sox_io_backend.save` can save flac format."""
self.assert_flac(sample_rate, num_channels, compression_level, duration=1)
@parameterized.expand(list(itertools.product(
[16000],
[2],
[0],
)), name_func=name_func)
def test_flac_large(self, sample_rate, num_channels, compression_level):
"""`sox_io_backend.save` can save large flac file."""
two_hours = 2 * 60 * 60
self.assert_flac(sample_rate, num_channels, compression_level, duration=two_hours)
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
[None, -1, 0, 1, 2, 3, 3.6, 5, 10],
)), name_func=name_func)
def test_vorbis(self, sample_rate, num_channels, quality_level):
"""`sox_io_backend.save` can save vorbis format."""
self.assert_vorbis(sample_rate, num_channels, quality_level, duration=20)
# note: torchaudio can load large vorbis file, but cannot save large volbis file
# the following test causes Segmentation fault
#
'''
@parameterized.expand(list(itertools.product(
[16000],
[2],
[10],
)), name_func=name_func)
def test_vorbis_large(self, sample_rate, num_channels, quality_level):
"""`sox_io_backend.save` can save large vorbis file correctly."""
two_hours = 2 * 60 * 60
self.assert_vorbis(sample_rate, num_channels, quality_level, two_hours)
'''
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
)), name_func=name_func)
def test_sphere(self, sample_rate, num_channels):
"""`sox_io_backend.save` can save sph format."""
self.assert_sphere(sample_rate, num_channels, duration=1)
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
[1, 2],
)), name_func=name_func)
def test_amb(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.save` can save amb format."""
self.assert_amb(dtype, sample_rate, num_channels, duration=1)
def test_amr_nb(self):
"""`sox_io_backend.save` can save amr-nb format."""
self.assert_amr_nb(duration=1)
@skipIfNoExec('sox') @skipIfNoExec('sox')
...@@ -385,136 +345,40 @@ class TestSave(SaveTestBase): ...@@ -385,136 +345,40 @@ class TestSave(SaveTestBase):
class TestSaveParams(TempDirMixin, PytorchTestCase): class TestSaveParams(TempDirMixin, PytorchTestCase):
"""Test the correctness of optional parameters of `sox_io_backend.save`""" """Test the correctness of optional parameters of `sox_io_backend.save`"""
@parameterized.expand([(True, ), (False, )], name_func=name_func) @parameterized.expand([(True, ), (False, )], name_func=name_func)
def test_channels_first(self, channels_first): def test_save_channels_first(self, channels_first):
"""channels_first swaps axes""" """channels_first swaps axes"""
path = self.get_temp_path('data.wav') path = self.get_temp_path('data.wav')
data = get_wav_data('int32', 2, channels_first=channels_first) data = get_wav_data(
'int16', 2, channels_first=channels_first, normalize=False)
sox_io_backend.save( sox_io_backend.save(
path, data, 8000, channels_first=channels_first, dtype=None) path, data, 8000, channels_first=channels_first)
found = load_wav(path)[0] found = load_wav(path, normalize=False)[0]
expected = data if channels_first else data.transpose(1, 0) expected = data if channels_first else data.transpose(1, 0)
self.assertEqual(found, expected) self.assertEqual(found, expected)
@parameterized.expand([ @parameterized.expand([
'float32', 'int32', 'int16', 'uint8' 'float32', 'int32', 'int16', 'uint8'
], name_func=name_func) ], name_func=name_func)
def test_noncontiguous(self, dtype): def test_save_noncontiguous(self, dtype):
"""Noncontiguous tensors are saved correctly""" """Noncontiguous tensors are saved correctly"""
path = self.get_temp_path('data.wav') path = self.get_temp_path('data.wav')
expected = get_wav_data(dtype, 4)[::2, ::2] enc, bps = get_enc_params(dtype)
expected = get_wav_data(dtype, 4, normalize=False)[::2, ::2]
assert not expected.is_contiguous() assert not expected.is_contiguous()
sox_io_backend.save(path, expected, 8000, dtype=None) sox_io_backend.save(
found = load_wav(path)[0] path, expected, 8000, encoding=enc, bits_per_sample=bps)
found = load_wav(path, normalize=False)[0]
self.assertEqual(found, expected) self.assertEqual(found, expected)
@parameterized.expand([ @parameterized.expand([
'float32', 'int32', 'int16', 'uint8', 'float32', 'int32', 'int16', 'uint8',
]) ])
def test_tensor_preserve(self, dtype): def test_save_tensor_preserve(self, dtype):
"""save function should not alter Tensor""" """save function should not alter Tensor"""
path = self.get_temp_path('data.wav') path = self.get_temp_path('data.wav')
expected = get_wav_data(dtype, 4)[::2, ::2] expected = get_wav_data(dtype, 4, normalize=False)[::2, ::2]
data = expected.clone() data = expected.clone()
sox_io_backend.save(path, data, 8000, dtype=None) sox_io_backend.save(path, data, 8000)
self.assertEqual(data, expected) self.assertEqual(data, expected)
@parameterized.expand([
('float32', torch.tensor([-1.0, -0.5, 0, 0.5, 1.0]).to(torch.float32)),
('int32', torch.tensor([-2147483648, -1073741824, 0, 1073741824, 2147483647]).to(torch.int32)),
('int16', torch.tensor([-32768, -16384, 0, 16384, 32767]).to(torch.int16)),
('uint8', torch.tensor([0, 64, 128, 192, 255]).to(torch.uint8)),
])
def test_dtype_conversion(self, dtype, expected):
"""`save` performs dtype conversion on float32 src tensors only."""
path = self.get_temp_path("data.wav")
data = torch.tensor([-1.0, -0.5, 0, 0.5, 1.0]).to(torch.float32).view(-1, 1)
sox_io_backend.save(path, data, 8000, dtype=dtype)
found = load_wav(path, normalize=False)[0]
self.assertEqual(found, expected.view(-1, 1))
@skipIfNoExtension
@skipIfNoExec('sox')
class TestFileObject(SaveTestBase):
"""
We campare the result of file-like object input against file path input because
`save` function is rigrously tested for file path inputs to match libsox's result,
"""
@parameterized.expand([
('wav', None),
('mp3', 128),
('mp3', 320),
('flac', 0),
('flac', 5),
('flac', 8),
('vorbis', -1),
('vorbis', 10),
('amb', None),
])
def test_fileobj(self, ext, compression):
"""Saving audio to file object returns the same result as via file path."""
sample_rate = 16000
dtype = 'float32'
num_channels = 2
num_frames = 16000
channels_first = True
data = get_wav_data(dtype, num_channels, num_frames=num_frames)
ref_path = self.get_temp_path(f'reference.{ext}')
res_path = self.get_temp_path(f'test.{ext}')
sox_io_backend.save(
ref_path, data, channels_first=channels_first,
sample_rate=sample_rate, compression=compression, dtype=None)
with open(res_path, 'wb') as fileobj:
sox_io_backend.save(
fileobj, data, channels_first=channels_first,
sample_rate=sample_rate, compression=compression, format=ext, dtype=None)
expected_data, _ = sox_io_backend.load(ref_path)
data, sr = sox_io_backend.load(res_path)
assert sample_rate == sr
self.assertEqual(expected_data, data)
@parameterized.expand([
('wav', None),
('mp3', 128),
('mp3', 320),
('flac', 0),
('flac', 5),
('flac', 8),
('vorbis', -1),
('vorbis', 10),
('amb', None),
])
def test_bytesio(self, ext, compression):
"""Saving audio to BytesIO object returns the same result as via file path."""
sample_rate = 16000
dtype = 'float32'
num_channels = 2
num_frames = 16000
channels_first = True
data = get_wav_data(dtype, num_channels, num_frames=num_frames)
ref_path = self.get_temp_path(f'reference.{ext}')
res_path = self.get_temp_path(f'test.{ext}')
sox_io_backend.save(
ref_path, data, channels_first=channels_first,
sample_rate=sample_rate, compression=compression, dtype=None)
fileobj = io.BytesIO()
sox_io_backend.save(
fileobj, data, channels_first=channels_first,
sample_rate=sample_rate, compression=compression, format=ext, dtype=None)
fileobj.seek(0)
with open(res_path, 'wb') as file_:
file_.write(fileobj.read())
expected_data, _ = sox_io_backend.load(ref_path)
data, sr = sox_io_backend.load(res_path)
assert sample_rate == sr
self.assertEqual(expected_data, data)
...@@ -17,6 +17,7 @@ from torchaudio_unittest.common_utils import ( ...@@ -17,6 +17,7 @@ from torchaudio_unittest.common_utils import (
) )
from .common import ( from .common import (
name_func, name_func,
get_enc_params,
) )
...@@ -35,8 +36,12 @@ def py_save_func( ...@@ -35,8 +36,12 @@ def py_save_func(
sample_rate: int, sample_rate: int,
channels_first: bool = True, channels_first: bool = True,
compression: Optional[float] = None, compression: Optional[float] = None,
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
): ):
torchaudio.save(filepath, tensor, sample_rate, channels_first, compression) torchaudio.save(
filepath, tensor, sample_rate, channels_first,
compression, None, encoding, bits_per_sample)
@skipIfNoExec('sox') @skipIfNoExec('sox')
...@@ -102,15 +107,16 @@ class SoxIO(TempDirMixin, TorchaudioTestCase): ...@@ -102,15 +107,16 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
torch.jit.script(py_save_func).save(script_path) torch.jit.script(py_save_func).save(script_path)
ts_save_func = torch.jit.load(script_path) ts_save_func = torch.jit.load(script_path)
expected = get_wav_data(dtype, num_channels) expected = get_wav_data(dtype, num_channels, normalize=False)
py_path = self.get_temp_path(f'test_save_py_{dtype}_{sample_rate}_{num_channels}.wav') py_path = self.get_temp_path(f'test_save_py_{dtype}_{sample_rate}_{num_channels}.wav')
ts_path = self.get_temp_path(f'test_save_ts_{dtype}_{sample_rate}_{num_channels}.wav') ts_path = self.get_temp_path(f'test_save_ts_{dtype}_{sample_rate}_{num_channels}.wav')
enc, bps = get_enc_params(dtype)
py_save_func(py_path, expected, sample_rate, True, None) py_save_func(py_path, expected, sample_rate, True, None, enc, bps)
ts_save_func(ts_path, expected, sample_rate, True, None) ts_save_func(ts_path, expected, sample_rate, True, None, enc, bps)
py_data, py_sr = load_wav(py_path) py_data, py_sr = load_wav(py_path, normalize=False)
ts_data, ts_sr = load_wav(ts_path) ts_data, ts_sr = load_wav(ts_path, normalize=False)
self.assertEqual(sample_rate, py_sr) self.assertEqual(sample_rate, py_sr)
self.assertEqual(sample_rate, ts_sr) self.assertEqual(sample_rate, ts_sr)
...@@ -131,8 +137,8 @@ class SoxIO(TempDirMixin, TorchaudioTestCase): ...@@ -131,8 +137,8 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
py_path = self.get_temp_path(f'test_save_py_{sample_rate}_{num_channels}_{compression_level}.flac') py_path = self.get_temp_path(f'test_save_py_{sample_rate}_{num_channels}_{compression_level}.flac')
ts_path = self.get_temp_path(f'test_save_ts_{sample_rate}_{num_channels}_{compression_level}.flac') ts_path = self.get_temp_path(f'test_save_ts_{sample_rate}_{num_channels}_{compression_level}.flac')
py_save_func(py_path, expected, sample_rate, True, compression_level) py_save_func(py_path, expected, sample_rate, True, compression_level, None, None)
ts_save_func(ts_path, expected, sample_rate, True, compression_level) ts_save_func(ts_path, expected, sample_rate, True, compression_level, None, None)
# converting to 32 bit because flac file has 24 bit depth which scipy cannot handle. # converting to 32 bit because flac file has 24 bit depth which scipy cannot handle.
py_path_wav = f'{py_path}.wav' py_path_wav = f'{py_path}.wav'
......
import sys
import subprocess import subprocess
import warnings import warnings
...@@ -32,6 +33,7 @@ def gen_audio_file( ...@@ -32,6 +33,7 @@ def gen_audio_file(
command = [ command = [
'sox', 'sox',
'-V3', # verbose '-V3', # verbose
'--no-dither', # disable automatic dithering
'-R', '-R',
# -R is supposed to be repeatable, though the implementation looks suspicious # -R is supposed to be repeatable, though the implementation looks suspicious
# and not setting the seed to a fixed value. # and not setting the seed to a fixed value.
...@@ -61,21 +63,23 @@ def gen_audio_file( ...@@ -61,21 +63,23 @@ def gen_audio_file(
] ]
if attenuation is not None: if attenuation is not None:
command += ['vol', f'-{attenuation}dB'] command += ['vol', f'-{attenuation}dB']
print(' '.join(command)) print(' '.join(command), file=sys.stderr)
subprocess.run(command, check=True) subprocess.run(command, check=True)
def convert_audio_file( def convert_audio_file(
src_path, dst_path, src_path, dst_path,
*, bit_depth=None, compression=None): *, encoding=None, bit_depth=None, compression=None):
"""Convert audio file with `sox` command.""" """Convert audio file with `sox` command."""
command = ['sox', '-V3', '-R', str(src_path)] command = ['sox', '-V3', '--no-dither', '-R', str(src_path)]
if encoding is not None:
command += ['--encoding', str(encoding)]
if bit_depth is not None: if bit_depth is not None:
command += ['--bits', str(bit_depth)] command += ['--bits', str(bit_depth)]
if compression is not None: if compression is not None:
command += ['--compression', str(compression)] command += ['--compression', str(compression)]
command += [dst_path] command += [dst_path]
print(' '.join(command)) print(' '.join(command), file=sys.stderr)
subprocess.run(command, check=True) subprocess.run(command, check=True)
......
import os import os
import warnings
from typing import Tuple, Optional from typing import Tuple, Optional
import torch import torch
...@@ -152,26 +151,6 @@ def load( ...@@ -152,26 +151,6 @@ def load(
filepath, frame_offset, num_frames, normalize, channels_first, format) filepath, frame_offset, num_frames, normalize, channels_first, format)
@torch.jit.unused
def _save(
filepath: str,
src: torch.Tensor,
sample_rate: int,
channels_first: bool = True,
compression: Optional[float] = None,
format: Optional[str] = None,
dtype: Optional[str] = None,
):
if hasattr(filepath, 'write'):
if format is None:
raise RuntimeError('`format` is required when saving to file object.')
torchaudio._torchaudio.save_audio_fileobj(
filepath, src, sample_rate, channels_first, compression, format, dtype)
else:
torch.ops.torchaudio.sox_io_save_audio_file(
os.fspath(filepath), src, sample_rate, channels_first, compression, format, dtype)
@_mod_utils.requires_module('torchaudio._torchaudio') @_mod_utils.requires_module('torchaudio._torchaudio')
def save( def save(
filepath: str, filepath: str,
...@@ -180,30 +159,11 @@ def save( ...@@ -180,30 +159,11 @@ def save(
channels_first: bool = True, channels_first: bool = True,
compression: Optional[float] = None, compression: Optional[float] = None,
format: Optional[str] = None, format: Optional[str] = None,
dtype: Optional[str] = None, encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
): ):
"""Save audio data to file. """Save audio data to file.
Note:
Supported formats are;
* WAV, AMB
* 32-bit floating-point
* 32-bit signed integer
* 16-bit signed integer
* 8-bit unsigned integer
* MP3
* FLAC
* OGG/VORBIS
* SPHERE
* AMR-NB
To save ``MP3``, ``FLAC``, ``OGG/VORBIS``, and other codecs ``libsox`` does not
handle natively, your installation of ``torchaudio`` has to be linked to ``libsox``
and corresponding codec libraries such as ``libmad`` or ``libmp3lame`` etc.
Args: Args:
filepath (str or pathlib.Path): Path to save file. filepath (str or pathlib.Path): Path to save file.
This function also handles ``pathlib.Path`` objects, but is annotated This function also handles ``pathlib.Path`` objects, but is annotated
...@@ -215,32 +175,137 @@ def save( ...@@ -215,32 +175,137 @@ def save(
compression (Optional[float]): Used for formats other than WAV. compression (Optional[float]): Used for formats other than WAV.
This corresponds to ``-C`` option of ``sox`` command. This corresponds to ``-C`` option of ``sox`` command.
* | ``MP3``: Either bitrate (in ``kbps``) with quality factor, such as ``128.2``, or ``"mp3"``
| VBR encoding with quality factor such as ``-4.2``. Default: ``-4.5``. Either bitrate (in ``kbps``) with quality factor, such as ``128.2``, or
* | ``FLAC``: compression level. Whole number from ``0`` to ``8``. VBR encoding with quality factor such as ``-4.2``. Default: ``-4.5``.
| ``8`` is default and highest compression.
* | ``OGG/VORBIS``: number from ``-1`` to ``10``; ``-1`` is the highest compression ``"flac"``
| and lowest quality. Default: ``3``. Whole number from ``0`` to ``8``. ``8`` is default and highest compression.
``"ogg"``, ``"vorbis"``
Number from ``-1`` to ``10``; ``-1`` is the highest compression
and lowest quality. Default: ``3``.
See the detail at http://sox.sourceforge.net/soxformat.html. See the detail at http://sox.sourceforge.net/soxformat.html.
format (str, optional): Output audio format. format (str, optional): Override the audio format.
This is required when the output audio format cannot be infered from When ``filepath`` argument is path-like object, audio format is infered from
``filepath``, (such as file extension or ``name`` attribute of the given file object). file extension. If file extension is missing or different, you can specify the
dtype (str, optional): Output tensor dtype. correct format with this argument.
Valid values: ``"uint8", "int16", "int32", "float32", "float64", None``
``dtype=None`` means no conversion is performed. When ``filepath`` argument is file-like object, this argument is required.
``dtype`` parameter is only effective for ``float32`` Tensor.
Valid values are ``"wav"``, ``"mp3"``, ``"ogg"``, ``"vorbis"``, ``"amr-nb"``,
``"amb"``, ``"flac"`` and ``"sph"``.
encoding (str, optional): Changes the encoding for the supported formats.
This argument is effective only for supported formats, cush as ``"wav"``, ``""amb"``
and ``"sph"``. Valid values are;
- ``"PCM_S"`` (signed integer Linear PCM)
- ``"PCM_U"`` (unsigned integer Linear PCM)
- ``"PCM_F"`` (floating point PCM)
- ``"ULAW"`` (mu-law)
- ``"ALAW"`` (a-law)
Default values
If not provided, the default value is picked based on ``format`` and ``bits_per_sample``.
``"wav"``, ``"amb"``
- | If both ``encoding`` and ``bits_per_sample`` are not provided, the ``dtype`` of the
| Tensor is used to determine the default value.
- ``"PCM_U"`` if dtype is ``uint8``
- ``"PCM_S"`` if dtype is ``int16`` or ``int32`
- ``"PCM_F"`` if dtype is ``float32``
- ``"PCM_U"`` if ``bits_per_sample=8``
- ``"PCM_S"`` otherwise
``"sph"`` format;
- the default value is ``"PCM_S"``
bits_per_sample (int, optional): Changes the bit depth for the supported formats.
When ``format`` is one of ``"wav"``, ``"flac"``, ``"sph"``, or ``"amb"``, you can change the
bit depth. Valid values are ``8``, ``16``, ``32`` and ``64``.
Default Value;
If not provided, the default values are picked based on ``format`` and ``"encoding"``;
``"wav"``, ``"amb"``;
- | If both ``encoding`` and ``bits_per_sample`` are not provided, the ``dtype`` of the
| Tensor is used.
- ``8`` if dtype is ``uint8``
- ``16`` if dtype is ``int16``
- ``32`` if dtype is ``int32`` or ``float32``
- ``8`` if ``encoding`` is ``"PCM_U"``, ``"ULAW"`` or ``"ALAW"``
- ``16`` if ``encoding`` is ``"PCM_S"``
- ``32`` if ``encoding`` is ``"PCM_F"``
``"flac"`` format;
- the default value is ``24``
``"sph"`` format;
- ``16`` if ``encoding`` is ``"PCM_U"``, ``"PCM_S"``, ``"PCM_F"`` or not provided.
- ``8`` if ``encoding`` is ``"ULAW"`` or ``"ALAW"``
``"amb"`` format;
- ``8`` if ``encoding`` is ``"PCM_U"``, ``"ULAW"`` or ``"ALAW"``
- ``16`` if ``encoding`` is ``"PCM_S"`` or not provided.
- ``32`` if ``encoding`` is ``"PCM_F"``
Supported formats/encodings/bit depth/compression are;
``"wav"``, ``"amb"``
- 32-bit floating-point PCM
- 32-bit signed integer PCM
- 24-bit signed integer PCM
- 16-bit signed integer PCM
- 8-bit unsigned integer PCM
- 8-bit mu-law
- 8-bit a-law
Note: Default encoding/bit depth is determined by the dtype of the input Tensor.
``"mp3"``
Fixed bit rate (such as 128kHz) and variable bit rate compression.
Default: VBR with high quality.
``"flac"``
- 8-bit
- 16-bit
- 24-bit (default)
``"ogg"``, ``"vorbis"``
- Different quality level. Default: approx. 112kbps
``"sph"``
- 8-bit signed integer PCM
- 16-bit signed integer PCM
- 24-bit signed integer PCM
- 32-bit signed integer PCM (default)
- 8-bit mu-law
- 8-bit a-law
- 16-bit a-law
- 24-bit a-law
- 32-bit a-law
``"amr-nb"``
Bitrate ranging from 4.75 kbit/s to 12.2 kbit/s. Default: 4.75 kbit/s
Note:
To save into formats that ``libsox`` does not handle natively, (such as ``"mp3"``,
``"flac"``, ``"ogg"`` and ``"vorbis"``), your installation of ``torchaudio`` has
to be linked to ``libsox`` and corresponding codec libraries such as ``libmad``
or ``libmp3lame`` etc.
""" """
if src.dtype == torch.float32 and dtype is None:
warnings.warn(
'`dtype` default value will be changed to `int16` in 0.9 release.'
'Specify `dtype` to suppress this warning.'
)
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
_save(filepath, src, sample_rate, channels_first, compression, format, dtype) if hasattr(filepath, 'write'):
return torchaudio._torchaudio.save_audio_fileobj(
filepath, src, sample_rate, channels_first, compression,
format, encoding, bits_per_sample)
return
filepath = os.fspath(filepath)
torch.ops.torchaudio.sox_io_save_audio_file( torch.ops.torchaudio.sox_io_save_audio_file(
filepath, src, sample_rate, channels_first, compression, format, dtype) filepath, src, sample_rate, channels_first, compression, format, encoding, bits_per_sample)
@_mod_utils.requires_module('torchaudio._torchaudio') @_mod_utils.requires_module('torchaudio._torchaudio')
......
...@@ -9,6 +9,7 @@ set( ...@@ -9,6 +9,7 @@ set(
sox/utils.cpp sox/utils.cpp
sox/effects.cpp sox/effects.cpp
sox/effects_chain.cpp sox/effects_chain.cpp
sox/types.cpp
) )
if(BUILD_TRANSDUCER) if(BUILD_TRANSDUCER)
......
...@@ -68,21 +68,43 @@ int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) { ...@@ -68,21 +68,43 @@ int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
// Ensure that it's a multiple of the number of channels // Ensure that it's a multiple of the number of channels
*osamp -= *osamp % num_channels; *osamp -= *osamp % num_channels;
// Slice the input Tensor and unnormalize the values // Slice the input Tensor
const auto tensor_ = [&]() { const auto tensor_ = [&]() {
auto i_frame = index / num_channels; auto i_frame = index / num_channels;
auto num_frames = *osamp / num_channels; auto num_frames = *osamp / num_channels;
auto t = (priv->channels_first) auto t = (priv->channels_first)
? tensor.index({Slice(), Slice(i_frame, i_frame + num_frames)}).t() ? tensor.index({Slice(), Slice(i_frame, i_frame + num_frames)}).t()
: tensor.index({Slice(i_frame, i_frame + num_frames), Slice()}); : tensor.index({Slice(i_frame, i_frame + num_frames), Slice()});
return unnormalize_wav(t.reshape({-1})).contiguous(); return t.reshape({-1}).contiguous();
}(); }();
priv->index += *osamp;
// Write data to SoxEffectsChain buffer.
auto ptr = tensor_.data_ptr<int32_t>();
std::copy(ptr, ptr + *osamp, obuf);
// Convert to sox_sample_t (int32_t) and write to buffer
SOX_SAMPLE_LOCALS;
const auto dtype = tensor_.dtype();
if (dtype == torch::kFloat32) {
auto ptr = tensor_.data_ptr<float_t>();
for (size_t i = 0; i < *osamp; ++i) {
obuf[i] = SOX_FLOAT_32BIT_TO_SAMPLE(ptr[i], effp->clips);
}
} else if (dtype == torch::kInt32) {
auto ptr = tensor_.data_ptr<int32_t>();
for (size_t i = 0; i < *osamp; ++i) {
obuf[i] = SOX_SIGNED_32BIT_TO_SAMPLE(ptr[i], effp->clips);
}
} else if (dtype == torch::kInt16) {
auto ptr = tensor_.data_ptr<int16_t>();
for (size_t i = 0; i < *osamp; ++i) {
obuf[i] = SOX_SIGNED_16BIT_TO_SAMPLE(ptr[i], effp->clips);
}
} else if (dtype == torch::kUInt8) {
auto ptr = tensor_.data_ptr<uint8_t>();
for (size_t i = 0; i < *osamp; ++i) {
obuf[i] = SOX_UNSIGNED_8BIT_TO_SAMPLE(ptr[i], effp->clips);
}
} else {
throw std::runtime_error("Unexpected dtype.");
}
priv->index += *osamp;
return (priv->index == num_samples) ? SOX_EOF : SOX_SUCCESS; return (priv->index == num_samples) ? SOX_EOF : SOX_SUCCESS;
} }
...@@ -430,7 +452,7 @@ int fileobj_output_flow( ...@@ -430,7 +452,7 @@ int fileobj_output_flow(
fflush(fp); fflush(fp);
// Copy the encoded chunk to python object. // Copy the encoded chunk to python object.
fileobj->attr("write")(py::bytes(*buffer, *buffer_size)); fileobj->attr("write")(py::bytes(*buffer, ftell(fp)));
// Reset FILE* // Reset FILE*
sf->tell_off = 0; sf->tell_off = 0;
......
...@@ -116,35 +116,27 @@ void save_audio_file( ...@@ -116,35 +116,27 @@ void save_audio_file(
torch::Tensor tensor, torch::Tensor tensor,
int64_t sample_rate, int64_t sample_rate,
bool channels_first, bool channels_first,
c10::optional<double> compression, c10::optional<double>& compression,
c10::optional<std::string> format, c10::optional<std::string>& format,
c10::optional<std::string> dtype) { c10::optional<std::string>& encoding,
c10::optional<int64_t>& bits_per_sample) {
validate_input_tensor(tensor); validate_input_tensor(tensor);
if (tensor.dtype() != torch::kFloat32 && dtype.has_value()) {
throw std::runtime_error(
"dtype conversion only supported for float32 tensors");
}
const auto tgt_dtype =
(tensor.dtype() == torch::kFloat32 && dtype.has_value())
? get_dtype_from_str(dtype.value())
: tensor.dtype();
const auto filetype = [&]() { const auto filetype = [&]() {
if (format.has_value()) if (format.has_value())
return format.value(); return format.value();
return get_filetype(path); return get_filetype(path);
}(); }();
if (filetype == "amr-nb") { if (filetype == "amr-nb") {
const auto num_channels = tensor.size(channels_first ? 0 : 1); const auto num_channels = tensor.size(channels_first ? 0 : 1);
TORCH_CHECK( TORCH_CHECK(
num_channels == 1, "amr-nb format only supports single channel audio."); num_channels == 1, "amr-nb format only supports single channel audio.");
tensor = (unnormalize_wav(tensor) / 65536).to(torch::kInt16);
} }
const auto signal_info = const auto signal_info =
get_signalinfo(&tensor, sample_rate, filetype, channels_first); get_signalinfo(&tensor, sample_rate, filetype, channels_first);
const auto encoding_info = const auto encoding_info = get_encodinginfo_for_save(
get_encodinginfo_for_save(filetype, tgt_dtype, compression); filetype, tensor.dtype(), compression, encoding, bits_per_sample);
SoxFormat sf(sox_open_write( SoxFormat sf(sox_open_write(
path.c_str(), path.c_str(),
...@@ -258,19 +250,17 @@ void save_audio_fileobj( ...@@ -258,19 +250,17 @@ void save_audio_fileobj(
torch::Tensor tensor, torch::Tensor tensor,
int64_t sample_rate, int64_t sample_rate,
bool channels_first, bool channels_first,
c10::optional<double> compression, c10::optional<double>& compression,
std::string filetype, c10::optional<std::string>& format,
c10::optional<std::string> dtype) { c10::optional<std::string>& encoding,
c10::optional<int64_t>& bits_per_sample) {
validate_input_tensor(tensor); validate_input_tensor(tensor);
if (tensor.dtype() != torch::kFloat32 && dtype.has_value()) { if (!format.has_value()) {
throw std::runtime_error( throw std::runtime_error(
"dtype conversion only supported for float32 tensors"); "`format` is required when saving to file object.");
} }
const auto tgt_dtype = const auto filetype = format.value();
(tensor.dtype() == torch::kFloat32 && dtype.has_value())
? get_dtype_from_str(dtype.value())
: tensor.dtype();
if (filetype == "amr-nb") { if (filetype == "amr-nb") {
const auto num_channels = tensor.size(channels_first ? 0 : 1); const auto num_channels = tensor.size(channels_first ? 0 : 1);
...@@ -278,12 +268,11 @@ void save_audio_fileobj( ...@@ -278,12 +268,11 @@ void save_audio_fileobj(
throw std::runtime_error( throw std::runtime_error(
"amr-nb format only supports single channel audio."); "amr-nb format only supports single channel audio.");
} }
tensor = (unnormalize_wav(tensor) / 65536).to(torch::kInt16);
} }
const auto signal_info = const auto signal_info =
get_signalinfo(&tensor, sample_rate, filetype, channels_first); get_signalinfo(&tensor, sample_rate, filetype, channels_first);
const auto encoding_info = const auto encoding_info = get_encodinginfo_for_save(
get_encodinginfo_for_save(filetype, tgt_dtype, compression); filetype, tensor.dtype(), compression, encoding, bits_per_sample);
AutoReleaseBuffer buffer; AutoReleaseBuffer buffer;
......
...@@ -28,9 +28,10 @@ void save_audio_file( ...@@ -28,9 +28,10 @@ void save_audio_file(
torch::Tensor tensor, torch::Tensor tensor,
int64_t sample_rate, int64_t sample_rate,
bool channels_first, bool channels_first,
c10::optional<double> compression, c10::optional<double>& compression,
c10::optional<std::string> format, c10::optional<std::string>& format,
c10::optional<std::string> dtype); c10::optional<std::string>& encoding,
c10::optional<int64_t>& bits_per_sample);
#ifdef TORCH_API_INCLUDE_EXTENSION_H #ifdef TORCH_API_INCLUDE_EXTENSION_H
...@@ -51,9 +52,10 @@ void save_audio_fileobj( ...@@ -51,9 +52,10 @@ void save_audio_fileobj(
torch::Tensor tensor, torch::Tensor tensor,
int64_t sample_rate, int64_t sample_rate,
bool channels_first, bool channels_first,
c10::optional<double> compression, c10::optional<double>& compression,
std::string filetype, c10::optional<std::string>& format,
c10::optional<std::string> dtype); c10::optional<std::string>& encoding,
c10::optional<int64_t>& bits_per_sample);
#endif // TORCH_API_INCLUDE_EXTENSION_H #endif // TORCH_API_INCLUDE_EXTENSION_H
......
#include <torchaudio/csrc/sox/types.h>
namespace torchaudio {
namespace sox_utils {
Format get_format_from_string(const std::string& format) {
if (format == "wav")
return Format::WAV;
if (format == "mp3")
return Format::MP3;
if (format == "flac")
return Format::FLAC;
if (format == "ogg" || format == "vorbis")
return Format::VORBIS;
if (format == "amr-nb")
return Format::AMR_NB;
if (format == "amr-wb")
return Format::AMR_WB;
if (format == "amb")
return Format::AMB;
if (format == "sph")
return Format::SPHERE;
std::ostringstream stream;
stream << "Internal Error: unexpected format value: " << format;
throw std::runtime_error(stream.str());
}
std::string to_string(Encoding v) {
switch (v) {
case Encoding::UNKNOWN:
return "UNKNOWN";
case Encoding::PCM_SIGNED:
return "PCM_S";
case Encoding::PCM_UNSIGNED:
return "PCM_U";
case Encoding::PCM_FLOAT:
return "PCM_F";
case Encoding::FLAC:
return "FLAC";
case Encoding::ULAW:
return "ULAW";
case Encoding::ALAW:
return "ALAW";
case Encoding::MP3:
return "MP3";
case Encoding::VORBIS:
return "VORBIS";
case Encoding::AMR_WB:
return "AMR_WB";
case Encoding::AMR_NB:
return "AMR_NB";
case Encoding::OPUS:
return "OPUS";
default:
throw std::runtime_error("Internal Error: unexpected encoding.");
}
}
Encoding get_encoding_from_option(const c10::optional<std::string>& encoding) {
if (!encoding.has_value())
return Encoding::NOT_PROVIDED;
std::string v = encoding.value();
if (v == "PCM_S")
return Encoding::PCM_SIGNED;
if (v == "PCM_U")
return Encoding::PCM_UNSIGNED;
if (v == "PCM_F")
return Encoding::PCM_FLOAT;
if (v == "ULAW")
return Encoding::ULAW;
if (v == "ALAW")
return Encoding::ALAW;
std::ostringstream stream;
stream << "Internal Error: unexpected encoding value: " << v;
throw std::runtime_error(stream.str());
}
BitDepth get_bit_depth_from_option(const c10::optional<int64_t>& bit_depth) {
if (!bit_depth.has_value())
return BitDepth::NOT_PROVIDED;
int64_t v = bit_depth.value();
switch (v) {
case 8:
return BitDepth::B8;
case 16:
return BitDepth::B16;
case 24:
return BitDepth::B24;
case 32:
return BitDepth::B32;
case 64:
return BitDepth::B64;
default: {
std::ostringstream s;
s << "Internal Error: unexpected bit depth value: " << v;
throw std::runtime_error(s.str());
}
}
}
} // namespace sox_utils
} // namespace torchaudio
#ifndef TORCHAUDIO_SOX_TYPES_H
#define TORCHAUDIO_SOX_TYPES_H
#include <torch/script.h>
namespace torchaudio {
namespace sox_utils {
enum class Format {
WAV,
MP3,
FLAC,
VORBIS,
AMR_NB,
AMR_WB,
AMB,
SPHERE,
};
Format get_format_from_string(const std::string& format);
enum class Encoding {
NOT_PROVIDED,
UNKNOWN,
PCM_SIGNED,
PCM_UNSIGNED,
PCM_FLOAT,
FLAC,
ULAW,
ALAW,
MP3,
VORBIS,
AMR_WB,
AMR_NB,
OPUS,
};
std::string to_string(Encoding v);
Encoding get_encoding_from_option(const c10::optional<std::string>& encoding);
enum class BitDepth : unsigned {
NOT_PROVIDED = 0,
B8 = 8,
B16 = 16,
B24 = 24,
B32 = 32,
B64 = 64,
};
BitDepth get_bit_depth_from_option(const c10::optional<int64_t>& bit_depth);
} // namespace sox_utils
} // namespace torchaudio
#endif
#include <c10/core/ScalarType.h> #include <c10/core/ScalarType.h>
#include <sox.h> #include <sox.h>
#include <torchaudio/csrc/sox/types.h>
#include <torchaudio/csrc/sox/utils.h> #include <torchaudio/csrc/sox/utils.h>
namespace torchaudio { namespace torchaudio {
...@@ -163,22 +164,32 @@ torch::Tensor convert_to_tensor( ...@@ -163,22 +164,32 @@ torch::Tensor convert_to_tensor(
const caffe2::TypeMeta dtype, const caffe2::TypeMeta dtype,
const bool normalize, const bool normalize,
const bool channels_first) { const bool channels_first) {
auto t = torch::from_blob( torch::Tensor t;
buffer, {num_samples / num_channels, num_channels}, torch::kInt32); uint64_t dummy;
// Note: Tensor created from_blob does not own data but borrwos SOX_SAMPLE_LOCALS;
// So make sure to create a new copy after processing samples.
if (normalize || dtype == torch::kFloat32) { if (normalize || dtype == torch::kFloat32) {
t = t.to(torch::kFloat32); t = torch::empty(
t *= (t > 0) / 2147483647. + (t < 0) / 2147483648.; {num_samples / num_channels, num_channels}, torch::kFloat32);
auto ptr = t.data_ptr<float_t>();
for (int32_t i = 0; i < num_samples; ++i) {
ptr[i] = SOX_SAMPLE_TO_FLOAT_32BIT(buffer[i], dummy);
}
} else if (dtype == torch::kInt32) { } else if (dtype == torch::kInt32) {
t = t.clone(); t = torch::from_blob(
buffer, {num_samples / num_channels, num_channels}, torch::kInt32)
.clone();
} else if (dtype == torch::kInt16) { } else if (dtype == torch::kInt16) {
t.floor_divide_(1 << 16); t = torch::empty({num_samples / num_channels, num_channels}, torch::kInt16);
t = t.to(torch::kInt16); auto ptr = t.data_ptr<int16_t>();
for (int32_t i = 0; i < num_samples; ++i) {
ptr[i] = SOX_SAMPLE_TO_SIGNED_16BIT(buffer[i], dummy);
}
} else if (dtype == torch::kUInt8) { } else if (dtype == torch::kUInt8) {
t.floor_divide_(1 << 24); t = torch::empty({num_samples / num_channels, num_channels}, torch::kUInt8);
t += 128; auto ptr = t.data_ptr<uint8_t>();
t = t.to(torch::kUInt8); for (int32_t i = 0; i < num_samples; ++i) {
ptr[i] = SOX_SAMPLE_TO_UNSIGNED_8BIT(buffer[i], dummy);
}
} else { } else {
throw std::runtime_error("Unsupported dtype."); throw std::runtime_error("Unsupported dtype.");
} }
...@@ -188,63 +199,181 @@ torch::Tensor convert_to_tensor( ...@@ -188,63 +199,181 @@ torch::Tensor convert_to_tensor(
return t.contiguous(); return t.contiguous();
} }
torch::Tensor unnormalize_wav(const torch::Tensor input_tensor) {
const auto dtype = input_tensor.dtype();
auto tensor = input_tensor;
if (dtype == torch::kFloat32) {
double multi_pos = 2147483647.;
double multi_neg = -2147483648.;
auto mult = (tensor > 0) * multi_pos - (tensor < 0) * multi_neg;
tensor = tensor.to(torch::dtype(torch::kFloat64));
tensor *= mult;
tensor.clamp_(multi_neg, multi_pos);
tensor = tensor.to(torch::dtype(torch::kInt32));
} else if (dtype == torch::kInt32) {
// already denormalized
} else if (dtype == torch::kInt16) {
tensor = tensor.to(torch::dtype(torch::kInt32));
tensor *= ((tensor != 0) * 65536);
} else if (dtype == torch::kUInt8) {
tensor = tensor.to(torch::dtype(torch::kInt32));
tensor -= 128;
tensor *= 16777216;
} else {
throw std::runtime_error("Unexpected dtype.");
}
return tensor;
}
const std::string get_filetype(const std::string path) { const std::string get_filetype(const std::string path) {
std::string ext = path.substr(path.find_last_of(".") + 1); std::string ext = path.substr(path.find_last_of(".") + 1);
std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower); std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);
return ext; return ext;
} }
sox_encoding_t get_encoding( namespace {
const std::string filetype,
const caffe2::TypeMeta dtype) { std::tuple<sox_encoding_t, unsigned> get_save_encoding_for_wav(
if (filetype == "mp3") const std::string format,
return SOX_ENCODING_MP3; const caffe2::TypeMeta dtype,
if (filetype == "flac") const Encoding& encoding,
return SOX_ENCODING_FLAC; const BitDepth& bits_per_sample) {
if (filetype == "ogg" || filetype == "vorbis") switch (encoding) {
return SOX_ENCODING_VORBIS; case Encoding::NOT_PROVIDED:
if (filetype == "wav" || filetype == "amb") { switch (bits_per_sample) {
if (dtype == torch::kUInt8) case BitDepth::NOT_PROVIDED:
return SOX_ENCODING_UNSIGNED; if (dtype == torch::kFloat32)
if (dtype == torch::kInt16) return std::make_tuple<>(SOX_ENCODING_FLOAT, 32);
return SOX_ENCODING_SIGN2; if (dtype == torch::kInt32)
if (dtype == torch::kInt32) return std::make_tuple<>(SOX_ENCODING_SIGN2, 32);
return SOX_ENCODING_SIGN2; if (dtype == torch::kInt16)
if (dtype == torch::kFloat32) return std::make_tuple<>(SOX_ENCODING_SIGN2, 16);
return SOX_ENCODING_FLOAT; if (dtype == torch::kUInt8)
throw std::runtime_error("Unsupported dtype."); return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8);
throw std::runtime_error("Internal Error: Unexpected dtype.");
case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8);
default:
return std::make_tuple<>(
SOX_ENCODING_SIGN2, static_cast<unsigned>(bits_per_sample));
}
case Encoding::PCM_SIGNED:
switch (bits_per_sample) {
case BitDepth::NOT_PROVIDED:
return std::make_tuple<>(SOX_ENCODING_SIGN2, 32);
case BitDepth::B8:
throw std::runtime_error(
format + " does not support 8-bit signed PCM encoding.");
default:
return std::make_tuple<>(
SOX_ENCODING_SIGN2, static_cast<unsigned>(bits_per_sample));
}
case Encoding::PCM_UNSIGNED:
switch (bits_per_sample) {
case BitDepth::NOT_PROVIDED:
case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8);
default:
throw std::runtime_error(
format + " only supports 8-bit for unsigned PCM encoding.");
}
case Encoding::PCM_FLOAT:
switch (bits_per_sample) {
case BitDepth::NOT_PROVIDED:
case BitDepth::B32:
return std::make_tuple<>(SOX_ENCODING_FLOAT, 32);
case BitDepth::B64:
return std::make_tuple<>(SOX_ENCODING_FLOAT, 64);
default:
throw std::runtime_error(
format +
" only supports 32-bit or 64-bit for floating-point PCM encoding.");
}
case Encoding::ULAW:
switch (bits_per_sample) {
case BitDepth::NOT_PROVIDED:
case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_ULAW, 8);
default:
throw std::runtime_error(
format + " only supports 8-bit for mu-law encoding.");
}
case Encoding::ALAW:
switch (bits_per_sample) {
case BitDepth::NOT_PROVIDED:
case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_ALAW, 8);
default:
throw std::runtime_error(
format + " only supports 8-bit for a-law encoding.");
}
default:
throw std::runtime_error(
format + " does not support encoding: " + to_string(encoding));
}
}
std::tuple<sox_encoding_t, unsigned> get_save_encoding(
const std::string& format,
const caffe2::TypeMeta dtype,
const c10::optional<std::string>& encoding,
const c10::optional<int64_t>& bits_per_sample) {
const Format fmt = get_format_from_string(format);
const Encoding enc = get_encoding_from_option(encoding);
const BitDepth bps = get_bit_depth_from_option(bits_per_sample);
switch (fmt) {
case Format::WAV:
case Format::AMB:
return get_save_encoding_for_wav(format, dtype, enc, bps);
case Format::MP3:
if (enc != Encoding::NOT_PROVIDED)
throw std::runtime_error("mp3 does not support `encoding` option.");
if (bps != BitDepth::NOT_PROVIDED)
throw std::runtime_error(
"mp3 does not support `bits_per_sample` option.");
return std::make_tuple<>(SOX_ENCODING_MP3, 16);
case Format::VORBIS:
if (enc != Encoding::NOT_PROVIDED)
throw std::runtime_error("vorbis does not support `encoding` option.");
if (bps != BitDepth::NOT_PROVIDED)
throw std::runtime_error(
"vorbis does not support `bits_per_sample` option.");
return std::make_tuple<>(SOX_ENCODING_VORBIS, 16);
case Format::AMR_NB:
if (enc != Encoding::NOT_PROVIDED)
throw std::runtime_error("amr-nb does not support `encoding` option.");
if (bps != BitDepth::NOT_PROVIDED)
throw std::runtime_error(
"amr-nb does not support `bits_per_sample` option.");
return std::make_tuple<>(SOX_ENCODING_AMR_NB, 16);
case Format::FLAC:
if (enc != Encoding::NOT_PROVIDED)
throw std::runtime_error("flac does not support `encoding` option.");
switch (bps) {
case BitDepth::B32:
case BitDepth::B64:
throw std::runtime_error(
"flac does not support `bits_per_sample` larger than 24.");
default:
return std::make_tuple<>(
SOX_ENCODING_FLAC, static_cast<unsigned>(bps));
}
case Format::SPHERE:
switch (enc) {
case Encoding::NOT_PROVIDED:
case Encoding::PCM_SIGNED:
switch (bps) {
case BitDepth::NOT_PROVIDED:
return std::make_tuple<>(SOX_ENCODING_SIGN2, 32);
default:
return std::make_tuple<>(
SOX_ENCODING_SIGN2, static_cast<unsigned>(bps));
}
case Encoding::PCM_UNSIGNED:
throw std::runtime_error(
"sph does not support unsigned integer PCM.");
case Encoding::PCM_FLOAT:
throw std::runtime_error("sph does not support floating point PCM.");
case Encoding::ULAW:
switch (bps) {
case BitDepth::NOT_PROVIDED:
case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_ULAW, 8);
default:
throw std::runtime_error(
"sph only supports 8-bit for mu-law encoding.");
}
case Encoding::ALAW:
switch (bps) {
case BitDepth::NOT_PROVIDED:
case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_ALAW, 8);
default:
return std::make_tuple<>(
SOX_ENCODING_ALAW, static_cast<unsigned>(bps));
}
default:
throw std::runtime_error(
"sph does not support encoding: " + encoding.value());
}
default:
throw std::runtime_error("Unsupported format: " + format);
} }
if (filetype == "sph")
return SOX_ENCODING_SIGN2;
if (filetype == "amr-nb")
return SOX_ENCODING_AMR_NB;
throw std::runtime_error("Unsupported file type: " + filetype);
} }
unsigned get_precision( unsigned get_precision(
...@@ -270,14 +399,13 @@ unsigned get_precision( ...@@ -270,14 +399,13 @@ unsigned get_precision(
if (filetype == "sph") if (filetype == "sph")
return 32; return 32;
if (filetype == "amr-nb") { if (filetype == "amr-nb") {
TORCH_INTERNAL_ASSERT(
dtype == torch::kInt16,
"When saving to AMR-NB format, the input tensor must be int16 type.");
return 16; return 16;
} }
throw std::runtime_error("Unsupported file type: " + filetype); throw std::runtime_error("Unsupported file type: " + filetype);
} }
} // namespace
sox_signalinfo_t get_signalinfo( sox_signalinfo_t get_signalinfo(
const torch::Tensor* waveform, const torch::Tensor* waveform,
const int64_t sample_rate, const int64_t sample_rate,
...@@ -325,12 +453,15 @@ sox_encodinginfo_t get_tensor_encodinginfo(const caffe2::TypeMeta dtype) { ...@@ -325,12 +453,15 @@ sox_encodinginfo_t get_tensor_encodinginfo(const caffe2::TypeMeta dtype) {
} }
sox_encodinginfo_t get_encodinginfo_for_save( sox_encodinginfo_t get_encodinginfo_for_save(
const std::string filetype, const std::string& format,
const caffe2::TypeMeta dtype, const caffe2::TypeMeta dtype,
c10::optional<double>& compression) { const c10::optional<double>& compression,
const c10::optional<std::string>& encoding,
const c10::optional<int64_t>& bits_per_sample) {
auto enc = get_save_encoding(format, dtype, encoding, bits_per_sample);
return sox_encodinginfo_t{ return sox_encodinginfo_t{
/*encoding=*/get_encoding(filetype, dtype), /*encoding=*/std::get<0>(enc),
/*bits_per_sample=*/get_precision(filetype, dtype), /*bits_per_sample=*/std::get<1>(enc),
/*compression=*/compression.value_or(HUGE_VAL), /*compression=*/compression.value_or(HUGE_VAL),
/*reverse_bytes=*/sox_option_default, /*reverse_bytes=*/sox_option_default,
/*reverse_nibbles=*/sox_option_default, /*reverse_nibbles=*/sox_option_default,
......
...@@ -93,11 +93,6 @@ torch::Tensor convert_to_tensor( ...@@ -93,11 +93,6 @@ torch::Tensor convert_to_tensor(
const bool normalize, const bool normalize,
const bool channels_first); const bool channels_first);
///
/// Convert float32/int32/int16/uint8 Tensor to int32 for Torch -> Sox
/// conversion.
torch::Tensor unnormalize_wav(const torch::Tensor);
/// Extract extension from file path /// Extract extension from file path
const std::string get_filetype(const std::string path); const std::string get_filetype(const std::string path);
...@@ -113,9 +108,11 @@ sox_encodinginfo_t get_tensor_encodinginfo(const caffe2::TypeMeta dtype); ...@@ -113,9 +108,11 @@ sox_encodinginfo_t get_tensor_encodinginfo(const caffe2::TypeMeta dtype);
/// Get sox_encodinginfo_t for saving to file/file object /// Get sox_encodinginfo_t for saving to file/file object
sox_encodinginfo_t get_encodinginfo_for_save( sox_encodinginfo_t get_encodinginfo_for_save(
const std::string filetype, const std::string& format,
const caffe2::TypeMeta dtype, const caffe2::TypeMeta dtype,
c10::optional<double>& compression); const c10::optional<double>& compression,
const c10::optional<std::string>& encoding,
const c10::optional<int64_t>& bits_per_sample);
#ifdef TORCH_API_INCLUDE_EXTENSION_H #ifdef TORCH_API_INCLUDE_EXTENSION_H
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment