# test_roundtrip.py: save/load round-trip tests for the sox_io backend.
import itertools

from parameterized import parameterized
from torchaudio.backend import sox_io_backend

from ..common_utils import (
    TempDirMixin,
    PytorchTestCase,
    skipIfNoExec,
    skipIfNoExtension,
    get_wav_data,
)
from .common import (
    name_func,
)


@skipIfNoExec('sox')
@skipIfNoExtension
class TestRoundTripIO(TempDirMixin, PytorchTestCase):
    """save/load round trip should not degrade data for lossless formats

    Each test saves a signal, loads it back, and feeds the loaded tensor into
    the next ``save``. This repeats ``NUM_ROUND_TRIPS`` times. Every iteration
    is compared against the *original* signal, so any loss introduced by a
    single round trip would accumulate and surface as a mismatch.
    """

    # How many consecutive save/load cycles each test performs.
    NUM_ROUND_TRIPS = 10

    @parameterized.expand(list(itertools.product(
        ['float32', 'int32', 'int16', 'uint8'],
        [8000, 16000],
        [1, 2],
    )), name_func=name_func)
    def test_wav(self, dtype, sample_rate, num_channels):
        """save/load round trip should not degrade data for wav formats

        Args:
            dtype: Sample dtype name passed to ``get_wav_data``.
            sample_rate: Rate passed to ``save`` and expected back from ``load``.
            num_channels: Number of channels in the generated signal.
        """
        original = get_wav_data(dtype, num_channels, normalize=False)
        data = original
        for i in range(self.NUM_ROUND_TRIPS):
            # A distinct file per iteration, so each cycle reads what it just wrote.
            path = self.get_temp_path(f'{i}.wav')
            sox_io_backend.save(path, data, sample_rate)
            # normalize=False on load mirrors how ``original`` was generated,
            # so both sides are compared un-normalized (presumably in ``dtype``
            # -- confirm against get_wav_data / load docs).
            data, sr = sox_io_backend.load(path, normalize=False)
            self.assertEqual(sr, sample_rate)
            self.assertEqual(original, data)

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        list(range(9)),
    )), name_func=name_func)
    def test_flac(self, sample_rate, num_channels, compression_level):
        """save/load round trip should not degrade data for flac formats

        Args:
            sample_rate: Rate passed to ``save`` and expected back from ``load``.
            num_channels: Number of channels in the generated signal.
            compression_level: Value forwarded as ``compression=`` to ``save``;
                the parameterization covers 0 through 8.
        """
        original = get_wav_data('float32', num_channels)
        data = original
        for i in range(self.NUM_ROUND_TRIPS):
            path = self.get_temp_path(f'{i}.flac')
            sox_io_backend.save(path, data, sample_rate, compression=compression_level)
            data, sr = sox_io_backend.load(path)
            self.assertEqual(sr, sample_rate)
            self.assertEqual(original, data)