torchscript_consistency_impl.py 19.5 KB
Newer Older
1
2
3
4
5
"""Test suites for jit-ability and its numerical compatibility"""
import unittest

import torch
import torchaudio.functional as F
6
import torchaudio.transforms as T
7

8
from . import common_utils
9
10


11
class Functional(common_utils.TestBaseMixin):
12
13
    """Implements tests for the `functional` module that are performed for different devices"""
    def _assert_consistency(self, func, tensor, shape_only=False):
        """Check that ``func`` and its TorchScript version agree on ``tensor``.

        The input is first moved to ``self.device`` / ``self.dtype``. When
        ``shape_only`` is true, only the shapes of the two outputs are
        compared instead of the outputs themselves.
        """
        tensor = tensor.to(device=self.device, dtype=self.dtype)

        scripted = torch.jit.script(func)
        # The eager call runs before the scripted call.
        expected = func(tensor)
        actual = scripted(tensor)
        if shape_only:
            actual, expected = actual.shape, expected.shape
        self.assertEqual(actual, expected)
23
24

    def test_spectrogram(self):
        """Eager and scripted `F.spectrogram` agree on white noise."""
        def func(tensor):
            win_length = 400
            window = torch.hann_window(win_length, device=tensor.device, dtype=tensor.dtype)
            # Positional arguments after the input:
            # pad, window, n_fft, hop, win_length, power, normalize
            return F.spectrogram(tensor, 0, window, 400, 200, win_length, 2., False)

        waveform = common_utils.get_whitenoise()
        self._assert_consistency(func, waveform)
37
38

    def test_griffinlim(self):
        """Eager and scripted `F.griffinlim` agree on a random (1, 201, 6) input."""
        def func(tensor):
            n_fft = 400
            win_length = 400
            hop_length = 200
            window = torch.hann_window(win_length, device=tensor.device, dtype=tensor.dtype)
            power = 2.
            normalize = False
            n_iter = 32
            momentum = 0.99
            length = 1000
            rand_init = False
            return F.griffinlim(
                tensor, window, n_fft, hop_length, win_length, power,
                normalize, n_iter, momentum, length, rand_init)

        specgram = torch.rand((1, 201, 6))
        self._assert_consistency(func, specgram)
54
55

    def test_compute_deltas(self):
        """Eager and scripted `F.compute_deltas` agree on a random 3-D input."""
        def func(tensor):
            return F.compute_deltas(tensor, win_length=2 * 7 + 1)

        n_channels = 13
        n_frames = 1021
        features = torch.randn(n_channels, n_channels * 3, n_frames)
        self._assert_consistency(func, features)
65
66

    def test_detect_pitch_frequency(self):
        """Eager and scripted `F.detect_pitch_frequency` agree on a 44.1 kHz sinusoid."""
        sinusoid = common_utils.get_sinusoid(sample_rate=44100)

        def func(tensor):
            return F.detect_pitch_frequency(tensor, 44100)

        self._assert_consistency(func, sinusoid)
74
75

    def test_create_fb_matrix(self):
        """Eager and scripted `F.create_fb_matrix` agree; runs on CPU only."""
        if self.device != torch.device('cpu'):
            raise unittest.SkipTest('No need to perform test on device other than CPU')

        def func(_):
            # Positional arguments: n_stft, f_min, f_max, n_mels, sample_rate, norm
            return F.create_fb_matrix(100, 0.0, 20.0, 10, 16000, "slaney")

        # The function under test takes no tensor input; a placeholder is
        # passed only because the shared helper expects one.
        placeholder = torch.zeros(1, 1)
        self._assert_consistency(func, placeholder)
90
91

    def test_amplitude_to_DB(self):
        """Eager and scripted `F.amplitude_to_DB` agree on a random (6, 201) input."""
        def func(tensor):
            # Positional arguments: multiplier, amin, db_multiplier, top_db
            return F.amplitude_to_DB(tensor, 10.0, 1e-10, 0.0, 80.0)

        spec = torch.rand((6, 201))
        self._assert_consistency(func, spec)
