# torchscript_consistency_impl.py
"""Test suites for jit-ability and its numerical compatibility"""
import unittest

import torch
import torchaudio
import torchaudio.functional as F
7
import torchaudio.transforms as T
8
9
10
11

import common_utils


12
class Functional(common_utils.TestBaseMixin):
    """Implements tests for the `functional` module that are performed for different devices.

    Every test builds a small single-argument wrapper around one function of
    ``torchaudio.functional`` and checks that the eager and TorchScript-compiled
    versions of that wrapper agree. The tests read ``self.device`` and
    ``self.dtype`` to decide where, and in which precision, to run.
    """

    def _assert_consistency(self, func, tensor, shape_only=False):
        """Assert that ``func`` and its scripted counterpart agree on ``tensor``.

        Args:
            func: Callable taking a single tensor; it is compiled with
                ``torch.jit.script``.
            tensor: Input tensor; moved to ``self.device`` / ``self.dtype``
                before being passed to both versions.
            shape_only: If True, compare only the shapes of the two outputs
                instead of their values.
        """
        tensor = tensor.to(device=self.device, dtype=self.dtype)

        ts_func = torch.jit.script(func)

        # Re-seed before each invocation so that functions drawing random
        # numbers consume the same RNG stream in eager and scripted mode.
        torch.random.manual_seed(40)
        output = func(tensor)

        torch.random.manual_seed(40)
        ts_output = ts_func(tensor)

        if shape_only:
            ts_output = ts_output.shape
            output = output.shape
        self.assertEqual(ts_output, output)

    def test_spectrogram(self):
        def func(tensor):
            n_fft = 400
            ws = 400
            hop = 200
            pad = 0
            window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype)
            power = 2.
            normalize = False
            return F.spectrogram(tensor, pad, window, n_fft, hop, ws, power, normalize)

        tensor = torch.rand((1, 1000))
        self._assert_consistency(func, tensor)

    def test_griffinlim(self):
        def func(tensor):
            n_fft = 400
            ws = 400
            hop = 200
            window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype)
            power = 2.
            normalize = False
            momentum = 0.99
            n_iter = 32
            length = 1000
            rand_init = False
            return F.griffinlim(
                tensor, window, n_fft, hop, ws, power, normalize, n_iter, momentum, length, rand_init)

        tensor = torch.rand((1, 201, 6))
        self._assert_consistency(func, tensor)

    def test_compute_deltas(self):
        def func(tensor):
            win_length = 2 * 7 + 1
            return F.compute_deltas(tensor, win_length=win_length)

        channel = 13
        n_mfcc = channel * 3
        time = 1021
        tensor = torch.randn(channel, n_mfcc, time)
        self._assert_consistency(func, tensor)

    def test_detect_pitch_frequency(self):
        filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(filepath)

        def func(tensor):
            sample_rate = 44100
            return F.detect_pitch_frequency(tensor, sample_rate)

        self._assert_consistency(func, waveform)

    def test_create_fb_matrix(self):
        if self.device != torch.device('cpu'):
            raise unittest.SkipTest('No need to perform test on device other than CPU')

        # The wrapped function takes no tensor input; the dummy argument only
        # satisfies the single-tensor interface of `_assert_consistency`.
        def func(_):
            n_stft = 100
            f_min = 0.0
            f_max = 20.0
            n_mels = 10
            sample_rate = 16000
            norm = ""
            return F.create_fb_matrix(n_stft, f_min, f_max, n_mels, sample_rate, norm)

        dummy = torch.zeros(1, 1)
        self._assert_consistency(func, dummy)

    def test_amplitude_to_DB(self):
        def func(tensor):
            multiplier = 10.0
            amin = 1e-10
            db_multiplier = 0.0
            top_db = 80.0
            return F.amplitude_to_DB(tensor, multiplier, amin, db_multiplier, top_db)

        tensor = torch.rand((6, 201))
        self._assert_consistency(func, tensor)

    def test_DB_to_amplitude(self):
        def func(tensor):
            ref = 1.
            power = 1.
            return F.DB_to_amplitude(tensor, ref, power)

        tensor = torch.rand((1, 100))
        self._assert_consistency(func, tensor)

    def test_create_dct(self):
        if self.device != torch.device('cpu'):
            raise unittest.SkipTest('No need to perform test on device other than CPU')

        # As in `test_create_fb_matrix`, the argument is an unused placeholder.
        def func(_):
            n_mfcc = 40
            n_mels = 128
            norm = "ortho"
            return F.create_dct(n_mfcc, n_mels, norm)

        dummy = torch.zeros(1, 1)
        self._assert_consistency(func, dummy)

    def test_mu_law_encoding(self):
        def func(tensor):
            qc = 256
            return F.mu_law_encoding(tensor, qc)

        tensor = torch.rand((1, 10))
        self._assert_consistency(func, tensor)

    def test_mu_law_decoding(self):
        def func(tensor):
            qc = 256
            return F.mu_law_decoding(tensor, qc)

        tensor = torch.rand((1, 10))
        self._assert_consistency(func, tensor)

    def test_complex_norm(self):
        def func(tensor):
            power = 2.
            return F.complex_norm(tensor, power)

        tensor = torch.randn(1, 2, 1025, 400, 2)
        self._assert_consistency(func, tensor)

    def test_mask_along_axis(self):
        def func(tensor):
            mask_param = 100
            mask_value = 30.
            axis = 2
            return F.mask_along_axis(tensor, mask_param, mask_value, axis)

        tensor = torch.randn(2, 1025, 400)
        self._assert_consistency(func, tensor)

    def test_mask_along_axis_iid(self):
        def func(tensor):
            mask_param = 100
            mask_value = 30.
            axis = 2
            return F.mask_along_axis_iid(tensor, mask_param, mask_value, axis)

        tensor = torch.randn(4, 2, 1025, 400)
        self._assert_consistency(func, tensor)

    def test_gain(self):
        def func(tensor):
            gainDB = 2.0
            return F.gain(tensor, gainDB)

        tensor = torch.rand((1, 1000))
        self._assert_consistency(func, tensor)

    # For the dither tests only the output shapes are compared.
    def test_dither_TPDF(self):
        def func(tensor):
            return F.dither(tensor, 'TPDF')

        tensor = torch.rand((2, 1000))
        self._assert_consistency(func, tensor, shape_only=True)

    def test_dither_RPDF(self):
        def func(tensor):
            return F.dither(tensor, 'RPDF')

        tensor = torch.rand((2, 1000))
        self._assert_consistency(func, tensor, shape_only=True)

    def test_dither_GPDF(self):
        def func(tensor):
            return F.dither(tensor, 'GPDF')

        tensor = torch.rand((2, 1000))
        self._assert_consistency(func, tensor, shape_only=True)

    def test_lfilter(self):
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        filepath = common_utils.get_asset_path('whitenoise.wav')
        waveform, _ = torchaudio.load(filepath, normalization=True)

        def func(tensor):
            # Design an IIR lowpass filter using scipy.signal filter design
            # https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.iirdesign.html#scipy.signal.iirdesign
            #
            # Example
            #     >>> from scipy.signal import iirdesign
            #     >>> b, a = iirdesign(0.2, 0.3, 1, 60)
            b_coeffs = torch.tensor(
                [
                    0.00299893,
                    -0.0051152,
                    0.00841964,
                    -0.00747802,
                    0.00841964,
                    -0.0051152,
                    0.00299893,
                ],
                device=tensor.device,
                dtype=tensor.dtype,
            )
            a_coeffs = torch.tensor(
                [
                    1.0,
                    -4.8155751,
                    10.2217618,
                    -12.14481273,
                    8.49018171,
                    -3.3066882,
                    0.56088705,
                ],
                device=tensor.device,
                dtype=tensor.dtype,
            )
            return F.lfilter(tensor, a_coeffs, b_coeffs)

        self._assert_consistency(func, waveform)

    def test_lowpass(self):
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        filepath = common_utils.get_asset_path('whitenoise.wav')
        waveform, _ = torchaudio.load(filepath, normalization=True)

        def func(tensor):
            sample_rate = 44100
            cutoff_freq = 3000.
            return F.lowpass_biquad(tensor, sample_rate, cutoff_freq)

        self._assert_consistency(func, waveform)

    def test_highpass(self):
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        filepath = common_utils.get_asset_path('whitenoise.wav')
        waveform, _ = torchaudio.load(filepath, normalization=True)

        def func(tensor):
            sample_rate = 44100
            cutoff_freq = 2000.
            return F.highpass_biquad(tensor, sample_rate, cutoff_freq)

        self._assert_consistency(func, waveform)

    def test_allpass(self):
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        filepath = common_utils.get_asset_path('whitenoise.wav')
        waveform, _ = torchaudio.load(filepath, normalization=True)

        def func(tensor):
            sample_rate = 44100
            central_freq = 1000.
            q = 0.707
            return F.allpass_biquad(tensor, sample_rate, central_freq, q)

        self._assert_consistency(func, waveform)

    def test_bandpass_with_csg(self):
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        filepath = common_utils.get_asset_path("whitenoise.wav")
        waveform, _ = torchaudio.load(filepath, normalization=True)

        def func(tensor):
            sample_rate = 44100
            central_freq = 1000.
            q = 0.707
            const_skirt_gain = True
            return F.bandpass_biquad(tensor, sample_rate, central_freq, q, const_skirt_gain)

        self._assert_consistency(func, waveform)

    def test_bandpass_without_csg(self):
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        filepath = common_utils.get_asset_path("whitenoise.wav")
        waveform, _ = torchaudio.load(filepath, normalization=True)

        def func(tensor):
            sample_rate = 44100
            central_freq = 1000.
            q = 0.707
            # Counterpart of `test_bandpass_with_csg`: constant skirt gain disabled.
            const_skirt_gain = False
            return F.bandpass_biquad(tensor, sample_rate, central_freq, q, const_skirt_gain)

        self._assert_consistency(func, waveform)

    def test_bandreject(self):
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        filepath = common_utils.get_asset_path("whitenoise.wav")
        waveform, _ = torchaudio.load(filepath, normalization=True)

        def func(tensor):
            sample_rate = 44100
            central_freq = 1000.
            q = 0.707
            return F.bandreject_biquad(tensor, sample_rate, central_freq, q)

        self._assert_consistency(func, waveform)

    def test_band_with_noise(self):
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        filepath = common_utils.get_asset_path("whitenoise.wav")
        waveform, _ = torchaudio.load(filepath, normalization=True)

        def func(tensor):
            sample_rate = 44100
            central_freq = 1000.
            q = 0.707
            noise = True
            return F.band_biquad(tensor, sample_rate, central_freq, q, noise)

        self._assert_consistency(func, waveform)

    def test_band_without_noise(self):
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        filepath = common_utils.get_asset_path("whitenoise.wav")
        waveform, _ = torchaudio.load(filepath, normalization=True)

        def func(tensor):
            sample_rate = 44100
            central_freq = 1000.
            q = 0.707
            noise = False
            return F.band_biquad(tensor, sample_rate, central_freq, q, noise)

        self._assert_consistency(func, waveform)

    def test_treble(self):
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        filepath = common_utils.get_asset_path("whitenoise.wav")
        waveform, _ = torchaudio.load(filepath, normalization=True)

        def func(tensor):
            sample_rate = 44100
            gain = 40.
            central_freq = 1000.
            q = 0.707
            return F.treble_biquad(tensor, sample_rate, gain, central_freq, q)

        self._assert_consistency(func, waveform)

    def test_deemph(self):
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        filepath = common_utils.get_asset_path("whitenoise.wav")
        waveform, _ = torchaudio.load(filepath, normalization=True)

        def func(tensor):
            sample_rate = 44100
            return F.deemph_biquad(tensor, sample_rate)

        self._assert_consistency(func, waveform)

    def test_riaa(self):
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        filepath = common_utils.get_asset_path("whitenoise.wav")
        waveform, _ = torchaudio.load(filepath, normalization=True)

        def func(tensor):
            sample_rate = 44100
            return F.riaa_biquad(tensor, sample_rate)

        self._assert_consistency(func, waveform)

    def test_equalizer(self):
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        filepath = common_utils.get_asset_path("whitenoise.wav")
        waveform, _ = torchaudio.load(filepath, normalization=True)

        def func(tensor):
            sample_rate = 44100
            center_freq = 300.
            gain = 1.
            q = 0.707
            return F.equalizer_biquad(tensor, sample_rate, center_freq, gain, q)

        self._assert_consistency(func, waveform)

    def test_perf_biquad_filtering(self):
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        filepath = common_utils.get_asset_path("whitenoise.wav")
        waveform, _ = torchaudio.load(filepath, normalization=True)

        def func(tensor):
            a = torch.tensor([0.7, 0.2, 0.6], device=tensor.device, dtype=tensor.dtype)
            b = torch.tensor([0.4, 0.2, 0.9], device=tensor.device, dtype=tensor.dtype)
            return F.lfilter(tensor, a, b)

        self._assert_consistency(func, waveform)

    def test_sliding_window_cmn(self):
        def func(tensor):
            cmn_window = 600
            min_cmn_window = 100
            center = False
            norm_vars = False
            return F.sliding_window_cmn(tensor, cmn_window, min_cmn_window, center, norm_vars)

        # Feed the data through the helper so that the scripted and eager
        # functions actually operate on the tensor they are given.
        tensor = torch.tensor(
            [
                [
                    -1.915875792503357,
                    1.147700309753418
                ],
                [
                    1.8242558240890503,
                    1.3869990110397339
                ]
            ]
        )
        self._assert_consistency(func, tensor)

    def test_contrast(self):
        filepath = common_utils.get_asset_path("whitenoise.wav")
        waveform, _ = torchaudio.load(filepath, normalization=True)

        def func(tensor):
            enhancement_amount = 80.
            return F.contrast(tensor, enhancement_amount)

        self._assert_consistency(func, waveform)

    def test_dcshift(self):
        filepath = common_utils.get_asset_path("whitenoise.wav")
        waveform, _ = torchaudio.load(filepath, normalization=True)

        def func(tensor):
            shift = 0.5
            limiter_gain = 0.05
            return F.dcshift(tensor, shift, limiter_gain)

        self._assert_consistency(func, waveform)

    def test_overdrive(self):
        filepath = common_utils.get_asset_path("whitenoise.wav")
        waveform, _ = torchaudio.load(filepath, normalization=True)

        def func(tensor):
            gain = 30.
            colour = 50.
            return F.overdrive(tensor, gain, colour)

        self._assert_consistency(func, waveform)

    def test_phaser(self):
        filepath = common_utils.get_asset_path("whitenoise.wav")
        waveform, _ = torchaudio.load(filepath, normalization=True)

        def func(tensor):
            gain_in = 0.5
            gain_out = 0.8
            delay_ms = 2.0
            decay = 0.4
            speed = 0.5
            sample_rate = 44100
            return F.phaser(
                tensor, sample_rate, gain_in, gain_out, delay_ms, decay, speed, sinusoidal=True)

        self._assert_consistency(func, waveform)

