from functools import partial
from typing import Callable, Tuple

import torch
from parameterized import parameterized
from torch import Tensor
from torch.autograd import gradcheck, gradgradcheck

import torchaudio.functional as F
from torchaudio_unittest.common_utils import (
    TestBaseMixin,
    get_whitenoise,
    rnnt_utils,
)


class Autograd(TestBaseMixin):
    """Gradient-correctness tests for torchaudio's filtering functionals.

    Each test feeds a short white-noise signal (and filter coefficients)
    through a functional and verifies first- and second-order gradients
    with ``torch.autograd.gradcheck`` / ``gradgradcheck`` at the dtype and
    device configured by ``TestBaseMixin``.
    """

    def assert_grad(
            self,
            transform: Callable[..., Tensor],
            inputs: Tuple[torch.Tensor],
            *,
            enable_all_grad: bool = True,
    ):
        """Assert that gradients of ``transform`` w.r.t. ``inputs`` are correct.

        Tensor entries of ``inputs`` are moved to the test dtype/device first.
        When ``enable_all_grad`` is True, gradient tracking is switched on for
        every tensor input; otherwise the caller is expected to have enabled
        ``requires_grad`` on the tensors of interest beforehand.
        """
        def _prepare(value):
            # Non-tensor arguments (sample rates, flags, ...) pass through.
            if not torch.is_tensor(value):
                return value
            value = value.to(dtype=self.dtype, device=self.device)
            if enable_all_grad:
                value.requires_grad = True
            return value

        prepared = [_prepare(value) for value in inputs]
        assert gradcheck(transform, prepared)
        assert gradgradcheck(transform, prepared)

    def test_lfilter_x(self):
        # Gradient w.r.t. the input signal only.
        torch.random.manual_seed(2434)
        waveform = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a_coeffs = torch.tensor([0.7, 0.2, 0.6])
        b_coeffs = torch.tensor([0.4, 0.2, 0.9])
        waveform.requires_grad = True
        self.assert_grad(F.lfilter, (waveform, a_coeffs, b_coeffs), enable_all_grad=False)

    def test_lfilter_a(self):
        # Gradient w.r.t. the denominator coefficients only.
        torch.random.manual_seed(2434)
        waveform = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a_coeffs = torch.tensor([0.7, 0.2, 0.6])
        b_coeffs = torch.tensor([0.4, 0.2, 0.9])
        a_coeffs.requires_grad = True
        self.assert_grad(F.lfilter, (waveform, a_coeffs, b_coeffs), enable_all_grad=False)

    def test_lfilter_b(self):
        # Gradient w.r.t. the numerator coefficients only.
        torch.random.manual_seed(2434)
        waveform = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a_coeffs = torch.tensor([0.7, 0.2, 0.6])
        b_coeffs = torch.tensor([0.4, 0.2, 0.9])
        b_coeffs.requires_grad = True
        self.assert_grad(F.lfilter, (waveform, a_coeffs, b_coeffs), enable_all_grad=False)

    def test_lfilter_all_inputs(self):
        # Gradients w.r.t. the signal and both coefficient vectors at once.
        torch.random.manual_seed(2434)
        waveform = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a_coeffs = torch.tensor([0.7, 0.2, 0.6])
        b_coeffs = torch.tensor([0.4, 0.2, 0.9])
        self.assert_grad(F.lfilter, (waveform, a_coeffs, b_coeffs))

    def test_lfilter_filterbanks(self):
        # Two filters applied to a 3-channel signal with batching disabled.
        torch.random.manual_seed(2434)
        waveform = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=3)
        a_coeffs = torch.tensor([[0.7, 0.2, 0.6],
                                 [0.8, 0.2, 0.9]])
        b_coeffs = torch.tensor([[0.4, 0.2, 0.9],
                                 [0.7, 0.2, 0.6]])
        self.assert_grad(partial(F.lfilter, batching=False), (waveform, a_coeffs, b_coeffs))

    def test_lfilter_batching(self):
        # One filter per channel (batched coefficients).
        torch.random.manual_seed(2434)
        waveform = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a_coeffs = torch.tensor([[0.7, 0.2, 0.6],
                                 [0.8, 0.2, 0.9]])
        b_coeffs = torch.tensor([[0.4, 0.2, 0.9],
                                 [0.7, 0.2, 0.6]])
        self.assert_grad(F.lfilter, (waveform, a_coeffs, b_coeffs))

    def test_filtfilt_a(self):
        # filtfilt: gradient w.r.t. the denominator coefficients only.
        torch.random.manual_seed(2434)
        waveform = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a_coeffs = torch.tensor([0.7, 0.2, 0.6])
        b_coeffs = torch.tensor([0.4, 0.2, 0.9])
        a_coeffs.requires_grad = True
        self.assert_grad(F.filtfilt, (waveform, a_coeffs, b_coeffs), enable_all_grad=False)

    def test_filtfilt_b(self):
        # filtfilt: gradient w.r.t. the numerator coefficients only.
        torch.random.manual_seed(2434)
        waveform = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a_coeffs = torch.tensor([0.7, 0.2, 0.6])
        b_coeffs = torch.tensor([0.4, 0.2, 0.9])
        b_coeffs.requires_grad = True
        self.assert_grad(F.filtfilt, (waveform, a_coeffs, b_coeffs), enable_all_grad=False)

    def test_filtfilt_all_inputs(self):
        # filtfilt: gradients w.r.t. all inputs at once.
        torch.random.manual_seed(2434)
        waveform = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a_coeffs = torch.tensor([0.7, 0.2, 0.6])
        b_coeffs = torch.tensor([0.4, 0.2, 0.9])
        self.assert_grad(F.filtfilt, (waveform, a_coeffs, b_coeffs))

    def test_filtfilt_batching(self):
        # filtfilt with one filter per channel (batched coefficients).
        torch.random.manual_seed(2434)
        waveform = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a_coeffs = torch.tensor([[0.7, 0.2, 0.6],
                                 [0.8, 0.2, 0.9]])
        b_coeffs = torch.tensor([[0.4, 0.2, 0.9],
                                 [0.7, 0.2, 0.6]])
        self.assert_grad(F.filtfilt, (waveform, a_coeffs, b_coeffs))

    def test_biquad(self):
        # biquad takes the six coefficients as individual scalars: b0..b2, a0..a2.
        torch.random.manual_seed(2434)
        waveform = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=1)
        a_coeffs = torch.tensor([0.7, 0.2, 0.6])
        b_coeffs = torch.tensor([0.4, 0.2, 0.9])
        self.assert_grad(
            F.biquad,
            (waveform,
             b_coeffs[0], b_coeffs[1], b_coeffs[2],
             a_coeffs[0], a_coeffs[1], a_coeffs[2]),
        )

    @parameterized.expand([
        (800, 0.7, True),
        (800, 0.7, False),
    ])
    def test_band_biquad(self, central_freq, Q, noise):
        torch.random.manual_seed(2434)
        sample_rate = 22050
        waveform = get_whitenoise(sample_rate=sample_rate, duration=0.01, n_channels=1)
        self.assert_grad(
            F.band_biquad,
            (waveform, sample_rate, torch.tensor(central_freq), torch.tensor(Q), noise),
        )

    @parameterized.expand([
        (800, 0.7, 10),
        (800, 0.7, -10),
    ])
    def test_bass_biquad(self, central_freq, Q, gain):
        torch.random.manual_seed(2434)
        sample_rate = 22050
        waveform = get_whitenoise(sample_rate=sample_rate, duration=0.01, n_channels=1)
        self.assert_grad(
            F.bass_biquad,
            (waveform, sample_rate, torch.tensor(gain), torch.tensor(central_freq), torch.tensor(Q)),
        )

    @parameterized.expand([
        (3000, 0.7, 10),
        (3000, 0.7, -10),
    ])
    def test_treble_biquad(self, central_freq, Q, gain):
        torch.random.manual_seed(2434)
        sample_rate = 22050
        waveform = get_whitenoise(sample_rate=sample_rate, duration=0.01, n_channels=1)
        self.assert_grad(
            F.treble_biquad,
            (waveform, sample_rate, torch.tensor(gain), torch.tensor(central_freq), torch.tensor(Q)),
        )

    @parameterized.expand([
        (800, 0.7, ),
    ])
    def test_allpass_biquad(self, central_freq, Q):
        torch.random.manual_seed(2434)
        sample_rate = 22050
        waveform = get_whitenoise(sample_rate=sample_rate, duration=0.01, n_channels=1)
        self.assert_grad(
            F.allpass_biquad,
            (waveform, sample_rate, torch.tensor(central_freq), torch.tensor(Q)),
        )

    @parameterized.expand([
        (800, 0.7, ),
    ])
    def test_lowpass_biquad(self, cutoff_freq, Q):
        torch.random.manual_seed(2434)
        sample_rate = 22050
        waveform = get_whitenoise(sample_rate=sample_rate, duration=0.01, n_channels=1)
        self.assert_grad(
            F.lowpass_biquad,
            (waveform, sample_rate, torch.tensor(cutoff_freq), torch.tensor(Q)),
        )

    @parameterized.expand([
        (800, 0.7, ),
    ])
    def test_highpass_biquad(self, cutoff_freq, Q):
        torch.random.manual_seed(2434)
        sample_rate = 22050
        waveform = get_whitenoise(sample_rate=sample_rate, duration=0.01, n_channels=1)
        self.assert_grad(
            F.highpass_biquad,
            (waveform, sample_rate, torch.tensor(cutoff_freq), torch.tensor(Q)),
        )

    @parameterized.expand([
        (800, 0.7, True),
        (800, 0.7, False),
    ])
    def test_bandpass_biquad(self, central_freq, Q, const_skirt_gain):
        torch.random.manual_seed(2434)
        sample_rate = 22050
        waveform = get_whitenoise(sample_rate=sample_rate, duration=0.01, n_channels=1)
        self.assert_grad(
            F.bandpass_biquad,
            (waveform, sample_rate, torch.tensor(central_freq), torch.tensor(Q), const_skirt_gain),
        )

    @parameterized.expand([
        (800, 0.7, 10),
        (800, 0.7, -10),
    ])
    def test_equalizer_biquad(self, central_freq, Q, gain):
        torch.random.manual_seed(2434)
        sample_rate = 22050
        waveform = get_whitenoise(sample_rate=sample_rate, duration=0.01, n_channels=1)
        self.assert_grad(
            F.equalizer_biquad,
            (waveform, sample_rate, torch.tensor(central_freq), torch.tensor(gain), torch.tensor(Q)),
        )

    @parameterized.expand([
        (800, 0.7, ),
    ])
    def test_bandreject_biquad(self, central_freq, Q):
        torch.random.manual_seed(2434)
        sample_rate = 22050
        waveform = get_whitenoise(sample_rate=sample_rate, duration=0.01, n_channels=1)
        self.assert_grad(
            F.bandreject_biquad,
            (waveform, sample_rate, torch.tensor(central_freq), torch.tensor(Q)),
        )