101
102

    def test_DB_to_amplitude(self):
        """Eager and scripted `F.DB_to_amplitude` agree on a random (1, 100) input."""
        def func(tensor):
            # Positional arguments: ref, power
            return F.DB_to_amplitude(tensor, 1., 1.)

        values = torch.rand((1, 100))
        self._assert_consistency(func, values)
110
111

    def test_create_dct(self):
        """Eager and scripted `F.create_dct` agree; runs on CPU only."""
        if self.device != torch.device('cpu'):
            raise unittest.SkipTest('No need to perform test on device other than CPU')

        def func(_):
            # Positional arguments: n_mfcc, n_mels, norm
            return F.create_dct(40, 128, "ortho")

        # Placeholder input; the scripted function ignores its argument.
        placeholder = torch.zeros(1, 1)
        self._assert_consistency(func, placeholder)
123
124

    def test_mu_law_encoding(self):
        """Eager and scripted `F.mu_law_encoding` agree on white noise."""
        def func(tensor):
            return F.mu_law_encoding(tensor, 256)

        noise = common_utils.get_whitenoise()
        self._assert_consistency(func, noise)
131
132

    def test_mu_law_decoding(self):
        """Eager and scripted `F.mu_law_decoding` agree on a random (1, 10) input."""
        def func(tensor):
            return F.mu_law_decoding(tensor, 256)

        encoded = torch.rand((1, 10))
        self._assert_consistency(func, encoded)
139
140

    def test_complex_norm(self):
        """Eager and scripted `F.complex_norm` agree on a random 5-D input."""
        def func(tensor):
            return F.complex_norm(tensor, 2.)

        complex_tensor = torch.randn(1, 2, 1025, 400, 2)
        self._assert_consistency(func, complex_tensor)
147
148

    def test_mask_along_axis(self):
        """Eager and scripted `F.mask_along_axis` agree on a random 3-D input."""
        def func(tensor):
            # Positional arguments: mask_param, mask_value, axis
            return F.mask_along_axis(tensor, 100, 30., 2)

        specgram = torch.randn(2, 1025, 400)
        self._assert_consistency(func, specgram)
157
158

    def test_mask_along_axis_iid(self):
        """Eager and scripted `F.mask_along_axis_iid` agree on a random 4-D input."""
        def func(tensor):
            # Positional arguments: mask_param, mask_value, axis
            return F.mask_along_axis_iid(tensor, 100, 30., 2)

        specgrams = torch.randn(4, 2, 1025, 400)
        self._assert_consistency(func, specgrams)
167
168

    def test_gain(self):
        """Eager and scripted `F.gain` agree on a random (1, 1000) input."""
        def func(tensor):
            return F.gain(tensor, 2.0)

        waveform = torch.rand((1, 1000))
        self._assert_consistency(func, waveform)

    def test_dither_TPDF(self):
        """`F.dither` with 'TPDF' is scriptable; only output shapes are compared."""
        def func(tensor):
            return F.dither(tensor, 'TPDF')

        stereo_noise = common_utils.get_whitenoise(n_channels=2)
        self._assert_consistency(func, stereo_noise, shape_only=True)
182

183
184
185
    def test_dither_RPDF(self):
        """`F.dither` with 'RPDF' is scriptable; only output shapes are compared."""
        def func(tensor):
            return F.dither(tensor, 'RPDF')

        stereo_noise = common_utils.get_whitenoise(n_channels=2)
        self._assert_consistency(func, stereo_noise, shape_only=True)
189

190
191
192
193
    def test_dither_GPDF(self):
        """`F.dither` with 'GPDF' is scriptable; only output shapes are compared."""
        def func(tensor):
            return F.dither(tensor, 'GPDF')

        stereo_noise = common_utils.get_whitenoise(n_channels=2)
        self._assert_consistency(func, stereo_noise, shape_only=True)
