batch_consistency_test.py 10.9 KB
Newer Older
1
2
"""Test numerical consistency among single input and batched input."""
import unittest
3
4
import itertools
from parameterized import parameterized
5
6
7
8
9

import torch
import torchaudio
import torchaudio.functional as F

10
from torchaudio_unittest import common_utils
11
12


moto's avatar
moto committed
13
14
class TestFunctional(common_utils.TorchaudioTestCase):
    """Test functions defined in `functional` module"""
    backend = 'default'

    def assert_batch_consistency(
            self, functional, tensor, *args, batch_size=1, atol=1e-8, rtol=1e-5, seed=42, **kwargs):
        """Assert that ``functional`` gives the same result on a batched input.

        The single-item result is repeated ``batch_size`` times along a new
        leading dimension and compared against ``functional`` applied to the
        input repeated the same way. The RNG is reseeded with ``seed`` before
        each call so both runs start from the same random state.

        Extra positional/keyword arguments are forwarded to ``functional``.
        """
        # run then batch the result
        torch.random.manual_seed(seed)
        expected = functional(tensor.clone(), *args, **kwargs)
        expected = expected.repeat([batch_size] + [1] * expected.dim())

        # batch the input and run
        torch.random.manual_seed(seed)
        pattern = [batch_size] + [1] * tensor.dim()
        computed = functional(tensor.repeat(pattern), *args, **kwargs)

        self.assertEqual(computed, expected, rtol=rtol, atol=atol)

    def assert_batch_consistencies(
            self, functional, tensor, *args, atol=1e-8, rtol=1e-5, seed=42, **kwargs):
        """Run ``assert_batch_consistency`` with batch sizes 1 and 3."""
        for batch_size in (1, 3):
            self.assert_batch_consistency(
                functional, tensor, *args, batch_size=batch_size, atol=atol, rtol=rtol, seed=seed, **kwargs)

    def test_griffinlim(self):
        n_fft = 400
        ws = 400
        hop = 200
        window = torch.hann_window(ws)
        power = 2
        normalize = False
        momentum = 0.99
        n_iter = 32
        length = 1000
        tensor = torch.rand((1, 201, 6))
        # Arguments are passed positionally, so their order must track F.griffinlim's signature.
        self.assert_batch_consistencies(
            F.griffinlim, tensor, window, n_fft, hop, ws, power, normalize, n_iter, momentum, length, 0, atol=5e-5
        )

    @parameterized.expand(list(itertools.product(
        [100, 440],
        [8000, 16000, 44100],
        [1, 2],
    )), name_func=lambda f, _, p: f'{f.__name__}_{"_".join(str(arg) for arg in p.args)}')
    def test_detect_pitch_frequency(self, frequency, sample_rate, n_channels):
        waveform = common_utils.get_sinusoid(frequency=frequency, sample_rate=sample_rate,
                                             n_channels=n_channels, duration=5)
        self.assert_batch_consistencies(F.detect_pitch_frequency, waveform, sample_rate)

    def test_contrast(self):
        waveform = torch.rand(2, 100) - 0.5
        self.assert_batch_consistencies(F.contrast, waveform, enhancement_amount=80.)

    def test_dcshift(self):
        waveform = torch.rand(2, 100) - 0.5
        self.assert_batch_consistencies(F.dcshift, waveform, shift=0.5, limiter_gain=0.05)

    def test_overdrive(self):
        waveform = torch.rand(2, 100) - 0.5
        self.assert_batch_consistencies(F.overdrive, waveform, gain=45, colour=30)

    def test_phaser(self):
        sample_rate = 44100
        waveform = common_utils.get_whitenoise(
            sample_rate=sample_rate, duration=5,
        )
        self.assert_batch_consistencies(F.phaser, waveform, sample_rate)

    def test_flanger(self):
        torch.random.manual_seed(40)
        waveform = torch.rand(2, 100) - 0.5
        sample_rate = 44100
        self.assert_batch_consistencies(F.flanger, waveform, sample_rate)

    def test_sliding_window_cmn(self):
        waveform = torch.randn(2, 1024) - 0.5
        # Cover every (center, norm_vars) combination.
        for center, norm_vars in itertools.product([True, False], repeat=2):
            self.assert_batch_consistencies(
                F.sliding_window_cmn, waveform, center=center, norm_vars=norm_vars)

    def test_vad(self):
        common_utils.set_audio_backend('default')
        filepath = common_utils.get_asset_path("vad-go-mono-32000.wav")
        waveform, sample_rate = torchaudio.load(filepath)
        self.assert_batch_consistencies(F.vad, waveform, sample_rate=sample_rate)
99

100

