# test_torchscript_consistency.py
"""Test suites for jit-ability and its numerical compatibility"""
import os
import unittest

import torch
import torchaudio
import torchaudio.functional as F
import torchaudio.transforms as T

import common_utils


def _assert_functional_consistency(func, tensor, device, shape_only=False, seed=0):
    """Check that ``func`` behaves the same eagerly and after ``torch.jit.script``.

    Args:
        func: Function to compile with ``torch.jit.script`` and compare against
            its eager (uncompiled) self.
        tensor: Input tensor; it is moved to ``device`` before either call.
        device: Device on which both the eager and the scripted call run.
        shape_only: If True, only the output shapes are compared; otherwise the
            values are compared with ``torch.testing.assert_allclose``.
        seed: Global RNG seed applied before each of the two calls, so that a
            ``func`` which draws random numbers sees the same random stream in
            both runs.

    Raises:
        AssertionError: If the scripted output does not match the eager output.
    """
    tensor = tensor.to(device)
    ts_func = torch.jit.script(func)

    # Reseed before *each* call. Without this, a function that draws random
    # numbers would consume different parts of the global random stream in the
    # eager and scripted runs, and the value comparison would fail spuriously.
    torch.manual_seed(seed)
    output = func(tensor)
    torch.manual_seed(seed)
    ts_output = ts_func(tensor)

    if shape_only:
        assert ts_output.shape == output.shape, (ts_output.shape, output.shape)
    else:
        torch.testing.assert_allclose(ts_output, output)


def _assert_transforms_consistency(transform, tensor, device, seed=0):
    """Check that ``transform`` behaves the same eagerly and after ``torch.jit.script``.

    Args:
        transform: Module-like object (supports ``.to(device)`` and is accepted
            by ``torch.jit.script``) to compare against its scripted version.
        tensor: Input tensor; it is moved to ``device`` before either call.
        device: Device on which both the eager and the scripted call run.
        seed: Global RNG seed applied before each of the two calls, so that a
            transform which draws random numbers sees the same random stream in
            both runs.

    Raises:
        AssertionError: If the scripted output does not match the eager output.
    """
    tensor = tensor.to(device)
    transform = transform.to(device)
    ts_transform = torch.jit.script(transform)

    # Reseed before *each* call so randomized transforms are reproducible
    # across the eager and scripted runs.
    torch.manual_seed(seed)
    output = transform(tensor)
    torch.manual_seed(seed)
    ts_output = ts_transform(tensor)
    torch.testing.assert_allclose(ts_output, output)