196

197
    def test_lfilter(self):
        """Eager and scripted `F.lfilter` agree on white noise (skipped for float64)."""
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        def func(tensor):
            # Design an IIR lowpass filter using scipy.signal filter design
            # https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.iirdesign.html#scipy.signal.iirdesign
            #
            # Example
            #     >>> from scipy.signal import iirdesign
            #     >>> b, a = iirdesign(0.2, 0.3, 1, 60)
            b = torch.tensor(
                [0.00299893, -0.0051152, 0.00841964, -0.00747802,
                 0.00841964, -0.0051152, 0.00299893],
                device=tensor.device,
                dtype=tensor.dtype,
            )
            a = torch.tensor(
                [1.0, -4.8155751, 10.2217618, -12.14481273,
                 8.49018171, -3.3066882, 0.56088705],
                device=tensor.device,
                dtype=tensor.dtype,
            )
            return F.lfilter(tensor, a, b)

        noise = common_utils.get_whitenoise()
        self._assert_consistency(func, noise)
239
240

    def test_lowpass(self):
        """Eager and scripted `F.lowpass_biquad` agree on white noise (skipped for float64)."""
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        def func(tensor):
            # Positional arguments: sample_rate, cutoff_freq
            return F.lowpass_biquad(tensor, 44100, 3000.)

        noise = common_utils.get_whitenoise(sample_rate=44100)
        self._assert_consistency(func, noise)
252

253
    def test_highpass(self):
        """Eager and scripted `F.highpass_biquad` agree on white noise (skipped for float64)."""
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        def func(tensor):
            # Positional arguments: sample_rate, cutoff_freq
            return F.highpass_biquad(tensor, 44100, 2000.)

        noise = common_utils.get_whitenoise(sample_rate=44100)
        self._assert_consistency(func, noise)

    def test_allpass(self):
        """Eager and scripted `F.allpass_biquad` agree on white noise (skipped for float64)."""
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        def func(tensor):
            # Positional arguments: sample_rate, central_freq, q
            return F.allpass_biquad(tensor, 44100, 1000., 0.707)

        noise = common_utils.get_whitenoise(sample_rate=44100)
        self._assert_consistency(func, noise)
279

280
    def test_bandpass_with_csg(self):
        """Eager and scripted `F.bandpass_biquad` agree with const_skirt_gain=True.

        Skipped for float64.
        """
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        def func(tensor):
            # Positional arguments: sample_rate, central_freq, q, const_skirt_gain
            return F.bandpass_biquad(tensor, 44100, 1000., 0.707, True)

        noise = common_utils.get_whitenoise(sample_rate=44100)
        self._assert_consistency(func, noise)

295
296
    def test_bandpass_without_csg(self):
        """Eager and scripted `F.bandpass_biquad` agree with const_skirt_gain=False.

        Counterpart of `test_bandpass_with_csg`: the only difference is that
        the constant-skirt-gain flag is disabled. Skipped for float64.
        """
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        waveform = common_utils.get_whitenoise(sample_rate=44100)

        def func(tensor):
            sample_rate = 44100
            central_freq = 1000.
            q = 0.707
            # Must be False here; True would duplicate test_bandpass_with_csg
            # and leave the non-constant-skirt-gain path untested.
            const_skirt_gain = False
            return F.bandpass_biquad(tensor, sample_rate, central_freq, q, const_skirt_gain)

        self._assert_consistency(func, waveform)
309

310
    def test_bandreject(self):
        """Eager and scripted `F.bandreject_biquad` agree on white noise (skipped for float64)."""
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        def func(tensor):
            # Positional arguments: sample_rate, central_freq, q
            return F.bandreject_biquad(tensor, 44100, 1000., 0.707)

        noise = common_utils.get_whitenoise(sample_rate=44100)
        self._assert_consistency(func, noise)

    def test_band_with_noise(self):
        """Eager and scripted `F.band_biquad` agree with noise=True (skipped for float64)."""
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        def func(tensor):
            # Positional arguments: sample_rate, central_freq, q, noise
            return F.band_biquad(tensor, 44100, 1000., 0.707, True)

        whitenoise = common_utils.get_whitenoise(sample_rate=44100)
        self._assert_consistency(func, whitenoise)
