from typing import Callable, Tuple
from functools import partial

import torch
from parameterized import parameterized
from torch import Tensor
from torch.autograd import gradcheck, gradgradcheck

import torchaudio.functional as F
from torchaudio_unittest.common_utils import (
    TestBaseMixin,
    get_whitenoise,
    rnnt_utils,
)


class Autograd(TestBaseMixin):
    """Autograd (gradcheck / gradgradcheck) tests for filtering functionals.

    Runs in the dtype/device supplied by ``TestBaseMixin``.
    """

    def assert_grad(
            self,
            transform: Callable[..., Tensor],
            inputs: Tuple[torch.Tensor],
            *,
            enable_all_grad: bool = True,
    ):
        """Check first and second order gradients of ``transform``.

        Tensor inputs are first moved to the test dtype/device.  When
        ``enable_all_grad`` is True, every tensor input is marked as
        requiring gradients; otherwise the caller is expected to have
        set ``requires_grad`` on the tensors of interest beforehand.
        """
        prepared = []
        for arg in inputs:
            if torch.is_tensor(arg):
                arg = arg.to(dtype=self.dtype, device=self.device)
                if enable_all_grad:
                    arg.requires_grad = True
            prepared.append(arg)
        assert gradcheck(transform, prepared)
        assert gradgradcheck(transform, prepared)

    def test_lfilter_x(self):
        # Gradient w.r.t. the input signal only.
        torch.random.manual_seed(2434)
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        x.requires_grad = True
        self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False)

    def test_lfilter_a(self):
        # Gradient w.r.t. the denominator coefficients only.
        torch.random.manual_seed(2434)
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        a.requires_grad = True
        self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False)

    def test_lfilter_b(self):
        # Gradient w.r.t. the numerator coefficients only.
        torch.random.manual_seed(2434)
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        b.requires_grad = True
        self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False)

    def test_lfilter_all_inputs(self):
        # Gradients w.r.t. signal and both coefficient tensors at once.
        torch.random.manual_seed(2434)
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        self.assert_grad(F.lfilter, (x, a, b))

    def test_lfilter_filterbanks(self):
        # Multiple filters applied channel-wise (batching disabled).
        torch.random.manual_seed(2434)
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=3)
        a = torch.tensor([[0.7, 0.2, 0.6],
                          [0.8, 0.2, 0.9]])
        b = torch.tensor([[0.4, 0.2, 0.9],
                          [0.7, 0.2, 0.6]])
        self.assert_grad(partial(F.lfilter, batching=False), (x, a, b))

    def test_lfilter_batching(self):
        # One filter per channel (default batching behavior).
        torch.random.manual_seed(2434)
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a = torch.tensor([[0.7, 0.2, 0.6],
                          [0.8, 0.2, 0.9]])
        b = torch.tensor([[0.4, 0.2, 0.9],
                          [0.7, 0.2, 0.6]])
        self.assert_grad(F.lfilter, (x, a, b))

    def test_biquad(self):
        torch.random.manual_seed(2434)
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=1)
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        # biquad takes scalar coefficients: (x, b0, b1, b2, a0, a1, a2)
        self.assert_grad(F.biquad, (x, b[0], b[1], b[2], a[0], a[1], a[2]))

    @parameterized.expand([
        (800, 0.7, True),
        (800, 0.7, False),
    ])
    def test_band_biquad(self, central_freq, Q, noise):
        torch.random.manual_seed(2434)
        sr = 22050
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.band_biquad, (x, sr, central_freq, Q, noise))

    @parameterized.expand([
        (800, 0.7, 10),
        (800, 0.7, -10),
    ])
    def test_bass_biquad(self, central_freq, Q, gain):
        torch.random.manual_seed(2434)
        sr = 22050
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        gain = torch.tensor(gain)
        self.assert_grad(F.bass_biquad, (x, sr, gain, central_freq, Q))

    @parameterized.expand([
        (3000, 0.7, 10),
        (3000, 0.7, -10),
    ])
    def test_treble_biquad(self, central_freq, Q, gain):
        torch.random.manual_seed(2434)
        sr = 22050
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        gain = torch.tensor(gain)
        self.assert_grad(F.treble_biquad, (x, sr, gain, central_freq, Q))

    @parameterized.expand([
        (800, 0.7, ),
    ])
    def test_allpass_biquad(self, central_freq, Q):
        torch.random.manual_seed(2434)
        sr = 22050
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.allpass_biquad, (x, sr, central_freq, Q))

    @parameterized.expand([
        (800, 0.7, ),
    ])
    def test_lowpass_biquad(self, cutoff_freq, Q):
        torch.random.manual_seed(2434)
        sr = 22050
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
        cutoff_freq = torch.tensor(cutoff_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.lowpass_biquad, (x, sr, cutoff_freq, Q))

    @parameterized.expand([
        (800, 0.7, ),
    ])
    def test_highpass_biquad(self, cutoff_freq, Q):
        torch.random.manual_seed(2434)
        sr = 22050
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
        cutoff_freq = torch.tensor(cutoff_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.highpass_biquad, (x, sr, cutoff_freq, Q))

    @parameterized.expand([
        (800, 0.7, True),
        (800, 0.7, False),
    ])
    def test_bandpass_biquad(self, central_freq, Q, const_skirt_gain):
        torch.random.manual_seed(2434)
        sr = 22050
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.bandpass_biquad, (x, sr, central_freq, Q, const_skirt_gain))

    @parameterized.expand([
        (800, 0.7, 10),
        (800, 0.7, -10),
    ])
    def test_equalizer_biquad(self, central_freq, Q, gain):
        torch.random.manual_seed(2434)
        sr = 22050
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        gain = torch.tensor(gain)
        self.assert_grad(F.equalizer_biquad, (x, sr, central_freq, gain, Q))

    @parameterized.expand([
        (800, 0.7, ),
    ])
    def test_bandreject_biquad(self, central_freq, Q):
        torch.random.manual_seed(2434)
        sr = 22050
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.bandreject_biquad, (x, sr, central_freq, Q))


