from __future__ import absolute_import, division, print_function, unicode_literals
import math
import os

import torch
import torchaudio
import torchaudio.augmentations as A
import torchaudio.transforms as transforms
import torchaudio.functional as F
from torchaudio.common_utils import IMPORT_LIBROSA, IMPORT_SCIPY

import unittest
import common_utils

if IMPORT_LIBROSA:
    import librosa

if IMPORT_SCIPY:
    import scipy

RUN_CUDA = torch.cuda.is_available()
print("Run test with cuda:", RUN_CUDA)


def _test_script_module(f, tensor, *args, **kwargs):
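    # Script the transform with torch.jit.script and check that the eager and
    # scripted outputs match; repeat the check on CUDA when available.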

    py_method = f(*args, **kwargs)
    jit_method = torch.jit.script(py_method)

    py_out = py_method(tensor)
    jit_out = jit_method(tensor)

    assert torch.allclose(jit_out, py_out)

    if RUN_CUDA:

        tensor = tensor.to("cuda")

        py_method = py_method.cuda()
        jit_method = torch.jit.script(py_method)

        py_out = py_method(tensor)
        jit_out = jit_method(tensor)

        assert torch.allclose(jit_out, py_out)

class Tester(unittest.TestCase):

    # create a sinewave signal for testing
    sample_rate = 16000
    freq = 440
    volume = .3

    waveform = (torch.cos(2 * math.pi * torch.arange(0, 4 * sample_rate).float() * freq / sample_rate))
    waveform.unsqueeze_(0)  # (1, 64000)
    waveform = (waveform * volume * 2**31).long()

    # file for stereo stft test
    test_dirpath, test_dir = common_utils.create_temp_assets_dir()
    test_filepath = os.path.join(test_dirpath, 'assets',
                                 'steam-train-whistle-daniel_simon.mp3')

    def scale(self, waveform, factor=float(2**31)):
        # scales a waveform by a factor; the default 2**31 undoes the int32
        # full-scale factor applied when self.waveform is constructed above
        if not waveform.is_floating_point():
            waveform = waveform.to(torch.get_default_dtype())
        return waveform / factor

    def test_scriptmodule_Spectrogram(self):
        tensor = torch.rand((1, 1000))
        _test_script_module(transforms.Spectrogram, tensor)

    def test_mu_law_companding(self):

        quantization_channels = 256
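        # mu-law companding maps a signal in [-1., 1.] to integer values in
        # [0, quantization_channels) via sign(x) * log(1 + mu*|x|) / log(1 + mu)
        # with mu = quantization_channels - 1; decoding approximately inverts it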

        waveform = self.waveform.clone()
        if not waveform.is_floating_point():
            waveform = waveform.to(torch.get_default_dtype())
        waveform /= torch.abs(waveform).max()
        self.assertTrue(waveform.min() >= -1. and waveform.max() <= 1.)

        waveform_mu = transforms.MuLawEncoding(quantization_channels)(waveform)
        self.assertTrue(waveform_mu.min() >= 0. and waveform_mu.max() <= quantization_channels)
        waveform_exp = transforms.MuLawDecoding(quantization_channels)(waveform_mu)
        self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.)

    def test_scriptmodule_AmplitudeToDB(self):
        spec = torch.rand((6, 201))
        _test_script_module(transforms.AmplitudeToDB, spec)

    def test_scriptmodule_MelScale(self):
        spec_f = torch.rand((1, 6, 201))
        _test_script_module(transforms.MelScale, spec_f)

    def test_melscale_load_save(self):
        specgram = torch.ones(1, 1000, 100)
        melscale_transform = transforms.MelScale()
        melscale_transform(specgram)
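        # note: MelScale() without n_stft builds its filterbank lazily on the
        # first forward call, which is why the transform is applied before
        # its state_dict is saved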

        melscale_transform_copy = transforms.MelScale(n_stft=1000)
        melscale_transform_copy.load_state_dict(melscale_transform.state_dict())

        fb = melscale_transform.fb
        fb_copy = melscale_transform_copy.fb

        self.assertEqual(fb_copy.size(), (1000, 128))
        self.assertTrue(torch.allclose(fb, fb_copy))

    def test_scriptmodule_MelSpectrogram(self):
        tensor = torch.rand((1, 1000))
        _test_script_module(transforms.MelSpectrogram, tensor)

    def test_melspectrogram_load_save(self):
        waveform = self.waveform.float()
        mel_spectrogram_transform = transforms.MelSpectrogram()
        mel_spectrogram_transform(waveform)

        mel_spectrogram_transform_copy = transforms.MelSpectrogram()
        mel_spectrogram_transform_copy.load_state_dict(mel_spectrogram_transform.state_dict())

        window = mel_spectrogram_transform.spectrogram.window
        window_copy = mel_spectrogram_transform_copy.spectrogram.window

        fb = mel_spectrogram_transform.mel_scale.fb
        fb_copy = mel_spectrogram_transform_copy.mel_scale.fb

        self.assertTrue(torch.allclose(window, window_copy))
        # with the defaults n_fft = 400 and n_mels = 128, fb has shape (201, 128)
        self.assertEqual(fb_copy.size(), (201, 128))
        self.assertTrue(torch.allclose(fb, fb_copy))

    def test_mel2(self):
        top_db = 80.
        s2db = transforms.AmplitudeToDB('power', top_db)

        waveform = self.waveform.clone()  # (1, 64000)
        waveform_scaled = self.scale(waveform)  # (1, 64000)
        mel_transform = transforms.MelSpectrogram()
        # check defaults
        spectrogram_torch = s2db(mel_transform(waveform_scaled))  # (1, 128, 321)
        self.assertTrue(spectrogram_torch.dim() == 3)
        self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
        self.assertEqual(spectrogram_torch.size(1), mel_transform.n_mels)
        # check correctness of filterbank conversion matrix
        self.assertTrue(mel_transform.mel_scale.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform.mel_scale.fb.sum(1).ge(0.).all())
        # check options
        kwargs = {'window_fn': torch.hamming_window, 'pad': 10, 'win_length': 500,
                  'hop_length': 125, 'n_fft': 800, 'n_mels': 50}
        mel_transform2 = transforms.MelSpectrogram(**kwargs)
        spectrogram2_torch = s2db(mel_transform2(waveform_scaled))  # (1, 50, 513)
        self.assertTrue(spectrogram2_torch.dim() == 3)
        self.assertTrue(spectrogram2_torch.ge(spectrogram2_torch.max() - top_db).all())
        self.assertEqual(spectrogram2_torch.size(1), mel_transform2.n_mels)
        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).ge(0.).all())
        # check on multi-channel audio
        x_stereo, sr_stereo = torchaudio.load(self.test_filepath)  # (2, 278756), 44100
        spectrogram_stereo = s2db(mel_transform(x_stereo))  # (2, 128, 1394)
        self.assertTrue(spectrogram_stereo.dim() == 3)
        self.assertTrue(spectrogram_stereo.size(0) == 2)
        self.assertTrue(spectrogram_stereo.ge(spectrogram_stereo.max() - top_db).all())
        self.assertEqual(spectrogram_stereo.size(1), mel_transform.n_mels)
        # check filterbank matrix creation
        fb_matrix_transform = transforms.MelScale(
            n_mels=100, sample_rate=16000, f_min=0., f_max=None, n_stft=400)
        self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all())
        self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
        self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))

    def test_scriptmodule_MFCC(self):
        tensor = torch.rand((1, 1000))
        _test_script_module(transforms.MFCC, tensor)

    def test_mfcc(self):
        audio_orig = self.waveform.clone()
        audio_scaled = self.scale(audio_orig)  # (1, 64000)

        sample_rate = 16000
        n_mfcc = 40
        n_mels = 128
        mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                    n_mfcc=n_mfcc,
                                                    norm='ortho')
        # check defaults
        torch_mfcc = mfcc_transform(audio_scaled)  # (1, 40, 321)
        self.assertTrue(torch_mfcc.dim() == 3)
        self.assertTrue(torch_mfcc.shape[1] == n_mfcc)
        self.assertTrue(torch_mfcc.shape[2] == 321)

        # check melkwargs are passed through
        melkwargs = {'win_length': 200}
        mfcc_transform2 = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                     n_mfcc=n_mfcc,
                                                     norm='ortho',
                                                     melkwargs=melkwargs)
        torch_mfcc2 = mfcc_transform2(audio_scaled)  # (1, 40, 641)
        self.assertTrue(torch_mfcc2.shape[2] == 641)

        # check norms work correctly
        mfcc_transform_norm_none = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                              n_mfcc=n_mfcc,
                                                              norm=None)
        torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled)  # (1, 40, 321)

        norm_check = torch_mfcc.clone()
        norm_check[:, 0, :] *= math.sqrt(n_mels) * 2
        norm_check[:, 1:, :] *= math.sqrt(n_mels / 2) * 2
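        # the factors above undo the DCT-II 'ortho' scaling: 2*sqrt(N) for the
        # first coefficient and sqrt(2N) = 2*sqrt(N/2) for the rest, with N = n_mels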

        self.assertTrue(torch_mfcc_norm_none.allclose(norm_check))

    @unittest.skipIf(not IMPORT_LIBROSA or not IMPORT_SCIPY, 'Librosa and scipy are not available')
    def test_librosa_consistency(self):
        def _test_librosa_consistency_helper(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):
            input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
            sound, sample_rate = torchaudio.load(input_path)
            sound_librosa = sound.cpu().numpy().squeeze()  # (64000)

            # test core spectrogram
            spect_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop_length=hop_length, power=2)
            out_librosa, _ = librosa.core.spectrum._spectrogram(y=sound_librosa,
                                                                n_fft=n_fft,
                                                                hop_length=hop_length,
                                                                power=2)

            out_torch = spect_transform(sound).squeeze().cpu()
            self.assertTrue(torch.allclose(out_torch, torch.from_numpy(out_librosa), atol=1e-5))

            # test mel spectrogram
            melspect_transform = torchaudio.transforms.MelSpectrogram(
                sample_rate=sample_rate, window_fn=torch.hann_window,
                hop_length=hop_length, n_mels=n_mels, n_fft=n_fft)
            librosa_mel = librosa.feature.melspectrogram(y=sound_librosa, sr=sample_rate,
                                                         n_fft=n_fft, hop_length=hop_length, n_mels=n_mels,
                                                         htk=True, norm=None)
            librosa_mel_tensor = torch.from_numpy(librosa_mel)
            torch_mel = melspect_transform(sound).squeeze().cpu()
            self.assertTrue(torch.allclose(torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3))

            # test s2db
            db_transform = torchaudio.transforms.AmplitudeToDB('power', 80.)
            db_torch = db_transform(spect_transform(sound)).squeeze().cpu()
            db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
            self.assertTrue(torch.allclose(db_torch, torch.from_numpy(db_librosa), atol=5e-3))

            db_torch = db_transform(melspect_transform(sound)).squeeze().cpu()
            db_librosa = librosa.core.spectrum.power_to_db(librosa_mel)
            db_librosa_tensor = torch.from_numpy(db_librosa)
            self.assertTrue(torch.allclose(db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3))

            # test MFCC
            melkwargs = {'hop_length': hop_length, 'n_fft': n_fft}
            mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                        n_mfcc=n_mfcc,
                                                        norm='ortho',
                                                        melkwargs=melkwargs)

            # librosa.feature.mfcc doesn't pass kwargs properly since some of the
            # kwargs for melspectrogram and mfcc are the same. We just follow the
            # function body in https://librosa.github.io/librosa/_modules/librosa/feature/spectral.html#melspectrogram
            # to mirror this function call with correct args:
            #
            #     librosa_mfcc = librosa.feature.mfcc(y=sound_librosa,
            #                                         sr=sample_rate,
            #                                         n_mfcc=n_mfcc,
            #                                         hop_length=hop_length,
            #                                         n_fft=n_fft,
            #                                         htk=True,
            #                                         norm=None,
            #                                         n_mels=n_mels)
            librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc]
            librosa_mfcc_tensor = torch.from_numpy(librosa_mfcc)
            torch_mfcc = mfcc_transform(sound).squeeze().cpu()
            self.assertTrue(torch.allclose(torch_mfcc.type(librosa_mfcc_tensor.dtype), librosa_mfcc_tensor, atol=5e-3))

        kwargs1 = {
            'n_fft': 400,
            'hop_length': 200,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 40,
            'sample_rate': 16000
        }

        kwargs2 = {
            'n_fft': 600,
            'hop_length': 100,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 20,
            'sample_rate': 16000
        }

        kwargs3 = {
            'n_fft': 200,
            'hop_length': 50,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 50,
            'sample_rate': 24000
        }

        _test_librosa_consistency_helper(**kwargs1)
        _test_librosa_consistency_helper(**kwargs2)
        _test_librosa_consistency_helper(**kwargs3)

    def test_resample_size(self):
        input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
        waveform, sample_rate = torchaudio.load(input_path)

        upsample_rate = sample_rate * 2
        downsample_rate = sample_rate // 2
        invalid_resample = torchaudio.transforms.Resample(sample_rate, upsample_rate, resampling_method='foo')
        self.assertRaises(ValueError, invalid_resample, waveform)

        upsample_resample = torchaudio.transforms.Resample(
            sample_rate, upsample_rate, resampling_method='sinc_interpolation')
        up_sampled = upsample_resample(waveform)

        # we expect the upsampled signal to have twice as many samples
        self.assertTrue(up_sampled.size(-1) == waveform.size(-1) * 2)

        downsample_resample = torchaudio.transforms.Resample(
            sample_rate, downsample_rate, resampling_method='sinc_interpolation')
        down_sampled = downsample_resample(waveform)

        # we expect the downsampled signal to have half as many samples
        self.assertTrue(down_sampled.size(-1) == waveform.size(-1) // 2)

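    # compute_deltas implements the standard delta-coefficient regression
    #   d_t = sum_{n=1..N} n * (c_{t+n} - c_{t-n}) / (2 * sum_{n=1..N} n**2)
    # with N = (win_length - 1) // 2 and replicate padding at the edges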
    def test_compute_deltas(self):
        channel = 13
        n_mfcc = channel * 3
        time = 1021
        win_length = 2 * 7 + 1
        specgram = torch.randn(channel, n_mfcc, time)
        transform = transforms.ComputeDeltas(win_length=win_length)
        computed = transform(specgram)
        self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))

    def test_compute_deltas_transform_same_as_functional(self, atol=1e-6, rtol=1e-8):
        channel = 13
        n_mfcc = channel * 3
        time = 1021
        win_length = 2 * 7 + 1
        specgram = torch.randn(channel, n_mfcc, time)

        transform = transforms.ComputeDeltas(win_length=win_length)
        computed_transform = transform(specgram)

        computed_functional = F.compute_deltas(specgram, win_length=win_length)
        torch.testing.assert_allclose(computed_functional, computed_transform, atol=atol, rtol=rtol)

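    # with win_length=3 the delta reduces to (x[t+1] - x[t-1]) / 2 under
    # replicate padding, which yields the expected tensor below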
    def test_compute_deltas_twochannel(self):
        specgram = torch.tensor([1., 2., 3., 4.]).repeat(1, 2, 1)
        expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5],
                                  [0.5, 1.0, 1.0, 0.5]]])
        transform = transforms.ComputeDeltas(win_length=3)
        computed = transform(specgram)
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_batch_compute_deltas(self):
        specgram = torch.randn(2, 31, 2786)

        # Single then transform then batch
        expected = transforms.ComputeDeltas()(specgram).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = transforms.ComputeDeltas()(specgram.repeat(3, 1, 1, 1))

        # shape = (3, 2, 31, 2786)
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_scriptmodule_MuLawEncoding(self):
        tensor = torch.rand((1, 10))
        _test_script_module(transforms.MuLawEncoding, tensor)

    def test_scriptmodule_MuLawDecoding(self):
        tensor = torch.rand((1, 10))
        _test_script_module(transforms.MuLawDecoding, tensor)

    def test_batch_mulaw(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)  # (2, 278756), 44100

        # Single then transform then batch
        waveform_encoded = transforms.MuLawEncoding()(waveform)
        expected = waveform_encoded.unsqueeze(0).repeat(3, 1, 1)

        # Batch then transform
        waveform_batched = waveform.unsqueeze(0).repeat(3, 1, 1)
        computed = transforms.MuLawEncoding()(waveform_batched)

        # shape = (3, 2, 278756)
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

        # Single then transform then batch
        waveform_decoded = transforms.MuLawDecoding()(waveform_encoded)
        expected = waveform_decoded.unsqueeze(0).repeat(3, 1, 1)

        # Batch then transform
        computed = transforms.MuLawDecoding()(computed)

        # shape = (3, 2, 278756)
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_batch_spectrogram(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)

        # Single then transform then batch
        expected = transforms.Spectrogram()(waveform).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = transforms.Spectrogram()(waveform.repeat(3, 1, 1))

        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_scriptmodule_TimeStretch(self):
        n_freq = 400
        hop_length = 512
        fixed_rate = 1.3
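        # the trailing dimension of size 2 holds the real/imaginary parts of
        # the complex spectrogram that TimeStretch expects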
        tensor = torch.rand((10, 2, n_freq, 10, 2))
        _test_script_module(A.TimeStretch, tensor, n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate)

    def test_scriptmodule_FrequencyMasking(self):
        tensor = torch.rand((10, 2, 50, 10, 2))
        _test_script_module(A.FrequencyMasking, tensor, freq_mask_param=60, iid_masks=False)

    def test_scriptmodule_TimeMasking(self):
        tensor = torch.rand((10, 2, 50, 10, 2))
        _test_script_module(A.TimeMasking, tensor, time_mask_param=30, iid_masks=False)


if __name__ == '__main__':
    unittest.main()