338

339
    def test_band_without_noise(self):
        """Eager and scripted `F.band_biquad` agree with noise=False (skipped for float64)."""
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        def func(tensor):
            # Positional arguments: sample_rate, central_freq, q, noise
            return F.band_biquad(tensor, 44100, 1000., 0.707, False)

        whitenoise = common_utils.get_whitenoise(sample_rate=44100)
        self._assert_consistency(func, whitenoise)
353

354
    def test_treble(self):
        """Eager and scripted `F.treble_biquad` agree on white noise (skipped for float64)."""
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        def func(tensor):
            # Positional arguments: sample_rate, gain, central_freq, q
            return F.treble_biquad(tensor, 44100, 40., 1000., 0.707)

        noise = common_utils.get_whitenoise(sample_rate=44100)
        self._assert_consistency(func, noise)
368

jimchen90's avatar
jimchen90 committed
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
    def test_bass(self):
        """Eager and scripted `F.bass_biquad` agree on white noise (skipped for float64)."""
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        def func(tensor):
            # Positional arguments: sample_rate, gain, central_freq, q
            return F.bass_biquad(tensor, 44100, 40., 1000., 0.707)

        noise = common_utils.get_whitenoise(sample_rate=44100)
        self._assert_consistency(func, noise)

384
    def test_deemph(self):
        """Eager and scripted `F.deemph_biquad` agree on white noise (skipped for float64)."""
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        def func(tensor):
            return F.deemph_biquad(tensor, 44100)

        noise = common_utils.get_whitenoise(sample_rate=44100)
        self._assert_consistency(func, noise)
395
396

    def test_riaa(self):
        """Eager and scripted `F.riaa_biquad` agree on white noise (skipped for float64)."""
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        def func(tensor):
            return F.riaa_biquad(tensor, 44100)

        noise = common_utils.get_whitenoise(sample_rate=44100)
        self._assert_consistency(func, noise)
407

408
    def test_equalizer(self):
        """Eager and scripted `F.equalizer_biquad` agree on white noise (skipped for float64)."""
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        def func(tensor):
            # Positional arguments: sample_rate, center_freq, gain, q
            return F.equalizer_biquad(tensor, 44100, 300., 1., 0.707)

        noise = common_utils.get_whitenoise(sample_rate=44100)
        self._assert_consistency(func, noise)
422
423

    def test_perf_biquad_filtering(self):
        """Eager and scripted `F.lfilter` agree for 3-tap coefficients (skipped for float64)."""
        if self.dtype == torch.float64:
            raise unittest.SkipTest("This test is known to fail for float64")

        def func(tensor):
            a_coeffs = torch.tensor([0.7, 0.2, 0.6], device=tensor.device, dtype=tensor.dtype)
            b_coeffs = torch.tensor([0.4, 0.2, 0.9], device=tensor.device, dtype=tensor.dtype)
            return F.lfilter(tensor, a_coeffs, b_coeffs)

        noise = common_utils.get_whitenoise()
        self._assert_consistency(func, noise)
435

wanglong001's avatar
wanglong001 committed
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
    def test_sliding_window_cmn(self):
        """Eager and scripted `F.sliding_window_cmn` agree on a fixed 2x2 input.

        The scripted function builds its own input; the tensor handed to the
        helper only supplies the device and dtype.
        """
        def func(tensor):
            features = torch.tensor(
                [
                    [-1.915875792503357, 1.147700309753418],
                    [1.8242558240890503, 1.3869990110397339],
                ],
                device=tensor.device,
                dtype=tensor.dtype
            )
            # Positional arguments: cmn_window, min_cmn_window, center, norm_vars
            return F.sliding_window_cmn(features, 600, 100, False, False)

        device_dtype_carrier = torch.tensor(
            [
                [-1.8701, -0.1196],
                [1.8701, 0.1196],
            ]
        )
        self._assert_consistency(func, device_dtype_carrier)

