# autograd_impl.py — gradcheck-based autograd tests for torchaudio.functional.
1
from functools import partial
2
3
from typing import Callable, Tuple

4
import torch
5
import torchaudio.functional as F
6
7
from parameterized import parameterized
from torch import Tensor
8
from torch.autograd import gradcheck, gradgradcheck
9
10
11
from torchaudio_unittest.common_utils import (
    TestBaseMixin,
    get_whitenoise,
12
    rnnt_utils,
13
14
15
16
17
)


class Autograd(TestBaseMixin):
    def assert_grad(
18
19
20
21
22
        self,
        transform: Callable[..., Tensor],
        inputs: Tuple[torch.Tensor],
        *,
        enable_all_grad: bool = True,
23
24
25
26
    ):
        inputs_ = []
        for i in inputs:
            if torch.is_tensor(i):
27
                i = i.to(dtype=self.complex_dtype if i.is_complex() else self.dtype, device=self.device)
28
29
30
31
                if enable_all_grad:
                    i.requires_grad = True
            inputs_.append(i)
        assert gradcheck(transform, inputs_)
32
        assert gradgradcheck(transform, inputs_)
33

34
    def test_lfilter_x(self):
35
        torch.random.manual_seed(2434)
36
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
37
38
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
39
        x.requires_grad = True
40
        self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False)
41

42
    def test_lfilter_a(self):
43
        torch.random.manual_seed(2434)
44
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
45
46
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
47
        a.requires_grad = True
48
        self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False)
49

50
    def test_lfilter_b(self):
51
        torch.random.manual_seed(2434)
52
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
53
54
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
55
        b.requires_grad = True
56
        self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False)
57

58
    def test_lfilter_all_inputs(self):
59
        torch.random.manual_seed(2434)
60
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
61
62
63
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        self.assert_grad(F.lfilter, (x, a, b))
64

65
66
67
    def test_lfilter_filterbanks(self):
        torch.random.manual_seed(2434)
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=3)
68
69
        a = torch.tensor([[0.7, 0.2, 0.6], [0.8, 0.2, 0.9]])
        b = torch.tensor([[0.4, 0.2, 0.9], [0.7, 0.2, 0.6]])
70
71
72
73
74
        self.assert_grad(partial(F.lfilter, batching=False), (x, a, b))

    def test_lfilter_batching(self):
        torch.random.manual_seed(2434)
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
75
76
        a = torch.tensor([[0.7, 0.2, 0.6], [0.8, 0.2, 0.9]])
        b = torch.tensor([[0.4, 0.2, 0.9], [0.7, 0.2, 0.6]])
77
78
        self.assert_grad(F.lfilter, (x, a, b))

79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
    def test_filtfilt_a(self):
        torch.random.manual_seed(2434)
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        a.requires_grad = True
        self.assert_grad(F.filtfilt, (x, a, b), enable_all_grad=False)

    def test_filtfilt_b(self):
        torch.random.manual_seed(2434)
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        b.requires_grad = True
        self.assert_grad(F.filtfilt, (x, a, b), enable_all_grad=False)

    def test_filtfilt_all_inputs(self):
        torch.random.manual_seed(2434)
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        self.assert_grad(F.filtfilt, (x, a, b))

    def test_filtfilt_batching(self):
        torch.random.manual_seed(2434)
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
105
106
        a = torch.tensor([[0.7, 0.2, 0.6], [0.8, 0.2, 0.9]])
        b = torch.tensor([[0.4, 0.2, 0.9], [0.7, 0.2, 0.6]])
107
108
        self.assert_grad(F.filtfilt, (x, a, b))

109
110
    def test_biquad(self):
        torch.random.manual_seed(2434)
111
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=1)
112
113
114
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        self.assert_grad(F.biquad, (x, b[0], b[1], b[2], a[0], a[1], a[2]))
115

116
117
118
119
120
121
    @parameterized.expand(
        [
            (800, 0.7, True),
            (800, 0.7, False),
        ]
    )
122
    def test_band_biquad(self, central_freq, Q, noise):
