test_functional.py 28.4 KB
Newer Older
1
from __future__ import absolute_import, division, print_function, unicode_literals
jamarshon's avatar
jamarshon committed
2
import math
Vincent QB's avatar
Vincent QB committed
3
import os
jamarshon's avatar
jamarshon committed
4
5
6

import torch
import torchaudio
7
import torchaudio.functional as F
8
import torchaudio.transforms as T
9
import pytest
jamarshon's avatar
jamarshon committed
10
import unittest
11
import common_utils
jamarshon's avatar
jamarshon committed
12

13
14
15
16
17
18
from torchaudio.common_utils import IMPORT_LIBROSA

if IMPORT_LIBROSA:
    import numpy as np
    import librosa

jamarshon's avatar
jamarshon committed
19

Vincent QB's avatar
Vincent QB committed
20
def _test_torchscript_functional_shape(py_method, *args, **kwargs):
21
22
23
24
25
    jit_method = torch.jit.script(py_method)

    jit_out = jit_method(*args, **kwargs)
    py_out = py_method(*args, **kwargs)

Vincent QB's avatar
Vincent QB committed
26
27
28
29
30
31
    assert jit_out.shape == py_out.shape
    return jit_out, py_out


def _test_torchscript_functional(py_method, *args, **kwargs):
    jit_out, py_out = _test_torchscript_functional_shape(py_method, *args, **kwargs)
32
33
34
    assert torch.allclose(jit_out, py_out)


jamarshon's avatar
jamarshon committed
35
36
37
class TestFunctional(unittest.TestCase):
    data_sizes = [(2, 20), (3, 15), (4, 10)]
    number_of_trials = 100
Vincent QB's avatar
Vincent QB committed
38
39
    specgram = torch.tensor([1., 2., 3., 4.])

Vincent QB's avatar
Vincent QB committed
40
    test_dirpath, test_dir = common_utils.create_temp_assets_dir()
41

Vincent QB's avatar
Vincent QB committed
42
43
    test_filepath = os.path.join(test_dirpath, 'assets',
                                 'steam-train-whistle-daniel_simon.mp3')
44
    waveform_train, sr_train = torchaudio.load(test_filepath)
Vincent QB's avatar
Vincent QB committed
45

46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
    def test_torchscript_spectrogram(self):
        """Compare eager and TorchScript results of ``F.spectrogram`` on random input."""
        win_length = 400
        waveform = torch.rand((1, 1000))
        # Remaining positional arguments, in call order:
        # pad, window, n_fft, hop, window length, power, normalize flag.
        extra_args = (0, torch.hann_window(win_length), 400, 200, win_length, 2, False)

        _test_torchscript_functional(F.spectrogram, waveform, *extra_args)

61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
    def test_torchscript_griffinlim(self):
        """Compare eager and TorchScript results of ``F.griffinlim`` on a random spectrogram."""
        specgram = torch.rand((1, 201, 6))
        n_fft = 400
        ws = 400
        hop = 200
        window = torch.hann_window(ws)
        power = 2
        normalize = False
        momentum = 0.99
        n_iter = 32
        length = 1000
        # Forwarded as the final positional argument of the call below.
        init = 0

        _test_torchscript_functional(
            F.griffinlim, specgram, window, n_fft, hop, ws, power, normalize, n_iter, momentum, length, init
        )

    @unittest.skipIf(not IMPORT_LIBROSA, 'Librosa not available')
    def test_griffinlim(self):
80
81
82
83

        # NOTE: This test is flaky without a fixed random seed
        # See https://github.com/pytorch/audio/issues/382
        torch.random.manual_seed(42)
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
        tensor = torch.rand((1, 1000))

        n_fft = 400
        ws = 400
        hop = 100
        window = torch.hann_window(ws)
        normalize = False
        momentum = 0.99
        n_iter = 8
        length = 1000
        rand_init = False
        init = 'random' if rand_init else None

        specgram = F.spectrogram(tensor, 0, window, n_fft, hop, ws, 2, normalize).sqrt()
        ta_out = F.griffinlim(specgram, window, n_fft, hop, ws, 1, normalize,
                              n_iter, momentum, length, rand_init)
        lr_out = librosa.griffinlim(specgram.squeeze(0).numpy(), n_iter=n_iter, hop_length=hop,
                                    momentum=momentum, init=init, length=length)
        lr_out = torch.from_numpy(lr_out).unsqueeze(0)

        self.assertTrue(torch.allclose(ta_out, lr_out, atol=5e-5))