471
    def test_contrast(self):
        """Eager and scripted `F.contrast` agree on white noise."""
        noise = common_utils.get_whitenoise()

        def func(tensor):
            return F.contrast(tensor, 80.)

        self._assert_consistency(func, noise)

480
    def test_dcshift(self):
        """Eager and scripted `F.dcshift` agree on white noise."""
        noise = common_utils.get_whitenoise()

        def func(tensor):
            # Positional arguments: shift, limiter_gain
            return F.dcshift(tensor, 0.5, 0.05)

        self._assert_consistency(func, noise)

490
    def test_overdrive(self):
        """Eager and scripted `F.overdrive` agree on white noise."""
        noise = common_utils.get_whitenoise()

        def func(tensor):
            # Positional arguments: gain, colour
            return F.overdrive(tensor, 30., 50.)

        self._assert_consistency(func, noise)

500
    def test_phaser(self):
        """Eager and scripted `F.phaser` agree on 44.1 kHz white noise."""
        noise = common_utils.get_whitenoise(sample_rate=44100)

        def func(tensor):
            # Positional arguments: sample_rate, gain_in, gain_out, delay_ms, decay, speed
            return F.phaser(tensor, 44100, 0.5, 0.8, 2.0, 0.4, 0.5, sinusoidal=True)

        self._assert_consistency(func, noise)

514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
    def test_flanger(self):
        """Eager and scripted `F.flanger` agree on a seeded random (2, 100) input."""
        torch.random.manual_seed(40)
        signal = torch.rand(2, 100) - 0.5

        def func(tensor):
            # Positional arguments: sample_rate, delay, depth, regen, width, speed, phase
            return F.flanger(
                tensor, 44100, 0.8, 0.88, 3.0, 0.23, 1.3, 60.,
                modulation='sinusoidal', interpolation='linear')

        self._assert_consistency(func, signal)

531

532
class Transforms(common_utils.TestBaseMixin):
533
534
    """Implements test for Transforms that are performed for different devices"""
    def _assert_consistency(self, transform, tensor):
        """Check that ``transform`` and its TorchScript version agree on ``tensor``.

        Both the input and the transform are first moved to ``self.device``
        and ``self.dtype``.
        """
        tensor = tensor.to(device=self.device, dtype=self.dtype)
        transform = transform.to(device=self.device, dtype=self.dtype)

        scripted = torch.jit.script(transform)
        # The eager call runs before the scripted call.
        expected = transform(tensor)
        actual = scripted(tensor)
        self.assertEqual(actual, expected)
542
543
544

    def test_Spectrogram(self):
        """Eager and scripted `T.Spectrogram` agree on a random (1, 1000) input."""
        waveform = torch.rand(1, 1000)
        self._assert_consistency(T.Spectrogram(), waveform)
546
547
548

    def test_GriffinLim(self):
        """Eager and scripted `T.GriffinLim` agree on a random (1, 201, 6) input."""
        specgram = torch.rand(1, 201, 6)
        self._assert_consistency(T.GriffinLim(length=1000, rand_init=False), specgram)
550
551
552

    def test_AmplitudeToDB(self):
        """Eager and scripted `T.AmplitudeToDB` agree on a random (6, 201) input."""
        amplitudes = torch.rand(6, 201)
        self._assert_consistency(T.AmplitudeToDB(), amplitudes)
554
555
556

    def test_MelScale(self):
        """Eager and scripted `T.MelScale` agree on a random (1, 6, 201) input."""
        specgram = torch.rand(1, 6, 201)
        self._assert_consistency(T.MelScale(), specgram)
