batch_consistency_test.py 13.9 KB
Newer Older
1
"""Test numerical consistency among single input and batched input."""
2
3
import itertools
from parameterized import parameterized
4

5
6
import math

7
8
9
10
import torch
import torchaudio
import torchaudio.functional as F

11
from torchaudio_unittest import common_utils
12
13


moto's avatar
moto committed
14
15
class TestFunctional(common_utils.TorchaudioTestCase):
    """Test functions defined in `functional` module"""

    backend = 'default'

    def assert_batch_consistency(
            self, functional, tensor, *args, batch_size=1, atol=1e-8, rtol=1e-5, seed=42, **kwargs):
        """Assert `functional` gives the same result for a single and a batched input.

        The function is run once on a clone of `tensor` and the result is tiled
        `batch_size` times along a new leading dimension. It is then run on
        `tensor` tiled the same way, and the two results are compared.

        Args:
            functional: Callable under test; called as
                `functional(input, *args, **kwargs)`.
            tensor: Un-batched input tensor.
            *args: Extra positional arguments forwarded to `functional`.
            batch_size: Number of copies along the new leading dimension.
            atol: Absolute tolerance of the comparison.
            rtol: Relative tolerance of the comparison.
            seed: Seed set before each of the two runs, so that any random
                number drawn by `functional` starts from the same state.
            **kwargs: Extra keyword arguments forwarded to `functional`.
        """
        # run then batch the result
        torch.random.manual_seed(seed)
        expected = functional(tensor.clone(), *args, **kwargs)
        expected = expected.repeat([batch_size] + [1] * expected.dim())

        # batch the input and run
        torch.random.manual_seed(seed)
        pattern = [batch_size] + [1] * tensor.dim()
        computed = functional(tensor.repeat(pattern), *args, **kwargs)

        self.assertEqual(computed, expected, rtol=rtol, atol=atol)

    def assert_batch_consistencies(
            self, functional, tensor, *args, atol=1e-8, rtol=1e-5, seed=42, **kwargs):
        """Run `assert_batch_consistency` with batch sizes 1 and 3.

        All other arguments are forwarded unchanged.
        """
        self.assert_batch_consistency(
            functional, tensor, *args, batch_size=1, atol=atol, rtol=rtol, seed=seed, **kwargs)
        self.assert_batch_consistency(
            functional, tensor, *args, batch_size=3, atol=atol, rtol=rtol, seed=seed, **kwargs)

    def test_griffinlim(self):
        """`F.griffinlim` is batch consistent (with a relaxed `atol`)."""
        n_fft = 400
        ws = 400
        hop = 200
        window = torch.hann_window(ws)
        power = 2
        normalize = False
        momentum = 0.99
        n_iter = 32
        length = 1000
        tensor = torch.rand((1, 201, 6))
        self.assert_batch_consistencies(
            F.griffinlim, tensor, window, n_fft, hop, ws, power, normalize, n_iter, momentum, length, 0, atol=5e-5
        )

    @parameterized.expand(list(itertools.product(
        [100, 440],
        [8000, 16000, 44100],
        [1, 2],
    )), name_func=lambda f, _, p: f'{f.__name__}_{"_".join(str(arg) for arg in p.args)}')
    def test_detect_pitch_frequency(self, frequency, sample_rate, n_channels):
        """`F.detect_pitch_frequency` is batch consistent on sinusoid inputs."""
        waveform = common_utils.get_sinusoid(frequency=frequency, sample_rate=sample_rate,
                                             n_channels=n_channels, duration=5)
        self.assert_batch_consistencies(F.detect_pitch_frequency, waveform, sample_rate)

    def test_amplitude_to_DB(self):
        """`F.amplitude_to_DB` is batch consistent with and without `top_db`."""
        torch.manual_seed(0)
        spec = torch.rand(2, 100, 100) * 200

        amplitude_mult = 20.
        amin = 1e-10
        ref = 1.0
        db_mult = math.log10(max(amin, ref))

        # Test with & without a `top_db` clamp
        self.assert_batch_consistencies(F.amplitude_to_DB, spec, amplitude_mult,
                                        amin, db_mult, top_db=None)
        self.assert_batch_consistencies(F.amplitude_to_DB, spec, amplitude_mult,
                                        amin, db_mult, top_db=40.)

    def test_amplitude_to_DB_itemwise_clamps(self):
        """Ensure that the clamps are separate for each spectrogram in a batch.

        The clamp was determined per-batch in a prior implementation, which
        meant it was determined by the loudest item, thus items weren't
        independent. See:

        https://github.com/pytorch/audio/issues/994

        """
        amplitude_mult = 20.
        amin = 1e-10
        ref = 1.0
        db_mult = math.log10(max(amin, ref))
        top_db = 20.

        # Make a batch of noise
        torch.manual_seed(0)
        spec = torch.rand([2, 2, 100, 100]) * 200
        # Make one item blow out the other
        spec[0] += 50

        batchwise_dbs = F.amplitude_to_DB(spec, amplitude_mult, amin,
                                          db_mult, top_db=top_db)
        itemwise_dbs = torch.stack([
            F.amplitude_to_DB(item, amplitude_mult, amin,
                              db_mult, top_db=top_db)
            for item in spec
        ])

        self.assertEqual(batchwise_dbs, itemwise_dbs)

    def test_amplitude_to_DB_not_channelwise_clamps(self):
        """Check that clamps are applied per-item, not per channel."""
        amplitude_mult = 20.
        amin = 1e-10
        ref = 1.0
        db_mult = math.log10(max(amin, ref))
        top_db = 40.

        torch.manual_seed(0)
        spec = torch.rand([1, 2, 100, 100]) * 200
        # Make one channel blow out the other
        spec[:, 0] += 50

        specwise_dbs = F.amplitude_to_DB(spec, amplitude_mult, amin,
                                         db_mult, top_db=top_db)
        # Stack along the channel dimension so the result lines up with
        # `specwise_dbs` (1, 2, 100, 100). Stacking on dim 0 would give
        # (2, 1, 100, 100); the subtraction below would then broadcast and
        # compare different channels against each other, which makes the
        # check pass regardless of how the clamp is applied.
        channelwise_dbs = torch.stack([
            F.amplitude_to_DB(spec[:, i], amplitude_mult, amin,
                              db_mult, top_db=top_db)
            for i in range(spec.size(-3))
        ], dim=1)

        # Just check channelwise gives a different answer.
        difference = (specwise_dbs - channelwise_dbs).abs()
        self.assertTrue((difference >= 1e-5).any())

    def test_contrast(self):
        """`F.contrast` is batch consistent."""
        waveform = torch.rand(2, 100) - 0.5
        self.assert_batch_consistencies(F.contrast, waveform, enhancement_amount=80.)

    def test_dcshift(self):
        """`F.dcshift` is batch consistent."""
        waveform = torch.rand(2, 100) - 0.5
        self.assert_batch_consistencies(F.dcshift, waveform, shift=0.5, limiter_gain=0.05)

    def test_overdrive(self):
        """`F.overdrive` is batch consistent."""
        waveform = torch.rand(2, 100) - 0.5
        self.assert_batch_consistencies(F.overdrive, waveform, gain=45, colour=30)

    def test_phaser(self):
        """`F.phaser` is batch consistent on white noise."""
        sample_rate = 44100
        waveform = common_utils.get_whitenoise(
            sample_rate=sample_rate, duration=5,
        )
        self.assert_batch_consistencies(F.phaser, waveform, sample_rate)

    def test_flanger(self):
        """`F.flanger` is batch consistent."""
        torch.random.manual_seed(40)
        waveform = torch.rand(2, 100) - 0.5
        sample_rate = 44100
        self.assert_batch_consistencies(F.flanger, waveform, sample_rate)

    def test_sliding_window_cmn(self):
        """`F.sliding_window_cmn` is batch consistent for every `center`/`norm_vars` combination."""
        waveform = torch.randn(2, 1024) - 0.5
        # Yields (True, True), (True, False), (False, True), (False, False).
        for center, norm_vars in itertools.product([True, False], repeat=2):
            self.assert_batch_consistencies(
                F.sliding_window_cmn, waveform, center=center, norm_vars=norm_vars)

    def test_vad(self):
        """`F.vad` is batch consistent on an audio asset loaded from disk."""
        common_utils.set_audio_backend('default')
        filepath = common_utils.get_asset_path("vad-go-mono-32000.wav")
        waveform, sample_rate = torchaudio.load(filepath)
        self.assert_batch_consistencies(F.vad, waveform, sample_rate=sample_rate)
