test.py 11.3 KB
Newer Older
1
from __future__ import absolute_import, division, print_function, unicode_literals
David Pollack's avatar
David Pollack committed
2
import unittest
3
import common_utils
4
import torch
Soumith Chintala's avatar
Soumith Chintala committed
5
import torchaudio
David Pollack's avatar
David Pollack committed
6
import math
David Pollack's avatar
David Pollack committed
7
import os
8
import sys
David Pollack's avatar
David Pollack committed
9

Soumith Chintala's avatar
Soumith Chintala committed
10

Vincent QB's avatar
Vincent QB committed
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class AudioBackendScope:
    """Context manager that temporarily switches the torchaudio audio backend.

    Entering the scope activates ``backend`` via ``torchaudio.set_audio_backend``
    and returns its name. Leaving the scope re-activates the backend that was
    current when the scope was entered, even if the ``with`` body raised.
    Exceptions raised inside the body are not suppressed.

    Args:
        backend: Name of the backend to activate (e.g. ``"sox"``).
    """

    def __init__(self, backend):
        self.new_backend = backend
        # Kept so the attribute exists right after construction (as before);
        # it is refreshed in __enter__ so the backend restored on exit is the
        # one active at *entry* time rather than at construction time.
        self.previous_backend = torchaudio.get_audio_backend()

    def __enter__(self):
        self.previous_backend = torchaudio.get_audio_backend()
        torchaudio.set_audio_backend(self.new_backend)
        return self.new_backend

    def __exit__(self, exc_type, exc_value, exc_traceback):
        torchaudio.set_audio_backend(self.previous_backend)
        # A falsy return value lets any exception from the body propagate.
        return False


