# autograd_impl.py: autograd gradient-check tests for torchaudio.functional filters.
import torch
import torchaudio.functional as F
from torch.autograd import gradcheck
from torchaudio_unittest import common_utils


class Autograd(common_utils.TestBaseMixin):
    """Gradient checks for differentiable filters in ``torchaudio.functional``.

    Each test compares analytical gradients against finite differences using
    ``torch.autograd.gradcheck``. ``self.dtype`` and ``self.device`` are not
    defined here; they come from the concrete test case that mixes this class
    in. ``gradcheck`` generally needs double precision, so presumably those
    subclasses use ``torch.float64`` (TODO confirm).
    """

    # Fixed seed so the random waveforms (and therefore the checks) are reproducible.
    _SEED = 2434
    # Sample rate passed to every ``*_biquad`` filter-design function below.
    _SAMPLE_RATE = 22050

    def _param(self, value):
        """Return ``value`` as a grad-requiring tensor on the test's dtype/device."""
        return torch.tensor(value, dtype=self.dtype, device=self.device, requires_grad=True)

    def _waveform(self, *shape, requires_grad=False):
        """Reseed the RNG, then return a uniform random tensor of the given ``shape``."""
        torch.random.manual_seed(self._SEED)
        return torch.rand(*shape, dtype=self.dtype, device=self.device, requires_grad=requires_grad)

    @staticmethod
    def _gradcheck(fn, inputs, **kwargs):
        """Run ``gradcheck(fn, inputs, **kwargs)`` and assert that it passed.

        The call is made outside the ``assert`` statement on purpose. ``python -O``
        strips ``assert`` statements, including their expressions, so
        ``assert gradcheck(...)`` would silently skip the check. ``gradcheck``
        raises on failure by default, so the check still fails loudly under ``-O``.
        """
        passed = gradcheck(fn, inputs, **kwargs)
        assert passed

    def _check_lfilter(self, *, x_grad=False, a_grad=False, b_grad=False):
        """gradcheck ``F.lfilter(x, a, b)``, enabling grad only on the selected inputs.

        ``x`` is a seeded random (2, 4, 512) tensor. ``a`` and ``b`` are fixed
        3-tap coefficient tensors.
        """
        x = self._waveform(2, 4, 256 * 2)
        a = torch.tensor([0.7, 0.2, 0.6], dtype=self.dtype, device=self.device)
        b = torch.tensor([0.4, 0.2, 0.9], dtype=self.dtype, device=self.device)
        x.requires_grad_(x_grad)
        a.requires_grad_(a_grad)
        b.requires_grad_(b_grad)
        self._gradcheck(F.lfilter, (x, a, b), eps=1e-10)

    def _check_biquad_design(self, fn, *params):
        """gradcheck ``fn(waveform, sample_rate, *params)`` on a seeded 1024-sample waveform."""
        x = self._waveform(1024, requires_grad=True)
        self._gradcheck(fn, (x, self._SAMPLE_RATE, *params))

    def test_lfilter_x(self):
        self._check_lfilter(x_grad=True)

    def test_lfilter_a(self):
        self._check_lfilter(a_grad=True)

    def test_lfilter_b(self):
        self._check_lfilter(b_grad=True)

    def test_lfilter_all_inputs(self):
        self._check_lfilter(x_grad=True, a_grad=True, b_grad=True)

    def test_biquad(self):
        x = self._waveform(1024, requires_grad=True)
        a = self._param([0.7, 0.2, 0.6])
        b = self._param([0.4, 0.2, 0.9])
        # F.biquad takes the numerator coefficients (b0, b1, b2) before the denominator (a0, a1, a2).
        self._gradcheck(F.biquad, (x, b[0], b[1], b[2], a[0], a[1], a[2]), eps=1e-10)

    def test_band_biquad(self):
        central_freq, Q = self._param(800), self._param(0.7)
        self._check_biquad_design(F.band_biquad, central_freq, Q)

    def test_band_biquad_with_noise(self):
        central_freq, Q = self._param(800), self._param(0.7)
        # The trailing positional True is the noise flag.
        self._check_biquad_design(F.band_biquad, central_freq, Q, True)

    def test_bass_biquad(self):
        gain, central_freq, Q = self._param(10), self._param(100), self._param(0.7)
        self._check_biquad_design(F.bass_biquad, gain, central_freq, Q)

    def test_treble_biquad(self):
        gain, central_freq, Q = self._param(10), self._param(3000), self._param(0.7)
        self._check_biquad_design(F.treble_biquad, gain, central_freq, Q)

    def test_allpass_biquad(self):
        central_freq, Q = self._param(800), self._param(0.7)
        self._check_biquad_design(F.allpass_biquad, central_freq, Q)

    def test_lowpass_biquad(self):
        cutoff_freq, Q = self._param(800), self._param(0.7)
        self._check_biquad_design(F.lowpass_biquad, cutoff_freq, Q)

    def test_highpass_biquad(self):
        cutoff_freq, Q = self._param(800), self._param(0.7)
        self._check_biquad_design(F.highpass_biquad, cutoff_freq, Q)

    def test_bandpass_biquad(self):
        central_freq, Q = self._param(800), self._param(0.7)
        self._check_biquad_design(F.bandpass_biquad, central_freq, Q)

    def test_bandpass_biquad_with_const_skirt_gain(self):
        central_freq, Q = self._param(800), self._param(0.7)
        # The trailing positional True is the constant-skirt-gain flag.
        self._check_biquad_design(F.bandpass_biquad, central_freq, Q, True)

    def test_equalizer_biquad(self):
        central_freq, gain, Q = self._param(800), self._param(10), self._param(0.7)
        self._check_biquad_design(F.equalizer_biquad, central_freq, gain, Q)

    def test_bandreject_biquad(self):
        central_freq, Q = self._param(800), self._param(0.7)
        self._check_biquad_design(F.bandreject_biquad, central_freq, Q)