"""Test suites for jit-ability and its numerical compatibility"""
import unittest

import torch
import torchaudio
import torchaudio.functional as F
import torchaudio.transforms as T

import common_utils


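# Shared helpers: each one scripts the given callable with ``torch.jit.script``
# and checks that the scripted version matches eager execution on the requested
# device (or, when ``shape_only`` is set, that only the output shapes agree).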
def _assert_functional_consistency(func, tensor, device, shape_only=False):
    tensor = tensor.to(device)
    ts_func = torch.jit.script(func)
    output = func(tensor)
    ts_output = ts_func(tensor)

    if shape_only:
        assert ts_output.shape == output.shape, (ts_output.shape, output.shape)
    else:
        torch.testing.assert_allclose(ts_output, output)


def _assert_transforms_consistency(transform, tensor, device):
    tensor = tensor.to(device)
    transform = transform.to(device)
    ts_transform = torch.jit.script(transform)
    output = transform(tensor)
    ts_output = ts_transform(tensor)
    torch.testing.assert_allclose(ts_output, output)


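# The mixins below hold the actual test bodies. They are combined with
# ``unittest.TestCase`` at the bottom of this file, once per device, so that
# every test runs on CPU and, when available, on CUDA.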
class _FunctionalTestMixin:
    """Implements test for `functinoal` modul that are performed for different devices"""
    device = None

    def _assert_consistency(self, func, tensor, shape_only=False):
        return _assert_functional_consistency(func, tensor, self.device, shape_only=shape_only)

    def test_spectrogram(self):
        def func(tensor):
            n_fft = 400
            ws = 400
            hop = 200
            pad = 0
            window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype)
            power = 2.
            normalize = False
            return F.spectrogram(tensor, pad, window, n_fft, hop, ws, power, normalize)

        tensor = torch.rand((1, 1000))
        self._assert_consistency(func, tensor)

    def test_griffinlim(self):
        def func(tensor):
            n_fft = 400
            ws = 400
            hop = 200
            window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype)
            power = 2.
            normalize = False
            momentum = 0.99
            n_iter = 32
            length = 1000
            rand_init = False
            return F.griffinlim(tensor, window, n_fft, hop, ws, power, normalize, n_iter, momentum, length, rand_init)

        tensor = torch.rand((1, 201, 6))
        self._assert_consistency(func, tensor)

    def test_compute_deltas(self):
        def func(tensor):
            win_length = 2 * 7 + 1
            return F.compute_deltas(tensor, win_length=win_length)

        channel = 13
        n_mfcc = channel * 3
        time = 1021
        tensor = torch.randn(channel, n_mfcc, time)
        self._assert_consistency(func, tensor)

    def test_detect_pitch_frequency(self):
        filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.mp3')
        waveform, _ = torchaudio.load(filepath)

        def func(tensor):
            sample_rate = 44100
            return F.detect_pitch_frequency(tensor, sample_rate)

        self._assert_consistency(func, waveform)

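    # ``F.create_fb_matrix`` and ``F.create_dct`` build matrices from scalar
    # arguments only, so running them on devices other than CPU adds no
    # coverage; the tests below skip in that case and ignore the dummy input.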
    def test_create_fb_matrix(self):
        if self.device != torch.device('cpu'):
            raise unittest.SkipTest('No need to perform test on device other than CPU')

        def func(_):
            n_stft = 100
            f_min = 0.0
            f_max = 20.0
            n_mels = 10
            sample_rate = 16000
            return F.create_fb_matrix(n_stft, f_min, f_max, n_mels, sample_rate)

        dummy = torch.zeros(1, 1)
        self._assert_consistency(func, dummy)

    def test_amplitude_to_DB(self):
        def func(tensor):
            multiplier = 10.0
            amin = 1e-10
            db_multiplier = 0.0
            top_db = 80.0
            return F.amplitude_to_DB(tensor, multiplier, amin, db_multiplier, top_db)

        tensor = torch.rand((6, 201))
        self._assert_consistency(func, tensor)

    def test_DB_to_amplitude(self):
        def func(tensor):
            ref = 1.
            power = 1.
            return F.DB_to_amplitude(tensor, ref, power)

        tensor = torch.rand((1, 100))
        self._assert_consistency(func, tensor)

    def test_create_dct(self):
        if self.device != torch.device('cpu'):
            raise unittest.SkipTest('No need to perform test on device other than CPU')

        def func(_):
            n_mfcc = 40
            n_mels = 128
            norm = "ortho"
            return F.create_dct(n_mfcc, n_mels, norm)

        dummy = torch.zeros(1, 1)
        self._assert_consistency(func, dummy)

    def test_mu_law_encoding(self):
        def func(tensor):
            qc = 256
            return F.mu_law_encoding(tensor, qc)

        tensor = torch.rand((1, 10))
        self._assert_consistency(func, tensor)

    def test_mu_law_decoding(self):
        def func(tensor):
            qc = 256
            return F.mu_law_decoding(tensor, qc)

        tensor = torch.rand((1, 10))
        self._assert_consistency(func, tensor)

    def test_complex_norm(self):
        def func(tensor):
            power = 2.
            return F.complex_norm(tensor, power)

        tensor = torch.randn(1, 2, 1025, 400, 2)
        _assert_functional_consistency(func, tensor, self.device)

    def test_mask_along_axis(self):
        def func(tensor):
            mask_param = 100
            mask_value = 30.
            axis = 2
            return F.mask_along_axis(tensor, mask_param, mask_value, axis)

        tensor = torch.randn(2, 1025, 400)
        self._assert_consistency(func, tensor)

    def test_mask_along_axis_iid(self):
        def func(tensor):
            mask_param = 100
            mask_value = 30.
            axis = 2
            return F.mask_along_axis_iid(tensor, mask_param, mask_value, axis)

        tensor = torch.randn(4, 2, 1025, 400)
        self._assert_consistency(func, tensor)

    def test_gain(self):
        def func(tensor):
            gainDB = 2.0
            return F.gain(tensor, gainDB)

        tensor = torch.rand((1, 1000))
        self._assert_consistency(func, tensor)

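    # The dither tests only compare output shapes: ``F.dither`` injects random
    # noise, so eager and scripted runs are not expected to match numerically.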
    def test_dither_TPDF(self):
        def func(tensor):
            return F.dither(tensor, 'TPDF')

        tensor = torch.rand((2, 1000))
        self._assert_consistency(func, tensor, shape_only=True)

    def test_dither_RPDF(self):
        def func(tensor):
            return F.dither(tensor, 'RPDF')

        tensor = torch.rand((2, 1000))
        self._assert_consistency(func, tensor, shape_only=True)

    def test_dither_GPDF(self):
        def func(tensor):
            return F.dither(tensor, 'GPDF')

        tensor = torch.rand((2, 1000))
        self._assert_consistency(func, tensor, shape_only=True)

    def test_lfilter(self):
        filepath = common_utils.get_asset_path('whitenoise.wav')
        waveform, _ = torchaudio.load(filepath, normalization=True)

        def func(tensor):
            # Design an IIR lowpass filter using scipy.signal filter design
            # https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.iirdesign.html#scipy.signal.iirdesign
            #
            # Example
            #     >>> from scipy.signal import iirdesign
            #     >>> b, a = iirdesign(0.2, 0.3, 1, 60)
            b_coeffs = torch.tensor(
                [
                    0.00299893,
                    -0.0051152,
                    0.00841964,
                    -0.00747802,
                    0.00841964,
                    -0.0051152,
                    0.00299893,
                ],
                device=tensor.device,
                dtype=tensor.dtype,
            )
            a_coeffs = torch.tensor(
                [
                    1.0,
                    -4.8155751,
                    10.2217618,
                    -12.14481273,
                    8.49018171,
                    -3.3066882,
                    0.56088705,
                ],
                device=tensor.device,
                dtype=tensor.dtype,
            )
            return F.lfilter(tensor, a_coeffs, b_coeffs)

        self._assert_consistency(func, waveform)

    def test_lowpass(self):
        filepath = common_utils.get_asset_path('whitenoise.wav')
        waveform, _ = torchaudio.load(filepath, normalization=True)

        def func(tensor):
            sample_rate = 44100
            cutoff_freq = 3000.
            return F.lowpass_biquad(tensor, sample_rate, cutoff_freq)

        self._assert_consistency(func, waveform)

    def test_highpass(self):
        filepath = common_utils.get_asset_path('whitenoise.wav')
        waveform, _ = torchaudio.load(filepath, normalization=True)

        def func(tensor):
            sample_rate = 44100
            cutoff_freq = 2000.
            return F.highpass_biquad(tensor, sample_rate, cutoff_freq)

        self._assert_consistency(func, waveform)

    def test_allpass(self):
        filepath = common_utils.get_asset_path('whitenoise.wav')
        waveform, _ = torchaudio.load(filepath, normalization=True)

        def func(tensor):
            sample_rate = 44100
            central_freq = 1000.
            q = 0.707
            return F.allpass_biquad(tensor, sample_rate, central_freq, q)

        self._assert_consistency(func, waveform)

    def test_bandpass_with_csg(self):
        filepath = common_utils.get_asset_path("whitenoise.wav")
        waveform, _ = torchaudio.load(filepath, normalization=True)

        def func(tensor):
            sample_rate = 44100
            central_freq = 1000.
            q = 0.707
            const_skirt_gain = True
            return F.bandpass_biquad(tensor, sample_rate, central_freq, q, const_skirt_gain)

        self._assert_consistency(func, waveform)

    def test_bandpass_without_csg(self):
        filepath = common_utils.get_asset_path("whitenoise.wav")
        waveform, _ = torchaudio.load(filepath, normalization=True)

        def func(tensor):
            sample_rate = 44100
            central_freq = 1000.
            q = 0.707
            const_skirt_gain = False
            return F.bandpass_biquad(tensor, sample_rate, central_freq, q, const_skirt_gain)

        self._assert_consistency(func, waveform)

    def test_bandreject(self):
        filepath = common_utils.get_asset_path("whitenoise.wav")
        waveform, _ = torchaudio.load(filepath, normalization=True)

        def func(tensor):
            sample_rate = 44100
            central_freq = 1000.
            q = 0.707
            return F.bandreject_biquad(tensor, sample_rate, central_freq, q)

        self._assert_consistency(func, waveform)

    def test_band_with_noise(self):
        filepath = common_utils.get_asset_path("whitenoise.wav")
        waveform, _ = torchaudio.load(filepath, normalization=True)

        def func(tensor):
            sample_rate = 44100
            central_freq = 1000.
            q = 0.707
            noise = True
            return F.band_biquad(tensor, sample_rate, central_freq, q, noise)

        self._assert_consistency(func, waveform)

    def test_band_without_noise(self):
        filepath = common_utils.get_asset_path("whitenoise.wav")
        waveform, _ = torchaudio.load(filepath, normalization=True)

        def func(tensor):
            sample_rate = 44100
            central_freq = 1000.
            q = 0.707
            noise = False
            return F.band_biquad(tensor, sample_rate, central_freq, q, noise)

        self._assert_consistency(func, waveform)

    def test_treble(self):
        filepath = common_utils.get_asset_path("whitenoise.wav")
        waveform, _ = torchaudio.load(filepath, normalization=True)

        def func(tensor):
            sample_rate = 44100
            gain = 40.
            central_freq = 1000.
            q = 0.707
            return F.treble_biquad(tensor, sample_rate, gain, central_freq, q)

        self._assert_consistency(func, waveform)

    def test_deemph(self):
        filepath = common_utils.get_asset_path("whitenoise.wav")
        waveform, _ = torchaudio.load(filepath, normalization=True)

        def func(tensor):
            sample_rate = 44100
            return F.deemph_biquad(tensor, sample_rate)

        self._assert_consistency(func, waveform)

    def test_riaa(self):
        filepath = common_utils.get_asset_path("whitenoise.wav")
        waveform, _ = torchaudio.load(filepath, normalization=True)

        def func(tensor):
            sample_rate = 44100
            return F.riaa_biquad(tensor, sample_rate)

        self._assert_consistency(func, waveform)

    def test_equalizer(self):
        filepath = common_utils.get_asset_path("whitenoise.wav")
        waveform, _ = torchaudio.load(filepath, normalization=True)

        def func(tensor):
            sample_rate = 44100
            center_freq = 300.
            gain = 1.
            q = 0.707
            return F.equalizer_biquad(tensor, sample_rate, center_freq, gain, q)

        self._assert_consistency(func, waveform)

    def test_perf_biquad_filtering(self):
        filepath = common_utils.get_asset_path("whitenoise.wav")
        waveform, _ = torchaudio.load(filepath, normalization=True)

        def func(tensor):
            a = torch.tensor([0.7, 0.2, 0.6], device=tensor.device, dtype=tensor.dtype)
            b = torch.tensor([0.4, 0.2, 0.9], device=tensor.device, dtype=tensor.dtype)
            return F.lfilter(tensor, a, b)

        self._assert_consistency(func, waveform)

    def test_sliding_window_cmn(self):
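        # Note: the scripted function operates on a tensor built inside ``func``;
        # the input passed to ``_assert_consistency`` only supplies the device
        # and dtype used for that tensor.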
        def func(tensor):
            cmn_window = 600
            min_cmn_window = 100
            center = False
            norm_vars = False
            a = torch.tensor(
                [
                    [
                        -1.915875792503357,
                        1.147700309753418
                    ],
                    [
                        1.8242558240890503,
                        1.3869990110397339
                    ]
                ],
                device=tensor.device,
                dtype=tensor.dtype
            )
            return F.sliding_window_cmn(a, cmn_window, min_cmn_window, center, norm_vars)
        b = torch.tensor(
            [
                [
                    -1.8701,
                    -0.1196
                ],
                [
                    1.8701,
                    0.1196
                ]
            ]
        )
        self._assert_consistency(func, b)

    def test_contrast(self):
        filepath = common_utils.get_asset_path("whitenoise.wav")
        waveform, _ = torchaudio.load(filepath, normalization=True)

        def func(tensor):
            enhancement_amount = 80.
            return F.contrast(tensor, enhancement_amount)

        self._assert_consistency(func, waveform)

    def test_dcshift(self):
        filepath = common_utils.get_asset_path("whitenoise.wav")
        waveform, _ = torchaudio.load(filepath, normalization=True)

        def func(tensor):
            shift = 0.5
            limiter_gain = 0.05
            return F.dcshift(tensor, shift, limiter_gain)

        self._assert_consistency(func, waveform)

    def test_overdrive(self):
        filepath = common_utils.get_asset_path("whitenoise.wav")
        waveform, _ = torchaudio.load(filepath, normalization=True)

        def func(tensor):
            gain = 30.
            colour = 50.
            return F.overdrive(tensor, gain, colour)

        self._assert_consistency(func, waveform)

    def test_phaser(self):
        filepath = common_utils.get_asset_path("whitenoise.wav")
        waveform, _ = torchaudio.load(filepath, normalization=True)

        def func(tensor):
            gain_in = 0.5
            gain_out = 0.8
            delay_ms = 2.0
            decay = 0.4
            speed = 0.5
            sample_rate = 44100
            return F.phaser(tensor, sample_rate, gain_in, gain_out, delay_ms, decay, speed, sinusoidal=True)

        self._assert_consistency(func, waveform)


class _TransformsTestMixin:
    """Implements test for Transforms that are performed for different devices"""
    device = None

    def _assert_consistency(self, transform, tensor):
        _assert_transforms_consistency(transform, tensor, self.device)

    def test_Spectrogram(self):
        tensor = torch.rand((1, 1000))
        self._assert_consistency(T.Spectrogram(), tensor)

    def test_GriffinLim(self):
        tensor = torch.rand((1, 201, 6))
        self._assert_consistency(T.GriffinLim(length=1000, rand_init=False), tensor)

    def test_AmplitudeToDB(self):
        spec = torch.rand((6, 201))
        self._assert_consistency(T.AmplitudeToDB(), spec)

    def test_MelScale(self):
        spec_f = torch.rand((1, 6, 201))
        self._assert_consistency(T.MelScale(), spec_f)

    def test_MelSpectrogram(self):
        tensor = torch.rand((1, 1000))
        self._assert_consistency(T.MelSpectrogram(), tensor)

    def test_MFCC(self):
        tensor = torch.rand((1, 1000))
        self._assert_consistency(T.MFCC(), tensor)

    def test_Resample(self):
        tensor = torch.rand((2, 1000))
        sample_rate = 100.
        sample_rate_2 = 50.
        self._assert_consistency(T.Resample(sample_rate, sample_rate_2), tensor)

    def test_ComplexNorm(self):
        tensor = torch.rand((1, 2, 201, 2))
        self._assert_consistency(T.ComplexNorm(), tensor)

    def test_MuLawEncoding(self):
        tensor = torch.rand((1, 10))
        self._assert_consistency(T.MuLawEncoding(), tensor)

    def test_MuLawDecoding(self):
        tensor = torch.rand((1, 10))
        self._assert_consistency(T.MuLawDecoding(), tensor)

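    # ``T.TimeStretch`` operates on a complex spectrogram laid out as
    # (..., freq, time, 2), with the real and imaginary parts in the last dimension.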
    def test_TimeStretch(self):
        n_freq = 400
        hop_length = 512
        fixed_rate = 1.3
        tensor = torch.rand((10, 2, n_freq, 10, 2))
        self._assert_consistency(
            T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate),
            tensor,
        )

    def test_Fade(self):
        test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(test_filepath)
        fade_in_len = 3000
        fade_out_len = 3000
        self._assert_consistency(T.Fade(fade_in_len, fade_out_len), waveform)

    def test_FrequencyMasking(self):
        tensor = torch.rand((10, 2, 50, 10, 2))
        self._assert_consistency(T.FrequencyMasking(freq_mask_param=60, iid_masks=False), tensor)

    def test_TimeMasking(self):
        tensor = torch.rand((10, 2, 50, 10, 2))
        self._assert_consistency(T.TimeMasking(time_mask_param=30, iid_masks=False), tensor)

    def test_Vol(self):
        test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(test_filepath)
        self._assert_consistency(T.Vol(1.1), waveform)

    def test_SlidingWindowCmn(self):
        tensor = torch.rand((1000, 10))
        self._assert_consistency(T.SlidingWindowCmn(), tensor)

    def test_Vad(self):
        filepath = common_utils.get_asset_path("vad-go-mono-32000.wav")
        waveform, sample_rate = torchaudio.load(filepath)
        self._assert_consistency(T.Vad(sample_rate=sample_rate), waveform)


class TestFunctionalCPU(_FunctionalTestMixin, unittest.TestCase):
    """Test suite for Functional module on CPU"""
    device = torch.device('cpu')


@unittest.skipIf(not torch.cuda.is_available(), 'CUDA not available')
class TestFunctionalCUDA(_FunctionalTestMixin, unittest.TestCase):
    """Test suite for Functional module on GPU"""
    device = torch.device('cuda')


class TestTransformsCPU(_TransformsTestMixin, unittest.TestCase):
    """Test suite for Transforms module on CPU"""
    device = torch.device('cpu')


@unittest.skipIf(not torch.cuda.is_available(), 'CUDA not available')
class TestTransformsCUDA(_TransformsTestMixin, unittest.TestCase):
    """Test suite for Transforms module on GPU"""
    device = torch.device('cuda')
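

# Additional device- or backend-specific suites can be added by combining the
# mixins with ``unittest.TestCase`` in the same way. A hypothetical example:
#
#     @unittest.skipIf(torch.cuda.device_count() < 2, 'requires a second GPU')
#     class TestTransformsSecondGPU(_TransformsTestMixin, unittest.TestCase):
#         device = torch.device('cuda:1')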


if __name__ == '__main__':
    unittest.main()