123
124
        torch.random.manual_seed(2434)
        sr = 22050
125
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
126
127
128
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.band_biquad, (x, sr, central_freq, Q, noise))
129

130
131
132
133
134
135
    @parameterized.expand(
        [
            (800, 0.7, 10),
            (800, 0.7, -10),
        ]
    )
136
    def test_bass_biquad(self, central_freq, Q, gain):
137
138
        torch.random.manual_seed(2434)
        sr = 22050
139
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
140
141
142
143
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        gain = torch.tensor(gain)
        self.assert_grad(F.bass_biquad, (x, sr, gain, central_freq, Q))
144

145
146
147
148
149
150
    @parameterized.expand(
        [
            (3000, 0.7, 10),
            (3000, 0.7, -10),
        ]
    )
151
    def test_treble_biquad(self, central_freq, Q, gain):
152
153
        torch.random.manual_seed(2434)
        sr = 22050
154
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
155
156
157
158
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        gain = torch.tensor(gain)
        self.assert_grad(F.treble_biquad, (x, sr, gain, central_freq, Q))
159

160
161
162
163
164
165
166
167
    @parameterized.expand(
        [
            (
                800,
                0.7,
            ),
        ]
    )
168
    def test_allpass_biquad(self, central_freq, Q):
169
170
        torch.random.manual_seed(2434)
        sr = 22050
171
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
172
173
174
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.allpass_biquad, (x, sr, central_freq, Q))
175

176
177
178
179
180
181
182
183
    @parameterized.expand(
        [
            (
                800,
                0.7,
            ),
        ]
    )
184
    def test_lowpass_biquad(self, cutoff_freq, Q):
185
186
        torch.random.manual_seed(2434)
        sr = 22050
187
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
188
189
190
        cutoff_freq = torch.tensor(cutoff_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.lowpass_biquad, (x, sr, cutoff_freq, Q))
191

192
193
194
195
196
197
198
199
    @parameterized.expand(
        [
            (
                800,
                0.7,
            ),
        ]
    )
200
    def test_highpass_biquad(self, cutoff_freq, Q):
201
202
        torch.random.manual_seed(2434)
        sr = 22050
203
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
204
205
206
        cutoff_freq = torch.tensor(cutoff_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.highpass_biquad, (x, sr, cutoff_freq, Q))
207

208
209
210
211
212
213
    @parameterized.expand(
        [
            (800, 0.7, True),
            (800, 0.7, False),
        ]
    )
214
    def test_bandpass_biquad(self, central_freq, Q, const_skirt_gain):
215
216
        torch.random.manual_seed(2434)
        sr = 22050
217
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
218
219
220
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.bandpass_biquad, (x, sr, central_freq, Q, const_skirt_gain))
221

222
223
224
225
226
227
    @parameterized.expand(
        [
            (800, 0.7, 10),
            (800, 0.7, -10),
        ]
    )
228
    def test_equalizer_biquad(self, central_freq, Q, gain):
229
230
        torch.random.manual_seed(2434)
        sr = 22050
231
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
232
233
234
235
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        gain = torch.tensor(gain)
        self.assert_grad(F.equalizer_biquad, (x, sr, central_freq, gain, Q))
236

237
238
239
240
241
242
243
244
    @parameterized.expand(
        [
            (
                800,
                0.7,
            ),
        ]
    )
245
    def test_bandreject_biquad(self, central_freq, Q):
246
247
        torch.random.manual_seed(2434)
        sr = 22050
248
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
249
250
251
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.bandreject_biquad, (x, sr, central_freq, Q))
252

253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
    @parameterized.expand(
        [
            (True,),
            (False,),
        ]
    )
    def test_psd(self, use_mask):
        torch.random.manual_seed(2434)
        specgram = torch.rand(4, 10, 5, dtype=torch.cfloat)
        if use_mask:
            mask = torch.rand(10, 5)
        else:
            mask = None
        self.assert_grad(F.psd, (specgram, mask))