Vincent QB's avatar
Vincent QB committed
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
    def test_batch_griffinlim(self):
        """Run the batch helper on ``F.griffinlim`` with a seeded random spectrogram."""
        torch.random.manual_seed(42)
        specgram = torch.rand((1, 201, 6))

        ws = 400
        # Positional arguments after the spectrogram, in call order:
        # window, n_fft, hop, window length, power, normalize, n_iter, momentum, length, trailing 0.
        griffinlim_args = (torch.hann_window(ws), 400, 200, ws, 2, False, 32, 0.99, 1000, 0)

        self._test_batch(F.griffinlim, specgram, *griffinlim_args, atol=5e-5)

Vincent QB's avatar
Vincent QB committed
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
    def _test_compute_deltas(self, specgram, expected, win_length=3, atol=1e-6, rtol=1e-8):
        """Assert that ``F.compute_deltas(specgram)`` matches ``expected`` in shape and value.

        Args:
            specgram: input passed to ``F.compute_deltas``.
            expected: tensor the result is compared against.
            win_length: forwarded as the ``win_length`` keyword.
            atol, rtol: tolerances for ``torch.testing.assert_allclose``.
        """
        deltas = F.compute_deltas(specgram, win_length=win_length)
        shapes = (deltas.shape, expected.shape)
        self.assertTrue(shapes[0] == shapes[1], shapes)
        torch.testing.assert_allclose(deltas, expected, atol=atol, rtol=rtol)

    def test_compute_deltas_onechannel(self):
        """Deltas of the class-level ``specgram`` with two leading singleton dimensions added."""
        single_channel = self.specgram[None, None]
        self._test_compute_deltas(single_channel, torch.tensor([[[0.5, 1.0, 1.0, 0.5]]]))

    def test_compute_deltas_twochannel(self):
        """Deltas of the class-level ``specgram`` repeated into two identical rows."""
        row = [0.5, 1.0, 1.0, 0.5]
        two_rows = self.specgram.repeat(1, 2, 1)
        self._test_compute_deltas(two_rows, torch.tensor([[row, row]]))

    def test_compute_deltas_randn(self):
        channel = 13
        n_mfcc = channel * 3
        time = 1021
        win_length = 2 * 7 + 1
        specgram = torch.randn(channel, n_mfcc, time)
        computed = F.compute_deltas(specgram, win_length=win_length)
Vincent QB's avatar
Vincent QB committed
148

Vincent QB's avatar
Vincent QB committed
149
        self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))
Vincent QB's avatar
Vincent QB committed
150

151
        _test_torchscript_functional(F.compute_deltas, specgram, win_length=win_length)
jamarshon's avatar
jamarshon committed
152

Vincent QB's avatar
Vincent QB committed
153
154
    def test_batch_pitch(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)
Vincent QB's avatar
Vincent QB committed
155
        self._test_batch(F.detect_pitch_frequency, waveform, sample_rate)
Vincent QB's avatar
Vincent QB committed
156

Vincent QB's avatar
Vincent QB committed
157
158
    def test_jit_pitch(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)
159
        _test_torchscript_functional(F.detect_pitch_frequency, waveform, sample_rate)
Vincent QB's avatar
Vincent QB committed
160

jamarshon's avatar
jamarshon committed
161
162
163
164
165
166
167
168
169
170
171
172
    def _compare_estimate(self, sound, estimate, atol=1e-6, rtol=1e-8):
        # trim sound for case when constructed signal is shorter than original
        sound = sound[..., :estimate.size(-1)]

        self.assertTrue(sound.shape == estimate.shape, (sound.shape, estimate.shape))
        self.assertTrue(torch.allclose(sound, estimate, atol=atol, rtol=rtol))

    def _test_istft_is_inverse_of_stft(self, kwargs):
        # generates a random sound signal for each trial and then does the stft/istft
        # operation to check whether we can reconstruct the signal
        for data_size in self.data_sizes:
            for i in range(self.number_of_trials):
Vincent QB's avatar
Vincent QB committed
173

174
                sound = common_utils.random_float_tensor(i, data_size)
