test_torchscript_consistency.py 20.2 KB
Newer Older
1
2
"""Test suites for jit-ability and its numerical compatibility"""
import unittest
3
import pytest
4
5
6
7

import torch
import torchaudio
import torchaudio.functional as F
8
import torchaudio.transforms as T
9
10
11
12

import common_utils


13
def _assert_functional_consistency(func, tensor, shape_only=False):
14
15
16
17
18
19
20
21
22
23
    ts_func = torch.jit.script(func)
    output = func(tensor)
    ts_output = ts_func(tensor)

    if shape_only:
        assert ts_output.shape == output.shape, (ts_output.shape, output.shape)
    else:
        torch.testing.assert_allclose(ts_output, output)


24
def _assert_transforms_consistency(transform, tensor):
25
26
27
28
29
30
    ts_transform = torch.jit.script(transform)
    output = transform(tensor)
    ts_output = ts_transform(tensor)
    torch.testing.assert_allclose(ts_output, output)


31
class Functional(common_utils.TestBaseMixin):
    """Implements test for `functional` module that are performed for different devices"""

    # ------------------------------------------------------------------
    # Helpers (names start with an underscore so they are not collected
    # as tests).
    # ------------------------------------------------------------------
    def _assert_consistency(self, func, tensor, shape_only=False):
        """Move ``tensor`` to ``self.device``/``self.dtype`` and compare eager vs. scripted ``func``."""
        tensor = tensor.to(device=self.device, dtype=self.dtype)
        return _assert_functional_consistency(func, tensor, shape_only=shape_only)

    def _xfail_on_float64(self):
        """Mark the running test as an expected failure when ``self.dtype`` is float64."""
        if self.dtype == torch.float64:
            pytest.xfail("This test is known to fail for float64")

    def _skip_unless_cpu(self):
        """Skip the running test when ``self.device`` is not the CPU device."""
        if self.device != torch.device('cpu'):
            raise unittest.SkipTest('No need to perform test on device other than CPU')

    def _load_whitenoise(self):
        """Return the waveform of the ``whitenoise.wav`` asset, loaded with ``normalization=True``."""
        filepath = common_utils.get_asset_path("whitenoise.wav")
        waveform, _ = torchaudio.load(filepath, normalization=True)
        return waveform

    # ------------------------------------------------------------------
    # Spectral / feature functions
    # ------------------------------------------------------------------
    def test_spectrogram(self):
        def func(tensor):
            n_fft = 400
            ws = 400
            hop = 200
            pad = 0
            window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype)
            power = 2.
            normalize = False
            return F.spectrogram(tensor, pad, window, n_fft, hop, ws, power, normalize)

        tensor = torch.rand((1, 1000))
        self._assert_consistency(func, tensor)

    def test_griffinlim(self):
        def func(tensor):
            n_fft = 400
            ws = 400
            hop = 200
            window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype)
            power = 2.
            normalize = False
            momentum = 0.99
            n_iter = 32
            length = 1000
            # Random initialisation is disabled so that the eager and the
            # scripted run can be compared value by value.
            rand_init = False
            return F.griffinlim(tensor, window, n_fft, hop, ws, power, normalize, n_iter, momentum, length, rand_init)

        tensor = torch.rand((1, 201, 6))
        self._assert_consistency(func, tensor)

    def test_compute_deltas(self):
        def func(tensor):
            win_length = 2 * 7 + 1
            return F.compute_deltas(tensor, win_length=win_length)

        channel = 13
        n_mfcc = channel * 3
        time = 1021
        tensor = torch.randn(channel, n_mfcc, time)
        self._assert_consistency(func, tensor)

    def test_detect_pitch_frequency(self):
        filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(filepath)

        def func(tensor):
            sample_rate = 44100
            return F.detect_pitch_frequency(tensor, sample_rate)

        self._assert_consistency(func, waveform)

    def test_create_fb_matrix(self):
        self._skip_unless_cpu()

        # The function under test takes no tensor input, so the argument
        # handed over by `_assert_consistency` is ignored.
        def func(_):
            n_stft = 100
            f_min = 0.0
            f_max = 20.0
            n_mels = 10
            sample_rate = 16000
            return F.create_fb_matrix(n_stft, f_min, f_max, n_mels, sample_rate)

        dummy = torch.zeros(1, 1)
        self._assert_consistency(func, dummy)

    def test_amplitude_to_DB(self):
        def func(tensor):
            multiplier = 10.0
            amin = 1e-10
            db_multiplier = 0.0
            top_db = 80.0
            return F.amplitude_to_DB(tensor, multiplier, amin, db_multiplier, top_db)

        tensor = torch.rand((6, 201))
        self._assert_consistency(func, tensor)

    def test_DB_to_amplitude(self):
        def func(tensor):
            ref = 1.
            power = 1.
            return F.DB_to_amplitude(tensor, ref, power)

        tensor = torch.rand((1, 100))
        self._assert_consistency(func, tensor)

    def test_create_dct(self):
        self._skip_unless_cpu()

        # As in `test_create_fb_matrix`, the tensor argument is unused.
        def func(_):
            n_mfcc = 40
            n_mels = 128
            norm = "ortho"
            return F.create_dct(n_mfcc, n_mels, norm)

        dummy = torch.zeros(1, 1)
        self._assert_consistency(func, dummy)

    def test_mu_law_encoding(self):
        def func(tensor):
            qc = 256
            return F.mu_law_encoding(tensor, qc)

        tensor = torch.rand((1, 10))
        self._assert_consistency(func, tensor)

    def test_mu_law_decoding(self):
        def func(tensor):
            qc = 256
            return F.mu_law_decoding(tensor, qc)

        tensor = torch.rand((1, 10))
        self._assert_consistency(func, tensor)

    def test_complex_norm(self):
        def func(tensor):
            power = 2.
            return F.complex_norm(tensor, power)

        tensor = torch.randn(1, 2, 1025, 400, 2)
        self._assert_consistency(func, tensor)

    def test_mask_along_axis(self):
        def func(tensor):
            mask_param = 100
            mask_value = 30.
            axis = 2
            return F.mask_along_axis(tensor, mask_param, mask_value, axis)

        tensor = torch.randn(2, 1025, 400)
        self._assert_consistency(func, tensor)

    def test_mask_along_axis_iid(self):
        def func(tensor):
            mask_param = 100
            mask_value = 30.
            axis = 2
            return F.mask_along_axis_iid(tensor, mask_param, mask_value, axis)

        tensor = torch.randn(4, 2, 1025, 400)
        self._assert_consistency(func, tensor)

    def test_gain(self):
        def func(tensor):
            gainDB = 2.0
            return F.gain(tensor, gainDB)

        tensor = torch.rand((1, 1000))
        self._assert_consistency(func, tensor)

    # The dither tests compare output shapes only (`shape_only=True`).
    def test_dither_TPDF(self):
        def func(tensor):
            return F.dither(tensor, 'TPDF')

        tensor = torch.rand((2, 1000))
        self._assert_consistency(func, tensor, shape_only=True)

    def test_dither_RPDF(self):
        def func(tensor):
            return F.dither(tensor, 'RPDF')

        tensor = torch.rand((2, 1000))
        self._assert_consistency(func, tensor, shape_only=True)

    def test_dither_GPDF(self):
        def func(tensor):
            return F.dither(tensor, 'GPDF')

        tensor = torch.rand((2, 1000))
        self._assert_consistency(func, tensor, shape_only=True)

    # ------------------------------------------------------------------
    # Filters (all of them are expected to fail for float64)
    # ------------------------------------------------------------------
    def test_lfilter(self):
        self._xfail_on_float64()
        waveform = self._load_whitenoise()

        def func(tensor):
            # Design an IIR lowpass filter using scipy.signal filter design
            # https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.iirdesign.html#scipy.signal.iirdesign
            #
            # Example
            #     >>> from scipy.signal import iirdesign
            #     >>> b, a = iirdesign(0.2, 0.3, 1, 60)
            b_coeffs = torch.tensor(
                [
                    0.00299893,
                    -0.0051152,
                    0.00841964,
                    -0.00747802,
                    0.00841964,
                    -0.0051152,
                    0.00299893,
                ],
                device=tensor.device,
                dtype=tensor.dtype,
            )
            a_coeffs = torch.tensor(
                [
                    1.0,
                    -4.8155751,
                    10.2217618,
                    -12.14481273,
                    8.49018171,
                    -3.3066882,
                    0.56088705,
                ],
                device=tensor.device,
                dtype=tensor.dtype,
            )
            return F.lfilter(tensor, a_coeffs, b_coeffs)

        self._assert_consistency(func, waveform)

    def test_lowpass(self):
        self._xfail_on_float64()
        waveform = self._load_whitenoise()

        def func(tensor):
            sample_rate = 44100
            cutoff_freq = 3000.
            return F.lowpass_biquad(tensor, sample_rate, cutoff_freq)

        self._assert_consistency(func, waveform)

    def test_highpass(self):
        self._xfail_on_float64()
        waveform = self._load_whitenoise()

        def func(tensor):
            sample_rate = 44100
            cutoff_freq = 2000.
            return F.highpass_biquad(tensor, sample_rate, cutoff_freq)

        self._assert_consistency(func, waveform)

    def test_allpass(self):
        self._xfail_on_float64()
        waveform = self._load_whitenoise()

        def func(tensor):
            sample_rate = 44100
            central_freq = 1000.
            q = 0.707
            return F.allpass_biquad(tensor, sample_rate, central_freq, q)

        self._assert_consistency(func, waveform)

    def test_bandpass_with_csg(self):
        self._xfail_on_float64()
        waveform = self._load_whitenoise()

        def func(tensor):
            sample_rate = 44100
            central_freq = 1000.
            q = 0.707
            const_skirt_gain = True
            return F.bandpass_biquad(tensor, sample_rate, central_freq, q, const_skirt_gain)

        self._assert_consistency(func, waveform)

    def test_bandpass_without_csg(self):
        self._xfail_on_float64()
        waveform = self._load_whitenoise()

        def func(tensor):
            sample_rate = 44100
            central_freq = 1000.
            q = 0.707
            # Must be False here: with True this test would be an exact
            # duplicate of `test_bandpass_with_csg` and the code path
            # without constant skirt gain would never be exercised.
            const_skirt_gain = False
            return F.bandpass_biquad(tensor, sample_rate, central_freq, q, const_skirt_gain)

        self._assert_consistency(func, waveform)

    def test_bandreject(self):
        self._xfail_on_float64()
        waveform = self._load_whitenoise()

        def func(tensor):
            sample_rate = 44100
            central_freq = 1000.
            q = 0.707
            return F.bandreject_biquad(tensor, sample_rate, central_freq, q)

        self._assert_consistency(func, waveform)

    def test_band_with_noise(self):
        self._xfail_on_float64()
        waveform = self._load_whitenoise()

        def func(tensor):
            sample_rate = 44100
            central_freq = 1000.
            q = 0.707
            noise = True
            return F.band_biquad(tensor, sample_rate, central_freq, q, noise)

        self._assert_consistency(func, waveform)

    def test_band_without_noise(self):
        self._xfail_on_float64()
        waveform = self._load_whitenoise()

        def func(tensor):
            sample_rate = 44100
            central_freq = 1000.
            q = 0.707
            noise = False
            return F.band_biquad(tensor, sample_rate, central_freq, q, noise)

        self._assert_consistency(func, waveform)

    def test_treble(self):
        self._xfail_on_float64()
        waveform = self._load_whitenoise()

        def func(tensor):
            sample_rate = 44100
            gain = 40.
            central_freq = 1000.
            q = 0.707
            return F.treble_biquad(tensor, sample_rate, gain, central_freq, q)

        self._assert_consistency(func, waveform)

    def test_deemph(self):
        self._xfail_on_float64()
        waveform = self._load_whitenoise()

        def func(tensor):
            sample_rate = 44100
            return F.deemph_biquad(tensor, sample_rate)

        self._assert_consistency(func, waveform)

    def test_riaa(self):
        self._xfail_on_float64()
        waveform = self._load_whitenoise()

        def func(tensor):
            sample_rate = 44100
            return F.riaa_biquad(tensor, sample_rate)

        self._assert_consistency(func, waveform)

    def test_equalizer(self):
        self._xfail_on_float64()
        waveform = self._load_whitenoise()

        def func(tensor):
            sample_rate = 44100
            center_freq = 300.
            gain = 1.
            q = 0.707
            return F.equalizer_biquad(tensor, sample_rate, center_freq, gain, q)

        self._assert_consistency(func, waveform)

    def test_perf_biquad_filtering(self):
        self._xfail_on_float64()
        waveform = self._load_whitenoise()

        def func(tensor):
            a = torch.tensor([0.7, 0.2, 0.6], device=tensor.device, dtype=tensor.dtype)
            b = torch.tensor([0.4, 0.2, 0.9], device=tensor.device, dtype=tensor.dtype)
            return F.lfilter(tensor, a, b)

        self._assert_consistency(func, waveform)

    # ------------------------------------------------------------------
    # Effects
    # ------------------------------------------------------------------
    def test_sliding_window_cmn(self):
        def func(tensor):
            cmn_window = 600
            min_cmn_window = 100
            center = False
            norm_vars = False
            # NOTE: the data fed to `sliding_window_cmn` is this constant;
            # the `tensor` argument only supplies the device and dtype.
            a = torch.tensor(
                [
                    [
                        -1.915875792503357,
                        1.147700309753418
                    ],
                    [
                        1.8242558240890503,
                        1.3869990110397339
                    ]
                ],
                device=tensor.device,
                dtype=tensor.dtype
            )
            return F.sliding_window_cmn(a, cmn_window, min_cmn_window, center, norm_vars)

        # NOTE(review): `b` is only used as a carrier for device/dtype; its
        # values look like the expected CMN result for `a` -- confirm.
        b = torch.tensor(
            [
                [
                    -1.8701,
                    -0.1196
                ],
                [
                    1.8701,
                    0.1196
                ]
            ]
        )
        self._assert_consistency(func, b)

    def test_contrast(self):
        waveform = self._load_whitenoise()

        def func(tensor):
            enhancement_amount = 80.
            return F.contrast(tensor, enhancement_amount)

        self._assert_consistency(func, waveform)

    def test_dcshift(self):
        waveform = self._load_whitenoise()

        def func(tensor):
            shift = 0.5
            limiter_gain = 0.05
            return F.dcshift(tensor, shift, limiter_gain)

        self._assert_consistency(func, waveform)

    def test_overdrive(self):
        waveform = self._load_whitenoise()

        def func(tensor):
            gain = 30.
            colour = 50.
            return F.overdrive(tensor, gain, colour)

        self._assert_consistency(func, waveform)

    def test_phaser(self):
        waveform = self._load_whitenoise()

        def func(tensor):
            gain_in = 0.5
            gain_out = 0.8
            delay_ms = 2.0
            decay = 0.4
            speed = 0.5
            # The sample rate is a constant inside the scripted function; the
            # rate returned by `torchaudio.load` is not used.
            sample_rate = 44100
            return F.phaser(tensor, sample_rate, gain_in, gain_out, delay_ms, decay, speed, sinusoidal=True)

        self._assert_consistency(func, waveform)

