"vscode:/vscode.git/clone" did not exist on "132dad874d2e44592d03a112e4b7d63b153e8346"
test_transforms.py 30 KB
Newer Older
1
from __future__ import absolute_import, division, print_function, unicode_literals
import math
import os

import torch
import torchaudio
import torchaudio.transforms as transforms
import torchaudio.functional as F
from torchaudio.common_utils import IMPORT_LIBROSA, IMPORT_SCIPY
import unittest
import common_utils

if IMPORT_LIBROSA:
    import librosa

if IMPORT_SCIPY:
    import scipy

RUN_CUDA = torch.cuda.is_available()
print("Run test with cuda:", RUN_CUDA)


def _test_script_module(f, tensor, *args, **kwargs):
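    """Check that a transform gives the same result in eager and scripted mode.

    `f` is the transform class: it is instantiated with `*args`/`**kwargs`,
    applied to `tensor` both directly and through `torch.jit.script`, and the
    two outputs are compared. The check is repeated on CUDA when available.
    """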

    py_method = f(*args, **kwargs)
    jit_method = torch.jit.script(py_method)

    py_out = py_method(tensor)
    jit_out = jit_method(tensor)

    assert torch.allclose(jit_out, py_out)

    if RUN_CUDA:

        tensor = tensor.to("cuda")

        py_method = py_method.cuda()
        jit_method = torch.jit.script(py_method)

        py_out = py_method(tensor)
        jit_out = jit_method(tensor)

        assert torch.allclose(jit_out, py_out)


class Tester(unittest.TestCase):

    # create a sinewave signal for testing
    sample_rate = 16000
    freq = 440
    volume = .3
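    # 4 seconds at 16 kHz, scaled to look like a 32-bit integer PCM signal
    # (hence the factor 2**31 and the cast to long)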
    waveform = (torch.cos(2 * math.pi * torch.arange(0, 4 * sample_rate).float() * freq / sample_rate))
    waveform.unsqueeze_(0)  # (1, 64000)
    waveform = (waveform * volume * 2**31).long()
    # file for stereo stft test
    test_dirpath, test_dir = common_utils.create_temp_assets_dir()
    test_filepath = os.path.join(test_dirpath, 'assets',
                                 'steam-train-whistle-daniel_simon.mp3')

    def scale(self, waveform, factor=float(2**31)):
        # scales a waveform by a factor
        if not waveform.is_floating_point():
            waveform = waveform.to(torch.get_default_dtype())
        return waveform / factor

    def test_scriptmodule_Spectrogram(self):
        tensor = torch.rand((1, 1000))
        _test_script_module(transforms.Spectrogram, tensor)

    def test_scriptmodule_GriffinLim(self):
        tensor = torch.rand((1, 201, 6))
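        # rand_init=False makes the phase initialization deterministic so that
        # the eager and scripted outputs can be compared exactly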
        _test_script_module(transforms.GriffinLim, tensor, length=1000, rand_init=False)

    def test_mu_law_companding(self):

        quantization_channels = 256

        waveform = self.waveform.clone()
        if not waveform.is_floating_point():
            waveform = waveform.to(torch.get_default_dtype())
        waveform /= torch.abs(waveform).max()
        self.assertTrue(waveform.min() >= -1. and waveform.max() <= 1.)

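        # MuLawEncoding expects input in [-1., 1.] and maps it to integers in
        # [0, quantization_channels - 1]; MuLawDecoding approximately inverts it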
        waveform_mu = transforms.MuLawEncoding(quantization_channels)(waveform)
        self.assertTrue(waveform_mu.min() >= 0. and waveform_mu.max() <= quantization_channels)

        waveform_exp = transforms.MuLawDecoding(quantization_channels)(waveform_mu)
        self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.)

    def test_scriptmodule_AmplitudeToDB(self):
        spec = torch.rand((6, 201))
        _test_script_module(transforms.AmplitudeToDB, spec)

    def test_batch_AmplitudeToDB(self):
        spec = torch.rand((6, 201))

        # Single then transform then batch
        expected = transforms.AmplitudeToDB()(spec).repeat(3, 1, 1)

        # Batch then transform
        computed = transforms.AmplitudeToDB()(spec.repeat(3, 1, 1))

        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_AmplitudeToDB(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)

        mag_to_db_transform = transforms.AmplitudeToDB('magnitude', 80.)
        power_to_db_transform = transforms.AmplitudeToDB('power', 80.)

        mag_to_db_torch = mag_to_db_transform(torch.abs(waveform))
        power_to_db_torch = power_to_db_transform(torch.pow(waveform, 2))
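        # the two must agree since 20 * log10(|x|) == 10 * log10(x ** 2)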

        self.assertTrue(torch.allclose(mag_to_db_torch, power_to_db_torch))

    def test_scriptmodule_MelScale(self):
        spec_f = torch.rand((1, 6, 201))
        _test_script_module(transforms.MelScale, spec_f)

    def test_melscale_load_save(self):
        specgram = torch.ones(1, 1000, 100)
        melscale_transform = transforms.MelScale()
        melscale_transform(specgram)
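        # MelScale was constructed without n_stft, so its filterbank buffer `fb`
        # is only materialized by this first forward call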

        melscale_transform_copy = transforms.MelScale(n_stft=1000)
        melscale_transform_copy.load_state_dict(melscale_transform.state_dict())

        fb = melscale_transform.fb
        fb_copy = melscale_transform_copy.fb

        self.assertEqual(fb_copy.size(), (1000, 128))
        self.assertTrue(torch.allclose(fb, fb_copy))

    def test_scriptmodule_MelSpectrogram(self):
        tensor = torch.rand((1, 1000))
        _test_script_module(transforms.MelSpectrogram, tensor)

    def test_melspectrogram_load_save(self):
        waveform = self.waveform.float()
        mel_spectrogram_transform = transforms.MelSpectrogram()
        mel_spectrogram_transform(waveform)

        mel_spectrogram_transform_copy = transforms.MelSpectrogram()
        mel_spectrogram_transform_copy.load_state_dict(mel_spectrogram_transform.state_dict())

        window = mel_spectrogram_transform.spectrogram.window
        window_copy = mel_spectrogram_transform_copy.spectrogram.window

        fb = mel_spectrogram_transform.mel_scale.fb
        fb_copy = mel_spectrogram_transform_copy.mel_scale.fb

        self.assertTrue(torch.allclose(window, window_copy))
        # the default for n_fft = 400 and n_mels = 128
        self.assertEqual(fb_copy.size(), (201, 128))
        self.assertTrue(torch.allclose(fb, fb_copy))

    def test_mel2(self):
        top_db = 80.
        s2db = transforms.AmplitudeToDB('power', top_db)

        waveform = self.waveform.clone()  # (1, 64000)
        waveform_scaled = self.scale(waveform)  # (1, 64000)
        mel_transform = transforms.MelSpectrogram()
        # check defaults
        spectrogram_torch = s2db(mel_transform(waveform_scaled))  # (1, 128, 321)
        self.assertTrue(spectrogram_torch.dim() == 3)
        self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
        self.assertEqual(spectrogram_torch.size(1), mel_transform.n_mels)
        # check correctness of filterbank conversion matrix
        self.assertTrue(mel_transform.mel_scale.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform.mel_scale.fb.sum(1).ge(0.).all())
        # check options
        kwargs = {'window_fn': torch.hamming_window, 'pad': 10, 'win_length': 500,
                  'hop_length': 125, 'n_fft': 800, 'n_mels': 50}
        mel_transform2 = transforms.MelSpectrogram(**kwargs)
        spectrogram2_torch = s2db(mel_transform2(waveform_scaled))  # (1, 50, 513)
        self.assertTrue(spectrogram2_torch.dim() == 3)
        self.assertTrue(spectrogram2_torch.ge(spectrogram2_torch.max() - top_db).all())
        self.assertEqual(spectrogram2_torch.size(1), mel_transform2.n_mels)
        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).ge(0.).all())
        # check on multi-channel audio
        x_stereo, sr_stereo = torchaudio.load(self.test_filepath)  # (2, 278756), 44100
        spectrogram_stereo = s2db(mel_transform(x_stereo))  # (2, 128, 1394)
        self.assertTrue(spectrogram_stereo.dim() == 3)
        self.assertTrue(spectrogram_stereo.size(0) == 2)
        self.assertTrue(spectrogram_stereo.ge(spectrogram_stereo.max() - top_db).all())
        self.assertEqual(spectrogram_stereo.size(1), mel_transform.n_mels)
        # check filterbank matrix creation
        fb_matrix_transform = transforms.MelScale(
            n_mels=100, sample_rate=16000, f_min=0., f_max=None, n_stft=400)
        self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all())
        self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
        self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))

    def test_scriptmodule_MFCC(self):
        tensor = torch.rand((1, 1000))
        _test_script_module(transforms.MFCC, tensor)

    def test_mfcc(self):
        audio_orig = self.waveform.clone()
        audio_scaled = self.scale(audio_orig)  # (1, 16000)

        sample_rate = 16000
        n_mfcc = 40
        n_mels = 128
        mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                    n_mfcc=n_mfcc,
                                                    norm='ortho')
        # check defaults
        torch_mfcc = mfcc_transform(audio_scaled)  # (1, 40, 321)
        self.assertTrue(torch_mfcc.dim() == 3)
        self.assertTrue(torch_mfcc.shape[1] == n_mfcc)
        self.assertTrue(torch_mfcc.shape[2] == 321)
        # check melkwargs are passed through
        melkwargs = {'win_length': 200}
        mfcc_transform2 = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                     n_mfcc=n_mfcc,
                                                     norm='ortho',
                                                     melkwargs=melkwargs)
        torch_mfcc2 = mfcc_transform2(audio_scaled)  # (1, 40, 641)
        self.assertTrue(torch_mfcc2.shape[2] == 641)

        # check norms work correctly
        mfcc_transform_norm_none = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                              n_mfcc=n_mfcc,
                                                              norm=None)
        torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled)  # (1, 40, 321)

        norm_check = torch_mfcc.clone()
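        # an unnormalized DCT-II is larger than the 'ortho' one by a factor of
        # 2 * sqrt(n_mels) for coefficient 0 and 2 * sqrt(n_mels / 2) for the
        # others, so rescaling the 'ortho' output reproduces the norm=None output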
        norm_check[:, 0, :] *= math.sqrt(n_mels) * 2
        norm_check[:, 1:, :] *= math.sqrt(n_mels / 2) * 2

        self.assertTrue(torch_mfcc_norm_none.allclose(norm_check))

    @unittest.skipIf(not IMPORT_LIBROSA or not IMPORT_SCIPY, 'Librosa and scipy are not available')
    def test_librosa_consistency(self):
        def _test_librosa_consistency_helper(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):
            input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
            sound, sample_rate = torchaudio.load(input_path)
            sound_librosa = sound.cpu().numpy().squeeze()  # (64000)

            # test core spectrogram
            spect_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop_length=hop_length, power=power)
            out_librosa, _ = librosa.core.spectrum._spectrogram(y=sound_librosa,
                                                                n_fft=n_fft,
                                                                hop_length=hop_length,
                                                                power=power)

            out_torch = spect_transform(sound).squeeze().cpu()
            self.assertTrue(torch.allclose(out_torch, torch.from_numpy(out_librosa), atol=1e-5))

            # test mel spectrogram
            melspect_transform = torchaudio.transforms.MelSpectrogram(
                sample_rate=sample_rate, window_fn=torch.hann_window,
                hop_length=hop_length, n_mels=n_mels, n_fft=n_fft)
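            # htk=True and norm=None make librosa build the same mel filterbank
            # as torchaudio's default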
            librosa_mel = librosa.feature.melspectrogram(y=sound_librosa, sr=sample_rate,
                                                         n_fft=n_fft, hop_length=hop_length, n_mels=n_mels,
                                                         htk=True, norm=None)
            librosa_mel_tensor = torch.from_numpy(librosa_mel)
            torch_mel = melspect_transform(sound).squeeze().cpu()

            self.assertTrue(torch.allclose(torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3))

            # test s2db
            power_to_db_transform = torchaudio.transforms.AmplitudeToDB('power', 80.)
            power_to_db_torch = power_to_db_transform(spect_transform(sound)).squeeze().cpu()
            power_to_db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
            self.assertTrue(torch.allclose(power_to_db_torch, torch.from_numpy(power_to_db_librosa), atol=5e-3))

            mag_to_db_transform = torchaudio.transforms.AmplitudeToDB('magnitude', 80.)
            mag_to_db_torch = mag_to_db_transform(torch.abs(sound)).squeeze().cpu()
            mag_to_db_librosa = librosa.core.spectrum.amplitude_to_db(sound_librosa)
            self.assertTrue(
                torch.allclose(mag_to_db_torch, torch.from_numpy(mag_to_db_librosa), atol=5e-3)
            )

            power_to_db_torch = power_to_db_transform(melspect_transform(sound)).squeeze().cpu()
            db_librosa = librosa.core.spectrum.power_to_db(librosa_mel)
            db_librosa_tensor = torch.from_numpy(db_librosa)
            self.assertTrue(
                torch.allclose(power_to_db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3)
            )

            # test MFCC
            melkwargs = {'hop_length': hop_length, 'n_fft': n_fft}
            mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                        n_mfcc=n_mfcc,
                                                        norm='ortho',
                                                        melkwargs=melkwargs)

            # librosa.feature.mfcc doesn't pass kwargs properly since some of the
            # kwargs for melspectrogram and mfcc are the same. We just follow the
            # function body in https://librosa.github.io/librosa/_modules/librosa/feature/spectral.html#melspectrogram
            # to mirror this function call with correct args:

            # librosa_mfcc = librosa.feature.mfcc(y=sound_librosa,
            #                                     sr=sample_rate,
            #                                     n_mfcc=n_mfcc,
            #                                     hop_length=hop_length,
            #                                     n_fft=n_fft,
            #                                     htk=True,
            #                                     norm=None,
            #                                     n_mels=n_mels)

            librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc]
            librosa_mfcc_tensor = torch.from_numpy(librosa_mfcc)
            torch_mfcc = mfcc_transform(sound).squeeze().cpu()

            self.assertTrue(torch.allclose(torch_mfcc.type(librosa_mfcc_tensor.dtype), librosa_mfcc_tensor, atol=5e-3))

        kwargs1 = {
            'n_fft': 400,
            'hop_length': 200,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 40,
            'sample_rate': 16000
        }

        kwargs2 = {
            'n_fft': 600,
            'hop_length': 100,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 20,
            'sample_rate': 16000
        }

        kwargs3 = {
            'n_fft': 200,
            'hop_length': 50,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 50,
            'sample_rate': 24000
        }

        kwargs4 = {
            'n_fft': 400,
            'hop_length': 200,
            'power': 3.0,
            'n_mels': 128,
            'n_mfcc': 40,
            'sample_rate': 16000
        }

        _test_librosa_consistency_helper(**kwargs1)
        _test_librosa_consistency_helper(**kwargs2)
        # NOTE Test passes offline, but fails on CircleCI, see #372.
        # _test_librosa_consistency_helper(**kwargs3)
        _test_librosa_consistency_helper(**kwargs4)

    def test_scriptmodule_Resample(self):
        tensor = torch.rand((2, 1000))
        sample_rate = 100.
        sample_rate_2 = 50.

        _test_script_module(transforms.Resample, tensor, sample_rate, sample_rate_2)

    def test_batch_Resample(self):
        waveform = torch.randn(2, 2786)

        # Single then transform then batch
        expected = transforms.Resample()(waveform).repeat(3, 1, 1)

        # Batch then transform
        computed = transforms.Resample()(waveform.repeat(3, 1, 1))

        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_scriptmodule_ComplexNorm(self):
        tensor = torch.rand((1, 2, 201, 2))
        _test_script_module(transforms.ComplexNorm, tensor)

    def test_resample_size(self):
        input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
        waveform, sample_rate = torchaudio.load(input_path)

        upsample_rate = sample_rate * 2
        downsample_rate = sample_rate // 2
        invalid_resample = torchaudio.transforms.Resample(sample_rate, upsample_rate, resampling_method='foo')

        self.assertRaises(ValueError, invalid_resample, waveform)

        upsample_resample = torchaudio.transforms.Resample(
            sample_rate, upsample_rate, resampling_method='sinc_interpolation')
        up_sampled = upsample_resample(waveform)

        # we expect the upsampled signal to have twice as many samples
        self.assertTrue(up_sampled.size(-1) == waveform.size(-1) * 2)

        downsample_resample = torchaudio.transforms.Resample(
            sample_rate, downsample_rate, resampling_method='sinc_interpolation')
        down_sampled = downsample_resample(waveform)

        # we expect the downsampled signal to have half as many samples
        self.assertTrue(down_sampled.size(-1) == waveform.size(-1) // 2)

    def test_compute_deltas(self):
        channel = 13
        n_mfcc = channel * 3
        time = 1021
        win_length = 2 * 7 + 1
        specgram = torch.randn(channel, n_mfcc, time)
        transform = transforms.ComputeDeltas(win_length=win_length)
        computed = transform(specgram)
        self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))

    def test_compute_deltas_transform_same_as_functional(self, atol=1e-6, rtol=1e-8):
        channel = 13
        n_mfcc = channel * 3
        time = 1021
        win_length = 2 * 7 + 1
        specgram = torch.randn(channel, n_mfcc, time)

        transform = transforms.ComputeDeltas(win_length=win_length)
        computed_transform = transform(specgram)

        computed_functional = F.compute_deltas(specgram, win_length=win_length)
        torch.testing.assert_allclose(computed_functional, computed_transform, atol=atol, rtol=rtol)

    def test_compute_deltas_twochannel(self):
        specgram = torch.tensor([1., 2., 3., 4.]).repeat(1, 2, 1)
        expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5],
                                  [0.5, 1.0, 1.0, 0.5]]])
        # win_length=3 reduces the delta to a two-point slope with edge
        # replication, which is what `expected` encodes
        transform = transforms.ComputeDeltas(win_length=3)
        computed = transform(specgram)
        self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_batch_MelScale(self):
        specgram = torch.randn(2, 31, 2786)

        # Single then transform then batch
        expected = transforms.MelScale()(specgram).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = transforms.MelScale()(specgram.repeat(3, 1, 1, 1))

        # shape = (3, 2, 128, 2786)
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_batch_InverseMelScale(self):
        n_fft = 8
        n_mels = 32
        n_stft = 5
        mel_spec = torch.randn(2, n_mels, 32) ** 2

        # Single then transform then batch
        expected = transforms.InverseMelScale(n_stft, n_mels)(mel_spec).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = transforms.InverseMelScale(n_stft, n_mels)(mel_spec.repeat(3, 1, 1, 1))

        # shape = (3, 2, n_stft, 32)
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))

        # Because InverseMelScale runs SGD from randomly initialized values, repeated
        # runs do not yield exactly the same result. For this reason, the tolerance is
        # very relaxed here.
        self.assertTrue(torch.allclose(computed, expected, atol=1.0))

    def test_batch_compute_deltas(self):
        specgram = torch.randn(2, 31, 2786)

        # Single then transform then batch
        expected = transforms.ComputeDeltas()(specgram).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = transforms.ComputeDeltas()(specgram.repeat(3, 1, 1, 1))

        # shape = (3, 2, 31, 2786)
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_scriptmodule_MuLawEncoding(self):
        tensor = torch.rand((1, 10))
        _test_script_module(transforms.MuLawEncoding, tensor)

    def test_scriptmodule_MuLawDecoding(self):
        tensor = torch.rand((1, 10))
        _test_script_module(transforms.MuLawDecoding, tensor)

    def test_batch_mulaw(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)  # (2, 278756), 44100

        # Single then transform then batch
        waveform_encoded = transforms.MuLawEncoding()(waveform)
        expected = waveform_encoded.unsqueeze(0).repeat(3, 1, 1)

        # Batch then transform
        waveform_batched = waveform.unsqueeze(0).repeat(3, 1, 1)
        computed = transforms.MuLawEncoding()(waveform_batched)

        # shape = (3, 2, 278756)
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

        # Single then transform then batch
        waveform_decoded = transforms.MuLawDecoding()(waveform_encoded)
        expected = waveform_decoded.unsqueeze(0).repeat(3, 1, 1)

        # Batch then transform
        computed = transforms.MuLawDecoding()(computed)

        # shape = (3, 2, 278756)
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_batch_spectrogram(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)

        # Single then transform then batch
        expected = transforms.Spectrogram()(waveform).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = transforms.Spectrogram()(waveform.repeat(3, 1, 1))

        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_batch_melspectrogram(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)

        # Single then transform then batch
        expected = transforms.MelSpectrogram()(waveform).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = transforms.MelSpectrogram()(waveform.repeat(3, 1, 1))

        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_batch_mfcc(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)

        # Single then transform then batch
        expected = transforms.MFCC()(waveform).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = transforms.MFCC()(waveform.repeat(3, 1, 1))

        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected, atol=1e-5))

    def test_scriptmodule_TimeStretch(self):
        n_freq = 400
        hop_length = 512
        fixed_rate = 1.3
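        # TimeStretch consumes a complex spectrogram stored as a real tensor
        # whose last dimension of size 2 holds the real and imaginary parts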
        tensor = torch.rand((10, 2, n_freq, 10, 2))
        _test_script_module(transforms.TimeStretch, tensor, n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate)

    def test_batch_TimeStretch(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)

        kwargs = {
            'n_fft': 2048,
            'hop_length': 512,
            'win_length': 2048,
            'window': torch.hann_window(2048),
            'center': True,
            'pad_mode': 'reflect',
            'normalized': True,
            'onesided': True,
        }
        rate = 2
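        # with this (pre-complex-dtype) torch.stft API, the output is a real
        # tensor whose last dimension of size 2 holds real and imaginary parts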

        complex_specgrams = torch.stft(waveform, **kwargs)

        # Single then transform then batch
        expected = transforms.TimeStretch(fixed_rate=rate,
                                          n_freq=1025,
                                          hop_length=512)(complex_specgrams).repeat(3, 1, 1, 1, 1)

        # Batch then transform
        computed = transforms.TimeStretch(fixed_rate=rate,
                                          n_freq=1025,
                                          hop_length=512)(complex_specgrams.repeat(3, 1, 1, 1, 1))

        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected, atol=1e-5))

    def test_batch_Fade(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)
        fade_in_len = 3000
        fade_out_len = 3000

        # Single then transform then batch
        expected = transforms.Fade(fade_in_len, fade_out_len)(waveform).repeat(3, 1, 1)

        # Batch then transform
        computed = transforms.Fade(fade_in_len, fade_out_len)(waveform.repeat(3, 1, 1))

        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_scriptmodule_Fade(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)
        fade_in_len = 3000
        fade_out_len = 3000

        _test_script_module(transforms.Fade, waveform, fade_in_len, fade_out_len)

    def test_scriptmodule_FrequencyMasking(self):
        tensor = torch.rand((10, 2, 50, 10, 2))
        _test_script_module(transforms.FrequencyMasking, tensor, freq_mask_param=60, iid_masks=False)

    def test_scriptmodule_TimeMasking(self):
        tensor = torch.rand((10, 2, 50, 10, 2))
        _test_script_module(transforms.TimeMasking, tensor, time_mask_param=30, iid_masks=False)

    def test_scriptmodule_Vol(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)

        _test_script_module(transforms.Vol, waveform, 1.1)

    def test_batch_Vol(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)

        # Single then transform then batch
        expected = transforms.Vol(gain=1.1)(waveform).repeat(3, 1, 1)

        # Batch then transform
        computed = transforms.Vol(gain=1.1)(waveform.repeat(3, 1, 1))

        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))


class TestLibrosaConsistency(unittest.TestCase):
    test_dirpath = None
    test_dir = None

    @classmethod
    def setUpClass(cls):
        cls.test_dirpath, cls.test_dir = common_utils.create_temp_assets_dir()

    def _to_librosa(self, sound):
        return sound.cpu().numpy().squeeze()

    def _get_sample_data(self, *asset_paths, **kwargs):
        file_path = os.path.join(self.test_dirpath, 'assets', *asset_paths)

        sound, sample_rate = torchaudio.load(file_path, **kwargs)
        return sound.mean(dim=0, keepdim=True), sample_rate

    @unittest.skipIf(not IMPORT_LIBROSA, 'Librosa is not available')
    def test_MelScale(self):
        """MelScale transform is comparable to that of librosa"""
        n_fft = 2048
        n_mels = 256
        hop_length = n_fft // 4

        # Prepare spectrogram input. We use torchaudio to compute one.
        sound, sample_rate = self._get_sample_data('whitenoise_1min.mp3')
        spec_ta = F.spectrogram(
            sound, pad=0, window=torch.hann_window(n_fft), n_fft=n_fft,
            hop_length=hop_length, win_length=n_fft, power=2, normalized=False)
        spec_lr = spec_ta.cpu().numpy().squeeze()
        # Perform MelScale with torchaudio and librosa
        melspec_ta = transforms.MelScale(n_mels=n_mels, sample_rate=sample_rate)(spec_ta)
        melspec_lr = librosa.feature.melspectrogram(
            S=spec_lr, sr=sample_rate, n_fft=n_fft, hop_length=hop_length,
            win_length=n_fft, center=True, window='hann', n_mels=n_mels, htk=True, norm=None)
        # Note: Using relaxed rtol instead of atol
        assert torch.allclose(melspec_ta, torch.from_numpy(melspec_lr[None, ...]), rtol=1e-3)

    @unittest.skipIf(not IMPORT_LIBROSA, 'Librosa is not available')
    def test_InverseMelScale(self):
        """InverseMelScale transform is comparable to that of librosa"""
        n_fft = 2048
        n_mels = 256
        n_stft = n_fft // 2 + 1
        hop_length = n_fft // 4

        # Prepare mel spectrogram input. We use torchaudio to compute one.
        sound, sample_rate = self._get_sample_data(
            'steam-train-whistle-daniel_simon.wav', offset=2**10, num_frames=2**14)
        spec_orig = F.spectrogram(
            sound, pad=0, window=torch.hann_window(n_fft), n_fft=n_fft,
            hop_length=hop_length, win_length=n_fft, power=2, normalized=False)
        melspec_ta = transforms.MelScale(n_mels=n_mels, sample_rate=sample_rate)(spec_orig)
        melspec_lr = melspec_ta.cpu().numpy().squeeze()
        # Perform InverseMelScale with torchaudio and librosa
        spec_ta = transforms.InverseMelScale(
            n_stft, n_mels=n_mels, sample_rate=sample_rate)(melspec_ta)
        spec_lr = librosa.feature.inverse.mel_to_stft(
            melspec_lr, sr=sample_rate, n_fft=n_fft, power=2.0, htk=True, norm=None)
        spec_lr = torch.from_numpy(spec_lr[None, ...])

        # Align scales:
        # librosa's mel_to_stft returns a magnitude spectrogram, while torchaudio's
        # InverseMelScale returns a power spectrogram, so compare everything as magnitudes
        spec_orig = spec_orig.sqrt()
        spec_ta = spec_ta.sqrt()

        threshold = 2.0
        # This threshold was chosen empirically, based on the following observation
        #
        # torch.dist(spec_lr, spec_ta, p=float('inf'))
        # >>> tensor(1.9666)
        #
        # The spectrograms reconstructed by librosa and torchaudio are not very comparable
        # elementwise because they use different approximation algorithms, so the resulting
        # values can differ in magnitude (although most of them are very close).
        # See https://github.com/pytorch/audio/pull/366 for the discussion of the choice of algorithm
        # See https://github.com/pytorch/audio/pull/448/files#r385747021 for the distribution of P-inf
        # distance over frequencies.
        assert torch.allclose(spec_ta, spec_lr, atol=threshold)

        threshold = 1700.0
        # This threshold was chosen empirically, based on the following observations
        #
        # torch.dist(spec_orig, spec_ta, p=1)
        # >>> tensor(1644.3516)
        # torch.dist(spec_orig, spec_lr, p=1)
        # >>> tensor(1420.7103)
        # torch.dist(spec_lr, spec_ta, p=1)
        # >>> tensor(943.2759)
        assert torch.dist(spec_orig, spec_ta, p=1) < threshold


if __name__ == '__main__':
    unittest.main()