class AutogradFloat32(TestBaseMixin):
    """Gradient tests for functionals that only support float32 (RNN-T loss).

    Unlike :class:`Autograd`, only first-order ``gradcheck`` is run, with
    relaxed tolerances appropriate for single precision.
    """

    def assert_grad(
            self,
            transform: Callable[..., Tensor],
            inputs: Tuple[torch.Tensor],
            enable_all_grad: bool = True,
    ):
        """Assert that gradients of ``transform`` w.r.t. ``inputs`` are correct.

        Args:
            transform: the callable under test.
            inputs: positional arguments for ``transform``; tensor entries are
                moved to the configured dtype/device first.
            enable_all_grad: when True, switch on gradient tracking for every
                tensor input; otherwise only tensors that already have
                ``requires_grad`` set participate.
        """
        inputs_ = []
        for i in inputs:
            if torch.is_tensor(i):
                i = i.to(dtype=self.dtype, device=self.device)
                if enable_all_grad:
                    i.requires_grad = True
            inputs_.append(i)
        # Bug fix: this previously passed the raw ``inputs`` to gradcheck,
        # silently discarding the dtype/device conversion and requires_grad
        # setup performed above. Pass the prepared ``inputs_`` instead.
        # gradcheck with float32 requires higher atol and epsilon.
        assert gradcheck(transform, inputs_, eps=1e-3, atol=1e-3, nondet_tol=0.)

    @parameterized.expand([
        (rnnt_utils.get_B1_T10_U3_D4_data, ),
        (rnnt_utils.get_B2_T4_U3_D3_data, ),
        (rnnt_utils.get_B1_T2_U3_D5_data, ),
    ])
    def test_rnnt_loss(self, data_func):
        def get_data(data_func, device):
            # NOTE(review): ``device`` is currently unused here — tensors are
            # moved to the right device inside ``assert_grad``. Kept for
            # signature stability; confirm whether callers rely on it.
            data = data_func()
            if isinstance(data, tuple):
                # Some fixtures return (data, ...) tuples; keep the payload.
                data = data[0]
            return data

        data = get_data(data_func, self.device)
        inputs = (
            data["logits"].to(torch.float32),  # logits
            data["targets"],                # targets
            data["logit_lengths"],          # logit_lengths
            data["target_lengths"],         # target_lengths
            data["blank"],                  # blank
            -1,                             # clamp
        )

        # Only the logits require gradients; lengths/targets are integers.
        self.assert_grad(F.rnnt_loss, inputs, enable_all_grad=False)