import math
import os

import torch
import torchaudio
import torchaudio.transforms as transforms
import torchaudio.functional as F
from torchaudio.common_utils import IMPORT_LIBROSA, IMPORT_SCIPY
import unittest
from common_utils import AudioBackendScope, BACKENDS, create_temp_assets_dir

if IMPORT_LIBROSA:
    import librosa

if IMPORT_SCIPY:
    import scipy

RUN_CUDA = torch.cuda.is_available()
print("Run test with cuda:", RUN_CUDA)


def _test_script_module(f, tensor, *args, **kwargs):
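    """Instantiate ``f(*args, **kwargs)``, script it with ``torch.jit.script`` and check that
    the scripted module produces the same output as eager mode, on CPU and (if available) CUDA."""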

    py_method = f(*args, **kwargs)
    jit_method = torch.jit.script(py_method)

    py_out = py_method(tensor)
    jit_out = jit_method(tensor)

    assert torch.allclose(jit_out, py_out)

    if RUN_CUDA:

        tensor = tensor.to("cuda")

        py_method = py_method.cuda()
        jit_method = torch.jit.script(py_method)

        py_out = py_method(tensor)
        jit_out = jit_method(tensor)

        assert torch.allclose(jit_out, py_out)


class Tester(unittest.TestCase):

    # create a sinewave signal for testing
    sample_rate = 16000
    freq = 440
    volume = .3
    waveform = (torch.cos(2 * math.pi * torch.arange(0, 4 * sample_rate).float() * freq / sample_rate))
    waveform.unsqueeze_(0)  # (1, 64000)
    waveform = (waveform * volume * 2**31).long()
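    # stored as int32-range integers; tests convert back to floats in [-1., 1.] via self.scale()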
    # file for stereo stft test
    test_dirpath, test_dir = create_temp_assets_dir()
    test_filepath = os.path.join(test_dirpath, 'assets',
                                 'steam-train-whistle-daniel_simon.wav')

    def scale(self, waveform, factor=float(2**31)):
        # scales a waveform by a factor
        if not waveform.is_floating_point():
            waveform = waveform.to(torch.get_default_dtype())
        return waveform / factor

    def test_scriptmodule_Spectrogram(self):
        tensor = torch.rand((1, 1000))
        _test_script_module(transforms.Spectrogram, tensor)

    def test_scriptmodule_GriffinLim(self):
        tensor = torch.rand((1, 201, 6))
        _test_script_module(transforms.GriffinLim, tensor, length=1000, rand_init=False)

    def test_mu_law_companding(self):

        quantization_channels = 256

        # convert the integer test signal to float so the normalization below is true
        # division; MuLawEncoding expects a float waveform in [-1., 1.]
        waveform = self.waveform.clone().to(torch.get_default_dtype())
        waveform /= torch.abs(waveform).max()
        self.assertTrue(waveform.min() >= -1. and waveform.max() <= 1.)

        waveform_mu = transforms.MuLawEncoding(quantization_channels)(waveform)
        self.assertTrue(waveform_mu.min() >= 0. and waveform_mu.max() <= quantization_channels)

        waveform_exp = transforms.MuLawDecoding(quantization_channels)(waveform_mu)
        self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.)

    def test_scriptmodule_AmplitudeToDB(self):
        spec = torch.rand((6, 201))
        _test_script_module(transforms.AmplitudeToDB, spec)

    def test_batch_AmplitudeToDB(self):
        spec = torch.rand((6, 201))

        # Single then transform then batch
        expected = transforms.AmplitudeToDB()(spec).repeat(3, 1, 1)

        # Batch then transform
        computed = transforms.AmplitudeToDB()(spec.repeat(3, 1, 1))

        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_AmplitudeToDB(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)

        mag_to_db_transform = transforms.AmplitudeToDB('magnitude', 80.)
        power_to_db_transform = transforms.AmplitudeToDB('power', 80.)

        mag_to_db_torch = mag_to_db_transform(torch.abs(waveform))
        power_to_db_torch = power_to_db_transform(torch.pow(waveform, 2))

        self.assertTrue(torch.allclose(mag_to_db_torch, power_to_db_torch))

    def test_scriptmodule_MelScale(self):
        spec_f = torch.rand((1, 6, 201))
        _test_script_module(transforms.MelScale, spec_f)

    def test_melscale_load_save(self):
        specgram = torch.ones(1, 1000, 100)
        melscale_transform = transforms.MelScale()
        melscale_transform(specgram)

        melscale_transform_copy = transforms.MelScale(n_stft=1000)
        melscale_transform_copy.load_state_dict(melscale_transform.state_dict())

        fb = melscale_transform.fb
        fb_copy = melscale_transform_copy.fb

        self.assertEqual(fb_copy.size(), (1000, 128))
        self.assertTrue(torch.allclose(fb, fb_copy))

    def test_scriptmodule_MelSpectrogram(self):
        tensor = torch.rand((1, 1000))
        _test_script_module(transforms.MelSpectrogram, tensor)

    def test_melspectrogram_load_save(self):
        waveform = self.waveform.float()
        mel_spectrogram_transform = transforms.MelSpectrogram()
        mel_spectrogram_transform(waveform)

        mel_spectrogram_transform_copy = transforms.MelSpectrogram()
        mel_spectrogram_transform_copy.load_state_dict(mel_spectrogram_transform.state_dict())

        window = mel_spectrogram_transform.spectrogram.window
        window_copy = mel_spectrogram_transform_copy.spectrogram.window

        fb = mel_spectrogram_transform.mel_scale.fb
        fb_copy = mel_spectrogram_transform_copy.mel_scale.fb

        self.assertTrue(torch.allclose(window, window_copy))
        # the default for n_fft = 400 and n_mels = 128
        self.assertEqual(fb_copy.size(), (201, 128))
        self.assertTrue(torch.allclose(fb, fb_copy))

    def test_mel2(self):
        top_db = 80.
        s2db = transforms.AmplitudeToDB('power', top_db)

        waveform = self.waveform.clone()  # (1, 64000)
        waveform_scaled = self.scale(waveform)  # (1, 64000)
        mel_transform = transforms.MelSpectrogram()
        # check defaults
        spectrogram_torch = s2db(mel_transform(waveform_scaled))  # (1, 128, 321)
        self.assertTrue(spectrogram_torch.dim() == 3)
        self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
        self.assertEqual(spectrogram_torch.size(1), mel_transform.n_mels)
        # check correctness of filterbank conversion matrix
        self.assertTrue(mel_transform.mel_scale.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform.mel_scale.fb.sum(1).ge(0.).all())
        # check options
        kwargs = {'window_fn': torch.hamming_window, 'pad': 10, 'win_length': 500,
                  'hop_length': 125, 'n_fft': 800, 'n_mels': 50}
        mel_transform2 = transforms.MelSpectrogram(**kwargs)
        spectrogram2_torch = s2db(mel_transform2(waveform_scaled))  # (1, 50, 513)
        self.assertTrue(spectrogram2_torch.dim() == 3)
        self.assertTrue(spectrogram2_torch.ge(spectrogram2_torch.max() - top_db).all())
        self.assertEqual(spectrogram2_torch.size(1), mel_transform2.n_mels)
        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).ge(0.).all())
        # check on multi-channel audio
        x_stereo, sr_stereo = torchaudio.load(self.test_filepath)  # (2, 278756), 44100
        spectrogram_stereo = s2db(mel_transform(x_stereo))  # (2, 128, 1394)
        self.assertTrue(spectrogram_stereo.dim() == 3)
        self.assertTrue(spectrogram_stereo.size(0) == 2)
        self.assertTrue(spectrogram_stereo.ge(spectrogram_stereo.max() - top_db).all())
        self.assertEqual(spectrogram_stereo.size(1), mel_transform.n_mels)
        # check filterbank matrix creation
        fb_matrix_transform = transforms.MelScale(
            n_mels=100, sample_rate=16000, f_min=0., f_max=None, n_stft=400)
        self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all())
        self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
        self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))

    def test_scriptmodule_MFCC(self):
        tensor = torch.rand((1, 1000))
        _test_script_module(transforms.MFCC, tensor)

    def test_mfcc(self):
        audio_orig = self.waveform.clone()
        audio_scaled = self.scale(audio_orig)  # (1, 64000)

        sample_rate = 16000
        n_mfcc = 40
        n_mels = 128
        mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                    n_mfcc=n_mfcc,
                                                    norm='ortho')
        # check defaults
        torch_mfcc = mfcc_transform(audio_scaled)  # (1, 40, 321)
        self.assertTrue(torch_mfcc.dim() == 3)
        self.assertTrue(torch_mfcc.shape[1] == n_mfcc)
        self.assertTrue(torch_mfcc.shape[2] == 321)
        # check melkwargs are passed through
        melkwargs = {'win_length': 200}
        mfcc_transform2 = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                     n_mfcc=n_mfcc,
                                                     norm='ortho',
                                                     melkwargs=melkwargs)
        torch_mfcc2 = mfcc_transform2(audio_scaled)  # (1, 40, 641)
        self.assertTrue(torch_mfcc2.shape[2] == 641)

        # check norms work correctly
        mfcc_transform_norm_none = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                              n_mfcc=n_mfcc,
                                                              norm=None)
        torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled)  # (1, 40, 321)

        norm_check = torch_mfcc.clone()
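        # an unnormalized DCT-II equals the ortho-normalized one scaled by
        # 2 * sqrt(n_mels) for coefficient 0 and 2 * sqrt(n_mels / 2) for the others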
        norm_check[:, 0, :] *= math.sqrt(n_mels) * 2
        norm_check[:, 1:, :] *= math.sqrt(n_mels / 2) * 2

        self.assertTrue(torch_mfcc_norm_none.allclose(norm_check))

    @unittest.skipIf(not IMPORT_LIBROSA or not IMPORT_SCIPY, 'Librosa and scipy are not available')
    def test_librosa_consistency(self):
        def _test_librosa_consistency_helper(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):
            input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
            sound, sample_rate = torchaudio.load(input_path)
            sound_librosa = sound.cpu().numpy().squeeze()  # (64000)

            # test core spectrogram
            spect_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop_length=hop_length, power=power)
            out_librosa, _ = librosa.core.spectrum._spectrogram(y=sound_librosa,
                                                                n_fft=n_fft,
                                                                hop_length=hop_length,
                                                                power=power)

            out_torch = spect_transform(sound).squeeze().cpu()
            self.assertTrue(torch.allclose(out_torch, torch.from_numpy(out_librosa), atol=1e-5))

            # test mel spectrogram
            melspect_transform = torchaudio.transforms.MelSpectrogram(
                sample_rate=sample_rate, window_fn=torch.hann_window,
                hop_length=hop_length, n_mels=n_mels, n_fft=n_fft)
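            # htk=True and norm=None make librosa's mel filterbank match torchaudio's
            # (HTK mel scale, no area normalization)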
            librosa_mel = librosa.feature.melspectrogram(y=sound_librosa, sr=sample_rate,
                                                         n_fft=n_fft, hop_length=hop_length, n_mels=n_mels,
                                                         htk=True, norm=None)
            librosa_mel_tensor = torch.from_numpy(librosa_mel)
            torch_mel = melspect_transform(sound).squeeze().cpu()

            self.assertTrue(torch.allclose(torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3))

            # test s2db
            power_to_db_transform = torchaudio.transforms.AmplitudeToDB('power', 80.)
            power_to_db_torch = power_to_db_transform(spect_transform(sound)).squeeze().cpu()
            power_to_db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
            self.assertTrue(torch.allclose(power_to_db_torch, torch.from_numpy(power_to_db_librosa), atol=5e-3))

            mag_to_db_transform = torchaudio.transforms.AmplitudeToDB('magnitude', 80.)
            mag_to_db_torch = mag_to_db_transform(torch.abs(sound)).squeeze().cpu()
            mag_to_db_librosa = librosa.core.spectrum.amplitude_to_db(sound_librosa)
            self.assertTrue(
                torch.allclose(mag_to_db_torch, torch.from_numpy(mag_to_db_librosa), atol=5e-3)
            )

            power_to_db_torch = power_to_db_transform(melspect_transform(sound)).squeeze().cpu()
            db_librosa = librosa.core.spectrum.power_to_db(librosa_mel)
            db_librosa_tensor = torch.from_numpy(db_librosa)
            self.assertTrue(
                torch.allclose(power_to_db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3)
            )

            # test MFCC
            melkwargs = {'hop_length': hop_length, 'n_fft': n_fft}
            mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                        n_mfcc=n_mfcc,
                                                        norm='ortho',
                                                        melkwargs=melkwargs)

            # librosa.feature.mfcc doesn't pass kwargs properly since some of the
            # kwargs for melspectrogram and mfcc are the same. We just follow the
            # function body in https://librosa.github.io/librosa/_modules/librosa/feature/spectral.html#melspectrogram
            # to mirror this function call with correct args:

            # librosa_mfcc = librosa.feature.mfcc(y=sound_librosa,
            #                                     sr=sample_rate,
            #                                     n_mfcc=n_mfcc,
            #                                     hop_length=hop_length,
            #                                     n_fft=n_fft,
            #                                     htk=True,
            #                                     norm=None,
            #                                     n_mels=n_mels)

            librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc]
            librosa_mfcc_tensor = torch.from_numpy(librosa_mfcc)
            torch_mfcc = mfcc_transform(sound).squeeze().cpu()

            self.assertTrue(torch.allclose(torch_mfcc.type(librosa_mfcc_tensor.dtype), librosa_mfcc_tensor, atol=5e-3))

        kwargs1 = {
            'n_fft': 400,
            'hop_length': 200,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 40,
            'sample_rate': 16000
        }

        kwargs2 = {
            'n_fft': 600,
            'hop_length': 100,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 20,
            'sample_rate': 16000
        }

        kwargs3 = {
            'n_fft': 200,
            'hop_length': 50,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 50,
            'sample_rate': 24000
        }

        kwargs4 = {
            'n_fft': 400,
            'hop_length': 200,
            'power': 3.0,
            'n_mels': 128,
            'n_mfcc': 40,
            'sample_rate': 16000
        }

        _test_librosa_consistency_helper(**kwargs1)
        _test_librosa_consistency_helper(**kwargs2)
        # NOTE Test passes offline, but fails on CircleCI, see #372.
        # _test_librosa_consistency_helper(**kwargs3)
        _test_librosa_consistency_helper(**kwargs4)

    def test_scriptmodule_Resample(self):
        tensor = torch.rand((2, 1000))
        sample_rate = 100.
        sample_rate_2 = 50.

        _test_script_module(transforms.Resample, tensor, sample_rate, sample_rate_2)

    def test_batch_Resample(self):
        waveform = torch.randn(2, 2786)

        # Single then transform then batch
        expected = transforms.Resample()(waveform).repeat(3, 1, 1)

        # Batch then transform
        computed = transforms.Resample()(waveform.repeat(3, 1, 1))

        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_scriptmodule_ComplexNorm(self):
        tensor = torch.rand((1, 2, 201, 2))
        _test_script_module(transforms.ComplexNorm, tensor)

    def test_resample_size(self):
        input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
        waveform, sample_rate = torchaudio.load(input_path)

        upsample_rate = sample_rate * 2
        downsample_rate = sample_rate // 2
        invalid_resample = torchaudio.transforms.Resample(sample_rate, upsample_rate, resampling_method='foo')

        self.assertRaises(ValueError, invalid_resample, waveform)

        upsample_resample = torchaudio.transforms.Resample(
            sample_rate, upsample_rate, resampling_method='sinc_interpolation')
        up_sampled = upsample_resample(waveform)

        # we expect the upsampled signal to have twice as many samples
        self.assertTrue(up_sampled.size(-1) == waveform.size(-1) * 2)

        downsample_resample = torchaudio.transforms.Resample(
            sample_rate, downsample_rate, resampling_method='sinc_interpolation')
        down_sampled = downsample_resample(waveform)

        # we expect the downsampled signal to have half as many samples
        self.assertTrue(down_sampled.size(-1) == waveform.size(-1) // 2)

    def test_compute_deltas(self):
        channel = 13
        n_mfcc = channel * 3
        time = 1021
        win_length = 2 * 7 + 1
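        # win_length = 15, i.e. 7 frames of context on each side of the current frame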
        specgram = torch.randn(channel, n_mfcc, time)
        transform = transforms.ComputeDeltas(win_length=win_length)
        computed = transform(specgram)
        self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))

    def test_compute_deltas_transform_same_as_functional(self, atol=1e-6, rtol=1e-8):
        channel = 13
        n_mfcc = channel * 3
        time = 1021
        win_length = 2 * 7 + 1
        specgram = torch.randn(channel, n_mfcc, time)

        transform = transforms.ComputeDeltas(win_length=win_length)
        computed_transform = transform(specgram)

        computed_functional = F.compute_deltas(specgram, win_length=win_length)
        torch.testing.assert_allclose(computed_functional, computed_transform, atol=atol, rtol=rtol)

    def test_compute_deltas_twochannel(self):
        specgram = torch.tensor([1., 2., 3., 4.]).repeat(1, 2, 1)
        expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5],
                                  [0.5, 1.0, 1.0, 0.5]]])
        # the reference values above correspond to a delta window of length 3
        transform = transforms.ComputeDeltas(win_length=3)
        computed = transform(specgram)
        self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_batch_MelScale(self):
        specgram = torch.randn(2, 31, 2786)

        # Single then transform then batch
        expected = transforms.MelScale()(specgram).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = transforms.MelScale()(specgram.repeat(3, 1, 1, 1))

        # shape = (3, 2, 128, 2786)
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_batch_InverseMelScale(self):
        n_fft = 8
        n_mels = 32
        n_stft = 5
        mel_spec = torch.randn(2, n_mels, 32) ** 2
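        # squared so the input is non-negative, like a real power mel spectrogram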

        # Single then transform then batch
        expected = transforms.InverseMelScale(n_stft, n_mels)(mel_spec).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = transforms.InverseMelScale(n_stft, n_mels)(mel_spec.repeat(3, 1, 1, 1))

        # shape = (3, 2, n_stft, 32)
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))

        # Because InverseMelScale runs SGD on randomly initialized values so they do not yield
        # exactly same result. For this reason, tolerance is very relaxed here.
        self.assertTrue(torch.allclose(computed, expected, atol=1.0))

    def test_batch_compute_deltas(self):
        specgram = torch.randn(2, 31, 2786)

        # Single then transform then batch
        expected = transforms.ComputeDeltas()(specgram).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = transforms.ComputeDeltas()(specgram.repeat(3, 1, 1, 1))

        # shape = (3, 2, 31, 2786)
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_scriptmodule_MuLawEncoding(self):
        tensor = torch.rand((1, 10))
        _test_script_module(transforms.MuLawEncoding, tensor)

    def test_scriptmodule_MuLawDecoding(self):
        tensor = torch.rand((1, 10))
        _test_script_module(transforms.MuLawDecoding, tensor)

    def test_batch_mulaw(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)  # (2, 278756), 44100

        # Single then transform then batch
        waveform_encoded = transforms.MuLawEncoding()(waveform)
        expected = waveform_encoded.unsqueeze(0).repeat(3, 1, 1)

        # Batch then transform
        waveform_batched = waveform.unsqueeze(0).repeat(3, 1, 1)
        computed = transforms.MuLawEncoding()(waveform_batched)

        # shape = (3, 2, 278756)
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

        # Single then transform then batch
        waveform_decoded = transforms.MuLawDecoding()(waveform_encoded)
        expected = waveform_decoded.unsqueeze(0).repeat(3, 1, 1)

        # Batch then transform
        computed = transforms.MuLawDecoding()(computed)

        # shape = (3, 2, 278756)
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_batch_spectrogram(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)

        # Single then transform then batch
        expected = transforms.Spectrogram()(waveform).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = transforms.Spectrogram()(waveform.repeat(3, 1, 1))

        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_batch_melspectrogram(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)

        # Single then transform then batch
        expected = transforms.MelSpectrogram()(waveform).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = transforms.MelSpectrogram()(waveform.repeat(3, 1, 1))

        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_batch_mfcc(self):
        test_filepath = os.path.join(
            self.test_dirpath, 'assets', 'steam-train-whistle-daniel_simon.mp3'
        )
        waveform, sample_rate = torchaudio.load(test_filepath)

        # Single then transform then batch
        expected = transforms.MFCC()(waveform).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = transforms.MFCC()(waveform.repeat(3, 1, 1))

        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected, atol=1e-5))

    def test_scriptmodule_TimeStretch(self):
        n_freq = 400
        hop_length = 512
        fixed_rate = 1.3
        tensor = torch.rand((10, 2, n_freq, 10, 2))
        _test_script_module(transforms.TimeStretch, tensor, n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate)

    def test_batch_TimeStretch(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)

        kwargs = {
            'n_fft': 2048,
            'hop_length': 512,
            'win_length': 2048,
            'window': torch.hann_window(2048),
            'center': True,
            'pad_mode': 'reflect',
            'normalized': True,
            'onesided': True,
        }
        rate = 2

        complex_specgrams = torch.stft(waveform, **kwargs)
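        # real/imaginary pairs of shape (channel, n_freq, time, 2), with n_freq = 2048 // 2 + 1 = 1025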

        # Single then transform then batch
        expected = transforms.TimeStretch(fixed_rate=rate,
                                          n_freq=1025,
                                          hop_length=512)(complex_specgrams).repeat(3, 1, 1, 1, 1)

        # Batch then transform
        computed = transforms.TimeStretch(fixed_rate=rate,
                                          n_freq=1025,
                                          hop_length=512)(complex_specgrams.repeat(3, 1, 1, 1, 1))

        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected, atol=1e-5))

    def test_batch_Fade(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)
        fade_in_len = 3000
        fade_out_len = 3000

        # Single then transform then batch
        expected = transforms.Fade(fade_in_len, fade_out_len)(waveform).repeat(3, 1, 1)

        # Batch then transform
        computed = transforms.Fade(fade_in_len, fade_out_len)(waveform.repeat(3, 1, 1))

        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_scriptmodule_Fade(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)
        fade_in_len = 3000
        fade_out_len = 3000

        _test_script_module(transforms.Fade, waveform, fade_in_len, fade_out_len)

    def test_scriptmodule_FrequencyMasking(self):
        tensor = torch.rand((10, 2, 50, 10, 2))
        _test_script_module(transforms.FrequencyMasking, tensor, freq_mask_param=60, iid_masks=False)

    def test_scriptmodule_TimeMasking(self):
        tensor = torch.rand((10, 2, 50, 10, 2))
        _test_script_module(transforms.TimeMasking, tensor, time_mask_param=30, iid_masks=False)

    def test_scriptmodule_Vol(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)

        _test_script_module(transforms.Vol, waveform, 1.1)

    def test_batch_Vol(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)

        # Single then transform then batch
        expected = transforms.Vol(gain=1.1)(waveform).repeat(3, 1, 1)

        # Batch then transform
        computed = transforms.Vol(gain=1.1)(waveform.repeat(3, 1, 1))

        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))


class TestLibrosaConsistency(unittest.TestCase):
    test_dirpath = None
    test_dir = None

    @classmethod
    def setUpClass(cls):
        cls.test_dirpath, cls.test_dir = create_temp_assets_dir()

    def _to_librosa(self, sound):
        return sound.cpu().numpy().squeeze()

    def _get_sample_data(self, *asset_paths, **kwargs):
        file_path = os.path.join(self.test_dirpath, 'assets', *asset_paths)

        sound, sample_rate = torchaudio.load(file_path, **kwargs)
        return sound.mean(dim=0, keepdim=True), sample_rate

    @unittest.skipIf(not IMPORT_LIBROSA, 'Librosa is not available')
    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_MelScale(self):
        """MelScale transform is comparable to that of librosa"""
        n_fft = 2048
        n_mels = 256
        hop_length = n_fft // 4

        # Prepare spectrogram input. We use torchaudio to compute one.
        sound, sample_rate = self._get_sample_data('whitenoise_1min.mp3')
        spec_ta = F.spectrogram(
            sound, pad=0, window=torch.hann_window(n_fft), n_fft=n_fft,
            hop_length=hop_length, win_length=n_fft, power=2, normalized=False)
        spec_lr = spec_ta.cpu().numpy().squeeze()
        # Perform MelScale with torchaudio and librosa
        melspec_ta = transforms.MelScale(n_mels=n_mels, sample_rate=sample_rate)(spec_ta)
        melspec_lr = librosa.feature.melspectrogram(
            S=spec_lr, sr=sample_rate, n_fft=n_fft, hop_length=hop_length,
            win_length=n_fft, center=True, window='hann', n_mels=n_mels, htk=True, norm=None)
        # Note: Using relaxed rtol instead of atol
        assert torch.allclose(melspec_ta, torch.from_numpy(melspec_lr[None, ...]), rtol=1e-3)

    @unittest.skipIf(not IMPORT_LIBROSA, 'Librosa is not available')
    def test_InverseMelScale(self):
        """InverseMelScale transform is comparable to that of librosa"""
        n_fft = 2048
        n_mels = 256
        n_stft = n_fft // 2 + 1
        hop_length = n_fft // 4

        # Prepare mel spectrogram input. We use torchaudio to compute one.
        sound, sample_rate = self._get_sample_data(
            'steam-train-whistle-daniel_simon.wav', offset=2**10, num_frames=2**14)
        spec_orig = F.spectrogram(
            sound, pad=0, window=torch.hann_window(n_fft), n_fft=n_fft,
            hop_length=hop_length, win_length=n_fft, power=2, normalized=False)
        melspec_ta = transforms.MelScale(n_mels=n_mels, sample_rate=sample_rate)(spec_orig)
        melspec_lr = melspec_ta.cpu().numpy().squeeze()
        # Perform InverseMelScale with torch audio and librosa
        spec_ta = transforms.InverseMelScale(
            n_stft, n_mels=n_mels, sample_rate=sample_rate)(melspec_ta)
        spec_lr = librosa.feature.inverse.mel_to_stft(
            melspec_lr, sr=sample_rate, n_fft=n_fft, power=2.0, htk=True, norm=None)
        spec_lr = torch.from_numpy(spec_lr[None, ...])

        # Align scales: librosa's mel_to_stft returns a magnitude spectrogram,
        # while the torchaudio path yields power spectrograms, so take the square root.
        spec_orig = spec_orig.sqrt()
        spec_ta = spec_ta.sqrt()

        threshold = 2.0
        # This threshold was chosen empirically, based on the following observation
        #
        # torch.dist(spec_lr, spec_ta, p=float('inf'))
        # >>> tensor(1.9666)
        #
        # The spectrograms reconstructed by librosa and torchaudio are not very comparable elementwise.
        # This is because they use different approximation algorithms, so the resulting values
        # can differ in magnitude (although most of them are very close).
        # See https://github.com/pytorch/audio/pull/366 for the discussion of the choice of algorithm
        # See https://github.com/pytorch/audio/pull/448/files#r385747021 for the distribution of P-inf
        # distance over frequencies.
        assert torch.allclose(spec_ta, spec_lr, atol=threshold)

        threshold = 1700.0
        # This threshold was chosen empirically, based on the following observations
        #
        # torch.dist(spec_orig, spec_ta, p=1)
        # >>> tensor(1644.3516)
        # torch.dist(spec_orig, spec_lr, p=1)
        # >>> tensor(1420.7103)
        # torch.dist(spec_lr, spec_ta, p=1)
        # >>> tensor(943.2759)
        assert torch.dist(spec_orig, spec_ta, p=1) < threshold


if __name__ == '__main__':
    unittest.main()