from functools import partial
from typing import Callable, Tuple

import torch
import torchaudio.functional as F
from parameterized import parameterized
from torch import Tensor
from torch.autograd import gradcheck, gradgradcheck
from torchaudio_unittest.common_utils import (
    get_spectrogram,
    get_whitenoise,
    rnnt_utils,
    TestBaseMixin,
)


class Autograd(TestBaseMixin):
    def assert_grad(
        self,
        transform: Callable[..., Tensor],
        inputs: Tuple[torch.Tensor],
        *,
        enable_all_grad: bool = True,
    ):
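        """Cast tensor inputs to the test dtype/device and verify gradients.

        Runs both ``gradcheck`` and ``gradgradcheck`` on ``transform``. When
        ``enable_all_grad`` is True, ``requires_grad`` is turned on for every
        tensor input; otherwise only tensors that already require grad are
        differentiated.
        """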
        inputs_ = []
        for i in inputs:
            if torch.is_tensor(i):
                i = i.to(dtype=self.complex_dtype if i.is_complex() else self.dtype, device=self.device)
                if enable_all_grad:
                    i.requires_grad = True
            inputs_.append(i)
        assert gradcheck(transform, inputs_)
        assert gradgradcheck(transform, inputs_)

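    # The lfilter tests check the gradient with respect to the waveform (x), the
    # denominator coefficients (a), and the numerator coefficients (b). With
    # enable_all_grad=False, assert_grad leaves requires_grad untouched, so only
    # the tensor marked requires_grad=True in each test is differentiated.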
    def test_lfilter_x(self):
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        x.requires_grad = True
        self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False)

    def test_lfilter_a(self):
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        a.requires_grad = True
        self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False)

    def test_lfilter_b(self):
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        b.requires_grad = True
        self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False)

    def test_lfilter_all_inputs(self):
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        self.assert_grad(F.lfilter, (x, a, b))

    def test_lfilter_filterbanks(self):
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=3)
        a = torch.tensor([[0.7, 0.2, 0.6], [0.8, 0.2, 0.9]])
        b = torch.tensor([[0.4, 0.2, 0.9], [0.7, 0.2, 0.6]])
        self.assert_grad(partial(F.lfilter, batching=False), (x, a, b))

    def test_lfilter_batching(self):
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a = torch.tensor([[0.7, 0.2, 0.6], [0.8, 0.2, 0.9]])
        b = torch.tensor([[0.4, 0.2, 0.9], [0.7, 0.2, 0.6]])
        self.assert_grad(F.lfilter, (x, a, b))

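    # F.filtfilt applies the same filter forward and backward for zero-phase
    # filtering; gradients are checked against a, b, and all inputs, mirroring
    # the lfilter tests above.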
    def test_filtfilt_a(self):
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        a.requires_grad = True
        self.assert_grad(F.filtfilt, (x, a, b), enable_all_grad=False)

    def test_filtfilt_b(self):
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        b.requires_grad = True
        self.assert_grad(F.filtfilt, (x, a, b), enable_all_grad=False)

    def test_filtfilt_all_inputs(self):
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        self.assert_grad(F.filtfilt, (x, a, b))

    def test_filtfilt_batching(self):
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a = torch.tensor([[0.7, 0.2, 0.6], [0.8, 0.2, 0.9]])
        b = torch.tensor([[0.4, 0.2, 0.9], [0.7, 0.2, 0.6]])
        self.assert_grad(F.filtfilt, (x, a, b))

    def test_biquad(self):
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=1)
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        self.assert_grad(F.biquad, (x, b[0], b[1], b[2], a[0], a[1], a[2]))

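    # For the parametric biquad filters below, central_freq, Q, and (where
    # applicable) gain are wrapped in tensors so that assert_grad also checks
    # gradients with respect to the filter parameters, not only the waveform.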
    @parameterized.expand(
        [
            (800, 0.7, True),
            (800, 0.7, False),
        ]
    )
    def test_band_biquad(self, central_freq, Q, noise):
        sr = 22050
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.band_biquad, (x, sr, central_freq, Q, noise))

    @parameterized.expand(
        [
            (800, 0.7, 10),
            (800, 0.7, -10),
        ]
    )
    def test_bass_biquad(self, central_freq, Q, gain):
        sr = 22050
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        gain = torch.tensor(gain)
        self.assert_grad(F.bass_biquad, (x, sr, gain, central_freq, Q))

    @parameterized.expand(
        [
            (3000, 0.7, 10),
            (3000, 0.7, -10),
        ]
    )
    def test_treble_biquad(self, central_freq, Q, gain):
        sr = 22050
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        gain = torch.tensor(gain)
        self.assert_grad(F.treble_biquad, (x, sr, gain, central_freq, Q))

    @parameterized.expand(
        [
            (
                800,
                0.7,
            ),
        ]
    )
    def test_allpass_biquad(self, central_freq, Q):
        sr = 22050
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.allpass_biquad, (x, sr, central_freq, Q))

    @parameterized.expand(
        [
            (
                800,
                0.7,
            ),
        ]
    )
    def test_lowpass_biquad(self, cutoff_freq, Q):
        sr = 22050
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
        cutoff_freq = torch.tensor(cutoff_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.lowpass_biquad, (x, sr, cutoff_freq, Q))

    @parameterized.expand(
        [
            (
                800,
                0.7,
            ),
        ]
    )
    def test_highpass_biquad(self, cutoff_freq, Q):
        sr = 22050
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
        cutoff_freq = torch.tensor(cutoff_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.highpass_biquad, (x, sr, cutoff_freq, Q))

    @parameterized.expand(
        [
            (800, 0.7, True),
            (800, 0.7, False),
        ]
    )
    def test_bandpass_biquad(self, central_freq, Q, const_skirt_gain):
        sr = 22050
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.bandpass_biquad, (x, sr, central_freq, Q, const_skirt_gain))

    @parameterized.expand(
        [
            (800, 0.7, 10),
            (800, 0.7, -10),
        ]
    )
    def test_equalizer_biquad(self, central_freq, Q, gain):
        sr = 22050
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        gain = torch.tensor(gain)
        self.assert_grad(F.equalizer_biquad, (x, sr, central_freq, gain, Q))

    @parameterized.expand(
        [
            (
                800,
                0.7,
            ),
        ]
    )
    def test_bandreject_biquad(self, central_freq, Q):
        sr = 22050
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.bandreject_biquad, (x, sr, central_freq, Q))

    def test_deemph_biquad(self):
        torch.random.manual_seed(2434)
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=1)
        self.assert_grad(F.deemph_biquad, (x, 44100))

    def test_flanger(self):
        torch.random.manual_seed(2434)
        x = get_whitenoise(sample_rate=8000, duration=0.01, n_channels=1)
        self.assert_grad(F.flanger, (x, 44100))

    def test_gain(self):
        torch.random.manual_seed(2434)
        x = get_whitenoise(sample_rate=8000, duration=0.01, n_channels=1)
        self.assert_grad(F.gain, (x, 1.1))

    def test_overdrive(self):
        torch.random.manual_seed(2434)
        x = get_whitenoise(sample_rate=8000, duration=0.01, n_channels=1)
        self.assert_grad(F.overdrive, (x,))

    @parameterized.expand([(True,), (False,)])
    def test_phaser(self, sinusoidal):
        torch.random.manual_seed(2434)
        sr = 8000
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
        self.assert_grad(F.phaser, (x, sr, sinusoidal))

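    # F.psd builds a (channel x channel) power spectral density matrix per
    # frequency bin from a multi-channel complex spectrogram; the optional
    # time-frequency mask weights the contribution of each frame.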
    @parameterized.expand(
        [
            (True,),
            (False,),
        ]
    )
    def test_psd(self, use_mask):
        specgram = torch.rand(4, 10, 5, dtype=torch.cfloat)
        if use_mask:
            mask = torch.rand(10, 5)
        else:
            mask = None
        self.assert_grad(F.psd, (specgram, mask))

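    # F.mvdr_weights_souden computes MVDR beamforming weights from speech and
    # noise PSD matrices; the reference channel may be given either as an
    # integer index or as a one-hot tensor, and both forms are tested below.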
    def test_mvdr_weights_souden(self):
        channel = 4
        n_fft_bin = 5
        psd_speech = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
        psd_noise = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
        self.assert_grad(F.mvdr_weights_souden, (psd_speech, psd_noise, 0))

    def test_mvdr_weights_souden_with_tensor(self):
        channel = 4
        n_fft_bin = 5
        psd_speech = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
        psd_noise = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
        reference_channel = torch.zeros(channel)
        reference_channel[0].fill_(1)
        self.assert_grad(F.mvdr_weights_souden, (psd_speech, psd_noise, reference_channel))

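    # F.mvdr_weights_rtf derives the beamforming weights from the relative
    # transfer function (RTF) of the target speech and the noise PSD matrix.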
    def test_mvdr_weights_rtf(self):
        batch_size = 2
        channel = 4
        n_fft_bin = 10
        rtf = torch.rand(batch_size, n_fft_bin, channel, dtype=self.complex_dtype)
        psd_noise = torch.rand(batch_size, n_fft_bin, channel, channel, dtype=self.complex_dtype)
        self.assert_grad(F.mvdr_weights_rtf, (rtf, psd_noise, 0))

    def test_mvdr_weights_rtf_with_tensor(self):
        batch_size = 2
        channel = 4
        n_fft_bin = 10
        rtf = torch.rand(batch_size, n_fft_bin, channel, dtype=self.complex_dtype)
        psd_noise = torch.rand(batch_size, n_fft_bin, channel, channel, dtype=self.complex_dtype)
        reference_channel = torch.zeros(batch_size, channel)
        reference_channel[..., 0].fill_(1)
        self.assert_grad(F.mvdr_weights_rtf, (rtf, psd_noise, reference_channel))

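    # F.rtf_power estimates the relative transfer function of the target source
    # with the power-iteration method; n_iter sets the number of iterations and
    # diagonal_loading regularizes the noise PSD matrix.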
    @parameterized.expand(
        [
            (1, True),
            (3, False),
        ]
    )
    def test_rtf_power(self, n_iter, diagonal_loading):
        channel = 4
        n_fft_bin = 5
        psd_speech = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
        psd_noise = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
        self.assert_grad(F.rtf_power, (psd_speech, psd_noise, 0, n_iter, diagonal_loading))

    @parameterized.expand(
        [
            (1, True),
            (3, False),
        ]
    )
    def test_rtf_power_with_tensor(self, n_iter, diagonal_loading):
        channel = 4
        n_fft_bin = 5
        psd_speech = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
        psd_noise = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
        reference_channel = torch.zeros(channel)
        reference_channel[0].fill_(1)
        self.assert_grad(F.rtf_power, (psd_speech, psd_noise, reference_channel, n_iter, diagonal_loading))

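    # F.apply_beamforming applies per-frequency beamforming weights to a
    # multi-channel spectrogram and returns a single-channel enhanced spectrogram.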
    def test_apply_beamforming(self):
        sr = 8000
        n_fft = 400
        batch_size, num_channels = 2, 3
        n_fft_bin = n_fft // 2 + 1
        x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=batch_size * num_channels)
        specgram = get_spectrogram(x, n_fft=n_fft, hop_length=100)
        specgram = specgram.view(batch_size, num_channels, n_fft_bin, specgram.size(-1))
        beamform_weights = torch.rand(batch_size, n_fft_bin, num_channels, dtype=torch.cfloat)
        self.assert_grad(F.apply_beamforming, (beamform_weights, specgram))


class AutogradFloat32(TestBaseMixin):
    def assert_grad(
        self,
        transform: Callable[..., Tensor],
        inputs: Tuple[torch.Tensor],
        enable_all_grad: bool = True,
    ):
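        """Float32 variant of assert_grad.

        Only first-order gradcheck is run, with relaxed eps/atol, since float32
        precision is not sufficient for the default tolerances.
        """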
        inputs_ = []
        for i in inputs:
            if torch.is_tensor(i):
                i = i.to(dtype=self.dtype, device=self.device)
                if enable_all_grad:
                    i.requires_grad = True
            inputs_.append(i)
        # gradcheck with float32 requires higher atol and epsilon
        assert gradcheck(transform, inputs_, eps=1e-3, atol=1e-3, nondet_tol=0.0)

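    # With enable_all_grad=False, gradcheck only differentiates with respect to
    # tensors that already require grad in the data returned by rnnt_utils; the
    # logits are explicitly cast to float32 before the check.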
    @parameterized.expand(
        [
            (rnnt_utils.get_B1_T10_U3_D4_data,),
            (rnnt_utils.get_B2_T4_U3_D3_data,),
            (rnnt_utils.get_B1_T2_U3_D5_data,),
        ]
    )
    def test_rnnt_loss(self, data_func):
        def get_data(data_func, device):
            data = data_func()
            if type(data) == tuple:
                data = data[0]
            return data

        data = get_data(data_func, self.device)
        inputs = (
            data["logits"].to(torch.float32),  # logits
            data["targets"],  # targets
            data["logit_lengths"],  # logit_lengths
            data["target_lengths"],  # target_lengths
            data["blank"],  # blank
            -1,  # clamp
        )

        self.assert_grad(F.rnnt_loss, inputs, enable_all_grad=False)