autograd_impl.py 15 KB
Newer Older
1
from functools import partial
2
3
from typing import Callable, Tuple

4
import torch
5
import torchaudio.functional as F
6
7
from parameterized import parameterized
from torch import Tensor
8
from torch.autograd import gradcheck, gradgradcheck
9
from torchaudio_unittest.common_utils import get_spectrogram, get_whitenoise, nested_params, rnnt_utils, TestBaseMixin
10
11
12
13


class Autograd(TestBaseMixin):
    def assert_grad(
        self,
        transform: Callable[..., Tensor],
        inputs: Tuple[torch.Tensor],
        *,
        enable_all_grad: bool = True,
    ):
        """Run first- and second-order gradient checks on ``transform``.

        Tensor inputs are cast to the test dtype (complex tensors to the
        complex dtype) and moved to the test device.  When ``enable_all_grad``
        is True every tensor input is checked; otherwise only tensors whose
        ``requires_grad`` flag was already set by the caller are.
        """
        prepared = []
        for item in inputs:
            if not torch.is_tensor(item):
                prepared.append(item)
                continue
            target_dtype = self.complex_dtype if item.is_complex() else self.dtype
            item = item.to(dtype=target_dtype, device=self.device)
            if enable_all_grad:
                item.requires_grad = True
            prepared.append(item)
        assert gradcheck(transform, prepared)
        assert gradgradcheck(transform, prepared)
29

30
    def test_lfilter_x(self):
        """lfilter gradients w.r.t. the input signal only."""
        waveform = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a_coeffs = torch.tensor([0.7, 0.2, 0.6])
        b_coeffs = torch.tensor([0.4, 0.2, 0.9])
        waveform.requires_grad = True
        self.assert_grad(F.lfilter, (waveform, a_coeffs, b_coeffs), enable_all_grad=False)
36

37
    def test_lfilter_a(self):
        """lfilter gradients w.r.t. the denominator coefficients only."""
        waveform = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a_coeffs = torch.tensor([0.7, 0.2, 0.6])
        b_coeffs = torch.tensor([0.4, 0.2, 0.9])
        a_coeffs.requires_grad = True
        self.assert_grad(F.lfilter, (waveform, a_coeffs, b_coeffs), enable_all_grad=False)
43

44
    def test_lfilter_b(self):
        """lfilter gradients w.r.t. the numerator coefficients only."""
        waveform = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a_coeffs = torch.tensor([0.7, 0.2, 0.6])
        b_coeffs = torch.tensor([0.4, 0.2, 0.9])
        b_coeffs.requires_grad = True
        self.assert_grad(F.lfilter, (waveform, a_coeffs, b_coeffs), enable_all_grad=False)
50

51
    def test_lfilter_all_inputs(self):
        """lfilter gradients w.r.t. the signal and both coefficient vectors."""
        waveform = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a_coeffs = torch.tensor([0.7, 0.2, 0.6])
        b_coeffs = torch.tensor([0.4, 0.2, 0.9])
        self.assert_grad(F.lfilter, (waveform, a_coeffs, b_coeffs))
56

57
58
    def test_lfilter_filterbanks(self):
        """lfilter gradients with a bank of two filters and batching disabled."""
        waveform = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=3)
        a_coeffs = torch.tensor([[0.7, 0.2, 0.6], [0.8, 0.2, 0.9]])
        b_coeffs = torch.tensor([[0.4, 0.2, 0.9], [0.7, 0.2, 0.6]])
        lfilter_no_batching = partial(F.lfilter, batching=False)
        self.assert_grad(lfilter_no_batching, (waveform, a_coeffs, b_coeffs))

    def test_lfilter_batching(self):
        """lfilter gradients with per-channel (batched) filter coefficients."""
        waveform = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a_coeffs = torch.tensor([[0.7, 0.2, 0.6], [0.8, 0.2, 0.9]])
        b_coeffs = torch.tensor([[0.4, 0.2, 0.9], [0.7, 0.2, 0.6]])
        self.assert_grad(F.lfilter, (waveform, a_coeffs, b_coeffs))