jamarshon's avatar
jamarshon committed
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333

                stft = torch.stft(sound, **kwargs)
                estimate = torchaudio.functional.istft(stft, length=sound.size(1), **kwargs)

                self._compare_estimate(sound, estimate)

    def test_istft_is_inverse_of_stft1(self):
        """Round-trip check: hann window, centered, normalized, onesided."""
        n_fft = 12
        self._test_istft_is_inverse_of_stft(dict(
            n_fft=n_fft,
            hop_length=4,
            win_length=n_fft,
            window=torch.hann_window(n_fft),
            center=True,
            pad_mode='reflect',
            normalized=True,
            onesided=True,
        ))

    def test_istft_is_inverse_of_stft2(self):
        """Round-trip check: hann window shorter than n_fft, centered, not normalized, not onesided."""
        win_length = 8
        self._test_istft_is_inverse_of_stft(dict(
            n_fft=12,
            hop_length=2,
            win_length=win_length,
            window=torch.hann_window(win_length),
            center=True,
            pad_mode='reflect',
            normalized=False,
            onesided=False,
        ))

    def test_istft_is_inverse_of_stft3(self):
        """Round-trip check: hamming window, centered, normalized, not onesided."""
        win_length = 11
        self._test_istft_is_inverse_of_stft(dict(
            n_fft=15,
            hop_length=3,
            win_length=win_length,
            window=torch.hamming_window(win_length),
            center=True,
            pad_mode='constant',
            normalized=True,
            onesided=False,
        ))

    def test_istft_is_inverse_of_stft4(self):
        """Round-trip check: hamming window, not centered, not normalized, onesided.

        The window has the same size as ``n_fft``.
        """
        n_fft = 5
        self._test_istft_is_inverse_of_stft(dict(
            n_fft=n_fft,
            hop_length=2,
            win_length=n_fft,
            window=torch.hamming_window(n_fft),
            center=False,
            pad_mode='constant',
            normalized=False,
            onesided=True,
        ))

    def test_istft_is_inverse_of_stft5(self):
        """Round-trip check: hamming window, not centered, not normalized, not onesided.

        The window has the same size as ``n_fft``.
        """
        n_fft = 3
        self._test_istft_is_inverse_of_stft(dict(
            n_fft=n_fft,
            hop_length=2,
            win_length=n_fft,
            window=torch.hamming_window(n_fft),
            center=False,
            pad_mode='reflect',
            normalized=False,
            onesided=False,
        ))

    def test_istft_of_ones(self):
        """``istft`` of a hard-coded spectrum should give back four ones."""
        # Same values as torch.stft(torch.ones(4), 4): shape (3, 5, 2) where
        # row 0 holds [4., 0.] in every column and the other rows are zero.
        stft = torch.zeros((3, 5, 2))
        stft[0, :, 0] = 4.

        reconstructed = torchaudio.functional.istft(stft, n_fft=4, length=4)
        self._compare_estimate(torch.ones(4), reconstructed)

    def test_istft_of_zeros(self):
        """``istft`` of an all-zero spectrum should give back four zeros."""
        # Same values as torch.stft(torch.zeros(4), 4)
        zero_stft = torch.zeros((3, 5, 2))
        reconstructed = torchaudio.functional.istft(zero_stft, n_fft=4, length=4)

        self._compare_estimate(torch.zeros(4), reconstructed)

    def test_istft_requires_overlap_windows(self):
        """``istft`` must raise when the hop (20) exceeds the window (size 1), leaving gaps."""
        stft = torch.zeros((3, 5, 2))
        with self.assertRaises(AssertionError):
            torchaudio.functional.istft(stft, n_fft=4, hop_length=20, win_length=1, window=torch.ones(1))

    def test_istft_requires_nola(self):
        """A window of ones meets NOLA and is accepted; a window of zeros does not and must raise."""
        stft = torch.zeros((3, 5, 2))
        shared = {'n_fft': 4, 'win_length': 4}
        window_ok = torch.ones(4)
        window_not_ok = torch.zeros(4)

        # Must not raise.
        torchaudio.functional.istft(stft, window=window_ok, **shared)

        with self.assertRaises(AssertionError):
            torchaudio.functional.istft(stft, window=window_not_ok, **shared)

    def test_istft_requires_non_empty(self):
        """``istft`` must raise when the second or the first dimension of the input is empty."""
        for empty_shape in ((3, 0, 2), (0, 3, 2)):
            with self.assertRaises(AssertionError):
                torchaudio.functional.istft(torch.zeros(empty_shape), 2)

    def _test_istft_of_sine(self, amplitude, L, n):
        """Check ``istft`` against ``amplitude * sin(2*pi/L * n * x)`` using a hand-built spectrum.

        The hop length and window size both equal ``L``. The spectrum stands in for
        ``torch.stft(sound, L, hop_length=L, win_length=L, window=torch.ones(L),
        center=False, normalized=False)``.
        """
        samples = torch.arange(2 * L + 1, dtype=torch.get_default_dtype())
        sound = amplitude * torch.sin(2 * math.pi / L * samples * n)

        n_bins = L // 2 + 1
        peak = (amplitude * L) / 2.0
        stft = torch.zeros((n_bins, 2, 2))
        if n < n_bins:
            stft[n, :, 1] = -peak

        mirrored = L - n
        if 0 <= mirrored < n_bins:
            # symmetric about L // 2
            stft[mirrored, :, 1] = peak

        estimate = torchaudio.functional.istft(stft, L, hop_length=L, win_length=L,
                                               window=torch.ones(L), center=False, normalized=False)
        # Looser tolerance: there is a larger error due to the scaling of amplitude.
        self._compare_estimate(sound, estimate, atol=1e-3)

    def test_istft_of_sine(self):
        """Run the sine reconstruction check over several (amplitude, L, n) combinations."""
        cases = (
            (123, 5, 1),
            (150, 5, 2),
            (111, 5, 3),
            (160, 7, 4),
            (145, 8, 5),
            (80, 9, 6),
            (99, 10, 7),
        )
        for amplitude, window_size, multiple in cases:
            self._test_istft_of_sine(amplitude=amplitude, L=window_size, n=multiple)

