from __future__ import print_function
import math
import os
import unittest

import torch
import torchaudio
import torchaudio.transforms as transforms
from torchaudio.common_utils import IMPORT_LIBROSA, IMPORT_SCIPY

if IMPORT_LIBROSA:
    import librosa

if IMPORT_SCIPY:
    import scipy


class Tester(unittest.TestCase):

    # create a sinewave signal for testing
    sr = 16000
    freq = 440
    volume = .3
    sig = (torch.cos(2 * math.pi * torch.arange(0, 4 * sr).float() * freq / sr))
    sig.unsqueeze_(1)  # (64000, 1)
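    # scale the waveform into the int32 range; transforms.Scale divides by
    # its factor (assumed default 2 ** 31) to map this back into [-1., 1.]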
    sig = (sig * volume * 2**31).long()
    # file for stereo stft test
    test_dirpath = os.path.dirname(os.path.realpath(__file__))
    test_filepath = os.path.join(test_dirpath, "assets",
                                 "steam-train-whistle-daniel_simon.mp3")

    def test_scale(self):

        audio_orig = self.sig.clone()
        result = transforms.Scale()(audio_orig)
        self.assertTrue(result.min() >= -1. and result.max() <= 1.)

        maxminmax = float(max(abs(audio_orig.min()), abs(audio_orig.max())))
        result = transforms.Scale(factor=maxminmax)(audio_orig)
        self.assertTrue((result.min() == -1. or result.max() == 1.) and
                        result.min() >= -1. and result.max() <= 1.)

        repr_test = transforms.Scale()
        self.assertTrue(repr_test.__repr__())

    def test_pad_trim(self):

        audio_orig = self.sig.clone()
        length_orig = audio_orig.size(0)
        length_new = int(length_orig * 1.2)

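        # PadTrim zero-pads (or truncates) along the length dimension until
        # the signal is exactly max_len samples long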
        result = transforms.PadTrim(max_len=length_new, channels_first=False)(audio_orig)
        self.assertEqual(result.size(0), length_new)

        result = transforms.PadTrim(max_len=length_new, channels_first=True)(audio_orig.transpose(0, 1))
        self.assertEqual(result.size(1), length_new)

        audio_orig = self.sig.clone()
        length_orig = audio_orig.size(0)
        length_new = int(length_orig * 0.8)

        result = transforms.PadTrim(max_len=length_new, channels_first=False)(audio_orig)

        self.assertEqual(result.size(0), length_new)

        repr_test = transforms.PadTrim(max_len=length_new, channels_first=False)
        self.assertTrue(repr_test.__repr__())

    def test_downmix_mono(self):

        audio_L = self.sig.clone()
        audio_R = self.sig.clone()
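        # rotate the right channel by 10% so the two channels differ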
        R_idx = int(audio_R.size(0) * 0.1)
        audio_R = torch.cat((audio_R[R_idx:], audio_R[:R_idx]))

        audio_Stereo = torch.cat((audio_L, audio_R), dim=1)

        self.assertTrue(audio_Stereo.size(1) == 2)

        result = transforms.DownmixMono(channels_first=False)(audio_Stereo)

        self.assertTrue(result.size(1) == 1)

        repr_test = transforms.DownmixMono(channels_first=False)
        self.assertTrue(repr_test.__repr__())

    def test_lc2cl(self):

        audio = self.sig.clone()
        result = transforms.LC2CL()(audio)
        self.assertTrue(result.size()[::-1] == audio.size())

        repr_test = transforms.LC2CL()
        self.assertTrue(repr_test.__repr__())

    def test_compose(self):

        audio_orig = self.sig.clone()
        length_orig = audio_orig.size(0)
        length_new = int(length_orig * 1.2)
        maxminmax = float(max(abs(audio_orig.min()), abs(audio_orig.max())))

        tset = (transforms.Scale(factor=maxminmax),
                transforms.PadTrim(max_len=length_new, channels_first=False))
        result = transforms.Compose(tset)(audio_orig)

        self.assertTrue(max(abs(result.min()), abs(result.max())) == 1.)

        self.assertTrue(result.size(0) == length_new)

        repr_test = transforms.Compose(tset)
        self.assertTrue(repr_test.__repr__())

    def test_mu_law_companding(self):

        quantization_channels = 256

        sig = self.sig.clone()
        sig = sig / torch.abs(sig).max()
        self.assertTrue(sig.min() >= -1. and sig.max() <= 1.)

        sig_mu = transforms.MuLawEncoding(quantization_channels)(sig)
        self.assertTrue(sig_mu.min() >= 0. and sig_mu.max() <= quantization_channels)

        sig_exp = transforms.MuLawExpanding(quantization_channels)(sig_mu)
        self.assertTrue(sig_exp.min() >= -1. and sig_exp.max() <= 1.)

        repr_test = transforms.MuLawEncoding(quantization_channels)
        self.assertTrue(repr_test.__repr__())
        repr_test = transforms.MuLawExpanding(quantization_channels)
        self.assertTrue(repr_test.__repr__())
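
    # For reference: MuLawEncoding quantizes the classic companding curve
    # F(x) = sign(x) * log(1 + mu * |x|) / log(1 + mu), mu = quantization_channels - 1.
    # A minimal sketch of such an encoder (illustrative only; the exact
    # rounding/offset of the transform's quantizer are assumptions, so it is
    # not asserted against the transform above):
    @staticmethod
    def _mu_law_encode_reference(x, quantization_channels=256):
        mu = quantization_channels - 1.
        x_mu = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / math.log1p(mu)
        return ((x_mu + 1.) / 2. * mu + 0.5).long()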

    def test_mel2(self):
        top_db = 80.
        s2db = transforms.SpectrogramToDB("power", top_db)

        audio_orig = self.sig.clone()  # (64000, 1)
        audio_scaled = transforms.Scale()(audio_orig)  # (64000, 1)
        audio_scaled = transforms.LC2CL()(audio_scaled)  # (1, 64000)
        mel_transform = transforms.MelSpectrogram()
        # check defaults
        spectrogram_torch = s2db(mel_transform(audio_scaled))  # (1, 321, 40)
        self.assertTrue(spectrogram_torch.dim() == 3)
        self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
        self.assertEqual(spectrogram_torch.size(-1), mel_transform.n_mels)
        # check correctness of filterbank conversion matrix
        self.assertTrue(mel_transform.fm.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform.fm.fb.sum(1).ge(0.).all())
        # check options
        kwargs = {"window": torch.hamming_window, "pad": 10, "ws": 500, "hop": 125, "n_fft": 800, "n_mels": 50}
        mel_transform2 = transforms.MelSpectrogram(**kwargs)
        spectrogram2_torch = s2db(mel_transform2(audio_scaled))  # (1, n_frames, 50)
        self.assertTrue(spectrogram2_torch.dim() == 3)
        self.assertTrue(spectrogram2_torch.ge(spectrogram2_torch.max() - top_db).all())
        self.assertEqual(spectrogram2_torch.size(-1), mel_transform2.n_mels)
        self.assertTrue(mel_transform2.fm.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform2.fm.fb.sum(1).ge(0.).all())
        # check on multi-channel audio
        x_stereo, sr_stereo = torchaudio.load(self.test_filepath)
        spectrogram_stereo = s2db(mel_transform(x_stereo))
        self.assertTrue(spectrogram_stereo.dim() == 3)
        self.assertTrue(spectrogram_stereo.size(0) == 2)
        self.assertTrue(spectrogram_stereo.ge(spectrogram_stereo.max() - top_db).all())
        self.assertEqual(spectrogram_stereo.size(-1), mel_transform.n_mels)
        # check filterbank matrix creation
        fb_matrix_transform = transforms.MelScale(n_mels=100, sr=16000, f_max=None, f_min=0., n_stft=400)
        self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all())
        self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
        self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))
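
        # For reference: the filterbanks checked above are spaced on the
        # HTK-style mel scale, m = 2595 * log10(1 + f / 700) (assumed here,
        # consistent with the htk=True librosa comparison further below);
        # a minimal sketch of the conversion with two sanity checks:
        def _hz_to_mel(f):
            return 2595. * math.log10(1. + f / 700.)

        self.assertTrue(_hz_to_mel(0.) == 0.)  # the scale is zero at 0 Hz
        self.assertTrue(_hz_to_mel(8000.) > _hz_to_mel(4000.))  # and monotone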

    def test_mfcc(self):
        audio_orig = self.sig.clone()
        audio_scaled = transforms.Scale()(audio_orig)  # (64000, 1)
        audio_scaled = transforms.LC2CL()(audio_scaled)  # (1, 64000)

        sample_rate = 16000
        n_mfcc = 40
        n_mels = 128
        mfcc_transform = torchaudio.transforms.MFCC(sr=sample_rate,
                                                    n_mfcc=n_mfcc,
                                                    norm='ortho')
        # check defaults
        torch_mfcc = mfcc_transform(audio_scaled)
        self.assertTrue(torch_mfcc.dim() == 3)
        self.assertTrue(torch_mfcc.shape[2] == n_mfcc)
        self.assertTrue(torch_mfcc.shape[1] == 321)
        # check melkwargs are passed through
        melkwargs = {'ws': 200}
        mfcc_transform2 = torchaudio.transforms.MFCC(sr=sample_rate,
                                                     n_mfcc=n_mfcc,
                                                     norm='ortho',
                                                     melkwargs=melkwargs)
        torch_mfcc2 = mfcc_transform2(audio_scaled)
        self.assertTrue(torch_mfcc2.shape[1] == 641)

        # check norms work correctly
        mfcc_transform_norm_none = torchaudio.transforms.MFCC(sr=sample_rate,
                                                              n_mfcc=n_mfcc,
                                                              norm=None)
        torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled)

        norm_check = torch_mfcc.clone()
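        # an orthonormal DCT-II scales coefficient 0 by sqrt(1 / (4 * N)) and
        # the rest by sqrt(1 / (2 * N)) relative to norm=None (N = n_mels),
        # so undoing it multiplies by 2 * sqrt(N) and 2 * sqrt(N / 2):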
        norm_check[:, :, 0] *= math.sqrt(n_mels) * 2
        norm_check[:, :, 1:] *= math.sqrt(n_mels / 2) * 2

        self.assertTrue(torch_mfcc_norm_none.allclose(norm_check))

    @unittest.skipIf(not IMPORT_LIBROSA or not IMPORT_SCIPY, 'Librosa and scipy are not available')
    def test_librosa_consistency(self):
        def _test_librosa_consistency_helper(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):
            input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
            sound, sample_rate = torchaudio.load(input_path)
            sound_librosa = sound.cpu().numpy().squeeze().T  # squeeze batch and channel first

            # test core spectrogram
            spect_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop=hop_length, power=2)
            out_librosa, _ = librosa.core.spectrum._spectrogram(y=sound_librosa,
                                                                n_fft=n_fft,
                                                                hop_length=hop_length,
                                                                power=2)

            out_torch = spect_transform(sound).squeeze().cpu().t()
            self.assertTrue(torch.allclose(out_torch, torch.from_numpy(out_librosa), atol=1e-5))

            # test mel spectrogram
            melspect_transform = torchaudio.transforms.MelSpectrogram(sr=sample_rate, window=torch.hann_window,
                                                                      hop=hop_length, n_mels=n_mels, n_fft=n_fft)
            librosa_mel = librosa.feature.melspectrogram(y=sound_librosa, sr=sample_rate,
                                                         n_fft=n_fft, hop_length=hop_length, n_mels=n_mels,
                                                         htk=True, norm=None)
            torch_mel = melspect_transform(sound).squeeze().cpu().t()

            # looser tolerance here, likely due to double (librosa) vs. float (torch)
            self.assertTrue(torch.allclose(torch_mel.type(torch.double), torch.from_numpy(librosa_mel), atol=5e-3))

            # test SpectrogramToDB
            db_transform = torchaudio.transforms.SpectrogramToDB("power", 80.)
            db_torch = db_transform(spect_transform(sound)).squeeze().cpu().t()
            db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
            self.assertTrue(torch.allclose(db_torch, torch.from_numpy(db_librosa), atol=5e-3))

            db_torch = db_transform(melspect_transform(sound)).squeeze().cpu().t()
            db_librosa = librosa.core.spectrum.power_to_db(librosa_mel)

            self.assertTrue(torch.allclose(db_torch.type(torch.double), torch.from_numpy(db_librosa), atol=5e-3))

            # test MFCC
            melkwargs = {'hop': hop_length, 'n_fft': n_fft}
            mfcc_transform = torchaudio.transforms.MFCC(sr=sample_rate,
                                                        n_mfcc=n_mfcc,
                                                        norm='ortho',
                                                        melkwargs=melkwargs)

            # librosa.feature.mfcc doesn't pass kwargs properly since some of the
            # kwargs for melspectrogram and mfcc are the same. We just follow the
            # function body in https://librosa.github.io/librosa/_modules/librosa/feature/spectral.html#melspectrogram
            # to mirror this function call with correct args:

            # librosa_mfcc = librosa.feature.mfcc(y=sound_librosa,
            #                                     sr=sample_rate,
            #                                     n_mfcc=n_mfcc,
            #                                     hop_length=hop_length,
            #                                     n_fft=n_fft,
            #                                     htk=True,
            #                                     norm=None,
            #                                     n_mels=n_mels)

            librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc]
            torch_mfcc = mfcc_transform(sound).squeeze().cpu().t()

            self.assertTrue(torch.allclose(torch_mfcc.type(torch.double), torch.from_numpy(librosa_mfcc), atol=5e-3))

        kwargs1 = {
            'n_fft': 400,
            'hop_length': 200,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 40,
            'sample_rate': 16000
        }

        kwargs2 = {
            'n_fft': 600,
            'hop_length': 100,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 20,
            'sample_rate': 16000
        }

        kwargs3 = {
            'n_fft': 200,
            'hop_length': 50,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 50,
            'sample_rate': 24000
        }

        _test_librosa_consistency_helper(**kwargs1)
        _test_librosa_consistency_helper(**kwargs2)
        _test_librosa_consistency_helper(**kwargs3)


if __name__ == '__main__':
    unittest.main()