import math
import os
import unittest

import torch
import torchaudio
import torchaudio.transforms as transforms
import torchaudio.functional as F

from common_utils import AudioBackendScope, BACKENDS, create_temp_assets_dir


class Tester(unittest.TestCase):

    # create a sinewave signal for testing
    sample_rate = 16000
    freq = 440
    volume = .3
    waveform = (torch.cos(2 * math.pi * torch.arange(0, 4 * sample_rate).float() * freq / sample_rate))
    waveform.unsqueeze_(0)  # (1, 64000)
    waveform = (waveform * volume * 2**31).long()
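    # samples are stored at int32 scale (volume * 2**31); scale() below converts
    # them back to floats in [-1., 1.]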
    # file for stereo stft test
    test_dirpath, test_dir = create_temp_assets_dir()
    test_filepath = os.path.join(test_dirpath, 'assets',
                                 'steam-train-whistle-daniel_simon.wav')

    def scale(self, waveform, factor=float(2**31)):
        # converts an integer waveform to float and scales it down by `factor`
        if not waveform.is_floating_point():
            waveform = waveform.to(torch.get_default_dtype())
        return waveform / factor

    def test_mu_law_companding(self):

        quantization_channels = 256

        waveform = self.waveform.clone()
        if not waveform.is_floating_point():
            # the test signal is stored as int32; convert to float before normalizing,
            # since in-place division is not defined for integer tensors
            waveform = waveform.to(torch.get_default_dtype())
        waveform /= torch.abs(waveform).max()
        self.assertTrue(waveform.min() >= -1. and waveform.max() <= 1.)

        waveform_mu = transforms.MuLawEncoding(quantization_channels)(waveform)
        self.assertTrue(waveform_mu.min() >= 0. and waveform_mu.max() <= quantization_channels - 1)
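        # encoded values are integer bins in [0, quantization_channels - 1]: MuLawEncoding
        # applies standard mu-law companding, y = sign(x) * log1p(mu * |x|) / log1p(mu)
        # with mu = quantization_channels - 1, then quantizes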

        waveform_exp = transforms.MuLawDecoding(quantization_channels)(waveform_mu)
        self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.)

    def test_batch_AmplitudeToDB(self):
        spec = torch.rand((6, 201))

        # Single then transform then batch
        expected = transforms.AmplitudeToDB()(spec).repeat(3, 1, 1)

        # Batch then transform
        computed = transforms.AmplitudeToDB()(spec.repeat(3, 1, 1))

        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_AmplitudeToDB(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)

        mag_to_db_transform = transforms.AmplitudeToDB('magnitude', 80.)
        power_to_db_transform = transforms.AmplitudeToDB('power', 80.)

        mag_to_db_torch = mag_to_db_transform(torch.abs(waveform))
        power_to_db_torch = power_to_db_transform(torch.pow(waveform, 2))
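        # 20 * log10(|x|) == 10 * log10(x**2), so the two transforms should agree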

        self.assertTrue(torch.allclose(mag_to_db_torch, power_to_db_torch))

    def test_melscale_load_save(self):
        specgram = torch.ones(1, 1000, 100)
        melscale_transform = transforms.MelScale()
        melscale_transform(specgram)
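        # n_stft was not specified, so the first forward call above materializes the
        # filterbank from the input's frequency dimension (1000 bins)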

        melscale_transform_copy = transforms.MelScale(n_stft=1000)
        melscale_transform_copy.load_state_dict(melscale_transform.state_dict())

        fb = melscale_transform.fb
        fb_copy = melscale_transform_copy.fb

        self.assertEqual(fb_copy.size(), (1000, 128))
        self.assertTrue(torch.allclose(fb, fb_copy))

    def test_melspectrogram_load_save(self):
        waveform = self.waveform.float()
        mel_spectrogram_transform = transforms.MelSpectrogram()
        mel_spectrogram_transform(waveform)

        mel_spectrogram_transform_copy = transforms.MelSpectrogram()
        mel_spectrogram_transform_copy.load_state_dict(mel_spectrogram_transform.state_dict())

        window = mel_spectrogram_transform.spectrogram.window
        window_copy = mel_spectrogram_transform_copy.spectrogram.window

        fb = mel_spectrogram_transform.mel_scale.fb
        fb_copy = mel_spectrogram_transform_copy.mel_scale.fb

        self.assertTrue(torch.allclose(window, window_copy))
        # the default for n_fft = 400 and n_mels = 128
        self.assertEqual(fb_copy.size(), (201, 128))
        self.assertTrue(torch.allclose(fb, fb_copy))

    def test_mel2(self):
        top_db = 80.
        s2db = transforms.AmplitudeToDB('power', top_db)

        waveform = self.waveform.clone()  # (1, 64000)
        waveform_scaled = self.scale(waveform)  # (1, 64000)
        mel_transform = transforms.MelSpectrogram()
        # check defaults
        spectrogram_torch = s2db(mel_transform(waveform_scaled))  # (1, 128, 321)
        self.assertTrue(spectrogram_torch.dim() == 3)
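        # AmplitudeToDB with top_db clamps the output to within top_db of its maximum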
        self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
        self.assertEqual(spectrogram_torch.size(1), mel_transform.n_mels)
        # check correctness of filterbank conversion matrix
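        # (each STFT bin is covered by at most two overlapping triangular mel filters,
        # so the per-bin filter weights sum to a value in [0., 1.])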
        self.assertTrue(mel_transform.mel_scale.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform.mel_scale.fb.sum(1).ge(0.).all())
        # check options
        kwargs = {'window_fn': torch.hamming_window, 'pad': 10, 'win_length': 500,
                  'hop_length': 125, 'n_fft': 800, 'n_mels': 50}
        mel_transform2 = transforms.MelSpectrogram(**kwargs)
        spectrogram2_torch = s2db(mel_transform2(waveform_scaled))  # (1, 50, 513)
        self.assertTrue(spectrogram2_torch.dim() == 3)
        self.assertTrue(spectrogram2_torch.ge(spectrogram2_torch.max() - top_db).all())
        self.assertEqual(spectrogram2_torch.size(1), mel_transform2.n_mels)
        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).ge(0.).all())
        # check on multi-channel audio
        x_stereo, sr_stereo = torchaudio.load(self.test_filepath)  # (2, 278756), 44100
        spectrogram_stereo = s2db(mel_transform(x_stereo))  # (2, 128, 1394)
        self.assertTrue(spectrogram_stereo.dim() == 3)
        self.assertTrue(spectrogram_stereo.size(0) == 2)
        self.assertTrue(spectrogram_stereo.ge(spectrogram_stereo.max() - top_db).all())
        self.assertEqual(spectrogram_stereo.size(1), mel_transform.n_mels)
        # check filterbank matrix creation
        fb_matrix_transform = transforms.MelScale(
            n_mels=100, sample_rate=16000, f_min=0., f_max=None, n_stft=400)
        self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all())
        self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
        self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))

    def test_mfcc(self):
        audio_orig = self.waveform.clone()
        audio_scaled = self.scale(audio_orig)  # (1, 64000)

        sample_rate = 16000
        n_mfcc = 40
        n_mels = 128
        mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                    n_mfcc=n_mfcc,
                                                    norm='ortho')
        # check defaults
        torch_mfcc = mfcc_transform(audio_scaled)  # (1, 40, 321)
        self.assertTrue(torch_mfcc.dim() == 3)
        self.assertTrue(torch_mfcc.shape[1] == n_mfcc)
        self.assertTrue(torch_mfcc.shape[2] == 321)
        # check melkwargs are passed through
        melkwargs = {'win_length': 200}
        mfcc_transform2 = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                     n_mfcc=n_mfcc,
                                                     norm='ortho',
                                                     melkwargs=melkwargs)
        torch_mfcc2 = mfcc_transform2(audio_scaled)  # (1, 40, 641)
        self.assertTrue(torch_mfcc2.shape[2] == 641)

        # check norms work correctly
        mfcc_transform_norm_none = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                              n_mfcc=n_mfcc,
                                                              norm=None)
        torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled)  # (1, 40, 321)

        norm_check = torch_mfcc.clone()
        norm_check[:, 0, :] *= math.sqrt(n_mels) * 2
        norm_check[:, 1:, :] *= math.sqrt(n_mels / 2) * 2
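        # undoing the orthonormal DCT-II scaling (1 / (2 * sqrt(N)) on coefficient 0,
        # 1 / (2 * sqrt(N / 2)) on the others, with N = n_mels) should recover norm=None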

        self.assertTrue(torch_mfcc_norm_none.allclose(norm_check))

    def test_batch_Resample(self):
        waveform = torch.randn(2, 2786)

        # Single then transform then batch
        expected = transforms.Resample()(waveform).repeat(3, 1, 1)

        # Batch then transform
        computed = transforms.Resample()(waveform.repeat(3, 1, 1))

        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_resample_size(self):
        input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
        waveform, sample_rate = torchaudio.load(input_path)

        upsample_rate = sample_rate * 2
        downsample_rate = sample_rate // 2
        invalid_resample = torchaudio.transforms.Resample(sample_rate, upsample_rate, resampling_method='foo')

        self.assertRaises(ValueError, invalid_resample, waveform)

        upsample_resample = torchaudio.transforms.Resample(
            sample_rate, upsample_rate, resampling_method='sinc_interpolation')
        up_sampled = upsample_resample(waveform)

        # we expect the upsampled signal to have twice as many samples
        self.assertTrue(up_sampled.size(-1) == waveform.size(-1) * 2)

        downsample_resample = torchaudio.transforms.Resample(
            sample_rate, downsample_rate, resampling_method='sinc_interpolation')
        down_sampled = downsample_resample(waveform)

        # we expect the downsampled signal to have half as many samples
        self.assertTrue(down_sampled.size(-1) == waveform.size(-1) // 2)

    def test_compute_deltas(self):
        channel = 13
        n_mfcc = channel * 3
        time = 1021
        win_length = 2 * 7 + 1
        specgram = torch.randn(channel, n_mfcc, time)
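        # deltas are computed along the last (time) axis, so the output shape matches the input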
        transform = transforms.ComputeDeltas(win_length=win_length)
        computed = transform(specgram)
        self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))

    def test_compute_deltas_transform_same_as_functional(self, atol=1e-6, rtol=1e-8):
        channel = 13
        n_mfcc = channel * 3
        time = 1021
        win_length = 2 * 7 + 1
        specgram = torch.randn(channel, n_mfcc, time)

        transform = transforms.ComputeDeltas(win_length=win_length)
        computed_transform = transform(specgram)

        computed_functional = F.compute_deltas(specgram, win_length=win_length)
        torch.testing.assert_allclose(computed_functional, computed_transform, atol=atol, rtol=rtol)

    def test_compute_deltas_twochannel(self):
        specgram = torch.tensor([1., 2., 3., 4.]).repeat(1, 2, 1)
        # expected values follow the standard delta regression with replicate padding,
        # d_t = sum_n n * (x_{t+n} - x_{t-n}) / (2 * sum_n n**2), for win_length=3
        expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5],
                                  [0.5, 1.0, 1.0, 0.5]]])
        transform = transforms.ComputeDeltas(win_length=3)
        computed = transform(specgram)
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_batch_MelScale(self):
        specgram = torch.randn(2, 31, 2786)

        # Single then transform then batch
        expected = transforms.MelScale()(specgram).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = transforms.MelScale()(specgram.repeat(3, 1, 1, 1))

        # shape = (3, 2, 128, 2786)
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_batch_InverseMelScale(self):
        n_fft = 8
        n_mels = 32
        n_stft = n_fft // 2 + 1
        mel_spec = torch.randn(2, n_mels, 32) ** 2
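        # squared so the input is a valid (non-negative) mel power spectrogram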

        # Single then transform then batch
        expected = transforms.InverseMelScale(n_stft, n_mels)(mel_spec).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = transforms.InverseMelScale(n_stft, n_mels)(mel_spec.repeat(3, 1, 1, 1))

        # shape = (3, 2, n_stft, 32)
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))

        # Because InverseMelScale runs SGD from a random initialization, repeated runs
        # do not yield exactly the same result. For this reason, the tolerance is very relaxed here.
        self.assertTrue(torch.allclose(computed, expected, atol=1.0))

    def test_batch_compute_deltas(self):
        specgram = torch.randn(2, 31, 2786)

        # Single then transform then batch
        expected = transforms.ComputeDeltas()(specgram).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = transforms.ComputeDeltas()(specgram.repeat(3, 1, 1, 1))

        # shape = (3, 2, 31, 2786)
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_batch_mulaw(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)  # (2, 278756), 44100

        # Single then transform then batch
        waveform_encoded = transforms.MuLawEncoding()(waveform)
        expected = waveform_encoded.unsqueeze(0).repeat(3, 1, 1)

        # Batch then transform
        waveform_batched = waveform.unsqueeze(0).repeat(3, 1, 1)
        computed = transforms.MuLawEncoding()(waveform_batched)

        # shape = (3, 2, 278756)
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

        # Single then transform then batch
        waveform_decoded = transforms.MuLawDecoding()(waveform_encoded)
        expected = waveform_decoded.unsqueeze(0).repeat(3, 1, 1)

        # Batch then transform
        computed = transforms.MuLawDecoding()(computed)

        # shape = (3, 2, 278756)
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_batch_spectrogram(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)

        # Single then transform then batch
        expected = transforms.Spectrogram()(waveform).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = transforms.Spectrogram()(waveform.repeat(3, 1, 1))

        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_batch_melspectrogram(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)

        # Single then transform then batch
        expected = transforms.MelSpectrogram()(waveform).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = transforms.MelSpectrogram()(waveform.repeat(3, 1, 1))

        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_batch_mfcc(self):
        test_filepath = os.path.join(
            self.test_dirpath, 'assets', 'steam-train-whistle-daniel_simon.mp3'
        )
        waveform, sample_rate = torchaudio.load(test_filepath)

        # Single then transform then batch
        expected = transforms.MFCC()(waveform).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = transforms.MFCC()(waveform.repeat(3, 1, 1))

        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected, atol=1e-5))

    def test_batch_TimeStretch(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)

        kwargs = {
            'n_fft': 2048,
            'hop_length': 512,
            'win_length': 2048,
            'window': torch.hann_window(2048),
            'center': True,
            'pad_mode': 'reflect',
            'normalized': True,
            'onesided': True,
        }
        rate = 2

        complex_specgrams = torch.stft(waveform, **kwargs)
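        # the pre-complex-dtype torch.stft returns (channel, n_fft // 2 + 1, time, 2),
        # so TimeStretch below is built with n_freq = 2048 // 2 + 1 = 1025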

        # Single then transform then batch
        expected = transforms.TimeStretch(fixed_rate=rate,
                                          n_freq=1025,
                                          hop_length=512)(complex_specgrams).repeat(3, 1, 1, 1, 1)

        # Batch then transform
        computed = transforms.TimeStretch(fixed_rate=rate,
                                          n_freq=1025,
                                          hop_length=512)(complex_specgrams.repeat(3, 1, 1, 1, 1))

        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected, atol=1e-5))

    def test_batch_Fade(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)
        fade_in_len = 3000
        fade_out_len = 3000

        # Single then transform then batch
        expected = transforms.Fade(fade_in_len, fade_out_len)(waveform).repeat(3, 1, 1)

        # Batch then transform
        computed = transforms.Fade(fade_in_len, fade_out_len)(waveform.repeat(3, 1, 1))

        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_batch_Vol(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)

        # Single then transform then batch
        expected = transforms.Vol(gain=1.1)(waveform).repeat(3, 1, 1)

        # Batch then transform
        computed = transforms.Vol(gain=1.1)(waveform.repeat(3, 1, 1))

        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))


if __name__ == '__main__':
    unittest.main()