268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
    def test_mvdr_weights_souden(self):
        torch.random.manual_seed(2434)
        channel = 4
        n_fft_bin = 5
        psd_speech = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
        psd_noise = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
        self.assert_grad(F.mvdr_weights_souden, (psd_speech, psd_noise, 0))

    def test_mvdr_weights_souden_with_tensor(self):
        torch.random.manual_seed(2434)
        channel = 4
        n_fft_bin = 5
        psd_speech = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
        psd_noise = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
        reference_channel = torch.zeros(channel)
        reference_channel[0].fill_(1)
        self.assert_grad(F.mvdr_weights_souden, (psd_speech, psd_noise, reference_channel))

286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
    def test_mvdr_weights_rtf(self):
        torch.random.manual_seed(2434)
        batch_size = 2
        channel = 4
        n_fft_bin = 10
        rtf = torch.rand(batch_size, n_fft_bin, channel, dtype=self.complex_dtype)
        psd_noise = torch.rand(batch_size, n_fft_bin, channel, channel, dtype=self.complex_dtype)
        self.assert_grad(F.mvdr_weights_rtf, (rtf, psd_noise, 0))

    def test_mvdr_weights_rtf_with_tensor(self):
        torch.random.manual_seed(2434)
        batch_size = 2
        channel = 4
        n_fft_bin = 10
        rtf = torch.rand(batch_size, n_fft_bin, channel, dtype=self.complex_dtype)
        psd_noise = torch.rand(batch_size, n_fft_bin, channel, channel, dtype=self.complex_dtype)
        reference_channel = torch.zeros(batch_size, channel)
        reference_channel[..., 0].fill_(1)
        self.assert_grad(F.mvdr_weights_rtf, (rtf, psd_noise, reference_channel))

306
307
308

class AutogradFloat32(TestBaseMixin):
    """Gradient checks that must run in float32 (e.g. RNN-T loss).

    Unlike ``Autograd``, only first-order ``gradcheck`` is run, with loosened
    ``eps``/``atol`` tolerances suitable for single precision.
    """

    def assert_grad(
        self,
        transform: Callable[..., Tensor],
        inputs: Tuple[torch.Tensor, ...],
        enable_all_grad: bool = True,
    ):
        """Verify first-order gradients of ``transform`` in single precision.

        Args:
            transform: differentiable callable under test.
            inputs: positional arguments for ``transform``. Tensor entries are
                moved to the test dtype/device; non-tensors pass through as-is.
            enable_all_grad: when ``True``, every tensor input is flagged
                ``requires_grad``; when ``False``, the caller pre-sets
                ``requires_grad`` on the inputs of interest.
        """
        inputs_ = []
        for i in inputs:
            if torch.is_tensor(i):
                i = i.to(dtype=self.dtype, device=self.device)
                if enable_all_grad:
                    i.requires_grad = True
            inputs_.append(i)
        # gradcheck with float32 requires higher atol and epsilon.
        # BUGFIX: pass the converted ``inputs_`` (not the raw ``inputs``) so the
        # dtype/device move and the requires_grad flags actually take effect.
        assert gradcheck(transform, inputs_, eps=1e-3, atol=1e-3, nondet_tol=0.0)

    @parameterized.expand(
        [
            (rnnt_utils.get_B1_T10_U3_D4_data,),
            (rnnt_utils.get_B2_T4_U3_D3_data,),
            (rnnt_utils.get_B1_T2_U3_D5_data,),
        ]
    )
    def test_rnnt_loss(self, data_func):
        """rnnt_loss: gradient w.r.t. logits only (other inputs are integral)."""

        def get_data(data_func, device):
            data = data_func()
            # Some fixtures return a (data, ...) tuple; keep only the data dict.
            if isinstance(data, tuple):
                data = data[0]
            return data

        data = get_data(data_func, self.device)
        inputs = (
            data["logits"].to(torch.float32),  # logits
            data["targets"],  # targets
            data["logit_lengths"],  # logit_lengths
            data["target_lengths"],  # target_lengths
            data["blank"],  # blank
            -1,  # clamp
        )
        self.assert_grad(F.rnnt_loss, inputs, enable_all_grad=False)