334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
    def _test_linearity_of_istft(self, data_size, kwargs, atol=1e-6, rtol=1e-8):
        """Check ``istft(a*x + b*y)`` against ``a*istft(x) + b*istft(y)``.

        Runs ``number_of_trials`` times; ``x`` and ``y`` come from
        ``common_utils.random_float_tensor`` (called with ``i`` and ``2 * i``)
        and the weights ``a``, ``b`` from ``torch.rand``.

        Args:
            data_size: size passed to ``random_float_tensor``.
            kwargs: keyword arguments forwarded to every ``istft`` call.
            atol, rtol: tolerances forwarded to ``_compare_estimate``.
        """
        istft = torchaudio.functional.istft
        for trial in range(self.number_of_trials):
            x = common_utils.random_float_tensor(trial, data_size)
            y = common_utils.random_float_tensor(trial * 2, data_size)
            a, b = torch.rand(2)

            combined_after = a * istft(x, **kwargs) + b * istft(y, **kwargs)
            combined_before = istft(a * x + b * y, **kwargs)
            self._compare_estimate(combined_after, combined_before, atol, rtol)

    def test_linearity_of_istft1(self):
        """Linearity check: hann window, centered, normalized, onesided."""
        n_fft = 12
        istft_kwargs = dict(
            n_fft=n_fft,
            window=torch.hann_window(n_fft),
            center=True,
            pad_mode='reflect',
            normalized=True,
            onesided=True,
        )
        self._test_linearity_of_istft((2, 7, 7, 2), istft_kwargs)

    def test_linearity_of_istft2(self):
        """Linearity check: hann window, centered, not normalized, not onesided."""
        n_fft = 12
        istft_kwargs = dict(
            n_fft=n_fft,
            window=torch.hann_window(n_fft),
            center=True,
            pad_mode='reflect',
            normalized=False,
            onesided=False,
        )
        self._test_linearity_of_istft((2, 12, 7, 2), istft_kwargs)

    def test_linearity_of_istft3(self):
        """Linearity check: hamming window, centered, normalized, not onesided."""
        n_fft = 12
        istft_kwargs = dict(
            n_fft=n_fft,
            window=torch.hamming_window(n_fft),
            center=True,
            pad_mode='constant',
            normalized=True,
            onesided=False,
        )
        self._test_linearity_of_istft((2, 12, 7, 2), istft_kwargs)

    def test_linearity_of_istft4(self):
        """Linearity check: hamming window, not centered, not normalized, onesided.

        Uses a looser absolute tolerance (1e-5) than the other linearity tests.
        """
        n_fft = 12
        istft_kwargs = dict(
            n_fft=n_fft,
            window=torch.hamming_window(n_fft),
            center=False,
            pad_mode='constant',
            normalized=False,
            onesided=True,
        )
        self._test_linearity_of_istft((2, 7, 3, 2), istft_kwargs, atol=1e-5, rtol=1e-8)

Vincent QB's avatar
Vincent QB committed
397
398
399
400
401
402
403
404
405
406
    def test_batch_istft(self):
        """Run the batch helper on ``F.istft`` with a small hard-coded spectrum."""
        # Shape (3, 5, 2): row 0 holds [4., 0.] in each of the 5 columns,
        # the other two rows are all zero.
        stft = torch.zeros((3, 5, 2))
        stft[0, :, 0] = 4.

        self._test_batch(F.istft, stft, n_fft=4, length=4)

