Unverified commit 80a8739a, authored by Caroline Chen and committed by GitHub
Browse files

Refactor sox_io load_test (#1394)

parent 6bad3a66
...@@ -29,74 +29,30 @@ if _mod_utils.is_module_available("requests"): ...@@ -29,74 +29,30 @@ if _mod_utils.is_module_available("requests"):
class LoadTestBase(TempDirMixin, PytorchTestCase): class LoadTestBase(TempDirMixin, PytorchTestCase):
def assert_wav(self, dtype, sample_rate, num_channels, normalize, duration): def assert_format(
"""`sox_io_backend.load` can load wav format correctly. self,
format: str,
Wav data loaded with sox_io backend should match those with scipy sample_rate: float,
""" num_channels: int,
path = self.get_temp_path('reference.wav') compression: float = None,
data = get_wav_data(dtype, num_channels, normalize=normalize, num_frames=duration * sample_rate) bit_depth: int = None,
save_wav(path, data, sample_rate) duration: float = 1,
expected = load_wav(path, normalize=normalize)[0] normalize: bool = True,
data, sr = sox_io_backend.load(path, normalize=normalize) encoding: str = None,
assert sr == sample_rate atol: float = 4e-05,
self.assertEqual(data, expected) rtol: float = 1.3e-06,
):
def assert_24bit_wav(self, sample_rate, num_channels, normalize, duration): """`sox_io_backend.load` can load given format correctly.
""" `sox_io_backend.load` can load 24-bit signed PCM wav format. Since torch does not support the ``int24`` dtype,
we implicitly cast the resulting tensor to the ``int32`` dtype. file encodings introduce delay and boundary effects so
we create a reference wav file from the original file format
It is not possible to use #assert_wav method above, as #get_wav_data does not support
the 'int24' dtype. This is because torch does not support the ``int24`` dtype.
Hence, we must use the following workaround.
x
|
| 1. Generate 24-bit wav with Sox.
|
v 2. Convert 24-bit wav to 32-bit wav with Sox.
wav(24-bit) ----------------------> wav(32-bit)
| |
| 3. Load 24-bit wav with torchaudio| 4. Load 32-bit wav with scipy
| |
v v
tensor ----------> x <----------- tensor
5. Compare
# Underlying assumptions are:
# i. Sox properly converts from 24-bit to 32-bit
# ii. Loading 32-bit wav file with scipy is correct.
"""
path = self.get_temp_path('1.original.wav')
ref_path = self.get_temp_path('2.reference.wav')
# 1. Generate 24-bit signed wav with Sox
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
bit_depth=24, duration=duration)
# 2. Convert from 24-bit wav to 32-bit wav with sox
sox_utils.convert_audio_file(path, ref_path, bit_depth=32)
# 3. Load 24-bit wav with torchaudio
data, sr = sox_io_backend.load(path, normalize=normalize)
# 4. Load 32-bit wav with scipy
data_ref = load_wav(ref_path, normalize=normalize)[0]
# 5. Compare
assert sr == sample_rate
self.assertEqual(data, data_ref, atol=3e-03, rtol=1.3e-06)
def assert_mp3(self, sample_rate, num_channels, bit_rate, duration):
"""`sox_io_backend.load` can load mp3 format.
mp3 encoding introduces delay and boundary effects so
we create reference wav file from mp3
x x
| |
| 1. Generate mp3 with Sox | 1. Generate given format with Sox
| |
v 2. Convert to wav with Sox v 2. Convert to wav with Sox
mp3 ------------------------------> wav given format ----------------------> wav
| | | |
| 3. Load with torchaudio | 4. Load with scipy | 3. Load with torchaudio | 4. Load with scipy
| | | |
...@@ -104,142 +60,45 @@ class LoadTestBase(TempDirMixin, PytorchTestCase): ...@@ -104,142 +60,45 @@ class LoadTestBase(TempDirMixin, PytorchTestCase):
tensor ----------> x <----------- tensor tensor ----------> x <----------- tensor
5. Compare 5. Compare
Underlying assumptions are: Underlying assumptions are:
i. Conversion of mp3 to wav with Sox preserves data. i. Conversion of given format to wav with Sox preserves data.
ii. Loading wav file with scipy is correct. ii. Loading wav file with scipy is correct.
By combining i & ii, step 2. and 4. allows to load reference mp3 data By combining i & ii, step 2. and 4. allows to load reference given format
without using torchaudio data without using torchaudio
""" """
path = self.get_temp_path('1.original.mp3')
ref_path = self.get_temp_path('2.reference.wav')
# 1. Generate mp3 with sox path = self.get_temp_path(f'1.original.{format}')
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
compression=bit_rate, duration=duration)
# 2. Convert to wav with sox
sox_utils.convert_audio_file(path, ref_path)
# 3. Load mp3 with torchaudio
data, sr = sox_io_backend.load(path)
# 4. Load wav with scipy
data_ref = load_wav(ref_path)[0]
# 5. Compare
assert sr == sample_rate
self.assertEqual(data, data_ref, atol=3e-03, rtol=1.3e-06)
def assert_flac(self, sample_rate, num_channels, compression_level, duration):
"""`sox_io_backend.load` can load flac format.
This test takes the same strategy as mp3 to compare the result
"""
path = self.get_temp_path('1.original.flac')
ref_path = self.get_temp_path('2.reference.wav') ref_path = self.get_temp_path('2.reference.wav')
# 1. Generate flac with sox # 1. Generate the given format with sox
sox_utils.gen_audio_file( sox_utils.gen_audio_file(
path, sample_rate, num_channels, path, sample_rate, num_channels, encoding=encoding,
compression=compression_level, bit_depth=16, duration=duration) compression=compression, bit_depth=bit_depth, duration=duration,
)
# 2. Convert to wav with sox # 2. Convert to wav with sox
sox_utils.convert_audio_file(path, ref_path) wav_bit_depth = 32 if bit_depth == 24 else None # for 24-bit wav
# 3. Load flac with torchaudio sox_utils.convert_audio_file(path, ref_path, bit_depth=wav_bit_depth)
data, sr = sox_io_backend.load(path) # 3. Load the given format with torchaudio
# 4. Load wav with scipy
data_ref = load_wav(ref_path)[0]
# 5. Compare
assert sr == sample_rate
self.assertEqual(data, data_ref, atol=4e-05, rtol=1.3e-06)
def assert_vorbis(self, sample_rate, num_channels, quality_level, duration):
"""`sox_io_backend.load` can load vorbis format.
This test takes the same strategy as mp3 to compare the result
"""
path = self.get_temp_path('1.original.vorbis')
ref_path = self.get_temp_path('2.reference.wav')
# 1. Generate vorbis with sox
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
compression=quality_level, bit_depth=16, duration=duration)
# 2. Convert to wav with sox
sox_utils.convert_audio_file(path, ref_path)
# 3. Load vorbis with torchaudio
data, sr = sox_io_backend.load(path)
# 4. Load wav with scipy
data_ref = load_wav(ref_path)[0]
# 5. Compare
assert sr == sample_rate
self.assertEqual(data, data_ref, atol=4e-05, rtol=1.3e-06)
def assert_sphere(self, sample_rate, num_channels, duration):
"""`sox_io_backend.load` can load sph format.
This test takes the same strategy as mp3 to compare the result
"""
path = self.get_temp_path('1.original.sph')
ref_path = self.get_temp_path('2.reference.wav')
# 1. Generate sph with sox
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
bit_depth=32, duration=duration)
# 2. Convert to wav with sox
sox_utils.convert_audio_file(path, ref_path)
# 3. Load sph with torchaudio
data, sr = sox_io_backend.load(path)
# 4. Load wav with scipy
data_ref = load_wav(ref_path)[0]
# 5. Compare
assert sr == sample_rate
self.assertEqual(data, data_ref, atol=4e-05, rtol=1.3e-06)
def assert_amb(self, dtype, sample_rate, num_channels, normalize, duration):
"""`sox_io_backend.load` can load amb format.
This test takes the same strategy as mp3 to compare the result
"""
path = self.get_temp_path('1.original.amb')
ref_path = self.get_temp_path('2.reference.wav')
# 1. Generate amb with sox
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
encoding=sox_utils.get_encoding(dtype),
bit_depth=sox_utils.get_bit_depth(dtype), duration=duration)
# 2. Convert to wav with sox
sox_utils.convert_audio_file(path, ref_path)
# 3. Load amb with torchaudio
data, sr = sox_io_backend.load(path, normalize=normalize) data, sr = sox_io_backend.load(path, normalize=normalize)
# 4. Load wav with scipy # 4. Load wav with scipy
data_ref = load_wav(ref_path, normalize=normalize)[0] data_ref = load_wav(ref_path, normalize=normalize)[0]
# 5. Compare # 5. Compare
assert sr == sample_rate assert sr == sample_rate
self.assertEqual(data, data_ref, atol=4e-05, rtol=1.3e-06) self.assertEqual(data, data_ref, atol=atol, rtol=rtol)
def assert_amr_nb(self, duration): def assert_wav(self, dtype, sample_rate, num_channels, normalize, duration):
"""`sox_io_backend.load` can load amr-nb format. """`sox_io_backend.load` can load wav format correctly.
This test takes the same strategy as mp3 to compare the result Wav data loaded with sox_io backend should match those with scipy
""" """
sample_rate = 8000 path = self.get_temp_path('reference.wav')
num_channels = 1 data = get_wav_data(dtype, num_channels, normalize=normalize, num_frames=duration * sample_rate)
path = self.get_temp_path('1.original.amr-nb') save_wav(path, data, sample_rate)
ref_path = self.get_temp_path('2.reference.wav') expected = load_wav(path, normalize=normalize)[0]
data, sr = sox_io_backend.load(path, normalize=normalize)
# 1. Generate amr-nb with sox
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
bit_depth=32, duration=duration)
# 2. Convert to wav with sox
sox_utils.convert_audio_file(path, ref_path)
# 3. Load amr-nb with torchaudio
data, sr = sox_io_backend.load(path)
# 4. Load wav with scipy
data_ref = load_wav(ref_path)[0]
# 5. Compare
assert sr == sample_rate assert sr == sample_rate
self.assertEqual(data, data_ref, atol=4e-05, rtol=1.3e-06) self.assertEqual(data, expected)
@skipIfNoExec('sox') @skipIfNoExec('sox')
...@@ -263,7 +122,7 @@ class TestLoad(LoadTestBase): ...@@ -263,7 +122,7 @@ class TestLoad(LoadTestBase):
)), name_func=name_func) )), name_func=name_func)
def test_24bit_wav(self, sample_rate, num_channels, normalize): def test_24bit_wav(self, sample_rate, num_channels, normalize):
"""`sox_io_backend.load` can load 24bit wav format correctly. Correctly casts it to ``int32`` tensor dtype.""" """`sox_io_backend.load` can load 24bit wav format correctly. Correctly casts it to ``int32`` tensor dtype."""
self.assert_24bit_wav(sample_rate, num_channels, normalize, duration=1) self.assert_format("wav", sample_rate, num_channels, bit_depth=24, normalize=normalize, duration=1)
@parameterized.expand(list(itertools.product( @parameterized.expand(list(itertools.product(
['int16'], ['int16'],
...@@ -293,7 +152,7 @@ class TestLoad(LoadTestBase): ...@@ -293,7 +152,7 @@ class TestLoad(LoadTestBase):
)), name_func=name_func) )), name_func=name_func)
def test_mp3(self, sample_rate, num_channels, bit_rate): def test_mp3(self, sample_rate, num_channels, bit_rate):
"""`sox_io_backend.load` can load mp3 format correctly.""" """`sox_io_backend.load` can load mp3 format correctly."""
self.assert_mp3(sample_rate, num_channels, bit_rate, duration=1) self.assert_format("mp3", sample_rate, num_channels, compression=bit_rate, duration=1, atol=5e-05)
@parameterized.expand(list(itertools.product( @parameterized.expand(list(itertools.product(
[16000], [16000],
...@@ -303,7 +162,7 @@ class TestLoad(LoadTestBase): ...@@ -303,7 +162,7 @@ class TestLoad(LoadTestBase):
def test_mp3_large(self, sample_rate, num_channels, bit_rate): def test_mp3_large(self, sample_rate, num_channels, bit_rate):
"""`sox_io_backend.load` can load large mp3 file correctly.""" """`sox_io_backend.load` can load large mp3 file correctly."""
two_hours = 2 * 60 * 60 two_hours = 2 * 60 * 60
self.assert_mp3(sample_rate, num_channels, bit_rate, two_hours) self.assert_format("mp3", sample_rate, num_channels, compression=bit_rate, duration=two_hours, atol=5e-05)
@parameterized.expand(list(itertools.product( @parameterized.expand(list(itertools.product(
[8000, 16000], [8000, 16000],
...@@ -312,7 +171,7 @@ class TestLoad(LoadTestBase): ...@@ -312,7 +171,7 @@ class TestLoad(LoadTestBase):
)), name_func=name_func) )), name_func=name_func)
def test_flac(self, sample_rate, num_channels, compression_level): def test_flac(self, sample_rate, num_channels, compression_level):
"""`sox_io_backend.load` can load flac format correctly.""" """`sox_io_backend.load` can load flac format correctly."""
self.assert_flac(sample_rate, num_channels, compression_level, duration=1) self.assert_format("flac", sample_rate, num_channels, compression=compression_level, bit_depth=16, duration=1)
@parameterized.expand(list(itertools.product( @parameterized.expand(list(itertools.product(
[16000], [16000],
...@@ -322,7 +181,8 @@ class TestLoad(LoadTestBase): ...@@ -322,7 +181,8 @@ class TestLoad(LoadTestBase):
def test_flac_large(self, sample_rate, num_channels, compression_level): def test_flac_large(self, sample_rate, num_channels, compression_level):
"""`sox_io_backend.load` can load large flac file correctly.""" """`sox_io_backend.load` can load large flac file correctly."""
two_hours = 2 * 60 * 60 two_hours = 2 * 60 * 60
self.assert_flac(sample_rate, num_channels, compression_level, two_hours) self.assert_format(
"flac", sample_rate, num_channels, compression=compression_level, bit_depth=16, duration=two_hours)
@parameterized.expand(list(itertools.product( @parameterized.expand(list(itertools.product(
[8000, 16000], [8000, 16000],
...@@ -331,7 +191,7 @@ class TestLoad(LoadTestBase): ...@@ -331,7 +191,7 @@ class TestLoad(LoadTestBase):
)), name_func=name_func) )), name_func=name_func)
def test_vorbis(self, sample_rate, num_channels, quality_level): def test_vorbis(self, sample_rate, num_channels, quality_level):
"""`sox_io_backend.load` can load vorbis format correctly.""" """`sox_io_backend.load` can load vorbis format correctly."""
self.assert_vorbis(sample_rate, num_channels, quality_level, duration=1) self.assert_format("vorbis", sample_rate, num_channels, compression=quality_level, bit_depth=16, duration=1)
@parameterized.expand(list(itertools.product( @parameterized.expand(list(itertools.product(
[16000], [16000],
...@@ -341,7 +201,8 @@ class TestLoad(LoadTestBase): ...@@ -341,7 +201,8 @@ class TestLoad(LoadTestBase):
def test_vorbis_large(self, sample_rate, num_channels, quality_level): def test_vorbis_large(self, sample_rate, num_channels, quality_level):
"""`sox_io_backend.load` can load large vorbis file correctly.""" """`sox_io_backend.load` can load large vorbis file correctly."""
two_hours = 2 * 60 * 60 two_hours = 2 * 60 * 60
self.assert_vorbis(sample_rate, num_channels, quality_level, two_hours) self.assert_format(
"vorbis", sample_rate, num_channels, compression=quality_level, bit_depth=16, duration=two_hours)
@parameterized.expand(list(itertools.product( @parameterized.expand(list(itertools.product(
['96k'], ['96k'],
...@@ -366,7 +227,7 @@ class TestLoad(LoadTestBase): ...@@ -366,7 +227,7 @@ class TestLoad(LoadTestBase):
)), name_func=name_func) )), name_func=name_func)
def test_sphere(self, sample_rate, num_channels): def test_sphere(self, sample_rate, num_channels):
"""`sox_io_backend.load` can load sph format correctly.""" """`sox_io_backend.load` can load sph format correctly."""
self.assert_sphere(sample_rate, num_channels, duration=1) self.assert_format("sph", sample_rate, num_channels, bit_depth=32, duration=1)
@parameterized.expand(list(itertools.product( @parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16'], ['float32', 'int32', 'int16'],
...@@ -375,12 +236,15 @@ class TestLoad(LoadTestBase): ...@@ -375,12 +236,15 @@ class TestLoad(LoadTestBase):
[False, True], [False, True],
)), name_func=name_func) )), name_func=name_func)
def test_amb(self, dtype, sample_rate, num_channels, normalize): def test_amb(self, dtype, sample_rate, num_channels, normalize):
"""`sox_io_backend.load` can load sph format correctly.""" """`sox_io_backend.load` can load amb format correctly."""
self.assert_amb(dtype, sample_rate, num_channels, normalize, duration=1) bit_depth = sox_utils.get_bit_depth(dtype)
encoding = sox_utils.get_encoding(dtype)
self.assert_format(
"amb", sample_rate, num_channels, bit_depth=bit_depth, duration=1, encoding=encoding, normalize=normalize)
def test_amr_nb(self): def test_amr_nb(self):
"""`sox_io_backend.load` can load amr_nb format correctly.""" """`sox_io_backend.load` can load amr_nb format correctly."""
self.assert_amr_nb(duration=1) self.assert_format("amr-nb", sample_rate=8000, num_channels=1, bit_depth=32, duration=1)
@skipIfNoExec('sox') @skipIfNoExec('sox')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment