# test_torchscript_consistency.py 13.7 KB
# Newer Older
1
2
3
4
5
6
7
"""Test suites for jit-ability and its numerical compatibility"""
import os
import unittest

import torch
import torchaudio
import torchaudio.functional as F
8
import torchaudio.transforms as T
9
10
11
12

import common_utils


13
14
15
16
17
18
19
20
21
def _assert_transforms_consistency(transform, tensor, device):
    """Check that a transform gives the same result eagerly and when scripted.

    Both ``transform`` and ``tensor`` are moved to ``device``. The transform is
    then compiled with ``torch.jit.script``, and the scripted output is compared
    against the eager output with ``torch.testing.assert_allclose``.
    """
    moved_input = tensor.to(device)
    eager_module = transform.to(device)
    scripted_module = torch.jit.script(eager_module)
    expected = eager_module(moved_input)
    actual = scripted_module(moved_input)
    torch.testing.assert_allclose(actual, expected)


22
def _assert_functional_consistency(py_method, *args, shape_only=False, **kwargs):
    """Check that a function gives matching results eagerly and when scripted.

    ``py_method`` is compiled with ``torch.jit.script``. The scripted and the
    original version are then called with the same ``*args`` / ``**kwargs``.

    Args:
        py_method: The Python function under test; it must be scriptable.
        *args: Positional arguments forwarded to both versions.
        shape_only: If True, only the output shapes are compared. This is for
            functions whose outputs are not expected to match element-wise.
        **kwargs: Keyword arguments forwarded to both versions.

    Raises:
        AssertionError: If the outputs (or, with ``shape_only``, their shapes)
            differ.
    """
    jit_method = torch.jit.script(py_method)

    jit_out = jit_method(*args, **kwargs)
    py_out = py_method(*args, **kwargs)

    if shape_only:
        # Raise explicitly rather than using ``assert`` so the check still
        # runs when Python is started with -O.
        if jit_out.shape != py_out.shape:
            raise AssertionError((jit_out.shape, py_out.shape))
    else:
        torch.testing.assert_allclose(jit_out, py_out)
32
33


34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
def _test_lfilter(waveform):
    """
    Check TorchScript consistency of ``F.lfilter`` using a fixed IIR lowpass filter.

    The coefficients come from an IIR lowpass filter designed with scipy.signal filter design
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.iirdesign.html#scipy.signal.iirdesign

    Example
        >>> from scipy.signal import iirdesign
        >>> b, a = iirdesign(0.2, 0.3, 1, 60)

    Args:
        waveform: Input signal tensor. The filter coefficients are created on
            its device and with its dtype, so the check works on CPU and CUDA
            and for any floating-point precision of the input.
    """
    b_coeffs = torch.tensor(
        [
            0.00299893,
            -0.0051152,
            0.00841964,
            -0.00747802,
            0.00841964,
            -0.0051152,
            0.00299893,
        ],
        device=waveform.device,
        dtype=waveform.dtype,
    )
    a_coeffs = torch.tensor(
        [
            1.0,
            -4.8155751,
            10.2217618,
            -12.14481273,
            8.49018171,
            -3.3066882,
            0.56088705,
        ],
        device=waveform.device,
        dtype=waveform.dtype,
    )
    # F.lfilter takes the denominator (a) coefficients before the numerator (b).
    _assert_functional_consistency(F.lfilter, waveform, a_coeffs, b_coeffs)
68
69


