from functools import partial

from typing import Callable, Tuple

import torch
import torchaudio.functional as F

from parameterized import parameterized
from torch import Tensor
from torch.autograd import gradcheck, gradgradcheck

from torchaudio_unittest.common_utils import (
    get_spectrogram,
    get_whitenoise,
    nested_params,
    rnnt_utils,
    TestBaseMixin,
    use_deterministic_algorithms,
)


class Autograd(TestBaseMixin):
    def assert_grad(
21
22
23
24
25
        self,
        transform: Callable[..., Tensor],
        inputs: Tuple[torch.Tensor],
        *,
        enable_all_grad: bool = True,
26
27
28
29
    ):
        inputs_ = []
        for i in inputs:
            if torch.is_tensor(i):
30
                i = i.to(dtype=self.complex_dtype if i.is_complex() else self.dtype, device=self.device)
31
32
33
34
                if enable_all_grad:
                    i.requires_grad = True
            inputs_.append(i)
        assert gradcheck(transform, inputs_)
35
        assert gradgradcheck(transform, inputs_)
36

37
    def test_lfilter_x(self):
38
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
39
40
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
41
        x.requires_grad = True
42
        self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False)
43

44
    def test_lfilter_a(self):
45
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
46
47
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
48
        a.requires_grad = True
49
        self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False)
50

51
    def test_lfilter_b(self):
52
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
53
54
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
55
        b.requires_grad = True
56
        self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False)
57

58
    def test_lfilter_all_inputs(self):
59
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
60
61
62
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        self.assert_grad(F.lfilter, (x, a, b))
63

64
65
    def test_lfilter_filterbanks(self):
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=3)
66
67
        a = torch.tensor([[0.7, 0.2, 0.6], [0.8, 0.2, 0.9]])
        b = torch.tensor([[0.4, 0.2, 0.9], [0.7, 0.2, 0.6]])
68
69
70
71
        self.assert_grad(partial(F.lfilter, batching=False), (x, a, b))

    def test_lfilter_batching(self):
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
72
73
        a = torch.tensor([[0.7, 0.2, 0.6], [0.8, 0.2, 0.9]])
        b = torch.tensor([[0.4, 0.2, 0.9], [0.7, 0.2, 0.6]])
74
75
        self.assert_grad(F.lfilter, (x, a, b))

76
77
78
79
80
    def test_filtfilt_a(self):
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        a.requires_grad = True
81
82
        with use_deterministic_algorithms(True, False):
            self.assert_grad(F.filtfilt, (x, a, b), enable_all_grad=False)
83
84
85
86
87
88

    def test_filtfilt_b(self):
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        b.requires_grad = True
89
90
        with use_deterministic_algorithms(True, False):
            self.assert_grad(F.filtfilt, (x, a, b), enable_all_grad=False)
91
92
93
94
95

    def test_filtfilt_all_inputs(self):
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
96
97
        with use_deterministic_algorithms(True, False):
            self.assert_grad(F.filtfilt, (x, a, b))
98
99
100

    def test_filtfilt_batching(self):
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
101
102
        a = torch.tensor([[0.7, 0.2, 0.6], [0.8, 0.2, 0.9]])
        b = torch.tensor([[0.4, 0.2, 0.9], [0.7, 0.2, 0.6]])
103
104
        with use_deterministic_algorithms(True, False):
            self.assert_grad(F.filtfilt, (x, a, b))
105

106
    def test_biquad(self):
107
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=1)
108
109
110
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        self.assert_grad(F.biquad, (x, b[0], b[1], b[2], a[0], a[1], a[2]))
111

112
113
114
115
116
117
    @parameterized.expand(
        [
            (800, 0.7, True),
            (800, 0.7, False),
        ]
    )
118
    def test_band_biquad(self, central_freq, Q, noise):
119
        sr = 22050
120
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
121
122
123
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.band_biquad, (x, sr, central_freq, Q, noise))
124

125
126
127
128
129
130
    @parameterized.expand(
        [
            (800, 0.7, 10),
            (800, 0.7, -10),
        ]
    )
131
    def test_bass_biquad(self, central_freq, Q, gain):
132
        sr = 22050
133
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
134
135
136
137
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        gain = torch.tensor(gain)
        self.assert_grad(F.bass_biquad, (x, sr, gain, central_freq, Q))
138

139
140
141
142
143
144
    @parameterized.expand(
        [
            (3000, 0.7, 10),
            (3000, 0.7, -10),
        ]
    )
145
    def test_treble_biquad(self, central_freq, Q, gain):
146
        sr = 22050
147
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
148
149
150
151
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        gain = torch.tensor(gain)
        self.assert_grad(F.treble_biquad, (x, sr, gain, central_freq, Q))
152

153
154
155
156
157
158
159
160
    @parameterized.expand(
        [
            (
                800,
                0.7,
            ),
        ]
    )
161
    def test_allpass_biquad(self, central_freq, Q):
162
        sr = 22050
163
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
164
165
166
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.allpass_biquad, (x, sr, central_freq, Q))
167