moto's avatar
moto committed
101
102
103
class TestTransforms(common_utils.TorchaudioTestCase):
    """Test suite for classes defined in `transforms` module"""
    backend = 'default'

    def test_batch_AmplitudeToDB(self):
        spec = torch.rand((6, 201))

        # Single then transform then batch
        expected = torchaudio.transforms.AmplitudeToDB()(spec).repeat(3, 1, 1)

        # Batch then transform
        computed = torchaudio.transforms.AmplitudeToDB()(spec.repeat(3, 1, 1))

        self.assertEqual(computed, expected)

    def test_batch_Resample(self):
        waveform = torch.randn(2, 2786)

        # Single then transform then batch
        expected = torchaudio.transforms.Resample()(waveform).repeat(3, 1, 1)

        # Batch then transform
        computed = torchaudio.transforms.Resample()(waveform.repeat(3, 1, 1))

        self.assertEqual(computed, expected)

    def test_batch_MelScale(self):
        specgram = torch.randn(2, 31, 2786)

        # Single then transform then batch
        expected = torchaudio.transforms.MelScale()(specgram).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = torchaudio.transforms.MelScale()(specgram.repeat(3, 1, 1, 1))

        # shape = (3, 2, n_mels, 2786), where n_mels is MelScale's default
        self.assertEqual(computed, expected)

    def test_batch_InverseMelScale(self):
        n_mels = 32
        n_stft = 5
        mel_spec = torch.randn(2, n_mels, 32) ** 2

        # Single then transform then batch
        expected = torchaudio.transforms.InverseMelScale(n_stft, n_mels)(mel_spec).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = torchaudio.transforms.InverseMelScale(n_stft, n_mels)(mel_spec.repeat(3, 1, 1, 1))

        # shape = (3, 2, n_stft, 32)

        # Because InverseMelScale runs SGD on randomly initialized values so they do not yield
        # exactly same result. For this reason, tolerance is very relaxed here.
        self.assertEqual(computed, expected, atol=1.0, rtol=1e-5)

    def test_batch_compute_deltas(self):
        specgram = torch.randn(2, 31, 2786)

        # Single then transform then batch
        expected = torchaudio.transforms.ComputeDeltas()(specgram).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = torchaudio.transforms.ComputeDeltas()(specgram.repeat(3, 1, 1, 1))

        # shape = (3, 2, 31, 2786)
        self.assertEqual(computed, expected)

    def test_batch_mulaw(self):
        test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(test_filepath)  # (2, 278756), 44100

        # Single then transform then batch
        waveform_encoded = torchaudio.transforms.MuLawEncoding()(waveform)
        expected = waveform_encoded.unsqueeze(0).repeat(3, 1, 1)

        # Batch then transform
        waveform_batched = waveform.unsqueeze(0).repeat(3, 1, 1)
        computed = torchaudio.transforms.MuLawEncoding()(waveform_batched)

        # shape = (3, 2, 278756)
        self.assertEqual(computed, expected)

        # Single then transform then batch
        waveform_decoded = torchaudio.transforms.MuLawDecoding()(waveform_encoded)
        expected = waveform_decoded.unsqueeze(0).repeat(3, 1, 1)

        # Batch then transform (decode the batched encoding computed above)
        computed = torchaudio.transforms.MuLawDecoding()(computed)

        # shape = (3, 2, 278756)
        self.assertEqual(computed, expected)

    def test_batch_spectrogram(self):
        test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(test_filepath)  # (2, 278756), 44100

        # Single then transform then batch
        expected = torchaudio.transforms.Spectrogram()(waveform).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = torchaudio.transforms.Spectrogram()(waveform.repeat(3, 1, 1))
        self.assertEqual(computed, expected)

    def test_batch_melspectrogram(self):
        test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(test_filepath)  # (2, 278756), 44100

        # Single then transform then batch
        expected = torchaudio.transforms.MelSpectrogram()(waveform).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = torchaudio.transforms.MelSpectrogram()(waveform.repeat(3, 1, 1))
        self.assertEqual(computed, expected)

    def test_batch_mfcc(self):
        test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(test_filepath)

        # Single then transform then batch
        expected = torchaudio.transforms.MFCC()(waveform).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = torchaudio.transforms.MFCC()(waveform.repeat(3, 1, 1))
        self.assertEqual(computed, expected, atol=1e-4, rtol=1e-5)

    def test_batch_TimeStretch(self):
        test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(test_filepath)  # (2, 278756), 44100

        rate = 2

        # Real view of the complex STFT: trailing dimension of size 2 holds (real, imag).
        complex_specgrams = torch.view_as_real(
            torch.stft(
                input=waveform,
                n_fft=2048,
                hop_length=512,
                win_length=2048,
                window=torch.hann_window(2048),
                center=True,
                pad_mode='reflect',
                normalized=True,
                onesided=True,
                return_complex=True,
            )
        )

        # Single then transform then batch
        expected = torchaudio.transforms.TimeStretch(
            fixed_rate=rate,
            n_freq=1025,
            hop_length=512,
        )(complex_specgrams).repeat(3, 1, 1, 1, 1)

        # Batch then transform
        computed = torchaudio.transforms.TimeStretch(
            fixed_rate=rate,
            n_freq=1025,
            hop_length=512,
        )(complex_specgrams.repeat(3, 1, 1, 1, 1))

        self.assertEqual(computed, expected, atol=1e-5, rtol=1e-5)

    def test_batch_Fade(self):
        test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(test_filepath)  # (2, 278756), 44100
        fade_in_len = 3000
        fade_out_len = 3000

        # Single then transform then batch
        expected = torchaudio.transforms.Fade(fade_in_len, fade_out_len)(waveform).repeat(3, 1, 1)

        # Batch then transform
        computed = torchaudio.transforms.Fade(fade_in_len, fade_out_len)(waveform.repeat(3, 1, 1))
        self.assertEqual(computed, expected)

    def test_batch_Vol(self):
        test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(test_filepath)  # (2, 278756), 44100

        # Single then transform then batch
        expected = torchaudio.transforms.Vol(gain=1.1)(waveform).repeat(3, 1, 1)

        # Batch then transform
        computed = torchaudio.transforms.Vol(gain=1.1)(waveform.repeat(3, 1, 1))
        self.assertEqual(computed, expected)