engineerchuan's avatar
engineerchuan committed
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
    def _test_create_fb(self, n_mels=40, sample_rate=22050, n_fft=2048, fmin=0.0, fmax=8000.0):
        """Compare ``F.create_fb_matrix`` with ``librosa.filters.mel``, one mel bank at a time.

        librosa is called with ``htk=True`` and ``norm=None``.

        Raises:
            unittest.SkipTest: if librosa could not be imported.
        """
        # Using a decorator here causes parametrize to fail on Python 2
        if not IMPORT_LIBROSA:
            raise unittest.SkipTest('Librosa is not available')

        n_freqs = n_fft // 2 + 1
        reference = librosa.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=n_mels,
                                        fmax=fmax, fmin=fmin, htk=True, norm=None)
        filterbank = F.create_fb_matrix(sample_rate=sample_rate, n_mels=n_mels,
                                        f_max=fmax, f_min=fmin, n_freqs=n_freqs)

        # Column i of the torchaudio matrix is compared with row i of librosa's.
        for bank in range(n_mels):
            column = filterbank[:, bank]
            row = torch.tensor(reference[bank])
            assert torch.allclose(column, row, atol=1e-4)

    def test_create_fb(self):
        """Exercise ``_test_create_fb`` with its defaults and several overrides.

        The last two cases pass ``fmin`` greater than ``fmax``.
        """
        overrides = (
            {},
            {'n_mels': 128, 'sample_rate': 44100},
            {'n_mels': 128, 'fmin': 2000.0, 'fmax': 5000.0},
            {'n_mels': 56, 'fmin': 100.0, 'fmax': 9000.0},
            {'n_mels': 56, 'fmin': 800.0, 'fmax': 900.0},
            {'n_mels': 56, 'fmin': 1900.0, 'fmax': 900.0},
            {'n_mels': 10, 'fmin': 1900.0, 'fmax': 900.0},
        )
        for params in overrides:
            self._test_create_fb(**params)

437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
    def test_gain(self):
        """Compare ``F.gain`` (called with 3) with the sox ``gain`` effect on the class-level test file."""
        waveform_gain = F.gain(self.waveform_train, 3)
        # NOTE(review): assertTrue treats its second argument as the failure
        # message, so this only checks that the peak absolute value is truthy
        # (non-zero); it does not compare the peak against 1. Confirm the
        # intended bound before tightening this assertion.
        self.assertTrue(waveform_gain.abs().max().item(), 1.)

        # Reference: the same file run through a sox chain with a "gain" effect of 3.
        E = torchaudio.sox_effects.SoxEffectsChain()
        E.set_input_file(self.test_filepath)
        E.append_effect_to_chain("gain", [3])
        sox_gain_waveform = E.sox_build_flow_effects()[0]

        self.assertTrue(torch.allclose(waveform_gain, sox_gain_waveform, atol=1e-04))

    def test_dither(self):
        """Compare ``F.dither`` (plain and noise-shaped) with the sox ``dither`` effect."""
        plain = F.dither(self.waveform_train)
        noise_shaped = F.dither(self.waveform_train, noise_shaping=True)

        chain = torchaudio.sox_effects.SoxEffectsChain()
        chain.set_input_file(self.test_filepath)

        chain.append_effect_to_chain("dither", [])
        sox_plain = chain.sox_build_flow_effects()[0]
        self.assertTrue(torch.allclose(plain, sox_plain, atol=1e-04))

        # Rebuild the chain with the "-s" option for the noise-shaped comparison,
        # which uses a looser tolerance.
        chain.clear_chain()
        chain.append_effect_to_chain("dither", ["-s"])
        sox_noise_shaped = chain.sox_build_flow_effects()[0]
        self.assertTrue(torch.allclose(noise_shaped, sox_noise_shaped, atol=1e-02))

    def test_vctk_transform_pipeline(self):
        test_filepath_vctk = os.path.join(self.test_dirpath, "assets/VCTK-Corpus/wav48/p224/", "p224_002.wav")
        wf_vctk, sr_vctk = torchaudio.load(test_filepath_vctk)

        # rate
        sample = T.Resample(sr_vctk, 16000, resampling_method='sinc_interpolation')
        wf_vctk = sample(wf_vctk)
        # dither
        wf_vctk = F.dither(wf_vctk, noise_shaping=True)
Vincent QB's avatar
Vincent QB committed
474

475
476
477
478
479
480
481
482
483
484
485
486
        E = torchaudio.sox_effects.SoxEffectsChain()
        E.set_input_file(test_filepath_vctk)
        E.append_effect_to_chain("gain", ["-h"])
        E.append_effect_to_chain("channels", [1])
        E.append_effect_to_chain("rate", [16000])
        E.append_effect_to_chain("gain", ["-rh"])
        E.append_effect_to_chain("dither", ["-s"])
        wf_vctk_sox = E.sox_build_flow_effects()[0]

        self.assertTrue(torch.allclose(wf_vctk, wf_vctk_sox, rtol=1e-03, atol=1e-03))

    def test_pitch(self):
Vincent QB's avatar
Vincent QB committed
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
        test_dirpath, test_dir = common_utils.create_temp_assets_dir()
        test_filepath_100 = os.path.join(test_dirpath, 'assets', "100Hz_44100Hz_16bit_05sec.wav")
        test_filepath_440 = os.path.join(test_dirpath, 'assets', "440Hz_44100Hz_16bit_05sec.wav")

        # Files from https://www.mediacollege.com/audio/tone/download/
        tests = [
            (test_filepath_100, 100),
            (test_filepath_440, 440),
        ]

        for filename, freq_ref in tests:
            waveform, sample_rate = torchaudio.load(filename)

            freq = torchaudio.functional.detect_pitch_frequency(waveform, sample_rate)

            threshold = 1
            s = ((freq - freq_ref).abs() > threshold).sum()
            self.assertFalse(s)