168
169
170
171
172
173
174
175
    @parameterized.expand(
        [
            (
                800,
                0.7,
            ),
        ]
    )
176
    def test_lowpass_biquad(self, cutoff_freq, Q):
177
        sr = 22050
178
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
179
180
181
        cutoff_freq = torch.tensor(cutoff_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.lowpass_biquad, (x, sr, cutoff_freq, Q))
182

183
184
185
186
187
188
189
190
    @parameterized.expand(
        [
            (
                800,
                0.7,
            ),
        ]
    )
191
    def test_highpass_biquad(self, cutoff_freq, Q):
192
        sr = 22050
193
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
194
195
196
        cutoff_freq = torch.tensor(cutoff_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.highpass_biquad, (x, sr, cutoff_freq, Q))
197

198
199
200
201
202
203
    @parameterized.expand(
        [
            (800, 0.7, True),
            (800, 0.7, False),
        ]
    )
204
    def test_bandpass_biquad(self, central_freq, Q, const_skirt_gain):
205
        sr = 22050
206
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
207
208
209
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.bandpass_biquad, (x, sr, central_freq, Q, const_skirt_gain))
210

211
212
213
214
215
216
    @parameterized.expand(
        [
            (800, 0.7, 10),
            (800, 0.7, -10),
        ]
    )
217
    def test_equalizer_biquad(self, central_freq, Q, gain):
218
        sr = 22050
219
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
220
221
222
223
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        gain = torch.tensor(gain)
        self.assert_grad(F.equalizer_biquad, (x, sr, central_freq, gain, Q))
224

225
226
227
228
229
230
231
232
    @parameterized.expand(
        [
            (
                800,
                0.7,
            ),
        ]
    )
233
    def test_bandreject_biquad(self, central_freq, Q):
234
        sr = 22050
235
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
236
237
238
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.bandreject_biquad, (x, sr, central_freq, Q))
239

moto's avatar
moto committed
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
    def test_deemph_biquad(self):
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=1)
        self.assert_grad(F.deemph_biquad, (x, 44100))

    def test_flanger(self):
        x = get_whitenoise(sample_rate=8000, duration=0.01, n_channels=1)
        self.assert_grad(F.flanger, (x, 44100))

    def test_gain(self):
        x = get_whitenoise(sample_rate=8000, duration=0.01, n_channels=1)
        self.assert_grad(F.gain, (x, 1.1))

    def test_overdrive(self):
        x = get_whitenoise(sample_rate=8000, duration=0.01, n_channels=1)
        self.assert_grad(F.gain, (x,))

    @parameterized.expand([(True,), (False,)])
    def test_phaser(self, sinusoidal):
        sr = 8000
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
        self.assert_grad(F.phaser, (x, sr, sinusoidal))

262
263
264
265
266
267
268
269
270
271
272
273
274
275
    @parameterized.expand(
        [
            (True,),
            (False,),
        ]
    )
    def test_psd(self, use_mask):
        specgram = torch.rand(4, 10, 5, dtype=torch.cfloat)
        if use_mask:
            mask = torch.rand(10, 5)
        else:
            mask = None
        self.assert_grad(F.psd, (specgram, mask))

276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
    def test_mvdr_weights_souden(self):
        channel = 4
        n_fft_bin = 5
        psd_speech = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
        psd_noise = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
        self.assert_grad(F.mvdr_weights_souden, (psd_speech, psd_noise, 0))

    def test_mvdr_weights_souden_with_tensor(self):
        channel = 4
        n_fft_bin = 5
        psd_speech = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
        psd_noise = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
        reference_channel = torch.zeros(channel)
        reference_channel[0].fill_(1)
        self.assert_grad(F.mvdr_weights_souden, (psd_speech, psd_noise, reference_channel))

292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
    def test_mvdr_weights_rtf(self):
        batch_size = 2
        channel = 4
        n_fft_bin = 10
        rtf = torch.rand(batch_size, n_fft_bin, channel, dtype=self.complex_dtype)
        psd_noise = torch.rand(batch_size, n_fft_bin, channel, channel, dtype=self.complex_dtype)
        self.assert_grad(F.mvdr_weights_rtf, (rtf, psd_noise, 0))

    def test_mvdr_weights_rtf_with_tensor(self):
        batch_size = 2
        channel = 4
        n_fft_bin = 10
        rtf = torch.rand(batch_size, n_fft_bin, channel, dtype=self.complex_dtype)
        psd_noise = torch.rand(batch_size, n_fft_bin, channel, channel, dtype=self.complex_dtype)
        reference_channel = torch.zeros(batch_size, channel)
        reference_channel[..., 0].fill_(1)
        self.assert_grad(F.mvdr_weights_rtf, (rtf, psd_noise, reference_channel))

310
311
    @parameterized.expand(
        [
312
313
            (1, True),
            (3, False),
314
315
        ]
    )
316
    def test_rtf_power(self, n_iter, diagonal_loading):
