"""Test suites for jit-ability and its numerical compatibility"""
import unittest

import torch
import torchaudio
import torchaudio.functional as F
import torchaudio.transforms as T

from . import common_utils


class Functional(common_utils.TestBaseMixin):
    """Implements tests for the `functional` module that are performed for different devices"""

    def _assert_consistency(self, func, tensor, shape_only=False):
        """Check that ``func`` produces the same result eagerly and under TorchScript.

        Args:
            func: Function taking a single tensor; it is compiled with ``torch.jit.script``.
            tensor: Input tensor; moved to ``self.device`` / ``self.dtype`` before use.
            shape_only: When True, only the output shapes are compared (used by the
                dither tests below, which compare shapes rather than values).
        """
        tensor = tensor.to(device=self.device, dtype=self.dtype)

        ts_func = torch.jit.script(func)
        output = func(tensor)
        ts_output = ts_func(tensor)
        if shape_only:
            ts_output = ts_output.shape
            output = output.shape
        self.assertEqual(ts_output, output)

    def test_spectrogram(self):
        def func(tensor):
            n_fft = 400
            ws = 400
            hop = 200
            pad = 0
            window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype)
            power = 2.
            normalize = False
            return F.spectrogram(tensor, pad, window, n_fft, hop, ws, power, normalize)

        tensor = common_utils.get_whitenoise()
        self._assert_consistency(func, tensor)

    def test_griffinlim(self):
        def func(tensor):
            n_fft = 400
            ws = 400
            hop = 200
            window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype)
            power = 2.
            normalize = False
            momentum = 0.99
            n_iter = 32
            length = 1000
            rand_init = False
            return F.griffinlim(tensor, window, n_fft, hop, ws, power, normalize, n_iter, momentum, length, rand_init)

        tensor = torch.rand((1, 201, 6))
        self._assert_consistency(func, tensor)

    def test_compute_deltas(self):
        def func(tensor):
            win_length = 2 * 7 + 1
            return F.compute_deltas(tensor, win_length=win_length)

        channel = 13
        n_mfcc = channel * 3
        time = 1021
        tensor = torch.randn(channel, n_mfcc, time)
        self._assert_consistency(func, tensor)

    def test_detect_pitch_frequency(self):
        waveform = common_utils.get_sinusoid(sample_rate=44100)

        def func(tensor):
            sample_rate = 44100
            return F.detect_pitch_frequency(tensor, sample_rate)

        self._assert_consistency(func, waveform)

    def test_create_fb_matrix(self):
        if self.device != torch.device('cpu'):
            raise unittest.SkipTest('No need to perform test on device other than CPU')

        def func(_):
            n_stft = 100
            f_min = 0.0
            f_max = 20.0
            n_mels = 10
            sample_rate = 16000
            norm = "slaney"
            return F.create_fb_matrix(n_stft, f_min, f_max, n_mels, sample_rate, norm)

        dummy = torch.zeros(1, 1)
        self._assert_consistency(func, dummy)

    def test_amplitude_to_DB(self):
        def func(tensor):
            multiplier = 10.0
            amin = 1e-10
            db_multiplier = 0.0
            top_db = 80.0
            return F.amplitude_to_DB(tensor, multiplier, amin, db_multiplier, top_db)

        tensor = torch.rand((6, 201))
        self._assert_consistency(func, tensor)

    def test_DB_to_amplitude(self):
        def func(tensor):
            ref = 1.
            power = 1.
            return F.DB_to_amplitude(tensor, ref, power)

        tensor = torch.rand((1, 100))
        self._assert_consistency(func, tensor)

    def test_create_dct(self):
        if self.device != torch.device('cpu'):
            raise unittest.SkipTest('No need to perform test on device other than CPU')

        def func(_):
            n_mfcc = 40
            n_mels = 128
            norm = "ortho"
            return F.create_dct(n_mfcc, n_mels, norm)

        dummy = torch.zeros(1, 1)
        self._assert_consistency(func, dummy)

    def test_mu_law_encoding(self):
        def func(tensor):
            qc = 256
            return F.mu_law_encoding(tensor, qc)

        waveform = common_utils.get_whitenoise()
        self._assert_consistency(func, waveform)

    def test_mu_law_decoding(self):
        def func(tensor):
            qc = 256
            return F.mu_law_decoding(tensor, qc)

        tensor = torch.rand((1, 10))
        self._assert_consistency(func, tensor)

    def test_complex_norm(self):
        def func(tensor):
            power = 2.
            return F.complex_norm(tensor, power)

        tensor = torch.randn(1, 2, 1025, 400, 2)
        self._assert_consistency(func, tensor)

    def test_mask_along_axis(self):
        def func(tensor):
            mask_param = 100
            mask_value = 30.
            axis = 2
            return F.mask_along_axis(tensor, mask_param, mask_value, axis)

        tensor = torch.randn(2, 1025, 400)
        self._assert_consistency(func, tensor)

    def test_mask_along_axis_iid(self):
        def func(tensor):
            mask_param = 100
            mask_value = 30.
            axis = 2
            return F.mask_along_axis_iid(tensor, mask_param, mask_value, axis)

        tensor = torch.randn(4, 2, 1025, 400)
        self._assert_consistency(func, tensor)

    def test_gain(self):
        def func(tensor):
            gainDB = 2.0
            return F.gain(tensor, gainDB)

        tensor = torch.rand((1, 1000))
        self._assert_consistency(func, tensor)

    def test_dither_TPDF(self):
        def func(tensor):
            return F.dither(tensor, 'TPDF')

        tensor = common_utils.get_whitenoise(n_channels=2)
        self._assert_consistency(func, tensor, shape_only=True)

    def test_dither_RPDF(self):
        def func(tensor):
            return F.dither(tensor, 'RPDF')

        tensor = common_utils.get_whitenoise(n_channels=2)
        self._assert_consistency(func, tensor, shape_only=True)

    def test_dither_GPDF(self):
        def func(tensor):
            return F.dither(tensor, 'GPDF')

        tensor = common_utils.get_whitenoise(n_channels=2)
        self._assert_consistency(func, tensor, shape_only=True)

    def test_lfilter(self):
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        waveform = common_utils.get_whitenoise()

        def func(tensor):
            # Design an IIR lowpass filter using scipy.signal filter design
            # https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.iirdesign.html#scipy.signal.iirdesign
            #
            # Example
            #     >>> from scipy.signal import iirdesign
            #     >>> b, a = iirdesign(0.2, 0.3, 1, 60)
            b_coeffs = torch.tensor(
                [
                    0.00299893,
                    -0.0051152,
                    0.00841964,
                    -0.00747802,
                    0.00841964,
                    -0.0051152,
                    0.00299893,
                ],
                device=tensor.device,
                dtype=tensor.dtype,
            )
            a_coeffs = torch.tensor(
                [
                    1.0,
                    -4.8155751,
                    10.2217618,
                    -12.14481273,
                    8.49018171,
                    -3.3066882,
                    0.56088705,
                ],
                device=tensor.device,
                dtype=tensor.dtype,
            )
            return F.lfilter(tensor, a_coeffs, b_coeffs)

        self._assert_consistency(func, waveform)

    def test_lowpass(self):
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        waveform = common_utils.get_whitenoise(sample_rate=44100)

        def func(tensor):
            sample_rate = 44100
            cutoff_freq = 3000.
            return F.lowpass_biquad(tensor, sample_rate, cutoff_freq)

        self._assert_consistency(func, waveform)

    def test_highpass(self):
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        waveform = common_utils.get_whitenoise(sample_rate=44100)

        def func(tensor):
            sample_rate = 44100
            cutoff_freq = 2000.
            return F.highpass_biquad(tensor, sample_rate, cutoff_freq)

        self._assert_consistency(func, waveform)

    def test_allpass(self):
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        waveform = common_utils.get_whitenoise(sample_rate=44100)

        def func(tensor):
            sample_rate = 44100
            central_freq = 1000.
            q = 0.707
            return F.allpass_biquad(tensor, sample_rate, central_freq, q)

        self._assert_consistency(func, waveform)

    def test_bandpass_with_csg(self):
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        waveform = common_utils.get_whitenoise(sample_rate=44100)

        def func(tensor):
            sample_rate = 44100
            central_freq = 1000.
            q = 0.707
            const_skirt_gain = True
            return F.bandpass_biquad(tensor, sample_rate, central_freq, q, const_skirt_gain)

        self._assert_consistency(func, waveform)

    def test_bandpass_without_csg(self):
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        waveform = common_utils.get_whitenoise(sample_rate=44100)

        def func(tensor):
            sample_rate = 44100
            central_freq = 1000.
            q = 0.707
            # "without csg": exercise the const_skirt_gain=False branch. Passing True
            # here would just duplicate test_bandpass_with_csg.
            const_skirt_gain = False
            return F.bandpass_biquad(tensor, sample_rate, central_freq, q, const_skirt_gain)

        self._assert_consistency(func, waveform)

    def test_bandreject(self):
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        waveform = common_utils.get_whitenoise(sample_rate=44100)

        def func(tensor):
            sample_rate = 44100
            central_freq = 1000.
            q = 0.707
            return F.bandreject_biquad(tensor, sample_rate, central_freq, q)

        self._assert_consistency(func, waveform)

    def test_band_with_noise(self):
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        waveform = common_utils.get_whitenoise(sample_rate=44100)

        def func(tensor):
            sample_rate = 44100
            central_freq = 1000.
            q = 0.707
            noise = True
            return F.band_biquad(tensor, sample_rate, central_freq, q, noise)

        self._assert_consistency(func, waveform)

    def test_band_without_noise(self):
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        waveform = common_utils.get_whitenoise(sample_rate=44100)

        def func(tensor):
            sample_rate = 44100
            central_freq = 1000.
            q = 0.707
            noise = False
            return F.band_biquad(tensor, sample_rate, central_freq, q, noise)

        self._assert_consistency(func, waveform)

    def test_treble(self):
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        waveform = common_utils.get_whitenoise(sample_rate=44100)

        def func(tensor):
            sample_rate = 44100
            gain = 40.
            central_freq = 1000.
            q = 0.707
            return F.treble_biquad(tensor, sample_rate, gain, central_freq, q)

        self._assert_consistency(func, waveform)

    def test_deemph(self):
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        waveform = common_utils.get_whitenoise(sample_rate=44100)

        def func(tensor):
            sample_rate = 44100
            return F.deemph_biquad(tensor, sample_rate)

        self._assert_consistency(func, waveform)

    def test_riaa(self):
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        waveform = common_utils.get_whitenoise(sample_rate=44100)

        def func(tensor):
            sample_rate = 44100
            return F.riaa_biquad(tensor, sample_rate)

        self._assert_consistency(func, waveform)

    def test_equalizer(self):
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        waveform = common_utils.get_whitenoise(sample_rate=44100)

        def func(tensor):
            sample_rate = 44100
            center_freq = 300.
            gain = 1.
            q = 0.707
            return F.equalizer_biquad(tensor, sample_rate, center_freq, gain, q)

        self._assert_consistency(func, waveform)

    def test_perf_biquad_filtering(self):
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        waveform = common_utils.get_whitenoise()

        def func(tensor):
            a = torch.tensor([0.7, 0.2, 0.6], device=tensor.device, dtype=tensor.dtype)
            b = torch.tensor([0.4, 0.2, 0.9], device=tensor.device, dtype=tensor.dtype)
            return F.lfilter(tensor, a, b)

        self._assert_consistency(func, waveform)

    def test_sliding_window_cmn(self):
        def func(tensor):
            cmn_window = 600
            min_cmn_window = 100
            center = False
            norm_vars = False
            return F.sliding_window_cmn(tensor, cmn_window, min_cmn_window, center, norm_vars)

        # The test data is passed in as the input, so the compiled function actually
        # processes the tensor it receives. Previously it ignored its argument's values
        # and normalized a hard-coded constant instead.
        tensor = torch.tensor(
            [
                [
                    -1.915875792503357,
                    1.147700309753418
                ],
                [
                    1.8242558240890503,
                    1.3869990110397339
                ]
            ]
        )
        self._assert_consistency(func, tensor)

    def test_contrast(self):
        waveform = common_utils.get_whitenoise()

        def func(tensor):
            enhancement_amount = 80.
            return F.contrast(tensor, enhancement_amount)

        self._assert_consistency(func, waveform)

    def test_dcshift(self):
        waveform = common_utils.get_whitenoise()

        def func(tensor):
            shift = 0.5
            limiter_gain = 0.05
            return F.dcshift(tensor, shift, limiter_gain)

        self._assert_consistency(func, waveform)

    def test_overdrive(self):
        waveform = common_utils.get_whitenoise()

        def func(tensor):
            gain = 30.
            colour = 50.
            return F.overdrive(tensor, gain, colour)

        self._assert_consistency(func, waveform)

    def test_phaser(self):
        waveform = common_utils.get_whitenoise(sample_rate=44100)

        def func(tensor):
            gain_in = 0.5
            gain_out = 0.8
            delay_ms = 2.0
            decay = 0.4
            speed = 0.5
            sample_rate = 44100
            return F.phaser(tensor, sample_rate, gain_in, gain_out, delay_ms, decay, speed, sinusoidal=True)

        self._assert_consistency(func, waveform)

    def test_flanger(self):
        torch.random.manual_seed(40)
        waveform = torch.rand(2, 100) - 0.5

        def func(tensor):
            delay = 0.8
            depth = 0.88
            regen = 3.0
            width = 0.23
            speed = 1.3
            phase = 60.
            sample_rate = 44100
            return F.flanger(tensor, sample_rate, delay, depth, regen, width, speed,
                             phase, modulation='sinusoidal', interpolation='linear')

        self._assert_consistency(func, waveform)


