from functools import partial
from typing import Any, Callable, Tuple

import torch
import torchaudio.functional as F
from parameterized import parameterized
from torch import Tensor
from torch.autograd import gradcheck, gradgradcheck

from torchaudio_unittest.common_utils import (
    TestBaseMixin,
    get_whitenoise,
)


class Autograd(TestBaseMixin):
    """Gradient checks for functions in ``torchaudio.functional``.

    Each test builds small inputs and hands them to ``assert_grad``, which runs
    ``gradcheck`` and ``gradgradcheck`` on the function under test.
    ``self.dtype`` / ``self.device`` are read from the instance; presumably
    they are provided by ``TestBaseMixin`` or a concrete subclass (confirm).
    """

    def assert_grad(
            self,
            transform: Callable[..., Tensor],
            inputs: Tuple[Any, ...],
            *,
            enable_all_grad: bool = True,
    ):
        """Assert that ``transform`` passes ``gradcheck`` and ``gradgradcheck``.

        Args:
            transform: Function under test; gradcheck calls it with the
                prepared ``inputs`` as positional arguments.
            inputs: Positional arguments for ``transform``. Tensor items are
                converted to ``self.dtype`` / ``self.device``; non-tensor items
                (sample rates, boolean flags, ...) are passed through unchanged.
            enable_all_grad: If True, every tensor input requires grad. If
                False, only tensors the caller already marked with
                ``requires_grad=True`` are differentiated.
        """
        inputs_ = []
        for i in inputs:
            if torch.is_tensor(i):
                # Read the flag before converting: ``.to`` on a tensor that
                # already requires grad returns a non-leaf copy, and assigning
                # ``requires_grad`` on a non-leaf raises a RuntimeError.
                requires_grad = enable_all_grad or i.requires_grad
                # ``detach`` gives a leaf, so gradcheck always receives leaf
                # inputs and the caller's tensor object is never mutated.
                i = i.detach().to(dtype=self.dtype, device=self.device)
                i.requires_grad_(requires_grad)
            inputs_.append(i)
        assert gradcheck(transform, inputs_)
        assert gradgradcheck(transform, inputs_)

    def test_lfilter_x(self):
        """Check ``lfilter`` gradients w.r.t. the waveform only."""
        torch.random.manual_seed(2434)
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        x.requires_grad = True
        self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False)

    def test_lfilter_a(self):
        """Check ``lfilter`` gradients w.r.t. ``a`` (second argument) only."""
        torch.random.manual_seed(2434)
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        a.requires_grad = True
        self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False)

    def test_lfilter_b(self):
        """Check ``lfilter`` gradients w.r.t. ``b`` (third argument) only."""
        torch.random.manual_seed(2434)
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        b.requires_grad = True
        self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False)

    def test_lfilter_all_inputs(self):
        """Check ``lfilter`` gradients w.r.t. all tensor inputs at once."""
        torch.random.manual_seed(2434)
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        self.assert_grad(F.lfilter, (x, a, b))

    def test_lfilter_filterbanks(self):
        """Check ``lfilter`` with a 2-row coefficient bank and ``batching=False``.

        The input has 3 channels, so the channel count differs from the number
        of coefficient rows.
        """
        torch.random.manual_seed(2434)
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=3)
        a = torch.tensor([[0.7, 0.2, 0.6],
                          [0.8, 0.2, 0.9]])
        b = torch.tensor([[0.4, 0.2, 0.9],
                          [0.7, 0.2, 0.6]])
        self.assert_grad(partial(F.lfilter, batching=False), (x, a, b))

    def test_lfilter_batching(self):
        """Check ``lfilter`` with a 2-row coefficient bank on a 2-channel input.

        Uses ``lfilter``'s default ``batching`` argument.
        """
        torch.random.manual_seed(2434)
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
        a = torch.tensor([[0.7, 0.2, 0.6],
                          [0.8, 0.2, 0.9]])
        b = torch.tensor([[0.4, 0.2, 0.9],
                          [0.7, 0.2, 0.6]])
        self.assert_grad(F.lfilter, (x, a, b))

    def test_biquad(self):
        """Check ``biquad`` with each coefficient passed as a 0-d tensor."""
        torch.random.manual_seed(2434)
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=1)
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        self.assert_grad(F.biquad, (x, b[0], b[1], b[2], a[0], a[1], a[2]))

    @parameterized.expand([
        (800, 0.7, True),
        (800, 0.7, False),
    ])
    def test_band_biquad(self, central_freq, Q, noise):
        """Check ``band_biquad``; ``noise`` is forwarded as a plain bool."""
        torch.random.manual_seed(2434)
        sr = 22050
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.band_biquad, (x, sr, central_freq, Q, noise))

    @parameterized.expand([
        (800, 0.7, 10),
        (800, 0.7, -10),
    ])
    def test_bass_biquad(self, central_freq, Q, gain):
        """Check ``bass_biquad`` for a positive and a negative gain."""
        torch.random.manual_seed(2434)
        sr = 22050
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        gain = torch.tensor(gain)
        self.assert_grad(F.bass_biquad, (x, sr, gain, central_freq, Q))

    @parameterized.expand([
        (3000, 0.7, 10),
        (3000, 0.7, -10),
    ])
    def test_treble_biquad(self, central_freq, Q, gain):
        """Check ``treble_biquad`` for a positive and a negative gain."""
        torch.random.manual_seed(2434)
        sr = 22050
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        gain = torch.tensor(gain)
        self.assert_grad(F.treble_biquad, (x, sr, gain, central_freq, Q))

    @parameterized.expand([
        (800, 0.7, ),
    ])
    def test_allpass_biquad(self, central_freq, Q):
        """Check ``allpass_biquad`` gradients."""
        torch.random.manual_seed(2434)
        sr = 22050
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.allpass_biquad, (x, sr, central_freq, Q))

    @parameterized.expand([
        (800, 0.7, ),
    ])
    def test_lowpass_biquad(self, cutoff_freq, Q):
        """Check ``lowpass_biquad`` gradients."""
        torch.random.manual_seed(2434)
        sr = 22050
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
        cutoff_freq = torch.tensor(cutoff_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.lowpass_biquad, (x, sr, cutoff_freq, Q))

    @parameterized.expand([
        (800, 0.7, ),
    ])
    def test_highpass_biquad(self, cutoff_freq, Q):
        """Check ``highpass_biquad`` gradients."""
        torch.random.manual_seed(2434)
        sr = 22050
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
        cutoff_freq = torch.tensor(cutoff_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.highpass_biquad, (x, sr, cutoff_freq, Q))

    @parameterized.expand([
        (800, 0.7, True),
        (800, 0.7, False),
    ])
    def test_bandpass_biquad(self, central_freq, Q, const_skirt_gain):
        """Check ``bandpass_biquad`` with ``const_skirt_gain`` on and off."""
        torch.random.manual_seed(2434)
        sr = 22050
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.bandpass_biquad, (x, sr, central_freq, Q, const_skirt_gain))

    @parameterized.expand([
        (800, 0.7, 10),
        (800, 0.7, -10),
    ])
    def test_equalizer_biquad(self, central_freq, Q, gain):
        """Check ``equalizer_biquad`` for a positive and a negative gain."""
        torch.random.manual_seed(2434)
        sr = 22050
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        gain = torch.tensor(gain)
        self.assert_grad(F.equalizer_biquad, (x, sr, central_freq, gain, Q))

    @parameterized.expand([
        (800, 0.7, ),
    ])
    def test_bandreject_biquad(self, central_freq, Q):
        """Check ``bandreject_biquad`` gradients."""
        torch.random.manual_seed(2434)
        sr = 22050
        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.bandreject_biquad, (x, sr, central_freq, Q))