"vscode:/vscode.git/clone" did not exist on "ba02a6742cb56919b09f80e058e197b27a72a0af"
test_transforms.py 14.7 KB
Newer Older
from __future__ import absolute_import, division, print_function, unicode_literals
import math
import os

import torch
import torchaudio
import torchaudio.transforms as transforms
import torchaudio.functional as F
from torchaudio.common_utils import IMPORT_LIBROSA, IMPORT_SCIPY
import unittest
import common_utils

if IMPORT_LIBROSA:
    import librosa

if IMPORT_SCIPY:
    import scipy


class Tester(unittest.TestCase):

    # create a sinewave signal for testing
    sample_rate = 16000
    freq = 440
    volume = .3
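    # 4 seconds of a 440 Hz cosine at 0.3 volume, scaled into 32-bit
    # integer range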
    waveform = (torch.cos(2 * math.pi * torch.arange(0, 4 * sample_rate).float() * freq / sample_rate))
    waveform.unsqueeze_(0)  # (1, 64000)
    waveform = (waveform * volume * 2**31).long()
    # file for stereo stft test
    test_dirpath, test_dir = common_utils.create_temp_assets_dir()
    test_filepath = os.path.join(test_dirpath, 'assets',
                                 'steam-train-whistle-daniel_simon.mp3')

    def scale(self, waveform, factor=float(2**31)):
        # scales a waveform by a factor
        if not waveform.is_floating_point():
            waveform = waveform.to(torch.get_default_dtype())
        return waveform / factor

    def test_mu_law_companding(self):

        quantization_channels = 256

        waveform = self.waveform.clone()
        # convert to float before normalizing; true division of the
        # integer-typed waveform would not produce values in [-1., 1.]
        waveform = waveform.to(torch.get_default_dtype())
        waveform /= torch.abs(waveform).max()
        self.assertTrue(waveform.min() >= -1. and waveform.max() <= 1.)

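        # MuLawEncoding maps a [-1., 1.] signal onto integer classes
        # in [0, quantization_channels - 1]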
        waveform_mu = transforms.MuLawEncoding(quantization_channels)(waveform)
        self.assertTrue(waveform_mu.min() >= 0. and waveform_mu.max() <= quantization_channels)

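        # decoding inverts the companding only up to quantization error,
        # so just the output range is checked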
        waveform_exp = transforms.MuLawDecoding(quantization_channels)(waveform_mu)
        self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.)

    def test_melscale_load_save(self):
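        # a MelScale built without n_stft is assumed to materialize its
        # filterbank on the first forward call; a copy restored from the
        # state dict should then carry an identical filterbank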
        specgram = torch.ones(1, 1000, 100)
        melscale_transform = transforms.MelScale()
        melscale_transform(specgram)

        melscale_transform_copy = transforms.MelScale(n_stft=1000)
        melscale_transform_copy.load_state_dict(melscale_transform.state_dict())

        fb = melscale_transform.fb
        fb_copy = melscale_transform_copy.fb

        self.assertEqual(fb_copy.size(), (1000, 128))
        self.assertTrue(torch.allclose(fb, fb_copy))

    def test_melspectrogram_load_save(self):
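        # round-trip the transform's buffers (STFT window and mel
        # filterbank) through a state dict and check they are preserved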
        waveform = self.waveform.float()
        mel_spectrogram_transform = transforms.MelSpectrogram()
        mel_spectrogram_transform(waveform)

        mel_spectrogram_transform_copy = transforms.MelSpectrogram()
        mel_spectrogram_transform_copy.load_state_dict(mel_spectrogram_transform.state_dict())

        window = mel_spectrogram_transform.spectrogram.window
        window_copy = mel_spectrogram_transform_copy.spectrogram.window

        fb = mel_spectrogram_transform.mel_scale.fb
        fb_copy = mel_spectrogram_transform_copy.mel_scale.fb

        self.assertTrue(torch.allclose(window, window_copy))
        # the default for n_fft = 400 and n_mels = 128
        self.assertEqual(fb_copy.size(), (201, 128))
        self.assertTrue(torch.allclose(fb, fb_copy))

    def test_mel2(self):
        top_db = 80.
        s2db = transforms.AmplitudeToDB('power', top_db)

        waveform = self.waveform.clone()  # (1, 64000)
        waveform_scaled = self.scale(waveform)  # (1, 64000)
        mel_transform = transforms.MelSpectrogram()
        # check defaults
        spectrogram_torch = s2db(mel_transform(waveform_scaled))  # (1, 128, 321)
        self.assertTrue(spectrogram_torch.dim() == 3)
        self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
        self.assertEqual(spectrogram_torch.size(1), mel_transform.n_mels)
        # check correctness of filterbank conversion matrix
        self.assertTrue(mel_transform.mel_scale.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform.mel_scale.fb.sum(1).ge(0.).all())
        # check options
        kwargs = {'window_fn': torch.hamming_window, 'pad': 10, 'win_length': 500,
                  'hop_length': 125, 'n_fft': 800, 'n_mels': 50}
        mel_transform2 = transforms.MelSpectrogram(**kwargs)
        spectrogram2_torch = s2db(mel_transform2(waveform_scaled))  # (1, 50, 513)
        self.assertTrue(spectrogram2_torch.dim() == 3)
        self.assertTrue(spectrogram2_torch.ge(spectrogram2_torch.max() - top_db).all())
        self.assertEqual(spectrogram2_torch.size(1), mel_transform2.n_mels)
        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).ge(0.).all())
        # check on multi-channel audio
        x_stereo, sr_stereo = torchaudio.load(self.test_filepath)  # (2, 278756), 44100
        spectrogram_stereo = s2db(mel_transform(x_stereo))  # (2, 128, 1394)
        self.assertTrue(spectrogram_stereo.dim() == 3)
        self.assertTrue(spectrogram_stereo.size(0) == 2)
        self.assertTrue(spectrogram_stereo.ge(spectrogram_stereo.max() - top_db).all())
        self.assertEqual(spectrogram_stereo.size(1), mel_transform.n_mels)
        # check filterbank matrix creation
        fb_matrix_transform = transforms.MelScale(
            n_mels=100, sample_rate=16000, f_min=0., f_max=None, n_stft=400)
        self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all())
        self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
        self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))

    def test_mfcc(self):
        audio_orig = self.waveform.clone()
        audio_scaled = self.scale(audio_orig)  # (1, 64000)

        sample_rate = 16000
        n_mfcc = 40
        n_mels = 128
        mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                    n_mfcc=n_mfcc,
                                                    norm='ortho')
        # check defaults
        torch_mfcc = mfcc_transform(audio_scaled)  # (1, 40, 321)
        self.assertTrue(torch_mfcc.dim() == 3)
        self.assertTrue(torch_mfcc.shape[1] == n_mfcc)
        self.assertTrue(torch_mfcc.shape[2] == 321)
        # check melkwargs are passed through
        melkwargs = {'win_length': 200}
        mfcc_transform2 = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                     n_mfcc=n_mfcc,
                                                     norm='ortho',
                                                     melkwargs=melkwargs)
        torch_mfcc2 = mfcc_transform2(audio_scaled)  # (1, 40, 641)
        self.assertTrue(torch_mfcc2.shape[2] == 641)

        # check norms work correctly
        mfcc_transform_norm_none = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                              n_mfcc=n_mfcc,
                                                              norm=None)
        torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled)  # (1, 40, 321)

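        # an unnormalized DCT-II equals the ortho-normalized one scaled by
        # 2 * sqrt(n_mels) on coefficient 0 and 2 * sqrt(n_mels / 2) on the
        # remaining coefficients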
        norm_check = torch_mfcc.clone()
        norm_check[:, 0, :] *= math.sqrt(n_mels) * 2
        norm_check[:, 1:, :] *= math.sqrt(n_mels / 2) * 2

        self.assertTrue(torch_mfcc_norm_none.allclose(norm_check))

    @unittest.skipIf(not IMPORT_LIBROSA or not IMPORT_SCIPY, 'Librosa and scipy are not available')
    def test_librosa_consistency(self):
        def _test_librosa_consistency_helper(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):
            input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
            sound, sample_rate = torchaudio.load(input_path)
            sound_librosa = sound.cpu().numpy().squeeze()  # (64000)

            # test core spectrogram
            spect_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop_length=hop_length, power=power)
            out_librosa, _ = librosa.core.spectrum._spectrogram(y=sound_librosa,
                                                                n_fft=n_fft,
                                                                hop_length=hop_length,
                                                                power=power)

            out_torch = spect_transform(sound).squeeze().cpu()
            self.assertTrue(torch.allclose(out_torch, torch.from_numpy(out_librosa), atol=1e-5))

            # test mel spectrogram
            melspect_transform = torchaudio.transforms.MelSpectrogram(
                sample_rate=sample_rate, window_fn=torch.hann_window,
                hop_length=hop_length, n_mels=n_mels, n_fft=n_fft)
            librosa_mel = librosa.feature.melspectrogram(y=sound_librosa, sr=sample_rate,
                                                         n_fft=n_fft, hop_length=hop_length, n_mels=n_mels,
                                                         htk=True, norm=None)
            librosa_mel_tensor = torch.from_numpy(librosa_mel)
            torch_mel = melspect_transform(sound).squeeze().cpu()

            self.assertTrue(torch.allclose(torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3))

            # test s2db
            db_transform = torchaudio.transforms.AmplitudeToDB('power', 80.)
            db_torch = db_transform(spect_transform(sound)).squeeze().cpu()
            db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
            self.assertTrue(torch.allclose(db_torch, torch.from_numpy(db_librosa), atol=5e-3))

            db_torch = db_transform(melspect_transform(sound)).squeeze().cpu()
            db_librosa = librosa.core.spectrum.power_to_db(librosa_mel)
            db_librosa_tensor = torch.from_numpy(db_librosa)

            self.assertTrue(torch.allclose(db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3))

            # test MFCC
            melkwargs = {'hop_length': hop_length, 'n_fft': n_fft}
            mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                        n_mfcc=n_mfcc,
                                                        norm='ortho',
                                                        melkwargs=melkwargs)

            # librosa.feature.mfcc does not forward these kwargs cleanly,
            # since melspectrogram and mfcc share some parameter names. We
            # instead follow the function body in
            # https://librosa.github.io/librosa/_modules/librosa/feature/spectral.html#melspectrogram
            # to mirror that call with the correct arguments:

            # librosa_mfcc = librosa.feature.mfcc(y=sound_librosa,
            #                                     sr=sample_rate,
            #                                     n_mfcc=n_mfcc,
            #                                     hop_length=hop_length,
            #                                     n_fft=n_fft,
            #                                     htk=True,
            #                                     norm=None,
            #                                     n_mels=n_mels)

            librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc]
            librosa_mfcc_tensor = torch.from_numpy(librosa_mfcc)
            torch_mfcc = mfcc_transform(sound).squeeze().cpu()

            self.assertTrue(torch.allclose(torch_mfcc.type(librosa_mfcc_tensor.dtype), librosa_mfcc_tensor, atol=5e-3))

        kwargs1 = {
            'n_fft': 400,
            'hop_length': 200,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 40,
            'sample_rate': 16000
        }

        kwargs2 = {
            'n_fft': 600,
            'hop_length': 100,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 20,
            'sample_rate': 16000
        }

        kwargs3 = {
            'n_fft': 200,
            'hop_length': 50,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 50,
            'sample_rate': 24000
        }

        _test_librosa_consistency_helper(**kwargs1)
        _test_librosa_consistency_helper(**kwargs2)
        _test_librosa_consistency_helper(**kwargs3)

    def test_resample_size(self):
        input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
        waveform, sample_rate = torchaudio.load(input_path)

        upsample_rate = sample_rate * 2
        downsample_rate = sample_rate // 2
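        # an unrecognized resampling_method should only raise once the
        # transform is applied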
        invalid_resample = torchaudio.transforms.Resample(sample_rate, upsample_rate, resampling_method='foo')

        self.assertRaises(ValueError, invalid_resample, waveform)

        upsample_resample = torchaudio.transforms.Resample(
            sample_rate, upsample_rate, resampling_method='sinc_interpolation')
        up_sampled = upsample_resample(waveform)

        # we expect the upsampled signal to have twice as many samples
        self.assertTrue(up_sampled.size(-1) == waveform.size(-1) * 2)

        downsample_resample = torchaudio.transforms.Resample(
            sample_rate, downsample_rate, resampling_method='sinc_interpolation')
        down_sampled = downsample_resample(waveform)

        # we expect the downsampled signal to have half as many samples
        self.assertTrue(down_sampled.size(-1) == waveform.size(-1) // 2)

    def test_compute_deltas(self):
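        # deltas are taken along the time axis, so the output shape must
        # match the input shape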
        channel = 13
        n_mfcc = channel * 3
        time = 1021
        win_length = 2 * 7 + 1
        specgram = torch.randn(channel, n_mfcc, time)
        transform = transforms.ComputeDeltas(win_length=win_length)
        computed = transform(specgram)
        self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))

    def test_compute_deltas_transform_same_as_functional(self, atol=1e-6, rtol=1e-8):
        channel = 13
        n_mfcc = channel * 3
        time = 1021
        win_length = 2 * 7 + 1
        specgram = torch.randn(channel, n_mfcc, time)

        transform = transforms.ComputeDeltas(win_length=win_length)
        computed_transform = transform(specgram)

        computed_functional = F.compute_deltas(specgram, win_length=win_length)
        torch.testing.assert_allclose(computed_functional, computed_transform, atol=atol, rtol=rtol)

    def test_compute_deltas_twochannel(self):
        specgram = torch.tensor([1., 2., 3., 4.]).repeat(1, 2, 1)
        expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5],
                                  [0.5, 1.0, 1.0, 0.5]]])
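        # with win_length=3 and replicate padding, each delta equals
        # (x[t + 1] - x[t - 1]) / 2, which produces `expected` above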
        transform = transforms.ComputeDeltas(win_length=3)
        computed = transform(specgram)
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_batch_spectrogram(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)

        # Single then transform then batch
        expected = transforms.Spectrogram()(waveform).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = transforms.Spectrogram()(waveform.repeat(3, 1, 1))

        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))


if __name__ == '__main__':
    unittest.main()