from functools import partial
from typing import Callable, Tuple

import torch
import torchaudio.functional as F
from parameterized import parameterized
from torch import Tensor
from torch.autograd import gradcheck, gradgradcheck
from torchaudio_unittest.common_utils import (
    get_spectrogram,
    get_whitenoise,
    nested_params,
    rnnt_utils,
    TestBaseMixin,
    use_deterministic_algorithms,
)


class Autograd(TestBaseMixin):
    def assert_grad(
        self,
        transform: Callable[..., Tensor],
        inputs: Tuple[torch.Tensor],
        *,
        enable_all_grad: bool = True,
    ):
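        # Move tensor inputs to the test dtype/device (complex tensors keep the
        # matching complex dtype), optionally turn on requires_grad for all of
        # them, and run both gradcheck and gradgradcheck on the transform.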
        inputs_ = []
        for i in inputs:
            if torch.is_tensor(i):
                i = i.to(dtype=self.complex_dtype if i.is_complex() else self.dtype, device=self.device)
                if enable_all_grad:
                    i.requires_grad = True
            inputs_.append(i)
        assert gradcheck(transform, inputs_)
        assert gradgradcheck(transform, inputs_)

    def test_lfilter_x(self):
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        x.requires_grad = True
        self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False)

    def test_lfilter_a(self):
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        a.requires_grad = True
        self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False)

    def test_lfilter_b(self):
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        b.requires_grad = True
        self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False)

    def test_lfilter_all_inputs(self):
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        self.assert_grad(F.lfilter, (x, a, b))

    def test_lfilter_filterbanks(self):
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=3)
        a = torch.tensor([[0.7, 0.2, 0.6], [0.8, 0.2, 0.9]])
        b = torch.tensor([[0.4, 0.2, 0.9], [0.7, 0.2, 0.6]])
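        # Two filters applied to three channels with batching disabled, so the
        # gradcheck covers the full coefficient bank.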
        self.assert_grad(partial(F.lfilter, batching=False), (x, a, b))

    def test_lfilter_batching(self):
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a = torch.tensor([[0.7, 0.2, 0.6], [0.8, 0.2, 0.9]])
        b = torch.tensor([[0.4, 0.2, 0.9], [0.7, 0.2, 0.6]])
        self.assert_grad(F.lfilter, (x, a, b))

    def test_filtfilt_a(self):
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        a.requires_grad = True
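        # The filtfilt checks run inside the deterministic-algorithms context
        # provided by common_utils.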
        with use_deterministic_algorithms(True, False):
            self.assert_grad(F.filtfilt, (x, a, b), enable_all_grad=False)

    def test_filtfilt_b(self):
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        b.requires_grad = True
        with use_deterministic_algorithms(True, False):
            self.assert_grad(F.filtfilt, (x, a, b), enable_all_grad=False)

    def test_filtfilt_all_inputs(self):
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        with use_deterministic_algorithms(True, False):
            self.assert_grad(F.filtfilt, (x, a, b))

    def test_filtfilt_batching(self):
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a = torch.tensor([[0.7, 0.2, 0.6], [0.8, 0.2, 0.9]])
        b = torch.tensor([[0.4, 0.2, 0.9], [0.7, 0.2, 0.6]])
        with use_deterministic_algorithms(True, False):
            self.assert_grad(F.filtfilt, (x, a, b))

    def test_biquad(self):
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=1)
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        self.assert_grad(F.biquad, (x, b[0], b[1], b[2], a[0], a[1], a[2]))

    @parameterized.expand(
        [
            (800, 0.7, True),
            (800, 0.7, False),
        ]
    )
    def test_band_biquad(self, central_freq, Q, noise):
        sr = 22050
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
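        # Passing central_freq and Q as tensors lets assert_grad enable grad on
        # them, so the check also covers the filter parameters.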
        self.assert_grad(F.band_biquad, (x, sr, central_freq, Q, noise))

    @parameterized.expand(
        [
            (800, 0.7, 10),
            (800, 0.7, -10),
        ]
    )
    def test_bass_biquad(self, central_freq, Q, gain):
        sr = 22050
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        gain = torch.tensor(gain)
        self.assert_grad(F.bass_biquad, (x, sr, gain, central_freq, Q))

    @parameterized.expand(
        [
            (3000, 0.7, 10),
            (3000, 0.7, -10),
        ]
    )
    def test_treble_biquad(self, central_freq, Q, gain):
        sr = 22050
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        gain = torch.tensor(gain)
        self.assert_grad(F.treble_biquad, (x, sr, gain, central_freq, Q))

    @parameterized.expand(
        [
            (
                800,
                0.7,
            ),
        ]
    )
    def test_allpass_biquad(self, central_freq, Q):
        sr = 22050
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.allpass_biquad, (x, sr, central_freq, Q))

    @parameterized.expand(
        [
            (
                800,
                0.7,
            ),
        ]
    )
    def test_lowpass_biquad(self, cutoff_freq, Q):
        sr = 22050
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
        cutoff_freq = torch.tensor(cutoff_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.lowpass_biquad, (x, sr, cutoff_freq, Q))

    @parameterized.expand(
        [
            (
                800,
                0.7,
            ),
        ]
    )
    def test_highpass_biquad(self, cutoff_freq, Q):
        sr = 22050
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
        cutoff_freq = torch.tensor(cutoff_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.highpass_biquad, (x, sr, cutoff_freq, Q))

    @parameterized.expand(
        [
            (800, 0.7, True),
            (800, 0.7, False),
        ]
    )
    def test_bandpass_biquad(self, central_freq, Q, const_skirt_gain):
        sr = 22050
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.bandpass_biquad, (x, sr, central_freq, Q, const_skirt_gain))

    @parameterized.expand(
        [
            (800, 0.7, 10),
            (800, 0.7, -10),
        ]
    )
    def test_equalizer_biquad(self, central_freq, Q, gain):
        sr = 22050
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        gain = torch.tensor(gain)
        self.assert_grad(F.equalizer_biquad, (x, sr, central_freq, gain, Q))

    @parameterized.expand(
        [
            (
                800,
                0.7,
            ),
        ]
    )
    def test_bandreject_biquad(self, central_freq, Q):
        sr = 22050
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.bandreject_biquad, (x, sr, central_freq, Q))

    def test_deemph_biquad(self):
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=1)
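        # deemph_biquad only accepts a sample_rate of 44100 or 48000.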
        self.assert_grad(F.deemph_biquad, (x, 44100))

    def test_flanger(self):
        x = get_whitenoise(sample_rate=8000, duration=0.01, n_channels=1)
        self.assert_grad(F.flanger, (x, 44100))

    def test_gain(self):
        x = get_whitenoise(sample_rate=8000, duration=0.01, n_channels=1)
        self.assert_grad(F.gain, (x, 1.1))

    def test_overdrive(self):
        x = get_whitenoise(sample_rate=8000, duration=0.01, n_channels=1)
        self.assert_grad(F.overdrive, (x,))

    @parameterized.expand([(True,), (False,)])
    def test_phaser(self, sinusoidal):
        sr = 8000
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
        # Bind sinusoidal by keyword so it is not passed positionally as phaser's gain_in.
        self.assert_grad(partial(F.phaser, sinusoidal=sinusoidal), (x, sr))

    @parameterized.expand(
        [
            (True,),
            (False,),
        ]
    )
    def test_psd(self, use_mask):
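        # PSD matrix of a multi-channel complex spectrogram, with and without a
        # time-frequency mask.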
        specgram = torch.rand(4, 10, 5, dtype=torch.cfloat)
        if use_mask:
            mask = torch.rand(10, 5)
        else:
            mask = None
        self.assert_grad(F.psd, (specgram, mask))

    def test_mvdr_weights_souden(self):
        channel = 4
        n_fft_bin = 5
        psd_speech = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
        psd_noise = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
        self.assert_grad(F.mvdr_weights_souden, (psd_speech, psd_noise, 0))

    def test_mvdr_weights_souden_with_tensor(self):
        channel = 4
        n_fft_bin = 5
        psd_speech = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
        psd_noise = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
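        # One-hot tensor selecting channel 0 as the reference, the tensor
        # counterpart of passing the integer index above.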
        reference_channel = torch.zeros(channel)
        reference_channel[0].fill_(1)
        self.assert_grad(F.mvdr_weights_souden, (psd_speech, psd_noise, reference_channel))

    def test_mvdr_weights_rtf(self):
        batch_size = 2
        channel = 4
        n_fft_bin = 10
        rtf = torch.rand(batch_size, n_fft_bin, channel, dtype=self.complex_dtype)
        psd_noise = torch.rand(batch_size, n_fft_bin, channel, channel, dtype=self.complex_dtype)
        self.assert_grad(F.mvdr_weights_rtf, (rtf, psd_noise, 0))

    def test_mvdr_weights_rtf_with_tensor(self):
        batch_size = 2
        channel = 4
        n_fft_bin = 10
        rtf = torch.rand(batch_size, n_fft_bin, channel, dtype=self.complex_dtype)
        psd_noise = torch.rand(batch_size, n_fft_bin, channel, channel, dtype=self.complex_dtype)
        reference_channel = torch.zeros(batch_size, channel)
        reference_channel[..., 0].fill_(1)
        self.assert_grad(F.mvdr_weights_rtf, (rtf, psd_noise, reference_channel))

    @parameterized.expand(
        [
            (1, True),
            (3, False),
        ]
    )
    def test_rtf_power(self, n_iter, diagonal_loading):
        channel = 4
        n_fft_bin = 5
        psd_speech = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
        psd_noise = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
        self.assert_grad(F.rtf_power, (psd_speech, psd_noise, 0, n_iter, diagonal_loading))

    @parameterized.expand(
        [
            (1, True),
            (3, False),
        ]
    )
    def test_rtf_power_with_tensor(self, n_iter, diagonal_loading):
        channel = 4
        n_fft_bin = 5
        psd_speech = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
        psd_noise = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
        reference_channel = torch.zeros(channel)
        reference_channel[0].fill_(1)
        self.assert_grad(F.rtf_power, (psd_speech, psd_noise, reference_channel, n_iter, diagonal_loading))

    def test_apply_beamforming(self):
        sr = 8000
        n_fft = 400
        batch_size, num_channels = 2, 3
        n_fft_bin = n_fft // 2 + 1
        x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=batch_size * num_channels)
        specgram = get_spectrogram(x, n_fft=n_fft, hop_length=100)
        specgram = specgram.view(batch_size, num_channels, n_fft_bin, specgram.size(-1))
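        # Complex beamforming weights of shape (batch, freq, channel) applied to
        # the multi-channel spectrogram.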
        beamform_weights = torch.rand(batch_size, n_fft_bin, num_channels, dtype=torch.cfloat)
        self.assert_grad(F.apply_beamforming, (beamform_weights, specgram))

    @nested_params(
        ["convolve", "fftconvolve"],
        ["full", "valid", "same"],
    )
    def test_convolve(self, fn, mode):
        leading_dims = (4, 3, 2)
        L_x, L_y = 23, 40
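        # Different signal and kernel lengths exercise the "full", "valid", and "same" modes.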
        x = torch.rand(*leading_dims, L_x, dtype=self.dtype, device=self.device)
        y = torch.rand(*leading_dims, L_y, dtype=self.dtype, device=self.device)
        self.assert_grad(getattr(F, fn), (x, y, mode))

    def test_add_noise(self):
        leading_dims = (5, 2, 3)
        L = 51
        waveform = torch.rand(*leading_dims, L, dtype=self.dtype, device=self.device)
        noise = torch.rand(*leading_dims, L, dtype=self.dtype, device=self.device)
        lengths = torch.rand(*leading_dims, dtype=self.dtype, device=self.device)
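        # Per-example signal-to-noise ratios, in dB.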
        snr = torch.rand(*leading_dims, dtype=self.dtype, device=self.device) * 10
        self.assert_grad(F.add_noise, (waveform, noise, snr, lengths))

    def test_speed(self):
        leading_dims = (3, 2)
        T = 200
        waveform = torch.rand(*leading_dims, T, dtype=self.dtype, device=self.device, requires_grad=True)
        lengths = torch.randint(1, T, leading_dims, dtype=self.dtype, device=self.device)
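        # Only the waveform (created with requires_grad=True) is gradchecked;
        # enable_all_grad=False keeps the lengths out of the check.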
        self.assert_grad(F.speed, (waveform, 1000, 1.1, lengths), enable_all_grad=False)

    def test_preemphasis(self):
        waveform = torch.rand(3, 2, 100, device=self.device, dtype=self.dtype, requires_grad=True)
        coeff = 0.9
        self.assert_grad(F.preemphasis, (waveform, coeff))

    def test_deemphasis(self):
        waveform = torch.rand(3, 2, 100, device=self.device, dtype=self.dtype, requires_grad=True)
        coeff = 0.9
        self.assert_grad(F.deemphasis, (waveform, coeff))

    def test_frechet_distance(self):
        N = 16
        mu_x = torch.rand((N,))
        sigma_x = torch.rand((N, N))
        mu_y = torch.rand((N,))
        sigma_y = torch.rand((N, N))
        self.assert_grad(F.frechet_distance, (mu_x, sigma_x, mu_y, sigma_y))


class AutogradFloat32(TestBaseMixin):
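    # Variant of the autograd checks that casts inputs to the test dtype and uses
    # relaxed gradcheck tolerances; used here for rnnt_loss, whose logits are
    # cast to float32.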
    def assert_grad(
        self,
        transform: Callable[..., Tensor],
        inputs: Tuple[torch.Tensor],
        enable_all_grad: bool = True,
    ):
        inputs_ = []
        for i in inputs:
            if torch.is_tensor(i):
                i = i.to(dtype=self.dtype, device=self.device)
                if enable_all_grad:
                    i.requires_grad = True
            inputs_.append(i)
        # gradcheck with float32 requires higher atol and epsilon
        assert gradcheck(transform, inputs_, eps=1e-3, atol=1e-3, nondet_tol=0.0)

    @parameterized.expand(
        [
            (rnnt_utils.get_B1_T10_U3_D4_data,),
            (rnnt_utils.get_B2_T4_U3_D3_data,),
            (rnnt_utils.get_B1_T2_U3_D5_data,),
        ]
    )
    def test_rnnt_loss(self, data_func):
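        # The rnnt_utils factories may return a (data, ...) tuple; keep only the data dict.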
        def get_data(data_func, device):
            data = data_func()
            if isinstance(data, tuple):
                data = data[0]
            return data

        data = get_data(data_func, self.device)
        inputs = (
            data["logits"].to(torch.float32),  # logits
            data["targets"],  # targets
            data["logit_lengths"],  # logit_lengths
            data["target_lengths"],  # target_lengths
            data["blank"],  # blank
            -1,  # clamp
        )

        self.assert_grad(F.rnnt_loss, inputs, enable_all_grad=False)