from __future__ import absolute_import, division, print_function, unicode_literals
import math
import os

import torch
import torchaudio
import torchaudio.transforms as transforms
import torchaudio.functional as F
from torchaudio.common_utils import IMPORT_LIBROSA, IMPORT_SCIPY
import unittest
import common_utils

if IMPORT_LIBROSA:
    import librosa

if IMPORT_SCIPY:
    import scipy

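# set to False to run the librosa/scipy consistency test below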
SKIP_LIBROSA_CONSISTENCY_TEST = True
RUN_CUDA = torch.cuda.is_available()
print("Run test with cuda:", RUN_CUDA)


def _test_script_module(f, tensor, *args, **kwargs):
    # check that the transform produces the same output when run eagerly and
    # when compiled with torch.jit.script, on CPU and (if available) on CUDA
    py_method = f(*args, **kwargs)
    py_method = f(*args, **kwargs)
    jit_method = torch.jit.script(py_method)

    py_out = py_method(tensor)
    jit_out = jit_method(tensor)

    assert torch.allclose(jit_out, py_out)

    if RUN_CUDA:

        tensor = tensor.to("cuda")

        py_method = py_method.cuda()
        jit_method = torch.jit.script(py_method)

        py_out = py_method(tensor)
        jit_out = jit_method(tensor)

        assert torch.allclose(jit_out, py_out)


class Tester(unittest.TestCase):

    # create a sinusoidal signal for testing
    sample_rate = 16000
    freq = 440
    volume = .3
    waveform = (torch.cos(2 * math.pi * torch.arange(0, 4 * sample_rate).float() * freq / sample_rate))
    waveform.unsqueeze_(0)  # (1, 64000)
    waveform = (waveform * volume * 2**31).long()
    # file for stereo stft test
    test_dirpath, test_dir = common_utils.create_temp_assets_dir()
    test_filepath = os.path.join(test_dirpath, 'assets',
                                 'steam-train-whistle-daniel_simon.mp3')

    def scale(self, waveform, factor=float(2**31)):
        # scales a waveform by a factor
        if not waveform.is_floating_point():
            waveform = waveform.to(torch.get_default_dtype())
        return waveform / factor

    def test_scriptmodule_Spectrogram(self):
        tensor = torch.rand((1, 1000))
        _test_script_module(transforms.Spectrogram, tensor)

    def test_mu_law_companding(self):

        quantization_channels = 256

        # the stored waveform is integer, so convert to float before normalizing to [-1, 1]
        waveform = self.waveform.clone().to(torch.get_default_dtype())
        waveform /= torch.abs(waveform).max()
        self.assertTrue(waveform.min() >= -1. and waveform.max() <= 1.)

        waveform_mu = transforms.MuLawEncoding(quantization_channels)(waveform)
        self.assertTrue(waveform_mu.min() >= 0. and waveform_mu.max() <= quantization_channels - 1)

        waveform_exp = transforms.MuLawDecoding(quantization_channels)(waveform_mu)
        self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.)
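
    def test_mulaw_transform_same_as_functional(self):
        # cross-check, mirroring test_compute_deltas_transform_same_as_functional:
        # the MuLawEncoding/MuLawDecoding transforms are expected to match the
        # corresponding torchaudio.functional calls (F.mu_law_encoding and
        # F.mu_law_decoding in this torchaudio version)
        quantization_channels = 256
        waveform = self.scale(self.waveform.clone())

        waveform_encoded = transforms.MuLawEncoding(quantization_channels)(waveform)
        functional_encoded = F.mu_law_encoding(waveform, quantization_channels)
        self.assertTrue(torch.allclose(waveform_encoded.to(torch.float), functional_encoded.to(torch.float)))

        waveform_decoded = transforms.MuLawDecoding(quantization_channels)(waveform_encoded)
        functional_decoded = F.mu_law_decoding(functional_encoded, quantization_channels)
        self.assertTrue(torch.allclose(waveform_decoded, functional_decoded))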

    def test_scriptmodule_AmplitudeToDB(self):
        spec = torch.rand((6, 201))
        _test_script_module(transforms.AmplitudeToDB, spec)

    def test_scriptmodule_MelScale(self):
        spec_f = torch.rand((1, 6, 201))
        _test_script_module(transforms.MelScale, spec_f)

    def test_melscale_load_save(self):
        specgram = torch.ones(1, 1000, 100)
        melscale_transform = transforms.MelScale()
        melscale_transform(specgram)

        melscale_transform_copy = transforms.MelScale(n_stft=1000)
        melscale_transform_copy.load_state_dict(melscale_transform.state_dict())

        fb = melscale_transform.fb
        fb_copy = melscale_transform_copy.fb

        self.assertEqual(fb_copy.size(), (1000, 128))
        self.assertTrue(torch.allclose(fb, fb_copy))

    def test_scriptmodule_MelSpectrogram(self):
        tensor = torch.rand((1, 1000))
        _test_script_module(transforms.MelSpectrogram, tensor)

    def test_melspectrogram_load_save(self):
        waveform = self.waveform.float()
        mel_spectrogram_transform = transforms.MelSpectrogram()
        mel_spectrogram_transform(waveform)

        mel_spectrogram_transform_copy = transforms.MelSpectrogram()
        mel_spectrogram_transform_copy.load_state_dict(mel_spectrogram_transform.state_dict())

        window = mel_spectrogram_transform.spectrogram.window
        window_copy = mel_spectrogram_transform_copy.spectrogram.window

        fb = mel_spectrogram_transform.mel_scale.fb
        fb_copy = mel_spectrogram_transform_copy.mel_scale.fb

        self.assertTrue(torch.allclose(window, window_copy))
        # with the defaults n_fft=400 and n_mels=128, the filterbank is (201, 128)
        self.assertEqual(fb_copy.size(), (201, 128))
        self.assertTrue(torch.allclose(fb, fb_copy))

    def test_mel2(self):
        top_db = 80.
        s2db = transforms.AmplitudeToDB('power', top_db)

        waveform = self.waveform.clone()  # (1, 64000)
        waveform_scaled = self.scale(waveform)  # (1, 64000)
        mel_transform = transforms.MelSpectrogram()
        # check defaults
        spectrogram_torch = s2db(mel_transform(waveform_scaled))  # (1, 128, 321)
        self.assertTrue(spectrogram_torch.dim() == 3)
        self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
        self.assertEqual(spectrogram_torch.size(1), mel_transform.n_mels)
        # check correctness of filterbank conversion matrix
        self.assertTrue(mel_transform.mel_scale.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform.mel_scale.fb.sum(1).ge(0.).all())
        # check options
        kwargs = {'window_fn': torch.hamming_window, 'pad': 10, 'win_length': 500,
                  'hop_length': 125, 'n_fft': 800, 'n_mels': 50}
        mel_transform2 = transforms.MelSpectrogram(**kwargs)
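        # with pad=10 and hop_length=125, frames = (64000 + 2 * 10) // 125 + 1 = 513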
        spectrogram2_torch = s2db(mel_transform2(waveform_scaled))  # (1, 50, 513)
        self.assertTrue(spectrogram2_torch.dim() == 3)
        self.assertTrue(spectrogram2_torch.ge(spectrogram2_torch.max() - top_db).all())
        self.assertEqual(spectrogram2_torch.size(1), mel_transform2.n_mels)
        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).ge(0.).all())
        # check on multi-channel audio
        x_stereo, sr_stereo = torchaudio.load(self.test_filepath)  # (2, 278756), 44100
        spectrogram_stereo = s2db(mel_transform(x_stereo))  # (2, 128, 1394)
        self.assertTrue(spectrogram_stereo.dim() == 3)
        self.assertTrue(spectrogram_stereo.size(0) == 2)
        self.assertTrue(spectrogram_stereo.ge(spectrogram_stereo.max() - top_db).all())
        self.assertEqual(spectrogram_stereo.size(1), mel_transform.n_mels)
        # check filterbank matrix creation
        fb_matrix_transform = transforms.MelScale(
            n_mels=100, sample_rate=16000, f_min=0., f_max=None, n_stft=400)
        self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all())
        self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
        self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))

    def test_scriptmodule_MFCC(self):
        tensor = torch.rand((1, 1000))
        _test_script_module(transforms.MFCC, tensor)

    def test_mfcc(self):
        audio_orig = self.waveform.clone()
        audio_scaled = self.scale(audio_orig)  # (1, 64000)

        sample_rate = 16000
        n_mfcc = 40
        n_mels = 128
        mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                    n_mfcc=n_mfcc,
                                                    norm='ortho')
        # check defaults
        torch_mfcc = mfcc_transform(audio_scaled)  # (1, 40, 321)
        self.assertTrue(torch_mfcc.dim() == 3)
        self.assertTrue(torch_mfcc.shape[1] == n_mfcc)
        self.assertTrue(torch_mfcc.shape[2] == 321)
        # check melkwargs are passed through
        melkwargs = {'win_length': 200}
        mfcc_transform2 = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                     n_mfcc=n_mfcc,
                                                     norm='ortho',
                                                     melkwargs=melkwargs)
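        # win_length=200 gives the default hop_length=100, hence 64000 // 100 + 1 = 641 frames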
        torch_mfcc2 = mfcc_transform2(audio_scaled)  # (1, 40, 641)
        self.assertTrue(torch_mfcc2.shape[2] == 641)

        # check norms work correctly
        mfcc_transform_norm_none = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                              n_mfcc=n_mfcc,
                                                              norm=None)
        torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled)  # (1, 40, 321)

        norm_check = torch_mfcc.clone()
        # scale the 'ortho' coefficients back up to the norm=None convention
        norm_check[:, 0, :] *= math.sqrt(n_mels) * 2
        norm_check[:, 1:, :] *= math.sqrt(n_mels / 2) * 2
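        # equivalently, with scipy's DCT-II (a sketch; x a 2-D array with n rows):
        #   dct(x, axis=0, type=2, norm=None)[0] == dct(x, axis=0, type=2, norm='ortho')[0] * 2 * sqrt(n)
        #   dct(x, axis=0, type=2, norm=None)[1:] == dct(x, axis=0, type=2, norm='ortho')[1:] * 2 * sqrt(n / 2)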

        self.assertTrue(torch_mfcc_norm_none.allclose(norm_check))

    @unittest.skipIf(
        SKIP_LIBROSA_CONSISTENCY_TEST or not IMPORT_LIBROSA or not IMPORT_SCIPY,
        'Librosa or scipy is not available, or the consistency test is disabled'
    )
    def test_librosa_consistency(self):
        def _test_librosa_consistency_helper(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):
            input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
            sound, sample_rate = torchaudio.load(input_path)
            sound_librosa = sound.cpu().numpy().squeeze()  # (64000,)

            # test core spectrogram
            spect_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop_length=hop_length, power=power)
            out_librosa, _ = librosa.core.spectrum._spectrogram(y=sound_librosa,
                                                                n_fft=n_fft,
                                                                hop_length=hop_length,
                                                                power=power)

            out_torch = spect_transform(sound).squeeze().cpu()
            self.assertTrue(torch.allclose(out_torch, torch.from_numpy(out_librosa), atol=1e-5))

            # test mel spectrogram
            melspect_transform = torchaudio.transforms.MelSpectrogram(
                sample_rate=sample_rate, window_fn=torch.hann_window,
                hop_length=hop_length, n_mels=n_mels, n_fft=n_fft)
            librosa_mel = librosa.feature.melspectrogram(y=sound_librosa, sr=sample_rate,
                                                         n_fft=n_fft, hop_length=hop_length, n_mels=n_mels,
                                                         htk=True, norm=None)
            librosa_mel_tensor = torch.from_numpy(librosa_mel)
            torch_mel = melspect_transform(sound).squeeze().cpu()

            self.assertTrue(torch.allclose(torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3))

            # test AmplitudeToDB (power to dB)
            db_transform = torchaudio.transforms.AmplitudeToDB('power', 80.)
            db_torch = db_transform(spect_transform(sound)).squeeze().cpu()
            db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
            self.assertTrue(torch.allclose(db_torch, torch.from_numpy(db_librosa), atol=5e-3))

            db_torch = db_transform(melspect_transform(sound)).squeeze().cpu()
            db_librosa = librosa.core.spectrum.power_to_db(librosa_mel)
            db_librosa_tensor = torch.from_numpy(db_librosa)

            self.assertTrue(torch.allclose(db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3))

            # test MFCC
            melkwargs = {'hop_length': hop_length, 'n_fft': n_fft}
            mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                        n_mfcc=n_mfcc,
                                                        norm='ortho',
                                                        melkwargs=melkwargs)

            # librosa.feature.mfcc doesn't pass kwargs through cleanly, since
            # melspectrogram and mfcc share some keyword names. We follow the
            # function body at
            # https://librosa.github.io/librosa/_modules/librosa/feature/spectral.html#melspectrogram
            # to mirror the call with the correct args:

            # librosa_mfcc = librosa.feature.mfcc(y=sound_librosa,
            #                                     sr=sample_rate,
            #                                     n_mfcc=n_mfcc,
            #                                     hop_length=hop_length,
            #                                     n_fft=n_fft,
            #                                     htk=True,
            #                                     norm=None,
            #                                     n_mels=n_mels)

            librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc]
            librosa_mfcc_tensor = torch.from_numpy(librosa_mfcc)
            torch_mfcc = mfcc_transform(sound).squeeze().cpu()

            self.assertTrue(torch.allclose(torch_mfcc.type(librosa_mfcc_tensor.dtype), librosa_mfcc_tensor, atol=5e-3))

        kwargs1 = {
            'n_fft': 400,
            'hop_length': 200,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 40,
            'sample_rate': 16000
        }

        kwargs2 = {
            'n_fft': 600,
            'hop_length': 100,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 20,
            'sample_rate': 16000
        }

        kwargs3 = {
            'n_fft': 200,
            'hop_length': 50,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 50,
            'sample_rate': 24000
        }

        _test_librosa_consistency_helper(**kwargs1)
        _test_librosa_consistency_helper(**kwargs2)
        _test_librosa_consistency_helper(**kwargs3)

    def test_resample_size(self):
        input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
        waveform, sample_rate = torchaudio.load(input_path)

        upsample_rate = sample_rate * 2
        downsample_rate = sample_rate // 2
        # an unknown resampling_method should raise a ValueError when the transform is applied
        invalid_resample = torchaudio.transforms.Resample(sample_rate, upsample_rate, resampling_method='foo')

        self.assertRaises(ValueError, invalid_resample, waveform)

        upsample_resample = torchaudio.transforms.Resample(
            sample_rate, upsample_rate, resampling_method='sinc_interpolation')
        up_sampled = upsample_resample(waveform)

        # we expect the upsampled signal to have twice as many samples
        self.assertTrue(up_sampled.size(-1) == waveform.size(-1) * 2)

        downsample_resample = torchaudio.transforms.Resample(
            sample_rate, downsample_rate, resampling_method='sinc_interpolation')
        down_sampled = downsample_resample(waveform)

        # we expect the downsampled signal to have half as many samples
        self.assertTrue(down_sampled.size(-1) == waveform.size(-1) // 2)
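        # more generally, resampling scales the signal length by new_freq / orig_freq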

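    # compute_deltas implements the standard delta-regression formula
    #   d_t = sum_{n=1..N} n * (c_{t+n} - c_{t-n}) / (2 * sum_{n=1..N} n^2)
    # with N = (win_length - 1) // 2 and replicate padding at the edges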
    def test_compute_deltas(self):
        channel = 13
        n_mfcc = channel * 3
        time = 1021
        win_length = 2 * 7 + 1
        specgram = torch.randn(channel, n_mfcc, time)
        transform = transforms.ComputeDeltas(win_length=win_length)
        computed = transform(specgram)
        self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))

    def test_compute_deltas_transform_same_as_functional(self, atol=1e-6, rtol=1e-8):
        channel = 13
        n_mfcc = channel * 3
        time = 1021
        win_length = 2 * 7 + 1
        specgram = torch.randn(channel, n_mfcc, time)

        transform = transforms.ComputeDeltas(win_length=win_length)
        computed_transform = transform(specgram)

        computed_functional = F.compute_deltas(specgram, win_length=win_length)
        torch.testing.assert_allclose(computed_functional, computed_transform, atol=atol, rtol=rtol)

    def test_compute_deltas_twochannel(self):
        specgram = torch.tensor([1., 2., 3., 4.]).repeat(1, 2, 1)
        expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5],
                                  [0.5, 1.0, 1.0, 0.5]]])
        # with win_length=3 and replicate padding, the deltas of [1, 2, 3, 4]
        # are [0.5, 1.0, 1.0, 0.5] per channel
        transform = transforms.ComputeDeltas(win_length=3)
        computed = transform(specgram)
        self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))
        torch.testing.assert_allclose(computed, expected)

    def test_batch_compute_deltas(self):
        specgram = torch.randn(2, 31, 2786)

        # Single then transform then batch
        expected = transforms.ComputeDeltas()(specgram).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = transforms.ComputeDeltas()(specgram.repeat(3, 1, 1, 1))

        # shape = (3, 2, 31, 2786)
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_scriptmodule_MuLawEncoding(self):
        tensor = torch.rand((1, 10))
        _test_script_module(transforms.MuLawEncoding, tensor)

    def test_scriptmodule_MuLawDecoding(self):
        tensor = torch.rand((1, 10))
        _test_script_module(transforms.MuLawDecoding, tensor)

    def test_batch_mulaw(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)  # (2, 278756), 44100

        # Single then transform then batch
        waveform_encoded = transforms.MuLawEncoding()(waveform)
        expected = waveform_encoded.unsqueeze(0).repeat(3, 1, 1)

        # Batch then transform
        waveform_batched = waveform.unsqueeze(0).repeat(3, 1, 1)
        computed = transforms.MuLawEncoding()(waveform_batched)

        # shape = (3, 2, 278756)
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

        # Single then transform then batch
        waveform_decoded = transforms.MuLawDecoding()(waveform_encoded)
        expected = waveform_decoded.unsqueeze(0).repeat(3, 1, 1)

        # Batch then transform
        computed = transforms.MuLawDecoding()(computed)

        # shape = (3, 2, 278756)
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_batch_spectrogram(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)

        # Single then transform then batch
        expected = transforms.Spectrogram()(waveform).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = transforms.Spectrogram()(waveform.repeat(3, 1, 1))

        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_scriptmodule_TimeStretch(self):
        n_freq = 400
        hop_length = 512
        fixed_rate = 1.3
        tensor = torch.rand((10, 2, n_freq, 10, 2))
        _test_script_module(transforms.TimeStretch, tensor, n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate)

    def test_scriptmodule_FrequencyMasking(self):
        tensor = torch.rand((10, 2, 50, 10, 2))
        _test_script_module(transforms.FrequencyMasking, tensor, freq_mask_param=60, iid_masks=False)

    def test_scriptmodule_TimeMasking(self):
        tensor = torch.rand((10, 2, 50, 10, 2))
        _test_script_module(transforms.TimeMasking, tensor, time_mask_param=30, iid_masks=False)


if __name__ == '__main__':
    unittest.main()