530

531
class Transforms(common_utils.TestBaseMixin):
    """Implements test for Transforms that are performed for different devices"""

    def _assert_consistency(self, transform, tensor):
        """Move ``tensor`` and ``transform`` to ``self.device``/``self.dtype`` and compare eager vs. scripted output."""
        tensor = tensor.to(device=self.device, dtype=self.dtype)
        transform = transform.to(device=self.device, dtype=self.dtype)
        _assert_transforms_consistency(transform, tensor)

    def test_Spectrogram(self):
        tensor = torch.rand((1, 1000))
        self._assert_consistency(T.Spectrogram(), tensor)

    def test_GriffinLim(self):
        tensor = torch.rand((1, 201, 6))
        # `rand_init=False` so that the eager and the scripted run can be
        # compared value by value.
        self._assert_consistency(T.GriffinLim(length=1000, rand_init=False), tensor)

    def test_AmplitudeToDB(self):
        spec = torch.rand((6, 201))
        self._assert_consistency(T.AmplitudeToDB(), spec)

    def test_MelScale(self):
        spec_f = torch.rand((1, 6, 201))
        self._assert_consistency(T.MelScale(), spec_f)

    def test_MelSpectrogram(self):
        tensor = torch.rand((1, 1000))
        self._assert_consistency(T.MelSpectrogram(), tensor)

    def test_MFCC(self):
        tensor = torch.rand((1, 1000))
        self._assert_consistency(T.MFCC(), tensor)

    def test_Resample(self):
        tensor = torch.rand((2, 1000))
        sample_rate = 100.
        sample_rate_2 = 50.
        self._assert_consistency(T.Resample(sample_rate, sample_rate_2), tensor)

    def test_ComplexNorm(self):
        tensor = torch.rand((1, 2, 201, 2))
        self._assert_consistency(T.ComplexNorm(), tensor)

    def test_MuLawEncoding(self):
        tensor = torch.rand((1, 10))
        self._assert_consistency(T.MuLawEncoding(), tensor)

    def test_MuLawDecoding(self):
        tensor = torch.rand((1, 10))
        self._assert_consistency(T.MuLawDecoding(), tensor)

    def test_TimeStretch(self):
        n_freq = 400
        hop_length = 512
        fixed_rate = 1.3
        tensor = torch.rand((10, 2, n_freq, 10, 2))
        self._assert_consistency(
            T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate),
            tensor,
        )

    def test_Fade(self):
        test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(test_filepath)
        fade_in_len = 3000
        fade_out_len = 3000
        self._assert_consistency(T.Fade(fade_in_len, fade_out_len), waveform)

    def test_FrequencyMasking(self):
        tensor = torch.rand((10, 2, 50, 10, 2))
        self._assert_consistency(T.FrequencyMasking(freq_mask_param=60, iid_masks=False), tensor)

    def test_TimeMasking(self):
        tensor = torch.rand((10, 2, 50, 10, 2))
        self._assert_consistency(T.TimeMasking(time_mask_param=30, iid_masks=False), tensor)

    def test_Vol(self):
        test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(test_filepath)
        self._assert_consistency(T.Vol(1.1), waveform)

    def test_SlidingWindowCmn(self):
        tensor = torch.rand((1000, 10))
        self._assert_consistency(T.SlidingWindowCmn(), tensor)

    def test_Vad(self):
        filepath = common_utils.get_asset_path("vad-go-mono-32000.wav")
        # Unlike the other asset-based tests, the sample rate reported by the
        # loader is forwarded to the transform.
        waveform, sample_rate = torchaudio.load(filepath)
        self._assert_consistency(T.Vad(sample_rate=sample_rate), waveform)

619

620
# Hand this module's namespace and the two mixin classes above to
# `common_utils.define_test_suites`.
# NOTE(review): presumably this generates the concrete test classes for each
# device/dtype combination and registers them in `globals()` -- confirm in
# `common_utils`.
common_utils.define_test_suites(globals(), [Functional, Transforms])