class AutogradFloat32(TestBaseMixin):
    """Autograd tests that must run in float32 (e.g. RNN-T loss kernels)."""

    def assert_grad(
            self,
            transform: Callable[..., Tensor],
            inputs: Tuple[torch.Tensor],
            enable_all_grad: bool = True,
    ):
        """Run ``gradcheck`` on ``transform`` with relaxed float32 tolerances.

        Tensor inputs are moved to the test dtype/device first; when
        ``enable_all_grad`` is True, every tensor input has
        ``requires_grad`` enabled.
        """
        inputs_ = []
        for i in inputs:
            if torch.is_tensor(i):
                i = i.to(dtype=self.dtype, device=self.device)
                if enable_all_grad:
                    i.requires_grad = True
            inputs_.append(i)
        # gradcheck with float32 requires higher atol and epsilon.
        # Bug fix: check the converted ``inputs_`` (on the test dtype/device,
        # with gradients enabled), not the raw ``inputs`` — previously the
        # prepared list was built and then never used.
        assert gradcheck(transform, inputs_, eps=1e-3, atol=1e-3, nondet_tol=0.)

    @parameterized.expand([
        (rnnt_utils.get_B1_T10_U3_D4_data, ),
        (rnnt_utils.get_B2_T4_U3_D3_data, ),
        (rnnt_utils.get_B1_T2_U3_D5_data, ),
    ])
    def test_rnnt_loss(self, data_func):
        def get_data(data_func, device):
            # Some fixtures return (data, ...) tuples; unwrap to the dict.
            # NOTE(review): ``device`` is currently unused here — tensors are
            # moved to the test device inside ``assert_grad``.
            data = data_func()
            if isinstance(data, tuple):
                data = data[0]
            return data

        data = get_data(data_func, self.device)
        inputs = (
            data["logits"].to(torch.float32),  # logits
            data["targets"],                # targets
            data["logit_lengths"],          # logit_lengths
            data["target_lengths"],         # target_lengths
            data["blank"],                  # blank
            -1,                             # clamp
        )

        self.assert_grad(F.rnnt_loss, inputs, enable_all_grad=False)