class _FunctionalTestMixin:
    """Implements tests for the `functional` module that are performed on different devices.

    Concrete subclasses combine this mixin with ``unittest.TestCase`` and set
    ``device``.
    """
    device = None

    def _assert_consistency(self, func, tensor, shape_only=False):
        return _assert_functional_consistency(func, tensor, self.device, shape_only=shape_only)

    @staticmethod
    def _load_whitenoise():
        """Load the ``assets/whitenoise.wav`` test file with ``normalization=True``."""
        filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav')
        waveform, _ = torchaudio.load(filepath, normalization=True)
        return waveform

    def test_spectrogram(self):
        def func(tensor):
            n_fft = 400
            ws = 400
            hop = 200
            pad = 0
            window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype)
            power = 2.
            normalize = False
            return F.spectrogram(tensor, pad, window, n_fft, hop, ws, power, normalize)

        tensor = torch.rand((1, 1000))
        self._assert_consistency(func, tensor)

    def test_griffinlim(self):
        def func(tensor):
            n_fft = 400
            ws = 400
            hop = 200
            window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype)
            power = 2.
            normalize = False
            momentum = 0.99
            n_iter = 32
            length = 1000
            rand_init = False
            return F.griffinlim(tensor, window, n_fft, hop, ws, power, normalize, n_iter, momentum, length, rand_init)

        tensor = torch.rand((1, 201, 6))
        self._assert_consistency(func, tensor)

    def test_compute_deltas(self):
        def func(tensor):
            win_length = 2 * 7 + 1
            return F.compute_deltas(tensor, win_length=win_length)

        channel = 13
        n_mfcc = channel * 3
        time = 1021
        tensor = torch.randn(channel, n_mfcc, time)
        self._assert_consistency(func, tensor)

    def test_detect_pitch_frequency(self):
        filepath = os.path.join(
            common_utils.TEST_DIR_PATH, 'assets', 'steam-train-whistle-daniel_simon.mp3')
        waveform, _ = torchaudio.load(filepath)

        def func(tensor):
            sample_rate = 44100
            return F.detect_pitch_frequency(tensor, sample_rate)

        self._assert_consistency(func, waveform)

    def test_create_fb_matrix(self):
        # The function under test ignores its tensor input, so running it on
        # another device adds no coverage.
        if self.device != torch.device('cpu'):
            raise unittest.SkipTest('No need to perform test on device other than CPU')

        def func(_):
            n_stft = 100
            f_min = 0.0
            f_max = 20.0
            n_mels = 10
            sample_rate = 16000
            return F.create_fb_matrix(n_stft, f_min, f_max, n_mels, sample_rate)

        dummy = torch.zeros(1, 1)
        self._assert_consistency(func, dummy)

    def test_amplitude_to_DB(self):
        def func(tensor):
            multiplier = 10.0
            amin = 1e-10
            db_multiplier = 0.0
            top_db = 80.0
            return F.amplitude_to_DB(tensor, multiplier, amin, db_multiplier, top_db)

        tensor = torch.rand((6, 201))
        self._assert_consistency(func, tensor)

    def test_DB_to_amplitude(self):
        def func(tensor):
            ref = 1.
            power = 1.
            return F.DB_to_amplitude(tensor, ref, power)

        tensor = torch.rand((1, 100))
        self._assert_consistency(func, tensor)

    def test_create_dct(self):
        # The function under test ignores its tensor input, so running it on
        # another device adds no coverage.
        if self.device != torch.device('cpu'):
            raise unittest.SkipTest('No need to perform test on device other than CPU')

        def func(_):
            n_mfcc = 40
            n_mels = 128
            norm = "ortho"
            return F.create_dct(n_mfcc, n_mels, norm)

        dummy = torch.zeros(1, 1)
        self._assert_consistency(func, dummy)

    def test_mu_law_encoding(self):
        def func(tensor):
            qc = 256
            return F.mu_law_encoding(tensor, qc)

        tensor = torch.rand((1, 10))
        self._assert_consistency(func, tensor)

    def test_mu_law_decoding(self):
        def func(tensor):
            qc = 256
            return F.mu_law_decoding(tensor, qc)

        tensor = torch.rand((1, 10))
        self._assert_consistency(func, tensor)

    def test_complex_norm(self):
        def func(tensor):
            power = 2.
            return F.complex_norm(tensor, power)

        tensor = torch.randn(1, 2, 1025, 400, 2)
        self._assert_consistency(func, tensor)

    def test_mask_along_axis(self):
        def func(tensor):
            mask_param = 100
            mask_value = 30.
            axis = 2
            return F.mask_along_axis(tensor, mask_param, mask_value, axis)

        tensor = torch.randn(2, 1025, 400)
        self._assert_consistency(func, tensor)

    def test_mask_along_axis_iid(self):
        def func(tensor):
            mask_param = 100
            mask_value = 30.
            axis = 2
            return F.mask_along_axis_iid(tensor, mask_param, mask_value, axis)

        tensor = torch.randn(4, 2, 1025, 400)
        self._assert_consistency(func, tensor)

    def test_gain(self):
        def func(tensor):
            gainDB = 2.0
            return F.gain(tensor, gainDB)

        tensor = torch.rand((1, 1000))
        self._assert_consistency(func, tensor)

    # The dither tests compare output shapes only (shape_only=True).
    def test_dither_TPDF(self):
        def func(tensor):
            return F.dither(tensor, 'TPDF')

        tensor = torch.rand((2, 1000))
        self._assert_consistency(func, tensor, shape_only=True)

    def test_dither_RPDF(self):
        def func(tensor):
            return F.dither(tensor, 'RPDF')

        tensor = torch.rand((2, 1000))
        self._assert_consistency(func, tensor, shape_only=True)

    def test_dither_GPDF(self):
        def func(tensor):
            return F.dither(tensor, 'GPDF')

        tensor = torch.rand((2, 1000))
        self._assert_consistency(func, tensor, shape_only=True)

    def test_lfilter(self):
        waveform = self._load_whitenoise()

        def func(tensor):
            # Design an IIR lowpass filter using scipy.signal filter design
            # https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.iirdesign.html#scipy.signal.iirdesign
            #
            # Example
            #     >>> from scipy.signal import iirdesign
            #     >>> b, a = iirdesign(0.2, 0.3, 1, 60)
            b_coeffs = torch.tensor(
                [
                    0.00299893,
                    -0.0051152,
                    0.00841964,
                    -0.00747802,
                    0.00841964,
                    -0.0051152,
                    0.00299893,
                ],
                device=tensor.device,
                dtype=tensor.dtype,
            )
            a_coeffs = torch.tensor(
                [
                    1.0,
                    -4.8155751,
                    10.2217618,
                    -12.14481273,
                    8.49018171,
                    -3.3066882,
                    0.56088705,
                ],
                device=tensor.device,
                dtype=tensor.dtype,
            )
            return F.lfilter(tensor, a_coeffs, b_coeffs)

        self._assert_consistency(func, waveform)

    def test_lowpass(self):
        waveform = self._load_whitenoise()

        def func(tensor):
            sample_rate = 44100
            cutoff_freq = 3000.
            return F.lowpass_biquad(tensor, sample_rate, cutoff_freq)

        self._assert_consistency(func, waveform)

    def test_highpass(self):
        waveform = self._load_whitenoise()

        def func(tensor):
            sample_rate = 44100
            cutoff_freq = 2000.
            return F.highpass_biquad(tensor, sample_rate, cutoff_freq)

        self._assert_consistency(func, waveform)

    def test_allpass(self):
        waveform = self._load_whitenoise()

        def func(tensor):
            sample_rate = 44100
            central_freq = 1000.
            q = 0.707
            return F.allpass_biquad(tensor, sample_rate, central_freq, q)

        self._assert_consistency(func, waveform)

    def test_bandpass_with_csg(self):
        waveform = self._load_whitenoise()

        def func(tensor):
            sample_rate = 44100
            central_freq = 1000.
            q = 0.707
            const_skirt_gain = True
            return F.bandpass_biquad(tensor, sample_rate, central_freq, q, const_skirt_gain)

        self._assert_consistency(func, waveform)

    def test_bandpass_without_csg(self):
        waveform = self._load_whitenoise()

        def func(tensor):
            sample_rate = 44100
            central_freq = 1000.
            q = 0.707
            # Must differ from test_bandpass_with_csg, otherwise the
            # const_skirt_gain=False path is never exercised.
            const_skirt_gain = False
            return F.bandpass_biquad(tensor, sample_rate, central_freq, q, const_skirt_gain)

        self._assert_consistency(func, waveform)

    def test_bandreject(self):
        waveform = self._load_whitenoise()

        def func(tensor):
            sample_rate = 44100
            central_freq = 1000.
            q = 0.707
            return F.bandreject_biquad(tensor, sample_rate, central_freq, q)

        self._assert_consistency(func, waveform)

    def test_band_with_noise(self):
        waveform = self._load_whitenoise()

        def func(tensor):
            sample_rate = 44100
            central_freq = 1000.
            q = 0.707
            noise = True
            return F.band_biquad(tensor, sample_rate, central_freq, q, noise)

        self._assert_consistency(func, waveform)

    def test_band_without_noise(self):
        waveform = self._load_whitenoise()

        def func(tensor):
            sample_rate = 44100
            central_freq = 1000.
            q = 0.707
            noise = False
            return F.band_biquad(tensor, sample_rate, central_freq, q, noise)

        self._assert_consistency(func, waveform)

    def test_treble(self):
        waveform = self._load_whitenoise()

        def func(tensor):
            sample_rate = 44100
            gain = 40.
            central_freq = 1000.
            q = 0.707
            return F.treble_biquad(tensor, sample_rate, gain, central_freq, q)

        self._assert_consistency(func, waveform)

    def test_deemph(self):
        waveform = self._load_whitenoise()

        def func(tensor):
            sample_rate = 44100
            return F.deemph_biquad(tensor, sample_rate)

        self._assert_consistency(func, waveform)

    def test_riaa(self):
        waveform = self._load_whitenoise()

        def func(tensor):
            sample_rate = 44100
            return F.riaa_biquad(tensor, sample_rate)

        self._assert_consistency(func, waveform)

    def test_equalizer(self):
        waveform = self._load_whitenoise()

        def func(tensor):
            sample_rate = 44100
            center_freq = 300.
            gain = 1.
            q = 0.707
            return F.equalizer_biquad(tensor, sample_rate, center_freq, gain, q)

        self._assert_consistency(func, waveform)

    def test_perf_biquad_filtering(self):
        waveform = self._load_whitenoise()

        def func(tensor):
            a = torch.tensor([0.7, 0.2, 0.6], device=tensor.device, dtype=tensor.dtype)
            b = torch.tensor([0.4, 0.2, 0.9], device=tensor.device, dtype=tensor.dtype)
            return F.lfilter(tensor, a, b)

        self._assert_consistency(func, waveform)


