# autograd_impl.py
# Autograd tests (gradcheck / gradgradcheck) for torchaudio.functional filters.
1
from typing import Callable, Tuple
2
import torch
3
4
from parameterized import parameterized
from torch import Tensor
5
import torchaudio.functional as F
6
from torch.autograd import gradcheck, gradgradcheck
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from torchaudio_unittest.common_utils import (
    TestBaseMixin,
    get_whitenoise,
)


class Autograd(TestBaseMixin):
    """Autograd checks for filter functions in ``torchaudio.functional``.

    Every test builds a short, seeded white-noise signal plus filter
    parameters and verifies first- and second-order gradients with
    ``torch.autograd.gradcheck`` and ``gradgradcheck``.

    ``self.dtype`` and ``self.device`` are read in ``assert_grad``; they are
    presumably supplied by ``TestBaseMixin`` or by the concrete test class
    that mixes this in (NOTE(review): confirm).
    """

    @staticmethod
    def _seeded_whitenoise(n_channels: int, sample_rate: int = 22050) -> Tensor:
        """Seed the global RNG, then return a short white-noise clip.

        Seeding immediately before generation keeps each test's input
        reproducible regardless of test execution order.
        """
        torch.random.manual_seed(2434)
        return get_whitenoise(
            sample_rate=sample_rate, duration=0.01, n_channels=n_channels)

    @staticmethod
    def _filter_coeffs() -> Tuple[Tensor, Tensor]:
        """Return fresh ``(a, b)`` coefficient tensors for the lfilter/biquad tests.

        New tensors are built on every call because tests mutate
        ``requires_grad`` on them, so they must not be shared between tests.
        """
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        return a, b

    def assert_grad(
            self,
            transform: Callable[..., Tensor],
            inputs: Tuple[object, ...],
            *,
            enable_all_grad: bool = True,
    ):
        """Assert that ``transform`` passes ``gradcheck`` and ``gradgradcheck``.

        Args:
            transform: Function under test; called as ``transform(*inputs)``.
            inputs: Positional arguments for ``transform``. Tensor entries are
                converted to ``self.dtype`` / ``self.device``; non-tensor
                entries (sample rates, flags) are passed through unchanged.
            enable_all_grad: If True, every tensor input is made to require
                grad. If False, only tensors the caller already marked with
                ``requires_grad`` are differentiated.
        """
        prepared = []
        for arg in inputs:
            if torch.is_tensor(arg):
                # ``.to`` returns a non-leaf tensor when it actually converts
                # a tensor that already requires grad; assigning to
                # ``.requires_grad`` on a non-leaf raises, so only flip the
                # flag on tensors that do not require grad yet.
                arg = arg.to(dtype=self.dtype, device=self.device)
                if enable_all_grad and not arg.requires_grad:
                    arg.requires_grad_()
            prepared.append(arg)
        assert gradcheck(transform, prepared)
        assert gradgradcheck(transform, prepared)

    def test_lfilter_x(self):
        x = self._seeded_whitenoise(n_channels=2)
        a, b = self._filter_coeffs()
        x.requires_grad = True
        self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False)

    def test_lfilter_a(self):
        x = self._seeded_whitenoise(n_channels=2)
        a, b = self._filter_coeffs()
        a.requires_grad = True
        self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False)

    def test_lfilter_b(self):
        x = self._seeded_whitenoise(n_channels=2)
        a, b = self._filter_coeffs()
        b.requires_grad = True
        self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False)

    def test_lfilter_all_inputs(self):
        x = self._seeded_whitenoise(n_channels=2)
        a, b = self._filter_coeffs()
        self.assert_grad(F.lfilter, (x, a, b))

    def test_biquad(self):
        x = self._seeded_whitenoise(n_channels=1)
        a, b = self._filter_coeffs()
        self.assert_grad(F.biquad, (x, b[0], b[1], b[2], a[0], a[1], a[2]))

    @parameterized.expand([
        (800, 0.7, True),
        (800, 0.7, False),
    ])
    def test_band_biquad(self, central_freq, Q, noise):
        sr = 22050
        x = self._seeded_whitenoise(n_channels=1, sample_rate=sr)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.band_biquad, (x, sr, central_freq, Q, noise))

    @parameterized.expand([
        (800, 0.7, 10),
        (800, 0.7, -10),
    ])
    def test_bass_biquad(self, central_freq, Q, gain):
        sr = 22050
        x = self._seeded_whitenoise(n_channels=1, sample_rate=sr)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        gain = torch.tensor(gain)
        self.assert_grad(F.bass_biquad, (x, sr, gain, central_freq, Q))

    @parameterized.expand([
        (3000, 0.7, 10),
        (3000, 0.7, -10),
    ])
    def test_treble_biquad(self, central_freq, Q, gain):
        sr = 22050
        x = self._seeded_whitenoise(n_channels=1, sample_rate=sr)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        gain = torch.tensor(gain)
        self.assert_grad(F.treble_biquad, (x, sr, gain, central_freq, Q))

    @parameterized.expand([
        (800, 0.7),
    ])
    def test_allpass_biquad(self, central_freq, Q):
        sr = 22050
        x = self._seeded_whitenoise(n_channels=1, sample_rate=sr)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.allpass_biquad, (x, sr, central_freq, Q))

    @parameterized.expand([
        (800, 0.7),
    ])
    def test_lowpass_biquad(self, cutoff_freq, Q):
        sr = 22050
        x = self._seeded_whitenoise(n_channels=1, sample_rate=sr)
        cutoff_freq = torch.tensor(cutoff_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.lowpass_biquad, (x, sr, cutoff_freq, Q))

    @parameterized.expand([
        (800, 0.7),
    ])
    def test_highpass_biquad(self, cutoff_freq, Q):
        sr = 22050
        x = self._seeded_whitenoise(n_channels=1, sample_rate=sr)
        cutoff_freq = torch.tensor(cutoff_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.highpass_biquad, (x, sr, cutoff_freq, Q))

    @parameterized.expand([
        (800, 0.7, True),
        (800, 0.7, False),
    ])
    def test_bandpass_biquad(self, central_freq, Q, const_skirt_gain):
        sr = 22050
        x = self._seeded_whitenoise(n_channels=1, sample_rate=sr)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.bandpass_biquad, (x, sr, central_freq, Q, const_skirt_gain))

    @parameterized.expand([
        (800, 0.7, 10),
        (800, 0.7, -10),
    ])
    def test_equalizer_biquad(self, central_freq, Q, gain):
        sr = 22050
        x = self._seeded_whitenoise(n_channels=1, sample_rate=sr)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        gain = torch.tensor(gain)
        self.assert_grad(F.equalizer_biquad, (x, sr, central_freq, gain, Q))

    @parameterized.expand([
        (800, 0.7),
    ])
    def test_bandreject_biquad(self, central_freq, Q):
        sr = 22050
        x = self._seeded_whitenoise(n_channels=1, sample_rate=sr)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.bandreject_biquad, (x, sr, central_freq, Q))