519

520
class Transforms(common_utils.TestBaseMixin):
    """Implements tests for Transforms that are performed for different devices.

    Every test instantiates one transform from ``torchaudio.transforms`` and
    checks that the eager module and its TorchScript-compiled version agree.
    The tests read ``self.device`` and ``self.dtype`` to decide where, and in
    which precision, to run.
    """

    def _assert_consistency(self, transform, tensor):
        """Assert that ``transform`` and its scripted counterpart agree on ``tensor``.

        Args:
            transform: Transform module; moved to ``self.device`` /
                ``self.dtype`` and compiled with ``torch.jit.script``.
            tensor: Input tensor; moved to ``self.device`` / ``self.dtype``
                before being passed to both versions.
        """
        tensor = tensor.to(device=self.device, dtype=self.dtype)
        transform = transform.to(device=self.device, dtype=self.dtype)

        ts_transform = torch.jit.script(transform)

        # Re-seed before each invocation so that transforms drawing random
        # numbers consume the same RNG stream in eager and scripted mode.
        torch.random.manual_seed(40)
        output = transform(tensor)

        torch.random.manual_seed(40)
        ts_output = ts_transform(tensor)

        self.assertEqual(ts_output, output)

    def test_Spectrogram(self):
        tensor = torch.rand((1, 1000))
        self._assert_consistency(T.Spectrogram(), tensor)

    def test_GriffinLim(self):
        tensor = torch.rand((1, 201, 6))
        self._assert_consistency(T.GriffinLim(length=1000, rand_init=False), tensor)

    def test_AmplitudeToDB(self):
        spec = torch.rand((6, 201))
        self._assert_consistency(T.AmplitudeToDB(), spec)

    def test_MelScale(self):
        spec_f = torch.rand((1, 6, 201))
        self._assert_consistency(T.MelScale(), spec_f)

    def test_MelSpectrogram(self):
        tensor = torch.rand((1, 1000))
        self._assert_consistency(T.MelSpectrogram(), tensor)

    def test_MFCC(self):
        tensor = torch.rand((1, 1000))
        self._assert_consistency(T.MFCC(), tensor)

    def test_Resample(self):
        tensor = torch.rand((2, 1000))
        sample_rate = 100.
        sample_rate_2 = 50.
        self._assert_consistency(T.Resample(sample_rate, sample_rate_2), tensor)

    def test_ComplexNorm(self):
        tensor = torch.rand((1, 2, 201, 2))
        self._assert_consistency(T.ComplexNorm(), tensor)

    def test_MuLawEncoding(self):
        tensor = torch.rand((1, 10))
        self._assert_consistency(T.MuLawEncoding(), tensor)

    def test_MuLawDecoding(self):
        tensor = torch.rand((1, 10))
        self._assert_consistency(T.MuLawDecoding(), tensor)

    def test_TimeStretch(self):
        n_freq = 400
        hop_length = 512
        fixed_rate = 1.3
        tensor = torch.rand((10, 2, n_freq, 10, 2))
        self._assert_consistency(
            T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate),
            tensor,
        )

    def test_Fade(self):
        test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(test_filepath)
        fade_in_len = 3000
        fade_out_len = 3000
        self._assert_consistency(T.Fade(fade_in_len, fade_out_len), waveform)

    def test_FrequencyMasking(self):
        tensor = torch.rand((10, 2, 50, 10, 2))
        self._assert_consistency(T.FrequencyMasking(freq_mask_param=60, iid_masks=False), tensor)

    def test_TimeMasking(self):
        tensor = torch.rand((10, 2, 50, 10, 2))
        self._assert_consistency(T.TimeMasking(time_mask_param=30, iid_masks=False), tensor)

    def test_Vol(self):
        test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(test_filepath)
        self._assert_consistency(T.Vol(1.1), waveform)

    def test_SlidingWindowCmn(self):
        tensor = torch.rand((1000, 10))
        self._assert_consistency(T.SlidingWindowCmn(), tensor)

    def test_Vad(self):
        filepath = common_utils.get_asset_path("vad-go-mono-32000.wav")
        waveform, sample_rate = torchaudio.load(filepath)
        self._assert_consistency(T.Vad(sample_rate=sample_rate), waveform)