class _TransformsTestMixin:
    """Implements tests for the `transforms` module that are performed on different devices.

    Concrete subclasses combine this mixin with ``unittest.TestCase`` and set
    ``device``.
    """
    device = None

    def _assert_consistency(self, transform, tensor):
        _assert_transforms_consistency(transform, tensor, self.device)

    @staticmethod
    def _load_steam_train_whistle():
        """Load the ``assets/steam-train-whistle-daniel_simon.wav`` test file."""
        test_filepath = os.path.join(
            common_utils.TEST_DIR_PATH, 'assets', 'steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(test_filepath)
        return waveform

    def test_Spectrogram(self):
        tensor = torch.rand((1, 1000))
        self._assert_consistency(T.Spectrogram(), tensor)

    def test_GriffinLim(self):
        tensor = torch.rand((1, 201, 6))
        self._assert_consistency(T.GriffinLim(length=1000, rand_init=False), tensor)

    def test_AmplitudeToDB(self):
        spec = torch.rand((6, 201))
        self._assert_consistency(T.AmplitudeToDB(), spec)

    def test_MelScale(self):
        spec_f = torch.rand((1, 6, 201))
        self._assert_consistency(T.MelScale(), spec_f)

    def test_MelSpectrogram(self):
        tensor = torch.rand((1, 1000))
        self._assert_consistency(T.MelSpectrogram(), tensor)

    def test_MFCC(self):
        tensor = torch.rand((1, 1000))
        self._assert_consistency(T.MFCC(), tensor)

    def test_Resample(self):
        tensor = torch.rand((2, 1000))
        sample_rate = 100.
        sample_rate_2 = 50.
        self._assert_consistency(T.Resample(sample_rate, sample_rate_2), tensor)

    def test_ComplexNorm(self):
        tensor = torch.rand((1, 2, 201, 2))
        self._assert_consistency(T.ComplexNorm(), tensor)

    def test_MuLawEncoding(self):
        tensor = torch.rand((1, 10))
        self._assert_consistency(T.MuLawEncoding(), tensor)

    def test_MuLawDecoding(self):
        tensor = torch.rand((1, 10))
        self._assert_consistency(T.MuLawDecoding(), tensor)

    def test_TimeStretch(self):
        n_freq = 400
        hop_length = 512
        fixed_rate = 1.3
        # n_freq appears both as the size of dim -3 of the input and as the
        # TimeStretch constructor argument.
        tensor = torch.rand((10, 2, n_freq, 10, 2))
        self._assert_consistency(
            T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate),
            tensor,
        )

    def test_Fade(self):
        waveform = self._load_steam_train_whistle()
        fade_in_len = 3000
        fade_out_len = 3000
        self._assert_consistency(T.Fade(fade_in_len, fade_out_len), waveform)

    def test_FrequencyMasking(self):
        tensor = torch.rand((10, 2, 50, 10, 2))
        self._assert_consistency(T.FrequencyMasking(freq_mask_param=60, iid_masks=False), tensor)

    def test_TimeMasking(self):
        tensor = torch.rand((10, 2, 50, 10, 2))
        self._assert_consistency(T.TimeMasking(time_mask_param=30, iid_masks=False), tensor)

    def test_Vol(self):
        waveform = self._load_steam_train_whistle()
        self._assert_consistency(T.Vol(1.1), waveform)


class TestFunctionalCPU(_FunctionalTestMixin, unittest.TestCase):
    """Runs the functional TorchScript-consistency checks on the CPU."""

    device = torch.device('cpu')


@unittest.skipUnless(torch.cuda.is_available(), 'CUDA not available')
class TestFunctionalCUDA(_FunctionalTestMixin, unittest.TestCase):
    """Runs the functional TorchScript-consistency checks on a CUDA device."""

    device = torch.device('cuda')


class TestTransformsCPU(_TransformsTestMixin, unittest.TestCase):
    """Runs the transforms TorchScript-consistency checks on the CPU."""

    device = torch.device('cpu')


@unittest.skipUnless(torch.cuda.is_available(), 'CUDA not available')
class TestTransformsCUDA(_TransformsTestMixin, unittest.TestCase):
    """Runs the transforms TorchScript-consistency checks on a CUDA device."""

    device = torch.device('cuda')


if __name__ == '__main__':
    # Only runs when this file is executed as a script, not when it is imported.
    unittest.main()