558
559
560

    def test_MelSpectrogram(self):
        """Eager and scripted `T.MelSpectrogram` agree on a random (1, 1000) input."""
        waveform = torch.rand(1, 1000)
        self._assert_consistency(T.MelSpectrogram(), waveform)
562
563
564

    def test_MFCC(self):
        """Eager and scripted `T.MFCC` agree on a random (1, 1000) input."""
        waveform = torch.rand(1, 1000)
        self._assert_consistency(T.MFCC(), waveform)
566
567

    def test_Resample(self):
        """Eager and scripted `T.Resample` agree when going from 16 kHz to 8 kHz."""
        orig_rate = 16000
        new_rate = 8000
        noise = common_utils.get_whitenoise(sample_rate=orig_rate)
        resampler = T.Resample(float(orig_rate), float(new_rate))
        self._assert_consistency(resampler, noise)
571
572
573

    def test_ComplexNorm(self):
        """Eager and scripted `T.ComplexNorm` agree on a random (1, 2, 201, 2) input."""
        complex_tensor = torch.rand(1, 2, 201, 2)
        self._assert_consistency(T.ComplexNorm(), complex_tensor)
575
576

    def test_MuLawEncoding(self):
        """Eager and scripted `T.MuLawEncoding` agree on white noise."""
        noise = common_utils.get_whitenoise()
        self._assert_consistency(T.MuLawEncoding(), noise)
579
580
581

    def test_MuLawDecoding(self):
        """Eager and scripted `T.MuLawDecoding` agree on a random (1, 10) input."""
        encoded = torch.rand(1, 10)
        self._assert_consistency(T.MuLawDecoding(), encoded)
583
584
585
586
587
588

    def test_TimeStretch(self):
        """Eager and scripted `T.TimeStretch` agree on a random 5-D input."""
        n_freq = 400
        spec = torch.rand(10, 2, n_freq, 10, 2)
        stretch = T.TimeStretch(n_freq=n_freq, hop_length=512, fixed_rate=1.3)
        self._assert_consistency(stretch, spec)
593
594

    def test_Fade(self):
        """Eager and scripted `T.Fade` agree on white noise with 3000-sample fades."""
        noise = common_utils.get_whitenoise()
        fade = T.Fade(3000, 3000)  # fade_in_len, fade_out_len
        self._assert_consistency(fade, noise)
599
600
601

    def test_FrequencyMasking(self):
        """Eager and scripted `T.FrequencyMasking` agree on a random 5-D input."""
        spec = torch.rand(10, 2, 50, 10, 2)
        masking = T.FrequencyMasking(freq_mask_param=60, iid_masks=False)
        self._assert_consistency(masking, spec)
603
604
605

    def test_TimeMasking(self):
        """Eager and scripted `T.TimeMasking` agree on a random 5-D input."""
        spec = torch.rand(10, 2, 50, 10, 2)
        masking = T.TimeMasking(time_mask_param=30, iid_masks=False)
        self._assert_consistency(masking, spec)
607
608

    def test_Vol(self):
        """Eager and scripted `T.Vol(1.1)` agree on white noise."""
        noise = common_utils.get_whitenoise()
        self._assert_consistency(T.Vol(1.1), noise)

wanglong001's avatar
wanglong001 committed
612
613
614
615
    def test_SlidingWindowCmn(self):
        """Eager and scripted `T.SlidingWindowCmn` agree on a random (1000, 10) input."""
        features = torch.rand(1000, 10)
        self._assert_consistency(T.SlidingWindowCmn(), features)

Artyom Astafurov's avatar
Artyom Astafurov committed
616
    def test_Vad(self):
        """Eager and scripted `T.Vad` agree on the `vad-go-mono-32000.wav` asset."""
        wav_path = common_utils.get_asset_path("vad-go-mono-32000.wav")
        waveform, sample_rate = common_utils.load_wav(wav_path)
        vad = T.Vad(sample_rate=sample_rate)
        self._assert_consistency(vad, waveform)