Vincent QB's avatar
Vincent QB committed
506
            # Convert to stereo and batch for testing purposes
Vincent QB's avatar
Vincent QB committed
507
            self._test_batch(F.detect_pitch_frequency, waveform, sample_rate)
Vincent QB's avatar
Vincent QB committed
508

Vincent QB's avatar
Vincent QB committed
509
    def _test_batch_shape(self, functional, tensor, *args, **kwargs):
Vincent QB's avatar
Vincent QB committed
510

Vincent QB's avatar
Vincent QB committed
511
512
513
514
515
        kwargs_compare = {}
        if 'atol' in kwargs:
            atol = kwargs['atol']
            del kwargs['atol']
            kwargs_compare['atol'] = atol
Vincent QB's avatar
Vincent QB committed
516

Vincent QB's avatar
Vincent QB committed
517
518
519
520
        if 'rtol' in kwargs:
            rtol = kwargs['rtol']
            del kwargs['rtol']
            kwargs_compare['rtol'] = rtol
Vincent QB's avatar
Vincent QB committed
521
522
523

        # Single then transform then batch

Vincent QB's avatar
Vincent QB committed
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
        torch.random.manual_seed(42)
        expected = functional(tensor.clone(), *args, **kwargs)
        expected = expected.unsqueeze(0).unsqueeze(0)

        # 1-Batch then transform

        tensors = tensor.unsqueeze(0).unsqueeze(0)

        torch.random.manual_seed(42)
        computed = functional(tensors.clone(), *args, **kwargs)

        self._compare_estimate(computed, expected, **kwargs_compare)

        return tensors, expected

    def _test_batch(self, functional, tensor, *args, **kwargs):

        tensors, expected = self._test_batch_shape(functional, tensor, *args, **kwargs)

        kwargs_compare = {}
        if 'atol' in kwargs:
            atol = kwargs['atol']
            del kwargs['atol']
            kwargs_compare['atol'] = atol

        if 'rtol' in kwargs:
            rtol = kwargs['rtol']
            del kwargs['rtol']
            kwargs_compare['rtol'] = rtol

        # 3-Batch then transform

        ind = [3] + [1] * (int(tensors.dim()) - 1)
        tensors = tensor.repeat(*ind)

        ind = [3] + [1] * (int(expected.dim()) - 1)
        expected = expected.repeat(*ind)

        torch.random.manual_seed(42)
        computed = functional(tensors.clone(), *args, **kwargs)
Vincent QB's avatar
Vincent QB committed
564

565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
    def test_torchscript_create_fb_matrix(self):
        """Compare eager and TorchScript results of ``F.create_fb_matrix``."""
        # Positional arguments in call order: n_stft, f_min, f_max, n_mels, sample_rate.
        fb_args = (100, 0.0, 20.0, 10, 16000)

        _test_torchscript_functional(F.create_fb_matrix, *fb_args)

    def test_torchscript_amplitude_to_DB(self):
        """Compare eager and TorchScript results of ``F.amplitude_to_DB`` on random input."""
        spectrum = torch.rand((6, 201))
        # Positional arguments in call order: multiplier, amin, db_multiplier, top_db.
        db_args = (10.0, 1e-10, 0.0, 80.0)

        _test_torchscript_functional(F.amplitude_to_DB, spectrum, *db_args)

584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
    def test_torchscript_DB_to_amplitude(self):
        """Compare eager and TorchScript results of ``F.DB_to_amplitude`` on random input."""
        values = torch.rand((1, 100))
        # Positional arguments in call order: ref, power.
        ref_and_power = (1., 1.)

        _test_torchscript_functional(F.DB_to_amplitude, values, *ref_and_power)

    def test_DB_to_amplitude(self):
        """``F.DB_to_amplitude`` should undo ``F.amplitude_to_DB`` for a waveform and a spectrogram."""
        # Make some noise
        x = torch.rand(1000)
        spec = torchaudio.transforms.Spectrogram()(x)

        amin = 1e-10
        ref = 1.0
        db_multiplier = math.log10(max(amin, ref))

        def round_trip(signal, multiplier, power):
            # signal -> DB -> signal, with no top_db clamping.
            db = F.amplitude_to_DB(signal, multiplier, amin, db_multiplier, top_db=None)
            return F.DB_to_amplitude(db, ref, power)

        # Amplitude scale (multiplier 20, power 0.5): waveform, then spectrogram.
        self.assertTrue(torch.allclose(torch.abs(x), round_trip(torch.abs(x), 20., 0.5), atol=5e-5))
        self.assertTrue(torch.allclose(spec, round_trip(spec, 20., 0.5), atol=5e-5))

        # Power scale (multiplier 10, power 1): waveform, then spectrogram.
        self.assertTrue(torch.allclose(torch.abs(x), round_trip(x, 10., 1.), atol=5e-5))
        self.assertTrue(torch.allclose(spec, round_trip(spec, 10., 1.), atol=5e-5))