70
71
72
73
74
75
76
77
78
79
80
81
class TestFunctional(unittest.TestCase):
    """Test functions in `functional` module."""

    @staticmethod
    def _load_whitenoise():
        """Return ``(waveform, sample_rate)`` for the normalized white-noise test asset."""
        filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav')
        return torchaudio.load(filepath, normalization=True)

    def test_spectrogram(self):
        tensor = torch.rand((1, 1000))
        n_fft = 400
        ws = 400
        hop = 200
        pad = 0
        window = torch.hann_window(ws)
        power = 2
        normalize = False

        _assert_functional_consistency(
            F.spectrogram, tensor, pad, window, n_fft, hop, ws, power, normalize
        )

    def test_griffinlim(self):
        tensor = torch.rand((1, 201, 6))
        n_fft = 400
        ws = 400
        hop = 200
        window = torch.hann_window(ws)
        power = 2
        normalize = False
        momentum = 0.99
        n_iter = 32
        length = 1000

        # NOTE(review): the trailing positional 0 is presumably ``rand_init``
        # (disabled so eager and scripted outputs are comparable) — confirm
        # against the F.griffinlim signature.
        _assert_functional_consistency(
            F.griffinlim, tensor, window, n_fft, hop, ws, power, normalize, n_iter, momentum, length, 0
        )

    def test_compute_deltas(self):
        channel = 13
        n_mfcc = channel * 3
        time = 1021
        win_length = 2 * 7 + 1
        specgram = torch.randn(channel, n_mfcc, time)

        _assert_functional_consistency(F.compute_deltas, specgram, win_length=win_length)

    def test_detect_pitch_frequency(self):
        filepath = os.path.join(
            common_utils.TEST_DIR_PATH, 'assets', 'steam-train-whistle-daniel_simon.mp3')
        waveform, sample_rate = torchaudio.load(filepath)
        _assert_functional_consistency(F.detect_pitch_frequency, waveform, sample_rate)

    def test_create_fb_matrix(self):
        n_stft = 100
        f_min = 0.0
        f_max = 20.0
        n_mels = 10
        sample_rate = 16000

        _assert_functional_consistency(F.create_fb_matrix, n_stft, f_min, f_max, n_mels, sample_rate)

    def test_amplitude_to_DB(self):
        spec = torch.rand((6, 201))
        multiplier = 10.0
        amin = 1e-10
        db_multiplier = 0.0
        top_db = 80.0

        _assert_functional_consistency(F.amplitude_to_DB, spec, multiplier, amin, db_multiplier, top_db)

    def test_DB_to_amplitude(self):
        x = torch.rand((1, 100))
        ref = 1.
        power = 1.

        _assert_functional_consistency(F.DB_to_amplitude, x, ref, power)

    def test_create_dct(self):
        n_mfcc = 40
        n_mels = 128
        norm = "ortho"

        _assert_functional_consistency(F.create_dct, n_mfcc, n_mels, norm)

    def test_mu_law_encoding(self):
        tensor = torch.rand((1, 10))
        qc = 256

        _assert_functional_consistency(F.mu_law_encoding, tensor, qc)

    def test_mu_law_decoding(self):
        tensor = torch.rand((1, 10))
        qc = 256

        _assert_functional_consistency(F.mu_law_decoding, tensor, qc)

    def test_complex_norm(self):
        complex_tensor = torch.randn(1, 2, 1025, 400, 2)
        power = 2

        _assert_functional_consistency(F.complex_norm, complex_tensor, power)

    def test_mask_along_axis(self):
        specgram = torch.randn(2, 1025, 400)
        mask_param = 100
        mask_value = 30.
        axis = 2

        _assert_functional_consistency(F.mask_along_axis, specgram, mask_param, mask_value, axis)

    def test_mask_along_axis_iid(self):
        specgrams = torch.randn(4, 2, 1025, 400)
        mask_param = 100
        mask_value = 30.
        axis = 2

        _assert_functional_consistency(F.mask_along_axis_iid, specgrams, mask_param, mask_value, axis)

    def test_gain(self):
        tensor = torch.rand((1, 1000))
        gainDB = 2.0

        _assert_functional_consistency(F.gain, tensor, gainDB)

    def test_dither(self):
        tensor = torch.rand((2, 1000))

        # Only shapes are compared; the dithered values are presumably random
        # and so differ between the eager and scripted calls.
        _assert_functional_consistency(F.dither, tensor, shape_only=True)
        _assert_functional_consistency(F.dither, tensor, "RPDF", shape_only=True)
        _assert_functional_consistency(F.dither, tensor, "GPDF", shape_only=True)

    def test_lfilter(self):
        waveform, _ = self._load_whitenoise()
        _test_lfilter(waveform)

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_lfilter_cuda(self):
        waveform, _ = self._load_whitenoise()
        _test_lfilter(waveform.cuda(device=torch.device("cuda:0")))

    def test_lowpass(self):
        cutoff_freq = 3000

        waveform, sample_rate = self._load_whitenoise()
        _assert_functional_consistency(F.lowpass_biquad, waveform, sample_rate, cutoff_freq)

    def test_highpass(self):
        cutoff_freq = 2000

        waveform, sample_rate = self._load_whitenoise()
        _assert_functional_consistency(F.highpass_biquad, waveform, sample_rate, cutoff_freq)

    def test_allpass(self):
        central_freq = 1000
        q = 0.707

        waveform, sample_rate = self._load_whitenoise()
        _assert_functional_consistency(F.allpass_biquad, waveform, sample_rate, central_freq, q)

    def test_bandpass_with_csg(self):
        central_freq = 1000
        q = 0.707
        const_skirt_gain = True

        waveform, sample_rate = self._load_whitenoise()
        _assert_functional_consistency(
            F.bandpass_biquad, waveform, sample_rate, central_freq, q, const_skirt_gain)

    def test_bandpass_without_csg(self):
        central_freq = 1000
        q = 0.707
        const_skirt_gain = False

        waveform, sample_rate = self._load_whitenoise()
        _assert_functional_consistency(
            F.bandpass_biquad, waveform, sample_rate, central_freq, q, const_skirt_gain)

    def test_bandreject(self):
        central_freq = 1000
        q = 0.707

        waveform, sample_rate = self._load_whitenoise()
        _assert_functional_consistency(
            F.bandreject_biquad, waveform, sample_rate, central_freq, q)

    def test_band_with_noise(self):
        central_freq = 1000
        q = 0.707
        noise = True

        waveform, sample_rate = self._load_whitenoise()
        _assert_functional_consistency(F.band_biquad, waveform, sample_rate, central_freq, q, noise)

    def test_band_without_noise(self):
        central_freq = 1000
        q = 0.707
        noise = False

        waveform, sample_rate = self._load_whitenoise()
        _assert_functional_consistency(F.band_biquad, waveform, sample_rate, central_freq, q, noise)

    def test_treble(self):
        gain = 40
        central_freq = 1000
        q = 0.707

        waveform, sample_rate = self._load_whitenoise()
        _assert_functional_consistency(F.treble_biquad, waveform, sample_rate, gain, central_freq, q)

    def test_deemph(self):
        waveform, sample_rate = self._load_whitenoise()
        _assert_functional_consistency(F.deemph_biquad, waveform, sample_rate)

    def test_riaa(self):
        waveform, sample_rate = self._load_whitenoise()
        _assert_functional_consistency(F.riaa_biquad, waveform, sample_rate)

    def test_equalizer(self):
        center_freq = 300
        gain = 1
        q = 0.707

        waveform, sample_rate = self._load_whitenoise()
        _assert_functional_consistency(
            F.equalizer_biquad, waveform, sample_rate, center_freq, gain, q)

    def test_perf_biquad_filtering(self):
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        waveform, _ = self._load_whitenoise()
        _assert_functional_consistency(F.lfilter, waveform, a, b)
