# autograd_impl.py: gradcheck-based autograd tests for filters in torchaudio.functional.
from typing import Any, Callable, Tuple

import torch
import torchaudio.functional as F
from parameterized import parameterized
from torch import Tensor
from torch.autograd import gradcheck

from torchaudio_unittest.common_utils import (
    TestBaseMixin,
    get_whitenoise,
)


class Autograd(TestBaseMixin):
    """Gradient checks (``torch.autograd.gradcheck``) for ``torchaudio.functional`` filters.

    ``self.dtype`` and ``self.device`` are read by ``assert_grad`` but not
    defined here. Presumably the concrete test class combining this mixin
    supplies them (NOTE(review): confirm against ``TestBaseMixin`` subclasses).
    """

    def _whitenoise(self, sample_rate: int = 22050, n_channels: int = 1) -> Tensor:
        """Reset the global RNG seed and return 0.01 s of white noise.

        The seed is reset on every call so each test starts from the same
        random state, matching the per-test setup this helper replaces.
        """
        torch.random.manual_seed(2434)
        return get_whitenoise(sample_rate=sample_rate, duration=0.01, n_channels=n_channels)

    def assert_grad(
            self,
            transform: Callable[..., Tensor],
            inputs: Tuple[Any, ...],
            *,
            enable_all_grad: bool = True,
    ):
        """Assert that ``transform`` passes ``gradcheck`` for the given inputs.

        Args:
            transform: Function under test. It receives the elements of
                ``inputs`` as positional arguments.
            inputs: Positional arguments for ``transform``. Tensor entries are
                converted to ``self.dtype`` / ``self.device``. Non-tensor
                entries (sample rates, boolean flags) are passed through
                unchanged.
            enable_all_grad: If True, every tensor input is marked
                ``requires_grad`` so gradients are checked for all of them.
                If False, only tensors the caller already marked
                ``requires_grad`` are checked. ``.to`` preserves that flag.
        """
        prepared = []
        for arg in inputs:
            if torch.is_tensor(arg):
                arg = arg.to(dtype=self.dtype, device=self.device)
                if enable_all_grad:
                    arg.requires_grad = True
            prepared.append(arg)
        assert gradcheck(transform, prepared)

    def test_lfilter_x(self):
        x = self._whitenoise(n_channels=2)
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        # Differentiate w.r.t. the waveform only; a and b stay constant.
        x.requires_grad = True
        self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False)

    def test_lfilter_a(self):
        x = self._whitenoise(n_channels=2)
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        # Differentiate w.r.t. a only.
        a.requires_grad = True
        self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False)

    def test_lfilter_b(self):
        x = self._whitenoise(n_channels=2)
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        # Differentiate w.r.t. b only.
        b.requires_grad = True
        self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False)

    def test_lfilter_all_inputs(self):
        x = self._whitenoise(n_channels=2)
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        self.assert_grad(F.lfilter, (x, a, b))

    def test_biquad(self):
        x = self._whitenoise(n_channels=1)
        a = torch.tensor([0.7, 0.2, 0.6])
        b = torch.tensor([0.4, 0.2, 0.9])
        # Coefficients are passed individually, in (b0, b1, b2, a0, a1, a2) order.
        self.assert_grad(F.biquad, (x, b[0], b[1], b[2], a[0], a[1], a[2]))

    # In the parameterized biquad tests below, numeric filter parameters are
    # wrapped in tensors so gradcheck also differentiates w.r.t. them. The
    # sample rate and boolean flags stay plain Python values, which
    # assert_grad passes through untouched.

    @parameterized.expand([
        (800, 0.7, True),
        (800, 0.7, False),
    ])
    def test_band_biquad(self, central_freq, Q, noise):
        sr = 22050
        x = self._whitenoise(sample_rate=sr)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.band_biquad, (x, sr, central_freq, Q, noise))

    @parameterized.expand([
        (800, 0.7, 10),
        (800, 0.7, -10),
    ])
    def test_bass_biquad(self, central_freq, Q, gain):
        sr = 22050
        x = self._whitenoise(sample_rate=sr)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        gain = torch.tensor(gain)
        self.assert_grad(F.bass_biquad, (x, sr, gain, central_freq, Q))

    @parameterized.expand([
        (3000, 0.7, 10),
        (3000, 0.7, -10),
    ])
    def test_treble_biquad(self, central_freq, Q, gain):
        sr = 22050
        x = self._whitenoise(sample_rate=sr)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        gain = torch.tensor(gain)
        self.assert_grad(F.treble_biquad, (x, sr, gain, central_freq, Q))

    @parameterized.expand([
        (800, 0.7, ),
    ])
    def test_allpass_biquad(self, central_freq, Q):
        sr = 22050
        x = self._whitenoise(sample_rate=sr)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.allpass_biquad, (x, sr, central_freq, Q))

    @parameterized.expand([
        (800, 0.7, ),
    ])
    def test_lowpass_biquad(self, cutoff_freq, Q):
        sr = 22050
        x = self._whitenoise(sample_rate=sr)
        cutoff_freq = torch.tensor(cutoff_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.lowpass_biquad, (x, sr, cutoff_freq, Q))

    @parameterized.expand([
        (800, 0.7, ),
    ])
    def test_highpass_biquad(self, cutoff_freq, Q):
        sr = 22050
        x = self._whitenoise(sample_rate=sr)
        cutoff_freq = torch.tensor(cutoff_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.highpass_biquad, (x, sr, cutoff_freq, Q))

    @parameterized.expand([
        (800, 0.7, True),
        (800, 0.7, False),
    ])
    def test_bandpass_biquad(self, central_freq, Q, const_skirt_gain):
        sr = 22050
        x = self._whitenoise(sample_rate=sr)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.bandpass_biquad, (x, sr, central_freq, Q, const_skirt_gain))

    @parameterized.expand([
        (800, 0.7, 10),
        (800, 0.7, -10),
    ])
    def test_equalizer_biquad(self, central_freq, Q, gain):
        sr = 22050
        x = self._whitenoise(sample_rate=sr)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        gain = torch.tensor(gain)
        self.assert_grad(F.equalizer_biquad, (x, sr, central_freq, gain, Q))

    @parameterized.expand([
        (800, 0.7, ),
    ])
    def test_bandreject_biquad(self, central_freq, Q):
        sr = 22050
        x = self._whitenoise(sample_rate=sr)
        central_freq = torch.tensor(central_freq)
        Q = torch.tensor(Q)
        self.assert_grad(F.bandreject_biquad, (x, sr, central_freq, Q))