69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
    def test_filtfilt_a(self):
        """filtfilt gradients w.r.t. the denominator coefficients only."""
        waveform = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a_coeffs = torch.tensor([0.7, 0.2, 0.6])
        b_coeffs = torch.tensor([0.4, 0.2, 0.9])
        a_coeffs.requires_grad = True
        self.assert_grad(F.filtfilt, (waveform, a_coeffs, b_coeffs), enable_all_grad=False)

    def test_filtfilt_b(self):
        """filtfilt gradients w.r.t. the numerator coefficients only."""
        waveform = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a_coeffs = torch.tensor([0.7, 0.2, 0.6])
        b_coeffs = torch.tensor([0.4, 0.2, 0.9])
        b_coeffs.requires_grad = True
        self.assert_grad(F.filtfilt, (waveform, a_coeffs, b_coeffs), enable_all_grad=False)

    def test_filtfilt_all_inputs(self):
        """filtfilt gradients w.r.t. the signal and both coefficient vectors."""
        waveform = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a_coeffs = torch.tensor([0.7, 0.2, 0.6])
        b_coeffs = torch.tensor([0.4, 0.2, 0.9])
        self.assert_grad(F.filtfilt, (waveform, a_coeffs, b_coeffs))

    def test_filtfilt_batching(self):
        """filtfilt gradients with per-channel (batched) filter coefficients."""
        waveform = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a_coeffs = torch.tensor([[0.7, 0.2, 0.6], [0.8, 0.2, 0.9]])
        b_coeffs = torch.tensor([[0.4, 0.2, 0.9], [0.7, 0.2, 0.6]])
        self.assert_grad(F.filtfilt, (waveform, a_coeffs, b_coeffs))

95
    def test_biquad(self):
        """biquad gradients w.r.t. the signal and all six scalar coefficients."""
        waveform = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=1)
        # Unpack into 0-dim tensors, matching the scalar-coefficient signature.
        a0, a1, a2 = torch.tensor([0.7, 0.2, 0.6])
        b0, b1, b2 = torch.tensor([0.4, 0.2, 0.9])
        self.assert_grad(F.biquad, (waveform, b0, b1, b2, a0, a1, a2))
100

101
102
103
104
105
106
    @parameterized.expand([(800, 0.7, True), (800, 0.7, False)])
    def test_band_biquad(self, central_freq, Q, noise):
        """band_biquad gradients for both noise modes."""
        sample_rate = 22050
        waveform = get_whitenoise(sample_rate=sample_rate, duration=0.01, n_channels=1)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.band_biquad, (waveform, sample_rate, central_freq, Q, noise))
113

114
115
116
117
118
119
    @parameterized.expand([(800, 0.7, 10), (800, 0.7, -10)])
    def test_bass_biquad(self, central_freq, Q, gain):
        """bass_biquad gradients for positive and negative gain."""
        sample_rate = 22050
        waveform = get_whitenoise(sample_rate=sample_rate, duration=0.01, n_channels=1)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        gain = torch.tensor(gain)
        self.assert_grad(F.bass_biquad, (waveform, sample_rate, gain, central_freq, Q))
127

128
129
130
131
132
133
    @parameterized.expand([(3000, 0.7, 10), (3000, 0.7, -10)])
    def test_treble_biquad(self, central_freq, Q, gain):
        """treble_biquad gradients for positive and negative gain."""
        sample_rate = 22050
        waveform = get_whitenoise(sample_rate=sample_rate, duration=0.01, n_channels=1)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        gain = torch.tensor(gain)
        self.assert_grad(F.treble_biquad, (waveform, sample_rate, gain, central_freq, Q))
141

142
143
144
145
146
147
148
149
    @parameterized.expand([(800, 0.7)])
    def test_allpass_biquad(self, central_freq, Q):
        """allpass_biquad gradients w.r.t. signal, frequency and Q."""
        sample_rate = 22050
        waveform = get_whitenoise(sample_rate=sample_rate, duration=0.01, n_channels=1)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.allpass_biquad, (waveform, sample_rate, central_freq, Q))
156

157
158
159
160
161
162
163
164
    @parameterized.expand([(800, 0.7)])
    def test_lowpass_biquad(self, cutoff_freq, Q):
        """lowpass_biquad gradients w.r.t. signal, cutoff and Q."""
        sample_rate = 22050
        waveform = get_whitenoise(sample_rate=sample_rate, duration=0.01, n_channels=1)
        cutoff_freq = torch.tensor(cutoff_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.lowpass_biquad, (waveform, sample_rate, cutoff_freq, Q))
