from typing import Any, Callable, Tuple

import torch
import torchaudio.functional as F
from parameterized import parameterized
from torch import Tensor
from torch.autograd import gradcheck, gradgradcheck

from torchaudio_unittest.common_utils import (
    TestBaseMixin,
    get_whitenoise,
)


class Autograd(TestBaseMixin):
    """First- and second-order gradient checks for ``torchaudio.functional`` filters.

    Tensors handed to :meth:`assert_grad` are cast to ``self.dtype`` /
    ``self.device``. Those attributes are not defined in this class; they are
    expected to come from ``TestBaseMixin`` or a concrete subclass
    (NOTE(review): ``gradcheck`` generally needs a double-precision dtype —
    confirm the subclasses set one).
    """

    @staticmethod
    def _whitenoise(n_channels: int, sample_rate: int = 22050) -> Tensor:
        """Seed the global torch RNG with 2434, then return 0.01 s of white noise.

        Args:
            n_channels: Number of channels requested from ``get_whitenoise``.
            sample_rate: Sample rate passed to ``get_whitenoise``.
        """
        torch.random.manual_seed(2434)
        return get_whitenoise(sample_rate=sample_rate, duration=0.01, n_channels=n_channels)

    @staticmethod
    def _lfilter_coeffs() -> Tuple[Tensor, Tensor]:
        """Return freshly allocated ``(a, b)`` 3-tap filter coefficients.

        A new pair is built on every call because several tests set
        ``requires_grad`` on one of them; sharing tensors between tests would
        leak that flag.
        """
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        return a, b

    def assert_grad(
            self,
            transform: Callable[..., Tensor],
            inputs: Tuple[Any, ...],
            *,
            enable_all_grad: bool = True,
    ):
        """Assert that ``gradcheck`` and ``gradgradcheck`` pass for ``transform(*inputs)``.

        Every tensor in ``inputs`` is converted to ``self.dtype`` /
        ``self.device`` as a fresh leaf tensor. Non-tensor entries (sample
        rates, boolean flags, ...) are passed through unchanged.

        Args:
            transform: Function under test.
            inputs: Positional arguments for ``transform``; may mix tensors
                and plain Python values.
            enable_all_grad: If True, every tensor input requires grad. If
                False, only tensors whose ``requires_grad`` was already set by
                the caller are differentiated.
        """
        inputs_ = []
        for i in inputs:
            if torch.is_tensor(i):
                requires_grad = enable_all_grad or i.requires_grad
                # Detach first so the converted tensor is always a leaf:
                # ``.to()`` on a tensor that already requires grad returns a
                # non-leaf, and changing ``requires_grad`` on a non-leaf raises.
                i = i.detach().to(dtype=self.dtype, device=self.device)
                i.requires_grad_(requires_grad)
            inputs_.append(i)
        assert gradcheck(transform, inputs_)
        assert gradgradcheck(transform, inputs_)

    def test_lfilter_x(self):
        """``lfilter``: differentiate w.r.t. the waveform only."""
        x = self._whitenoise(n_channels=2)
        a, b = self._lfilter_coeffs()
        x.requires_grad = True
        self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False)

    def test_lfilter_a(self):
        """``lfilter``: differentiate w.r.t. ``a`` (2nd positional argument) only."""
        x = self._whitenoise(n_channels=2)
        a, b = self._lfilter_coeffs()
        a.requires_grad = True
        self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False)

    def test_lfilter_b(self):
        """``lfilter``: differentiate w.r.t. ``b`` (3rd positional argument) only."""
        x = self._whitenoise(n_channels=2)
        a, b = self._lfilter_coeffs()
        b.requires_grad = True
        self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False)

    def test_lfilter_all_inputs(self):
        """``lfilter``: differentiate w.r.t. waveform and both coefficient sets."""
        x = self._whitenoise(n_channels=2)
        a, b = self._lfilter_coeffs()
        self.assert_grad(F.lfilter, (x, a, b))

    def test_lfilter_filterbanks(self):
        """``lfilter`` with 2-D (two-filter) coefficient tensors."""
        x = self._whitenoise(n_channels=3)
        a = torch.tensor([[0.7, 0.2, 0.6],
                          [0.8, 0.2, 0.9]])
        b = torch.tensor([[0.4, 0.2, 0.9],
                          [0.7, 0.2, 0.6]])
        self.assert_grad(F.lfilter, (x, a, b))

    def test_biquad(self):
        """``biquad`` with each coefficient passed as a separate 0-d tensor."""
        x = self._whitenoise(n_channels=1)
        a, b = self._lfilter_coeffs()
        self.assert_grad(F.biquad, (x, b[0], b[1], b[2], a[0], a[1], a[2]))

    @parameterized.expand([
        (800, 0.7, True),
        (800, 0.7, False),
    ])
    def test_band_biquad(self, central_freq, Q, noise):
        sr = 22050
        x = self._whitenoise(n_channels=1, sample_rate=sr)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.band_biquad, (x, sr, central_freq, Q, noise))

    @parameterized.expand([
        (800, 0.7, 10),
        (800, 0.7, -10),
    ])
    def test_bass_biquad(self, central_freq, Q, gain):
        sr = 22050
        x = self._whitenoise(n_channels=1, sample_rate=sr)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        gain = torch.tensor(gain)
        self.assert_grad(F.bass_biquad, (x, sr, gain, central_freq, Q))

    @parameterized.expand([
        (3000, 0.7, 10),
        (3000, 0.7, -10),
    ])
    def test_treble_biquad(self, central_freq, Q, gain):
        sr = 22050
        x = self._whitenoise(n_channels=1, sample_rate=sr)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        gain = torch.tensor(gain)
        self.assert_grad(F.treble_biquad, (x, sr, gain, central_freq, Q))

    @parameterized.expand([
        (800, 0.7),
    ])
    def test_allpass_biquad(self, central_freq, Q):
        sr = 22050
        x = self._whitenoise(n_channels=1, sample_rate=sr)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.allpass_biquad, (x, sr, central_freq, Q))

    @parameterized.expand([
        (800, 0.7),
    ])
    def test_lowpass_biquad(self, cutoff_freq, Q):
        sr = 22050
        x = self._whitenoise(n_channels=1, sample_rate=sr)
        cutoff_freq = torch.tensor(cutoff_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.lowpass_biquad, (x, sr, cutoff_freq, Q))

    @parameterized.expand([
        (800, 0.7),
    ])
    def test_highpass_biquad(self, cutoff_freq, Q):
        sr = 22050
        x = self._whitenoise(n_channels=1, sample_rate=sr)
        cutoff_freq = torch.tensor(cutoff_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.highpass_biquad, (x, sr, cutoff_freq, Q))

    @parameterized.expand([
        (800, 0.7, True),
        (800, 0.7, False),
    ])
    def test_bandpass_biquad(self, central_freq, Q, const_skirt_gain):
        sr = 22050
        x = self._whitenoise(n_channels=1, sample_rate=sr)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.bandpass_biquad, (x, sr, central_freq, Q, const_skirt_gain))

    @parameterized.expand([
        (800, 0.7, 10),
        (800, 0.7, -10),
    ])
    def test_equalizer_biquad(self, central_freq, Q, gain):
        sr = 22050
        x = self._whitenoise(n_channels=1, sample_rate=sr)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        gain = torch.tensor(gain)
        self.assert_grad(F.equalizer_biquad, (x, sr, central_freq, gain, Q))

    @parameterized.expand([
        (800, 0.7),
    ])
    def test_bandreject_biquad(self, central_freq, Q):
        sr = 22050
        x = self._whitenoise(n_channels=1, sample_rate=sr)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.bandreject_biquad, (x, sr, central_freq, Q))