David Pollack's avatar
David Pollack committed
25
class Test_LoadSave(unittest.TestCase):
    """Tests for ``torchaudio.load`` / ``save`` / ``info`` across audio backends.

    Each public ``test_*`` method loops over the backends it exercises and
    delegates to a private ``_test_*`` helper. Every backend runs inside its
    own ``subTest`` labelled with the backend name, so a failure reports which
    backend broke.
    """

    # NOTE(review): ``test_dir`` is unused here; presumably it is the handle
    # that keeps the temporary assets directory alive for the class lifetime.
    # Confirm against common_utils.create_temp_assets_dir before removing it.
    test_dirpath, test_dir = common_utils.create_temp_assets_dir()
    test_filepath = os.path.join(test_dirpath, "assets",
                                 "steam-train-whistle-daniel_simon.mp3")
    test_filepath_wav = os.path.join(test_dirpath, "assets",
                                     "steam-train-whistle-daniel_simon.wav")

    @unittest.skipIf(sys.version_info < (3, 4), "subTest unavailable for this Python version")
    def test_1_save(self):
        """Save the mp3 fixture with sox, and the wav fixture with every backend."""
        for backend in ["sox"]:
            with self.subTest(backend=backend):
                with AudioBackendScope(backend):
                    self._test_1_save(self.test_filepath, False)

        for backend in ["sox", "soundfile"]:
            with self.subTest(backend=backend):
                with AudioBackendScope(backend):
                    self._test_1_save(self.test_filepath_wav, True)

    def _test_1_save(self, test_filepath, normalization):
        """Exercise ``torchaudio.save`` with the currently active backend.

        Args:
            test_filepath: Fixture to load the signal from.
            normalization: Forwarded to ``torchaudio.load``.
        """
        # load signal
        x, sr = torchaudio.load(test_filepath, normalization=normalization)

        # check save
        new_filepath = os.path.join(self.test_dirpath, "test.wav")
        torchaudio.save(new_filepath, x, sr)
        self.assertTrue(os.path.isfile(new_filepath))
        os.unlink(new_filepath)

        # check automatic normalization: scale the signal down by 2**31 and
        # make sure save still accepts it
        x /= 1 << 31
        torchaudio.save(new_filepath, x, sr)
        self.assertTrue(os.path.isfile(new_filepath))
        os.unlink(new_filepath)

        # test save 1d tensor
        x = x[0, :]  # get mono signal
        x.squeeze_()  # remove channel dim
        torchaudio.save(new_filepath, x, sr)
        self.assertTrue(os.path.isfile(new_filepath))
        os.unlink(new_filepath)

        # don't allow invalid sizes as inputs
        with self.assertRaises(ValueError):
            x.unsqueeze_(1)  # L x C not C x L
            torchaudio.save(new_filepath, x, sr)

        with self.assertRaises(ValueError):
            x.squeeze_()
            x.unsqueeze_(1)
            x.unsqueeze_(0)  # 1 x L x 1
            torchaudio.save(new_filepath, x, sr)

        # don't save to folders that don't exist
        with self.assertRaises(OSError):
            new_filepath = os.path.join(self.test_dirpath, "no-path",
                                        "test.wav")
            torchaudio.save(new_filepath, x, sr)

    @unittest.skipIf(sys.version_info < (3, 4), "subTest unavailable for this Python version")
    def test_1_save_sine(self):
        """Write a synthetic sine wave and check save precision, per backend."""
        for backend in ["sox", "soundfile"]:
            with self.subTest(backend=backend):
                with AudioBackendScope(backend):
                    self._test_1_save_sine()

    def _test_1_save_sine(self):
        """Save a 4 s, 440 Hz cosine as ``assets/sinewave.wav`` and check precision."""
        # save created file
        sinewave_filepath = os.path.join(self.test_dirpath, "assets",
                                         "sinewave.wav")
        sr = 16000
        freq = 440
        volume = 0.3

        y = (torch.cos(
            2 * math.pi * torch.arange(0, 4 * sr).float() * freq / sr))
        y.unsqueeze_(0)
        # y is between -1 and 1, so must scale
        y = (y * volume * (2**31)).long()
        torchaudio.save(sinewave_filepath, y, sr)
        self.assertTrue(os.path.isfile(sinewave_filepath))

        # test precision
        new_precision = 32
        new_filepath = os.path.join(self.test_dirpath, "test.wav")
        si, ei = torchaudio.info(sinewave_filepath)
        torchaudio.save(new_filepath, y, sr, new_precision)
        si32, ei32 = torchaudio.info(new_filepath)
        # remove the temp file before asserting so a failure doesn't leak it
        os.unlink(new_filepath)
        self.assertEqual(si.precision, 16)
        self.assertEqual(si32.precision, new_precision)

    @unittest.skipIf(sys.version_info < (3, 4), "subTest unavailable for this Python version")
    def test_2_load(self):
        """Load the mp3 fixture with sox, and the wav fixture with every backend."""
        for backend in ["sox"]:
            with self.subTest(backend=backend):
                with AudioBackendScope(backend):
                    self._test_2_load(self.test_filepath, 278756)

        for backend in ["sox", "soundfile"]:
            with self.subTest(backend=backend):
                with AudioBackendScope(backend):
                    self._test_2_load(self.test_filepath_wav, 276858)

    def _test_2_load(self, test_filepath, length):
        """Check shape, sample rate, offset, num_frames and channel layout on load.

        Args:
            test_filepath: Stereo 44.1 kHz fixture to load.
            length: Expected number of frames in the full file.
        """
        # check normal loading
        x, sr = torchaudio.load(test_filepath)
        self.assertEqual(sr, 44100)
        self.assertEqual(x.size(), (2, length))

        # check offset
        offset = 15
        x, _ = torchaudio.load(test_filepath)
        x_offset, _ = torchaudio.load(test_filepath, offset=offset)
        self.assertTrue(x[:, offset:].allclose(x_offset))

        # check number of frames
        n = 201
        x, _ = torchaudio.load(test_filepath, num_frames=n)
        # assertEqual, not assertTrue: assertTrue(a, b) treats b as the
        # failure message and passes for any non-empty size.
        self.assertEqual(x.size(), (2, n))

        # check channels first
        x, _ = torchaudio.load(test_filepath, channels_first=False)
        self.assertEqual(x.size(), (length, 2))

        # check raising errors
        with self.assertRaises(OSError):
            torchaudio.load("file-does-not-exist.mp3")

        with self.assertRaises(OSError):
            tdir = os.path.join(
                os.path.dirname(self.test_dirpath), "torchaudio")
            torchaudio.load(tdir)

    @unittest.skipIf(sys.version_info < (3, 4), "subTest unavailable for this Python version")
    def test_2_load_nonormalization(self):
        """Check loading without normalization (sox only)."""
        for backend in ["sox"]:
            with self.subTest(backend=backend):
                with AudioBackendScope(backend):
                    self._test_2_load_nonormalization(self.test_filepath, 278756)

    def _test_2_load_nonormalization(self, test_filepath, length):
        """Unnormalized samples exceed [-1, 1], and a LongTensor output type is honoured.

        Args:
            test_filepath: Fixture to load.
            length: Currently unused; kept for signature parity with the
                other load helpers.
        """
        # check no normalizing
        x, _ = torchaudio.load(test_filepath, normalization=False)
        self.assertLessEqual(x.min().item(), -1.0)
        self.assertGreaterEqual(x.max().item(), 1.0)

        # check different input tensor type
        x, _ = torchaudio.load(test_filepath, torch.LongTensor(), normalization=False)
        self.assertIsInstance(x, torch.LongTensor)

    @unittest.skipIf(sys.version_info < (3, 4), "subTest unavailable for this Python version")
    def test_3_load_and_save_is_identity(self):
        """Load -> save -> load must reproduce the signal, per backend."""
        for backend in ["sox", "soundfile"]:
            with self.subTest(backend=backend):
                with AudioBackendScope(backend):
                    self._test_3_load_and_save_is_identity()

    def _test_3_load_and_save_is_identity(self):
        """Round-trip ``assets/sinewave.wav`` through save/load with the active backend."""
        input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
        tensor, sample_rate = torchaudio.load(input_path)
        output_path = os.path.join(self.test_dirpath, 'test.wav')
        torchaudio.save(output_path, tensor, sample_rate)
        tensor2, sample_rate2 = torchaudio.load(output_path)
        # remove the temp file before asserting so a failure doesn't leak it
        os.unlink(output_path)
        self.assertTrue(tensor.allclose(tensor2))
        self.assertEqual(sample_rate, sample_rate2)

    @unittest.skipIf(sys.version_info < (3, 4), "subTest unavailable for this Python version")
    def test_3_load_and_save_is_identity_across_backend(self):
        """A file written by one backend must read back identically with the other."""
        for backend1, backend2 in [("sox", "soundfile"), ("soundfile", "sox")]:
            with self.subTest(backend1=backend1, backend2=backend2):
                self._test_3_load_and_save_is_identity_across_backend(backend1, backend2)

    def _test_3_load_and_save_is_identity_across_backend(self, backend1, backend2):
        """Load and save with ``backend1``, then reload with ``backend2`` and compare.

        Args:
            backend1: Backend used for the initial load and the save.
            backend2: Backend used to read the saved file back.
        """
        input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
        output_path = os.path.join(self.test_dirpath, 'test.wav')

        with AudioBackendScope(backend1):
            tensor1, sample_rate1 = torchaudio.load(input_path)
            torchaudio.save(output_path, tensor1, sample_rate1)

        with AudioBackendScope(backend2):
            tensor2, sample_rate2 = torchaudio.load(output_path)

        # remove the temp file before asserting so a failure doesn't leak it
        os.unlink(output_path)
        self.assertTrue(tensor1.allclose(tensor2))
        self.assertEqual(sample_rate1, sample_rate2)

    @unittest.skipIf(sys.version_info < (3, 4), "subTest unavailable for this Python version")
    def test_4_load_partial(self):
        """Check partial loads via ``offset`` / ``num_frames`` (sox only)."""
        for backend in ["sox"]:
            with self.subTest(backend=backend):
                with AudioBackendScope(backend):
                    self._test_4_load_partial()

    def _test_4_load_partial(self):
        """Partial loads must equal the matching slice of a full load."""
        num_frames = 101
        offset = 201

        # load entire mono sinewave wav file, load a partial copy and then compare
        input_sine_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
        x_sine_full, sr_sine = torchaudio.load(input_sine_path)
        x_sine_part, _ = torchaudio.load(input_sine_path, num_frames=num_frames, offset=offset)
        l1_error = x_sine_full[:, offset:(num_frames + offset)].sub(x_sine_part).abs().sum().item()
        # test for the correct number of samples and that the correct portion was loaded
        self.assertEqual(x_sine_part.size(1), num_frames)
        self.assertEqual(l1_error, 0.)

        # create a two channel version of this wavefile: tensors are
        # channels-first (C x L), so duplicate along dim 0. repeat(1, 2) would
        # instead produce a mono signal of twice the length.
        x_2ch_sine = x_sine_full.repeat(2, 1)
        out_2ch_sine_path = os.path.join(self.test_dirpath, 'assets', '2ch_sinewave.wav')
        torchaudio.save(out_2ch_sine_path, x_2ch_sine, sr_sine)
        x_2ch_sine_load, _ = torchaudio.load(out_2ch_sine_path, num_frames=num_frames, offset=offset)
        os.unlink(out_2ch_sine_path)
        l1_error = x_2ch_sine_load.sub(x_2ch_sine[:, offset:(offset + num_frames)]).abs().sum().item()
        self.assertEqual(l1_error, 0.)

        # test with two channel mp3
        x_2ch_full, sr_2ch = torchaudio.load(self.test_filepath, normalization=True)
        x_2ch_part, _ = torchaudio.load(self.test_filepath, normalization=True, num_frames=num_frames, offset=offset)
        l1_error = x_2ch_full[:, offset:(offset + num_frames)].sub(x_2ch_part).abs().sum().item()
        self.assertEqual(x_2ch_part.size(1), num_frames)
        self.assertEqual(l1_error, 0.)

        # check behavior if number of samples would exceed file length
        offset_ns = 300
        x_ns, _ = torchaudio.load(input_sine_path, num_frames=100000, offset=offset_ns)
        self.assertEqual(x_ns.size(1), x_sine_full.size(1) - offset_ns)

        # check when offset is beyond the end of the file
        with self.assertRaises(RuntimeError):
            torchaudio.load(input_sine_path, offset=100000)

    @unittest.skipIf(sys.version_info < (3, 4), "subTest unavailable for this Python version")
    def test_5_get_info(self):
        """Check ``torchaudio.info`` metadata, per backend."""
        for backend in ["sox", "soundfile"]:
            with self.subTest(backend=backend):
                with AudioBackendScope(backend):
                    self._test_5_get_info()

    def _test_5_get_info(self):
        """``info`` on ``assets/sinewave.wav`` reports 1 ch, 64000 samples, 16 kHz, 16 bit."""
        input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
        channels, samples, rate, precision = (1, 64000, 16000, 16)
        si, ei = torchaudio.info(input_path)
        self.assertEqual(si.channels, channels)
        self.assertEqual(si.length, samples)
        self.assertEqual(si.rate, rate)
        self.assertEqual(ei.bits_per_sample, precision)
Soumith Chintala's avatar
Soumith Chintala committed
276

David Pollack's avatar
David Pollack committed
277
278
if __name__ == "__main__":
    # Running the file directly executes every test case defined in this module.
    unittest.main()