632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
    @unittest.skipIf(not IMPORT_LIBROSA, 'Librosa not available')
    def test_amplitude_to_DB(self):
        """Compare ``F.amplitude_to_DB`` with librosa's power and amplitude dB conversions."""
        spec = torch.rand((6, 201))
        amin, db_multiplier, top_db = 1e-10, 0.0, 80.0

        # Multiplier 10 is checked against power_to_db, 20 against amplitude_to_db.
        cases = (
            (10.0, librosa.core.power_to_db),
            (20.0, librosa.core.amplitude_to_db),
        )
        for multiplier, to_db in cases:
            ta_out = F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db)
            lr_out = torch.from_numpy(to_db(spec.numpy())).unsqueeze(0)

            self.assertTrue(torch.allclose(ta_out, lr_out, atol=5e-5))

658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
    def test_torchscript_create_dct(self):
        """Compare eager and TorchScript results of ``F.create_dct``."""
        # Positional arguments in call order: n_mfcc, n_mels, norm.
        dct_args = (40, 128, "ortho")

        _test_torchscript_functional(F.create_dct, *dct_args)

    def test_torchscript_mu_law_encoding(self):
        """Compare eager and TorchScript results of ``F.mu_law_encoding`` on random input."""
        signal = torch.rand((1, 10))

        # 256 is passed as the second positional argument (named ``qc`` in this test file).
        _test_torchscript_functional(F.mu_law_encoding, signal, 256)

    def test_torchscript_mu_law_decoding(self):
        """Compare eager and TorchScript results of ``F.mu_law_decoding`` on random input."""
        signal = torch.rand((1, 10))

        # 256 is passed as the second positional argument (named ``qc`` in this test file).
        _test_torchscript_functional(F.mu_law_decoding, signal, 256)

    def test_torchscript_complex_norm(self):

Vincent QB's avatar
Vincent QB committed
682
        complex_tensor = torch.randn(1, 2, 1025, 400, 2)
683
684
685
686
687
688
        power = 2

        _test_torchscript_functional(F.complex_norm, complex_tensor, power)

    def test_mask_along_axis(self):

Vincent QB's avatar
Vincent QB committed
689
        specgram = torch.randn(2, 1025, 400)
690
691
692
693
694
695
696
697
        mask_param = 100
        mask_value = 30.
        axis = 2

        _test_torchscript_functional(F.mask_along_axis, specgram, mask_param, mask_value, axis)

    def test_mask_along_axis_iid(self):

Vincent QB's avatar
Vincent QB committed
698
        specgrams = torch.randn(4, 2, 1025, 400)
699
700
701
702
703
704
        mask_param = 100
        mask_value = 30.
        axis = 2

        _test_torchscript_functional(F.mask_along_axis_iid, specgrams, mask_param, mask_value, axis)

705
706
707
708
709
710
711
    def test_torchscript_gain(self):
        """Compare eager and TorchScript results of ``F.gain`` on random input."""
        waveform = torch.rand((1, 1000))

        # 2.0 is the gain value (named ``gainDB`` in this test file).
        _test_torchscript_functional(F.gain, waveform, 2.0)

    def test_torchscript_dither(self):
Vincent QB's avatar
Vincent QB committed
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
        tensor = torch.rand((2, 1000))

        _test_torchscript_functional_shape(F.dither, tensor)
        _test_torchscript_functional_shape(F.dither, tensor, "RPDF")
        _test_torchscript_functional_shape(F.dither, tensor, "GPDF")


def _num_stft_bins(signal_len, fft_len, hop_length, pad):
    return (signal_len + 2 * pad - fft_len + hop_length) // hop_length