311

312

313
314
315
class _TransformsTestMixin:
    """TorchScript-consistency tests for the ``transforms`` module.

    The tests do not depend on a particular device. A concrete subclass
    combines this mixin with ``unittest.TestCase`` and sets ``device``;
    every transform and input is moved to that device before comparison.
    """
    # Device the transforms and inputs are moved to; overridden by subclasses.
    device = None

    def _assert_consistency(self, transform, tensor):
        """Compare eager vs. scripted ``transform`` output on ``self.device``."""
        _assert_transforms_consistency(transform, tensor, self.device)

    def test_Spectrogram(self):
        tensor = torch.rand((1, 1000))
        self._assert_consistency(T.Spectrogram(), tensor)

    def test_GriffinLim(self):
        tensor = torch.rand((1, 201, 6))
        # rand_init=False is presumably needed so the eager and scripted
        # runs start from the same phase estimate.
        self._assert_consistency(T.GriffinLim(length=1000, rand_init=False), tensor)

    def test_AmplitudeToDB(self):
        spec = torch.rand((6, 201))
        self._assert_consistency(T.AmplitudeToDB(), spec)

    def test_MelScale(self):
        spec_f = torch.rand((1, 6, 201))
        self._assert_consistency(T.MelScale(), spec_f)

    def test_MelSpectrogram(self):
        tensor = torch.rand((1, 1000))
        self._assert_consistency(T.MelSpectrogram(), tensor)

    def test_MFCC(self):
        tensor = torch.rand((1, 1000))
        self._assert_consistency(T.MFCC(), tensor)

    def test_Resample(self):
        tensor = torch.rand((2, 1000))
        sample_rate = 100.
        sample_rate_2 = 50.
        self._assert_consistency(T.Resample(sample_rate, sample_rate_2), tensor)

    def test_ComplexNorm(self):
        tensor = torch.rand((1, 2, 201, 2))
        self._assert_consistency(T.ComplexNorm(), tensor)

    def test_MuLawEncoding(self):
        tensor = torch.rand((1, 10))
        self._assert_consistency(T.MuLawEncoding(), tensor)

    def test_MuLawDecoding(self):
        tensor = torch.rand((1, 10))
        self._assert_consistency(T.MuLawDecoding(), tensor)

    def test_TimeStretch(self):
        n_freq = 400
        hop_length = 512
        fixed_rate = 1.3
        tensor = torch.rand((10, 2, n_freq, 10, 2))
        self._assert_consistency(
            T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate),
            tensor,
        )

    def test_Fade(self):
        test_filepath = os.path.join(
            common_utils.TEST_DIR_PATH, 'assets', 'steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(test_filepath)
        fade_in_len = 3000
        fade_out_len = 3000
        self._assert_consistency(T.Fade(fade_in_len, fade_out_len), waveform)

    def test_FrequencyMasking(self):
        tensor = torch.rand((10, 2, 50, 10, 2))
        self._assert_consistency(T.FrequencyMasking(freq_mask_param=60, iid_masks=False), tensor)

    def test_TimeMasking(self):
        tensor = torch.rand((10, 2, 50, 10, 2))
        self._assert_consistency(T.TimeMasking(time_mask_param=30, iid_masks=False), tensor)

    def test_Vol(self):
        test_filepath = os.path.join(
            common_utils.TEST_DIR_PATH, 'assets', 'steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(test_filepath)
        self._assert_consistency(T.Vol(1.1), waveform)


class TestTransformsCPU(_TransformsTestMixin, unittest.TestCase):
    """Run every transform TorchScript-consistency test on the CPU."""

    device = torch.device("cpu")


@unittest.skipUnless(torch.cuda.is_available(), 'CUDA not available')
class TestTransformsCUDA(_TransformsTestMixin, unittest.TestCase):
    """Run every transform TorchScript-consistency test on a CUDA device.

    The whole class is skipped when ``torch.cuda.is_available()`` is False.
    """

    device = torch.device("cuda")
# Vincent QB's avatar
# Vincent QB committed
404
405
406
407


if __name__ == "__main__":
    # Discover and run every TestCase in this module when executed directly.
    unittest.main()