test_load.py 9.83 KB
Newer Older
moto's avatar
moto committed
1
2
3
4
5
6
7
8
9
10
import itertools

from torchaudio.backend import sox_io_backend
from parameterized import parameterized

from ..common_utils import (
    TempDirMixin,
    PytorchTestCase,
    skipIfNoExec,
    skipIfNoExtension,
moto's avatar
moto committed
11
    get_asset_path,
moto's avatar
moto committed
12
13
14
    get_wav_data,
    load_wav,
    save_wav,
moto's avatar
moto committed
15
16
17
18
    sox_utils,
)
from .common import (
    name_func,
moto's avatar
moto committed
19
20
21
22
23
24
25
26
27
)


class LoadTestBase(TempDirMixin, PytorchTestCase):
    """Reusable assertions that check `sox_io_backend.load` against reference data.

    Each ``assert_*`` method writes its audio files into temporary paths
    obtained from ``self.get_temp_path``, loads them with
    ``sox_io_backend.load`` and compares the result with a reference
    loaded through ``load_wav``.
    """

    def assert_wav(self, dtype, sample_rate, num_channels, normalize, duration):
        """`sox_io_backend.load` can load wav format correctly.

        Wav data loaded with sox_io backend should match those with scipy

        Args:
            dtype: dtype name forwarded to ``get_wav_data``.
            sample_rate: sample rate the file is saved with and expected back.
            num_channels: number of channels forwarded to ``get_wav_data``.
            normalize: forwarded to ``get_wav_data``, ``load_wav`` and
                ``sox_io_backend.load`` alike.
            duration: multiplied by ``sample_rate`` to get the frame count.
        """
        path = self.get_temp_path('reference.wav')
        data = get_wav_data(dtype, num_channels, normalize=normalize, num_frames=duration * sample_rate)
        save_wav(path, data, sample_rate)
        expected = load_wav(path, normalize=normalize)[0]
        data, sr = sox_io_backend.load(path, normalize=normalize)
        # Use the TestCase assertion so the check survives `python -O`
        # and reports both values on failure.
        self.assertEqual(sr, sample_rate)
        self.assertEqual(data, expected)

    def _assert_matches_sox_reference(self, path, ref_path, sample_rate, atol, rtol):
        """Run steps 2-5 of the strategy described in `assert_mp3`.

        Args:
            path: encoded audio file that already exists on disk.
            ref_path: destination of the wav file converted from ``path``.
            sample_rate: sample rate ``sox_io_backend.load`` must report.
            atol: absolute tolerance for the data comparison.
            rtol: relative tolerance for the data comparison.
        """
        # 2. Convert to wav with sox
        sox_utils.convert_audio_file(path, ref_path)
        # 3. Load the encoded file with torchaudio
        data, sr = sox_io_backend.load(path)
        # 4. Load wav with scipy
        data_ref = load_wav(ref_path)[0]
        # 5. Compare
        self.assertEqual(sr, sample_rate)
        self.assertEqual(data, data_ref, atol=atol, rtol=rtol)

    def assert_mp3(self, sample_rate, num_channels, bit_rate, duration):
        """`sox_io_backend.load` can load mp3 format.

        mp3 encoding introduces delay and boundary effects so
        we create reference wav file from mp3

         x
         |
         | 1. Generate mp3 with Sox
         |
         v    2. Convert to wav with Sox
        mp3 ------------------------------> wav
         |                                   |
         | 3. Load with torchaudio           | 4. Load with scipy
         |                                   |
         v                                   v
        tensor ----------> x <----------- tensor
                       5. Compare

        The underlying assumptions are:
        i. Conversion of mp3 to wav with Sox preserves data.
        ii. Loading wav file with scipy is correct.

        By combining i & ii, steps 2 and 4 allow loading the reference mp3
        data without using torchaudio.
        """
        path = self.get_temp_path('1.original.mp3')
        ref_path = self.get_temp_path('2.reference.wav')

        # 1. Generate mp3 with sox
        sox_utils.gen_audio_file(
            path, sample_rate, num_channels,
            compression=bit_rate, duration=duration)
        # 2-5. Convert, load both and compare
        self._assert_matches_sox_reference(
            path, ref_path, sample_rate, atol=3e-03, rtol=1.3e-06)

    def assert_flac(self, sample_rate, num_channels, compression_level, duration):
        """`sox_io_backend.load` can load flac format.

        This test takes the same strategy as mp3 to compare the result
        """
        path = self.get_temp_path('1.original.flac')
        ref_path = self.get_temp_path('2.reference.wav')

        # 1. Generate flac with sox
        sox_utils.gen_audio_file(
            path, sample_rate, num_channels,
            compression=compression_level, bit_depth=16, duration=duration)
        # 2-5. Convert, load both and compare
        self._assert_matches_sox_reference(
            path, ref_path, sample_rate, atol=4e-05, rtol=1.3e-06)

    def assert_vorbis(self, sample_rate, num_channels, quality_level, duration):
        """`sox_io_backend.load` can load vorbis format.

        This test takes the same strategy as mp3 to compare the result
        """
        path = self.get_temp_path('1.original.vorbis')
        ref_path = self.get_temp_path('2.reference.wav')

        # 1. Generate vorbis with sox
        sox_utils.gen_audio_file(
            path, sample_rate, num_channels,
            compression=quality_level, bit_depth=16, duration=duration)
        # 2-5. Convert, load both and compare
        self._assert_matches_sox_reference(
            path, ref_path, sample_rate, atol=4e-05, rtol=1.3e-06)