@pytest.mark.parametrize('complex_specgrams', [
    torch.randn(2, 1025, 400, 2)
])
@pytest.mark.parametrize('rate', [0.5, 1.01, 1.3])
@pytest.mark.parametrize('hop_length', [256])
def test_phase_vocoder(complex_specgrams, rate, hop_length):

    # Using a decorator here causes parametrize to fail on Python 2
    if not IMPORT_LIBROSA:
        raise unittest.SkipTest('Librosa is not available')

    # Due to the cumulative sum, numerical error when using torch.float32 causes
    # the bottom-right values of the stretched spectrogram to not
    # match librosa.

    complex_specgrams = complex_specgrams.type(torch.float64)
    phase_advance = torch.linspace(0, np.pi * hop_length, complex_specgrams.shape[-3], dtype=torch.float64)[..., None]

    complex_specgrams_stretch = F.phase_vocoder(complex_specgrams, rate=rate, phase_advance=phase_advance)

    # == Test shape
    expected_size = list(complex_specgrams.size())
    expected_size[-2] = int(np.ceil(expected_size[-2] / rate))

    assert complex_specgrams.dim() == complex_specgrams_stretch.dim()
    assert complex_specgrams_stretch.size() == torch.Size(expected_size)

    # == Test values
    index = [0] * (complex_specgrams.dim() - 3) + [slice(None)] * 3
    mono_complex_specgram = complex_specgrams[index].numpy()
    mono_complex_specgram = mono_complex_specgram[..., 0] + \
        mono_complex_specgram[..., 1] * 1j
    expected_complex_stretch = librosa.phase_vocoder(mono_complex_specgram,
                                                     rate=rate,
                                                     hop_length=hop_length)

    complex_stretch = complex_specgrams_stretch[index].numpy()
    complex_stretch = complex_stretch[..., 0] + 1j * complex_stretch[..., 1]
761

Vincent QB's avatar
Vincent QB committed
762
    assert np.allclose(complex_stretch, expected_complex_stretch, atol=1e-5)
763

764
765
766
767
768
769
770
771
772
773
774
775
776

@pytest.mark.parametrize('complex_tensor', [
    torch.randn(1, 2, 1025, 400, 2),
    torch.randn(1025, 400, 2)
])
@pytest.mark.parametrize('power', [1, 2, 0.7])
def test_complex_norm(complex_tensor, power):
    """Check ``F.complex_norm`` against a direct computation.

    The reference is the sum of squares over the last dimension, raised to
    ``power / 2``.
    """
    squared_magnitude = complex_tensor.pow(2).sum(-1)
    reference = squared_magnitude.pow(power / 2)
    actual = F.complex_norm(complex_tensor, power)

    assert torch.allclose(reference, actual, atol=1e-5)


777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
@pytest.mark.parametrize('specgram', [
    torch.randn(2, 1025, 400),
    torch.randn(1, 201, 100)
])
@pytest.mark.parametrize('mask_param', [100])
@pytest.mark.parametrize('mask_value', [0., 30.])
@pytest.mark.parametrize('axis', [1, 2])
def test_mask_along_axis(specgram, mask_param, mask_value, axis):
    """Check that ``F.mask_along_axis`` keeps the shape and masks fewer than ``mask_param`` lines.

    A line along ``axis`` counts as masked when every element across the
    other axis equals ``mask_value``.
    """
    mask_specgram = F.mask_along_axis(specgram, mask_param, mask_value, axis)

    other_axis = 1 if axis == 2 else 2

    masked_columns = (mask_specgram == mask_value).sum(other_axis)
    num_masked_columns = (masked_columns == mask_specgram.size(other_axis)).sum()
    # Average the count over the first dimension. Explicit floor division keeps
    # the integer result; in-place ``/=`` on an integer tensor is true division,
    # which cannot be stored back into an integer tensor on current PyTorch.
    num_masked_columns = num_masked_columns // mask_specgram.size(0)

    assert mask_specgram.size() == specgram.size()
    assert num_masked_columns < mask_param


@pytest.mark.parametrize('specgrams', [
    torch.randn(4, 2, 1025, 400),
])
@pytest.mark.parametrize('mask_param', [100])
@pytest.mark.parametrize('mask_value', [0., 30.])
@pytest.mark.parametrize('axis', [2, 3])
def test_mask_along_axis_iid(specgrams, mask_param, mask_value, axis):
    """Check that ``F.mask_along_axis_iid`` keeps the shape and masks fewer than ``mask_param`` lines.

    The masked-line count is taken separately for every index of the two
    leading dimensions, and each count must be below ``mask_param``.
    """
    masked = F.mask_along_axis_iid(specgrams, mask_param, mask_value, axis)

    # Reduce over the axis that was not masked: 2 when axis is 3, otherwise 3.
    if axis == 3:
        reduce_axis = 2
    else:
        reduce_axis = 3

    hits_per_line = (masked == mask_value).sum(reduce_axis)
    fully_masked = hits_per_line == masked.size(reduce_axis)
    num_masked_columns = fully_masked.sum(-1)

    assert masked.size() == specgrams.size()
    assert (num_masked_columns < mask_param).sum() == num_masked_columns.numel()


jamarshon's avatar
jamarshon committed
817
818
if __name__ == '__main__':
    unittest.main()