class Transforms(common_utils.TestBaseMixin):
    """Implements test for Transforms that are performed for different devices"""

    def _assert_consistency(self, transform, tensor):
        """Check that ``transform`` gives the same output eagerly and when scripted.

        Args:
            transform: Module to compile with ``torch.jit.script``; it is moved to
                ``self.device`` / ``self.dtype`` first.
            tensor: Input tensor; moved to ``self.device`` / ``self.dtype`` before use.
        """
        tensor = tensor.to(device=self.device, dtype=self.dtype)
        transform = transform.to(device=self.device, dtype=self.dtype)

        ts_transform = torch.jit.script(transform)
        output = transform(tensor)
        ts_output = ts_transform(tensor)
        self.assertEqual(ts_output, output)

    def test_Spectrogram(self):
        tensor = torch.rand((1, 1000))
        self._assert_consistency(T.Spectrogram(), tensor)

    def test_GriffinLim(self):
        tensor = torch.rand((1, 201, 6))
        self._assert_consistency(T.GriffinLim(length=1000, rand_init=False), tensor)

    def test_AmplitudeToDB(self):
        spec = torch.rand((6, 201))
        self._assert_consistency(T.AmplitudeToDB(), spec)

    def test_MelScale(self):
        spec_f = torch.rand((1, 6, 201))
        self._assert_consistency(T.MelScale(), spec_f)

    def test_MelSpectrogram(self):
        tensor = torch.rand((1, 1000))
        self._assert_consistency(T.MelSpectrogram(), tensor)

    def test_MFCC(self):
        tensor = torch.rand((1, 1000))
        self._assert_consistency(T.MFCC(), tensor)

    def test_Resample(self):
        sr1, sr2 = 16000, 8000
        tensor = common_utils.get_whitenoise(sample_rate=sr1)
        self._assert_consistency(T.Resample(float(sr1), float(sr2)), tensor)

    def test_ComplexNorm(self):
        tensor = torch.rand((1, 2, 201, 2))
        self._assert_consistency(T.ComplexNorm(), tensor)

    def test_MuLawEncoding(self):
        tensor = common_utils.get_whitenoise()
        self._assert_consistency(T.MuLawEncoding(), tensor)

    def test_MuLawDecoding(self):
        tensor = torch.rand((1, 10))
        self._assert_consistency(T.MuLawDecoding(), tensor)

    def test_TimeStretch(self):
        n_freq = 400
        hop_length = 512
        fixed_rate = 1.3
        tensor = torch.rand((10, 2, n_freq, 10, 2))
        self._assert_consistency(
            T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate),
            tensor,
        )

    def test_Fade(self):
        waveform = common_utils.get_whitenoise()
        fade_in_len = 3000
        fade_out_len = 3000
        self._assert_consistency(T.Fade(fade_in_len, fade_out_len), waveform)

    def test_FrequencyMasking(self):
        tensor = torch.rand((10, 2, 50, 10, 2))
        self._assert_consistency(T.FrequencyMasking(freq_mask_param=60, iid_masks=False), tensor)

    def test_TimeMasking(self):
        tensor = torch.rand((10, 2, 50, 10, 2))
        self._assert_consistency(T.TimeMasking(time_mask_param=30, iid_masks=False), tensor)

    def test_Vol(self):
        waveform = common_utils.get_whitenoise()
        self._assert_consistency(T.Vol(1.1), waveform)

    def test_SlidingWindowCmn(self):
        tensor = torch.rand((1000, 10))
        self._assert_consistency(T.SlidingWindowCmn(), tensor)

    def test_Vad(self):
        filepath = common_utils.get_asset_path("vad-go-mono-32000.wav")
        waveform, sample_rate = torchaudio.load(filepath)
        self._assert_consistency(T.Vad(sample_rate=sample_rate), waveform)