@skipIfNoExec('sox')
@skipIfNoExtension
class TestLoad(LoadTestBase):
    """Test the correctness of `sox_io_backend.load` for various formats"""

    @parameterized.expand(list(itertools.product(
        ['float32', 'int32', 'int16', 'uint8'],
        [8000, 16000],
        [1, 2],
        [False, True],
    )), name_func=name_func)
    def test_wav(self, dtype, sample_rate, num_channels, normalize):
        """`sox_io_backend.load` can load wav format correctly."""
        self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=1)

    @parameterized.expand(list(itertools.product(
        ['int16'],
        [16000],
        [2],
        [False],
    )), name_func=name_func)
    def test_wav_large(self, dtype, sample_rate, num_channels, normalize):
        """`sox_io_backend.load` can load large wav file correctly."""
        two_hours = 2 * 60 * 60  # duration in seconds
        self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=two_hours)

    @parameterized.expand(list(itertools.product(
        ['float32', 'int32', 'int16', 'uint8'],
        [4, 8, 16, 32],
    )), name_func=name_func)
    def test_multiple_channels(self, dtype, num_channels):
        """`sox_io_backend.load` can load wav file with more than 2 channels."""
        sample_rate = 8000
        normalize = False
        self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=1)

    @parameterized.expand(list(itertools.product(
        [8000, 16000, 44100],
        [1, 2],
        [96, 128, 160, 192, 224, 256, 320],
    )), name_func=name_func)
    def test_mp3(self, sample_rate, num_channels, bit_rate):
        """`sox_io_backend.load` can load mp3 format correctly."""
        self.assert_mp3(sample_rate, num_channels, bit_rate, duration=1)

    @parameterized.expand(list(itertools.product(
        [16000],
        [2],
        [128],
    )), name_func=name_func)
    def test_mp3_large(self, sample_rate, num_channels, bit_rate):
        """`sox_io_backend.load` can load large mp3 file correctly."""
        two_hours = 2 * 60 * 60  # duration in seconds
        self.assert_mp3(sample_rate, num_channels, bit_rate, duration=two_hours)

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        list(range(9)),
    )), name_func=name_func)
    def test_flac(self, sample_rate, num_channels, compression_level):
        """`sox_io_backend.load` can load flac format correctly."""
        self.assert_flac(sample_rate, num_channels, compression_level, duration=1)

    @parameterized.expand(list(itertools.product(
        [16000],
        [2],
        [0],
    )), name_func=name_func)
    def test_flac_large(self, sample_rate, num_channels, compression_level):
        """`sox_io_backend.load` can load large flac file correctly."""
        two_hours = 2 * 60 * 60  # duration in seconds
        self.assert_flac(sample_rate, num_channels, compression_level, duration=two_hours)

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        [-1, 0, 1, 2, 3, 3.6, 5, 10],
    )), name_func=name_func)
    def test_vorbis(self, sample_rate, num_channels, quality_level):
        """`sox_io_backend.load` can load vorbis format correctly."""
        self.assert_vorbis(sample_rate, num_channels, quality_level, duration=1)

    @parameterized.expand(list(itertools.product(
        [16000],
        [2],
        [10],
    )), name_func=name_func)
    def test_vorbis_large(self, sample_rate, num_channels, quality_level):
        """`sox_io_backend.load` can load large vorbis file correctly."""
        two_hours = 2 * 60 * 60  # duration in seconds
        self.assert_vorbis(sample_rate, num_channels, quality_level, duration=two_hours)

    @parameterized.expand(list(itertools.product(
        ['96k'],
        [1, 2],
        [0, 5, 10],
    )), name_func=name_func)
    def test_opus(self, bitrate, num_channels, compression_level):
        """`sox_io_backend.load` can load opus file correctly.

        The opus file is a pre-built asset looked up with ``get_asset_path``.
        It is converted to wav with sox, and the wav loaded through
        ``load_wav`` serves as the reference.
        """
        opus_path = get_asset_path('io', f'{bitrate}_{compression_level}_{num_channels}ch.opus')
        wav_path = self.get_temp_path(f'{bitrate}_{compression_level}_{num_channels}ch.opus.wav')
        sox_utils.convert_audio_file(opus_path, wav_path)

        expected, sample_rate = load_wav(wav_path)
        found, sr = sox_io_backend.load(opus_path)

        # TestCase assertions survive `python -O` and report both values.
        self.assertEqual(sr, sample_rate)
        self.assertEqual(found, expected)