171

172
173
174
175
176
177
178
179
    @parameterized.expand([(800, 0.7)])
    def test_highpass_biquad(self, cutoff_freq, Q):
        """highpass_biquad gradients w.r.t. signal, cutoff and Q."""
        sample_rate = 22050
        waveform = get_whitenoise(sample_rate=sample_rate, duration=0.01, n_channels=1)
        cutoff_freq = torch.tensor(cutoff_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.highpass_biquad, (waveform, sample_rate, cutoff_freq, Q))
186

187
188
189
190
191
192
    @parameterized.expand([(800, 0.7, True), (800, 0.7, False)])
    def test_bandpass_biquad(self, central_freq, Q, const_skirt_gain):
        """bandpass_biquad gradients for both skirt-gain modes."""
        sample_rate = 22050
        waveform = get_whitenoise(sample_rate=sample_rate, duration=0.01, n_channels=1)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.bandpass_biquad, (waveform, sample_rate, central_freq, Q, const_skirt_gain))
199

200
201
202
203
204
205
    @parameterized.expand([(800, 0.7, 10), (800, 0.7, -10)])
    def test_equalizer_biquad(self, central_freq, Q, gain):
        """equalizer_biquad gradients for positive and negative gain."""
        sample_rate = 22050
        waveform = get_whitenoise(sample_rate=sample_rate, duration=0.01, n_channels=1)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        gain = torch.tensor(gain)
        self.assert_grad(F.equalizer_biquad, (waveform, sample_rate, central_freq, gain, Q))
213

214
215
216
217
218
219
220
221
    @parameterized.expand([(800, 0.7)])
    def test_bandreject_biquad(self, central_freq, Q):
        """bandreject_biquad gradients w.r.t. signal, frequency and Q."""
        sample_rate = 22050
        waveform = get_whitenoise(sample_rate=sample_rate, duration=0.01, n_channels=1)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.bandreject_biquad, (waveform, sample_rate, central_freq, Q))
228

moto's avatar
moto committed
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
    def test_deemph_biquad(self):
        """deemph_biquad gradients (44100 Hz preset)."""
        waveform = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=1)
        self.assert_grad(F.deemph_biquad, (waveform, 44100))

    def test_flanger(self):
        """flanger gradients w.r.t. the waveform."""
        waveform = get_whitenoise(sample_rate=8000, duration=0.01, n_channels=1)
        self.assert_grad(F.flanger, (waveform, 44100))

    def test_gain(self):
        """gain gradients w.r.t. the waveform."""
        waveform = get_whitenoise(sample_rate=8000, duration=0.01, n_channels=1)
        self.assert_grad(F.gain, (waveform, 1.1))

    def test_overdrive(self):
        """overdrive gradients w.r.t. the waveform.

        Bug fix: the test previously called ``F.gain`` (a copy-paste leftover
        from ``test_gain``), so ``F.overdrive`` was never gradient-checked.
        """
        waveform = get_whitenoise(sample_rate=8000, duration=0.01, n_channels=1)
        self.assert_grad(F.overdrive, (waveform,))

    @parameterized.expand([(True,), (False,)])
    def test_phaser(self, sinusoidal):
        """phaser gradients for sinusoidal and triangular modulation."""
        sample_rate = 8000
        waveform = get_whitenoise(sample_rate=sample_rate, duration=0.01, n_channels=1)
        self.assert_grad(F.phaser, (waveform, sample_rate, sinusoidal))

251
252
253
254
255
256
257
258
259
260
261
262
263
264
    @parameterized.expand([(True,), (False,)])
    def test_psd(self, use_mask):
        """psd gradients with and without a time-frequency mask."""
        specgram = torch.rand(4, 10, 5, dtype=torch.cfloat)
        mask = torch.rand(10, 5) if use_mask else None
        self.assert_grad(F.psd, (specgram, mask))