317
318
319
320
        channel = 4
        n_fft_bin = 5
        psd_speech = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
        psd_noise = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
321
        self.assert_grad(F.rtf_power, (psd_speech, psd_noise, 0, n_iter, diagonal_loading))
322
323
324

    @parameterized.expand(
        [
325
326
            (1, True),
            (3, False),
327
328
        ]
    )
329
    def test_rtf_power_with_tensor(self, n_iter, diagonal_loading):
330
331
332
333
334
335
        channel = 4
        n_fft_bin = 5
        psd_speech = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
        psd_noise = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
        reference_channel = torch.zeros(channel)
        reference_channel[0].fill_(1)
336
        self.assert_grad(F.rtf_power, (psd_speech, psd_noise, reference_channel, n_iter, diagonal_loading))
337

338
339
340
341
342
343
344
345
    def test_apply_beamforming(self):
        sr = 8000
        n_fft = 400
        batch_size, num_channels = 2, 3
        n_fft_bin = n_fft // 2 + 1
        x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=batch_size * num_channels)
        specgram = get_spectrogram(x, n_fft=n_fft, hop_length=100)
        specgram = specgram.view(batch_size, num_channels, n_fft_bin, specgram.size(-1))
346
        beamform_weights = torch.rand(batch_size, n_fft_bin, num_channels, dtype=torch.cfloat)
347
348
        self.assert_grad(F.apply_beamforming, (beamform_weights, specgram))

349
    @nested_params(
350
        ["convolve", "fftconvolve"],
351
352
353
354
355
356
357
        ["full", "valid", "same"],
    )
    def test_convolve(self, fn, mode):
        leading_dims = (4, 3, 2)
        L_x, L_y = 23, 40
        x = torch.rand(*leading_dims, L_x, dtype=self.dtype, device=self.device)
        y = torch.rand(*leading_dims, L_y, dtype=self.dtype, device=self.device)
358
        self.assert_grad(getattr(F, fn), (x, y, mode))
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373

    def test_add_noise(self):
        leading_dims = (5, 2, 3)
        L = 51
        waveform = torch.rand(*leading_dims, L, dtype=self.dtype, device=self.device)
        noise = torch.rand(*leading_dims, L, dtype=self.dtype, device=self.device)
        lengths = torch.rand(*leading_dims, dtype=self.dtype, device=self.device)
        snr = torch.rand(*leading_dims, dtype=self.dtype, device=self.device) * 10
        self.assert_grad(F.add_noise, (waveform, noise, snr, lengths))

    def test_speed(self):
        leading_dims = (3, 2)
        T = 200
        waveform = torch.rand(*leading_dims, T, dtype=self.dtype, device=self.device, requires_grad=True)
        lengths = torch.randint(1, T, leading_dims, dtype=self.dtype, device=self.device)
374
        self.assert_grad(F.speed, (waveform, 1000, 1.1, lengths), enable_all_grad=False)
375
376
377
378
379
380
381
382
383
384
385

    def test_preemphasis(self):
        waveform = torch.rand(3, 2, 100, device=self.device, dtype=self.dtype, requires_grad=True)
        coeff = 0.9
        self.assert_grad(F.preemphasis, (waveform, coeff))

    def test_deemphasis(self):
        waveform = torch.rand(3, 2, 100, device=self.device, dtype=self.dtype, requires_grad=True)
        coeff = 0.9
        self.assert_grad(F.deemphasis, (waveform, coeff))

386
387
388

class AutogradFloat32(TestBaseMixin):
    def assert_grad(
389
390
391
392
        self,
        transform: Callable[..., Tensor],
        inputs: Tuple[torch.Tensor],
        enable_all_grad: bool = True,
393
394
395
396
397
398
399
400
401
    ):
        inputs_ = []
        for i in inputs:
            if torch.is_tensor(i):
                i = i.to(dtype=self.dtype, device=self.device)
                if enable_all_grad:
                    i.requires_grad = True
            inputs_.append(i)
        # gradcheck with float32 requires higher atol and epsilon
402
        assert gradcheck(transform, inputs, eps=1e-3, atol=1e-3, nondet_tol=0.0)
403

404
405
406
407
408
409
410
    @parameterized.expand(
        [
            (rnnt_utils.get_B1_T10_U3_D4_data,),
            (rnnt_utils.get_B2_T4_U3_D3_data,),
            (rnnt_utils.get_B1_T2_U3_D5_data,),
        ]
    )
411
412
413
414
415
416
417
418
419
420
    def test_rnnt_loss(self, data_func):
        def get_data(data_func, device):
            data = data_func()
            if type(data) == tuple:
                data = data[0]
            return data

        data = get_data(data_func, self.device)
        inputs = (
            data["logits"].to(torch.float32),  # logits
421
422
423
424
425
            data["targets"],  # targets
            data["logit_lengths"],  # logit_lengths
            data["target_lengths"],  # target_lengths
            data["blank"],  # blank
            -1,  # clamp
426
427
428
        )

        self.assert_grad(F.rnnt_loss, inputs, enable_all_grad=False)