moto's avatar
moto committed
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250

@skipIfNoExec('sox')
@skipIfNoExtension
class TestLoadParams(TempDirMixin, PytorchTestCase):
    """Test the correctness of frame parameters of `sox_io_backend.load`"""
    # Reference data written to `path`; assigned in `setUp` for every test.
    original = None
    # Temporary file holding `original`; assigned in `setUp` for every test.
    path = None

    def setUp(self):
        """Generate 2-channel float32 data and save it to a temporary file."""
        super().setUp()
        sample_rate = 8000
        self.original = get_wav_data('float32', num_channels=2)
        self.path = self.get_temp_path('test.wave')
        save_wav(self.path, self.original, sample_rate)

    @parameterized.expand(list(itertools.product(
        [0, 1, 10, 100, 1000],
        [-1, 1, 10, 100, 1000],
    )), name_func=name_func)
    def test_frame(self, frame_offset, num_frames):
        """num_frames and frame_offset correctly specify the region of data"""
        found, _ = sox_io_backend.load(self.path, frame_offset, num_frames)
        # For num_frames == -1 the expected slice is left open-ended, i.e. the
        # test expects everything from frame_offset onwards.
        frame_end = None if num_frames == -1 else frame_offset + num_frames
        self.assertEqual(found, self.original[:, frame_offset:frame_end])

    @parameterized.expand([(True, ), (False, )], name_func=name_func)
    def test_channels_first(self, channels_first):
        """channels_first swaps axes"""
        found, _ = sox_io_backend.load(self.path, channels_first=channels_first)
        # With channels_first=False the expected data is the saved data with
        # its two axes swapped.
        expected = self.original if channels_first else self.original.transpose(1, 0)
        self.assertEqual(found, expected)