265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
    def test_mvdr_weights_souden(self):
        """mvdr_weights_souden gradients with an integer reference channel."""
        num_channels, num_bins = 4, 5
        psd_speech = torch.rand(num_bins, num_channels, num_channels, dtype=torch.cfloat)
        psd_noise = torch.rand(num_bins, num_channels, num_channels, dtype=torch.cfloat)
        self.assert_grad(F.mvdr_weights_souden, (psd_speech, psd_noise, 0))

    def test_mvdr_weights_souden_with_tensor(self):
        """mvdr_weights_souden gradients with a one-hot reference-channel tensor."""
        num_channels, num_bins = 4, 5
        psd_speech = torch.rand(num_bins, num_channels, num_channels, dtype=torch.cfloat)
        psd_noise = torch.rand(num_bins, num_channels, num_channels, dtype=torch.cfloat)
        # One-hot vector selecting channel 0 as the reference.
        reference_channel = torch.zeros(num_channels)
        reference_channel[0] = 1
        self.assert_grad(F.mvdr_weights_souden, (psd_speech, psd_noise, reference_channel))

281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
    def test_mvdr_weights_rtf(self):
        """mvdr_weights_rtf gradients with an integer reference channel."""
        batch_size, num_channels, num_bins = 2, 4, 10
        rtf = torch.rand(batch_size, num_bins, num_channels, dtype=self.complex_dtype)
        psd_noise = torch.rand(batch_size, num_bins, num_channels, num_channels, dtype=self.complex_dtype)
        self.assert_grad(F.mvdr_weights_rtf, (rtf, psd_noise, 0))

    def test_mvdr_weights_rtf_with_tensor(self):
        """mvdr_weights_rtf gradients with a batched one-hot reference tensor."""
        batch_size, num_channels, num_bins = 2, 4, 10
        rtf = torch.rand(batch_size, num_bins, num_channels, dtype=self.complex_dtype)
        psd_noise = torch.rand(batch_size, num_bins, num_channels, num_channels, dtype=self.complex_dtype)
        # One-hot per batch entry, selecting channel 0 as the reference.
        reference_channel = torch.zeros(batch_size, num_channels)
        reference_channel[..., 0] = 1
        self.assert_grad(F.mvdr_weights_rtf, (rtf, psd_noise, reference_channel))

299
300
    @parameterized.expand([(1, True), (3, False)])
    def test_rtf_power(self, n_iter, diagonal_loading):
        """rtf_power gradients with an integer reference channel."""
        num_channels, num_bins = 4, 5
        psd_speech = torch.rand(num_bins, num_channels, num_channels, dtype=torch.cfloat)
        psd_noise = torch.rand(num_bins, num_channels, num_channels, dtype=torch.cfloat)
        self.assert_grad(F.rtf_power, (psd_speech, psd_noise, 0, n_iter, diagonal_loading))
311
312
313

    @parameterized.expand([(1, True), (3, False)])
    def test_rtf_power_with_tensor(self, n_iter, diagonal_loading):
        """rtf_power gradients with a one-hot reference-channel tensor."""
        num_channels, num_bins = 4, 5
        psd_speech = torch.rand(num_bins, num_channels, num_channels, dtype=torch.cfloat)
        psd_noise = torch.rand(num_bins, num_channels, num_channels, dtype=torch.cfloat)
        # One-hot vector selecting channel 0 as the reference.
        reference_channel = torch.zeros(num_channels)
        reference_channel[0] = 1
        self.assert_grad(F.rtf_power, (psd_speech, psd_noise, reference_channel, n_iter, diagonal_loading))
326

327
328
329
330
331
332
333
334
    def test_apply_beamforming(self):
        """apply_beamforming gradients on a batched multi-channel spectrogram."""
        sample_rate, n_fft = 8000, 400
        batch_size, num_channels = 2, 3
        num_bins = n_fft // 2 + 1
        noise = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=batch_size * num_channels)
        specgram = get_spectrogram(noise, n_fft=n_fft, hop_length=100)
        # Fold the flat channel axis into (batch, channel).
        specgram = specgram.view(batch_size, num_channels, num_bins, specgram.size(-1))
        beamform_weights = torch.rand(batch_size, num_bins, num_channels, dtype=torch.cfloat)
        self.assert_grad(F.apply_beamforming, (beamform_weights, specgram))