172

173

moto's avatar
moto committed
174
175
176
class TestTransforms(common_utils.TorchaudioTestCase):
    """Test suite for classes defined in `transforms` module"""

    backend = 'default'

    def test_batch_AmplitudeToDB(self):
        """`AmplitudeToDB` gives the same result on a single and a tiled input."""
        spec = torch.rand((2, 6, 201))

        # Single then transform then batch
        expected = torchaudio.transforms.AmplitudeToDB()(spec).repeat(3, 1, 1)

        # Batch then transform
        computed = torchaudio.transforms.AmplitudeToDB()(spec.repeat(3, 1, 1))

        self.assertEqual(computed, expected)

    def test_batch_Resample(self):
        """`Resample` (default arguments) is batch consistent."""
        waveform = torch.randn(2, 2786)

        # Single then transform then batch
        expected = torchaudio.transforms.Resample()(waveform).repeat(3, 1, 1)

        # Batch then transform
        computed = torchaudio.transforms.Resample()(waveform.repeat(3, 1, 1))

        self.assertEqual(computed, expected)

    def test_batch_MelScale(self):
        """`MelScale` (default arguments) is batch consistent."""
        specgram = torch.randn(2, 31, 2786)

        # Single then transform then batch
        expected = torchaudio.transforms.MelScale()(specgram).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = torchaudio.transforms.MelScale()(specgram.repeat(3, 1, 1, 1))

        self.assertEqual(computed, expected)

    def test_batch_InverseMelScale(self):
        """`InverseMelScale` is batch consistent within a relaxed tolerance."""
        n_mels = 32
        n_stft = 5
        mel_spec = torch.randn(2, n_mels, 32) ** 2

        # Single then transform then batch
        expected = torchaudio.transforms.InverseMelScale(n_stft, n_mels)(mel_spec).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = torchaudio.transforms.InverseMelScale(n_stft, n_mels)(mel_spec.repeat(3, 1, 1, 1))

        # InverseMelScale runs SGD on randomly initialized values, so the two
        # runs do not yield exactly the same result. For this reason the
        # tolerance is very relaxed here.
        self.assertEqual(computed, expected, atol=1.0, rtol=1e-5)

    def test_batch_compute_deltas(self):
        """`ComputeDeltas` is batch consistent."""
        specgram = torch.randn(2, 31, 2786)

        # Single then transform then batch
        expected = torchaudio.transforms.ComputeDeltas()(specgram).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = torchaudio.transforms.ComputeDeltas()(specgram.repeat(3, 1, 1, 1))

        self.assertEqual(computed, expected)

    def test_batch_mulaw(self):
        """`MuLawEncoding` and `MuLawDecoding` are batch consistent."""
        test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(test_filepath)  # (2, 278756), 44100

        # Single then transform then batch
        waveform_encoded = torchaudio.transforms.MuLawEncoding()(waveform)
        expected = waveform_encoded.unsqueeze(0).repeat(3, 1, 1)

        # Batch then transform
        waveform_batched = waveform.unsqueeze(0).repeat(3, 1, 1)
        computed = torchaudio.transforms.MuLawEncoding()(waveform_batched)

        self.assertEqual(computed, expected)

        # Single then transform then batch
        waveform_decoded = torchaudio.transforms.MuLawDecoding()(waveform_encoded)
        expected = waveform_decoded.unsqueeze(0).repeat(3, 1, 1)

        # Batch then transform; decodes the batched encoding computed above
        computed = torchaudio.transforms.MuLawDecoding()(computed)

        self.assertEqual(computed, expected)

    def test_batch_spectrogram(self):
        """`Spectrogram` (default arguments) is batch consistent."""
        test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(test_filepath)  # (2, 278756), 44100

        # Single then transform then batch
        expected = torchaudio.transforms.Spectrogram()(waveform).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = torchaudio.transforms.Spectrogram()(waveform.repeat(3, 1, 1))
        self.assertEqual(computed, expected)

    def test_batch_melspectrogram(self):
        """`MelSpectrogram` (default arguments) is batch consistent."""
        test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(test_filepath)  # (2, 278756), 44100

        # Single then transform then batch
        expected = torchaudio.transforms.MelSpectrogram()(waveform).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = torchaudio.transforms.MelSpectrogram()(waveform.repeat(3, 1, 1))
        self.assertEqual(computed, expected)

    def test_batch_mfcc(self):
        """`MFCC` (default arguments) is batch consistent within `atol=1e-4`."""
        test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(test_filepath)

        # Single then transform then batch
        expected = torchaudio.transforms.MFCC()(waveform).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = torchaudio.transforms.MFCC()(waveform.repeat(3, 1, 1))
        self.assertEqual(computed, expected, atol=1e-4, rtol=1e-5)

    def test_batch_TimeStretch(self):
        """`TimeStretch` is batch consistent on an STFT viewed as real pairs."""
        test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(test_filepath)  # (2, 278756), 44100

        rate = 2

        # `view_as_real` appends a trailing dimension of size 2 (real, imag),
        # which is why the tiling patterns below have five entries.
        complex_specgrams = torch.view_as_real(
            torch.stft(
                input=waveform,
                n_fft=2048,
                hop_length=512,
                win_length=2048,
                window=torch.hann_window(2048),
                center=True,
                pad_mode='reflect',
                normalized=True,
                onesided=True,
                return_complex=True,
            )
        )

        # Single then transform then batch
        expected = torchaudio.transforms.TimeStretch(
            fixed_rate=rate,
            n_freq=1025,
            hop_length=512,
        )(complex_specgrams).repeat(3, 1, 1, 1, 1)

        # Batch then transform
        computed = torchaudio.transforms.TimeStretch(
            fixed_rate=rate,
            n_freq=1025,
            hop_length=512,
        )(complex_specgrams.repeat(3, 1, 1, 1, 1))

        self.assertEqual(computed, expected, atol=1e-5, rtol=1e-5)

    def test_batch_Fade(self):
        """`Fade` is batch consistent."""
        test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(test_filepath)  # (2, 278756), 44100
        fade_in_len = 3000
        fade_out_len = 3000

        # Single then transform then batch
        expected = torchaudio.transforms.Fade(fade_in_len, fade_out_len)(waveform).repeat(3, 1, 1)

        # Batch then transform
        computed = torchaudio.transforms.Fade(fade_in_len, fade_out_len)(waveform.repeat(3, 1, 1))
        self.assertEqual(computed, expected)

    def test_batch_Vol(self):
        """`Vol` is batch consistent."""
        test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(test_filepath)  # (2, 278756), 44100

        # Single then transform then batch
        expected = torchaudio.transforms.Vol(gain=1.1)(waveform).repeat(3, 1, 1)

        # Batch then transform
        computed = torchaudio.transforms.Vol(gain=1.1)(waveform.repeat(3, 1, 1))
        self.assertEqual(computed, expected)

    def test_batch_spectral_centroid(self):
        """`SpectralCentroid` is batch consistent on white noise."""
        sample_rate = 44100
        waveform = common_utils.get_whitenoise(sample_rate=sample_rate)

        # Single then transform then batch
        expected = torchaudio.transforms.SpectralCentroid(sample_rate)(waveform).repeat(3, 1, 1)

        # Batch then transform
        computed = torchaudio.transforms.SpectralCentroid(sample_rate)(waveform.repeat(3, 1, 1))
        self.assertEqual(computed, expected)