338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
    @nested_params(
        [F.convolve, F.fftconvolve],
        ["full", "valid", "same"],
    )
    def test_convolve(self, fn, mode):
        """convolve/fftconvolve gradients for every padding mode."""
        batch_dims = (4, 3, 2)
        len_x, len_y = 23, 40
        x = torch.rand(*batch_dims, len_x, dtype=self.dtype, device=self.device)
        y = torch.rand(*batch_dims, len_y, dtype=self.dtype, device=self.device)
        self.assert_grad(fn, (x, y, mode))

    def test_add_noise(self):
        """add_noise gradients w.r.t. waveform, noise, SNR and lengths."""
        batch_dims = (5, 2, 3)
        num_frames = 51
        waveform = torch.rand(*batch_dims, num_frames, dtype=self.dtype, device=self.device)
        noise = torch.rand(*batch_dims, num_frames, dtype=self.dtype, device=self.device)
        lengths = torch.rand(*batch_dims, dtype=self.dtype, device=self.device)
        snr = torch.rand(*batch_dims, dtype=self.dtype, device=self.device) * 10
        self.assert_grad(F.add_noise, (waveform, noise, snr, lengths))

    def test_speed(self):
        """speed gradients w.r.t. the waveform only (lengths are fixed)."""
        batch_dims = (3, 2)
        num_frames = 200
        waveform = torch.rand(*batch_dims, num_frames, dtype=self.dtype, device=self.device, requires_grad=True)
        lengths = torch.randint(1, num_frames, batch_dims, dtype=self.dtype, device=self.device)
        self.assert_grad(F.speed, (waveform, lengths, 1000, 1.1), enable_all_grad=False)

    def test_preemphasis(self):
        """preemphasis gradients w.r.t. the waveform."""
        waveform = torch.rand(3, 2, 100, device=self.device, dtype=self.dtype, requires_grad=True)
        self.assert_grad(F.preemphasis, (waveform, 0.9))

    def test_deemphasis(self):
        """deemphasis gradients w.r.t. the waveform."""
        waveform = torch.rand(3, 2, 100, device=self.device, dtype=self.dtype, requires_grad=True)
        self.assert_grad(F.deemphasis, (waveform, 0.9))

375
376
377

class AutogradFloat32(TestBaseMixin):
    def assert_grad(
        self,
        transform: Callable[..., Tensor],
        inputs: Tuple[torch.Tensor],
        enable_all_grad: bool = True,
    ):
        """Run a float32 gradcheck on ``transform``.

        Tensor inputs are cast to the test dtype, moved to the test device,
        and (when ``enable_all_grad`` is True) marked as requiring gradients.
        """
        inputs_ = []
        for i in inputs:
            if torch.is_tensor(i):
                i = i.to(dtype=self.dtype, device=self.device)
                if enable_all_grad:
                    i.requires_grad = True
            inputs_.append(i)
        # gradcheck with float32 requires higher atol and epsilon.
        # Bug fix: check the prepared ``inputs_`` (device/dtype cast with grads
        # enabled), not the raw ``inputs`` — otherwise the preparation loop
        # above had no effect on what gradcheck saw.
        assert gradcheck(transform, inputs_, eps=1e-3, atol=1e-3, nondet_tol=0.0)
392

393
394
395
396
397
398
399
    @parameterized.expand(
        [
            (rnnt_utils.get_B1_T10_U3_D4_data,),
            (rnnt_utils.get_B2_T4_U3_D3_data,),
            (rnnt_utils.get_B1_T2_U3_D5_data,),
        ]
    )
    def test_rnnt_loss(self, data_func):
        """Gradcheck ``F.rnnt_loss`` on several fixture batch shapes."""

        def get_data(data_func, device):
            # Some fixture factories return a tuple whose first element is the
            # data dict; keep only the dict.
            # NOTE(review): ``device`` is unused here — tensors stay wherever
            # the fixture created them; confirm intent against rnnt_utils.
            data = data_func()
            if type(data) == tuple:
                data = data[0]
            return data

        data = get_data(data_func, self.device)
        # Logits cast to float32 since this suite targets float32 autograd;
        # only the logits receive gradients (enable_all_grad=False below).
        inputs = (
            data["logits"].to(torch.float32),  # logits
            data["targets"],  # targets
            data["logit_lengths"],  # logit_lengths
            data["target_lengths"],  # target_lengths
            data["blank"],  # blank
            -1,  # clamp
        )

        self.assert_grad(F.rnnt_loss, inputs, enable_all_grad=False)