test_transforms.py 12.9 KB
Newer Older
1
from __future__ import print_function
2
import math
3
import os
4

David Pollack's avatar
David Pollack committed
5
6
import torch
import torchaudio
7
from torchaudio.common_utils import IMPORT_LIBROSA, IMPORT_SCIPY
David Pollack's avatar
David Pollack committed
8
9
import torchaudio.transforms as transforms
import unittest
10
import test.common_utils
David Pollack's avatar
David Pollack committed
11

12
13
14
15
16
17
if IMPORT_LIBROSA:
    import librosa

if IMPORT_SCIPY:
    import scipy

Soumith Chintala's avatar
Soumith Chintala committed
18

David Pollack's avatar
David Pollack committed
19
20
class Tester(unittest.TestCase):
    """Unit tests for torchaudio.transforms."""

    # Create a sinewave signal for testing: 4 seconds of a 440 Hz cosine
    # at 16 kHz, scaled to 30% of the int32 range and stored as long.
    sr = 16000
    freq = 440
    volume = .3
    sig = (torch.cos(2 * math.pi * torch.arange(0, 4 * sr).float() * freq / sr))
    sig.unsqueeze_(1)  # (64000, 1)
    sig = (sig * volume * 2**31).long()
    # file for stereo stft test
    test_dirpath, test_dir = test.common_utils.create_temp_assets_dir()
    test_filepath = os.path.join(test_dirpath, "assets",
                                 "steam-train-whistle-daniel_simon.mp3")

    def test_scale(self):
        """Scale should map an integer signal into [-1., 1.]."""
        audio_orig = self.sig.clone()
        result = transforms.Scale()(audio_orig)
        self.assertTrue(result.min() >= -1. and result.max() <= 1.)

        maxminmax = float(max(abs(audio_orig.min()), abs(audio_orig.max())))

        # Scaling by the exact peak magnitude must touch one of the bounds.
        result = transforms.Scale(factor=maxminmax)(audio_orig)
        self.assertTrue((result.min() == -1. or result.max() == 1.) and
                        result.min() >= -1. and result.max() <= 1.)

        repr_test = transforms.Scale()
        self.assertTrue(repr_test.__repr__())

    def test_pad_trim(self):
        """PadTrim should pad up to or trim down to max_len on the time axis."""
        audio_orig = self.sig.clone()
        length_orig = audio_orig.size(0)
        length_new = int(length_orig * 1.2)

        # pad, time-first layout
        result = transforms.PadTrim(max_len=length_new, channels_first=False)(audio_orig)
        self.assertEqual(result.size(0), length_new)

        # pad, channels-first layout
        result = transforms.PadTrim(max_len=length_new, channels_first=True)(audio_orig.transpose(0, 1))
        self.assertEqual(result.size(1), length_new)

        # trim
        audio_orig = self.sig.clone()
        length_orig = audio_orig.size(0)
        length_new = int(length_orig * 0.8)

        result = transforms.PadTrim(max_len=length_new, channels_first=False)(audio_orig)

        self.assertEqual(result.size(0), length_new)

        repr_test = transforms.PadTrim(max_len=length_new, channels_first=False)
        self.assertTrue(repr_test.__repr__())

    def test_downmix_mono(self):
        """DownmixMono should merge a stereo signal into a single channel."""
        audio_L = self.sig.clone()
        audio_R = self.sig.clone()
        # Rotate the right channel by 10% so the two channels differ.
        R_idx = int(audio_R.size(0) * 0.1)
        audio_R = torch.cat((audio_R[R_idx:], audio_R[:R_idx]))

        audio_Stereo = torch.cat((audio_L, audio_R), dim=1)

        self.assertTrue(audio_Stereo.size(1) == 2)

        result = transforms.DownmixMono(channels_first=False)(audio_Stereo)

        self.assertTrue(result.size(1) == 1)

        repr_test = transforms.DownmixMono(channels_first=False)
        self.assertTrue(repr_test.__repr__())

    def test_lc2cl(self):
        """LC2CL should swap the length and channel dimensions."""
        audio = self.sig.clone()
        result = transforms.LC2CL()(audio)
        self.assertTrue(result.size()[::-1] == audio.size())

        repr_test = transforms.LC2CL()
        self.assertTrue(repr_test.__repr__())

    def test_compose(self):
        """Compose should apply its transforms in sequence."""
        audio_orig = self.sig.clone()
        length_orig = audio_orig.size(0)
        length_new = int(length_orig * 1.2)
        maxminmax = float(max(abs(audio_orig.min()), abs(audio_orig.max())))

        tset = (transforms.Scale(factor=maxminmax),
                transforms.PadTrim(max_len=length_new, channels_first=False))
        result = transforms.Compose(tset)(audio_orig)

        # Scale ran first: the peak must sit exactly at 1.
        self.assertTrue(max(abs(result.min()), abs(result.max())) == 1.)

        # PadTrim ran second: the length must be padded.
        self.assertTrue(result.size(0) == length_new)

        repr_test = transforms.Compose(tset)
        self.assertTrue(repr_test.__repr__())

    def test_mu_law_companding(self):
        """Mu-law encode then expand should stay within the expected ranges."""
        quantization_channels = 256

        sig = self.sig.clone()
        sig = sig / torch.abs(sig).max()
        self.assertTrue(sig.min() >= -1. and sig.max() <= 1.)

        sig_mu = transforms.MuLawEncoding(quantization_channels)(sig)
        # BUGFIX: check the encoded signal's max, not the input's
        # (was `sig.max()`, which trivially passed).
        self.assertTrue(sig_mu.min() >= 0. and sig_mu.max() <= quantization_channels)

        sig_exp = transforms.MuLawExpanding(quantization_channels)(sig_mu)
        self.assertTrue(sig_exp.min() >= -1. and sig_exp.max() <= 1.)

        repr_test = transforms.MuLawEncoding(quantization_channels)
        self.assertTrue(repr_test.__repr__())
        repr_test = transforms.MuLawExpanding(quantization_channels)
        self.assertTrue(repr_test.__repr__())

    def test_mel2(self):
        """MelSpectrogram / SpectrogramToDB / MelScale sanity checks."""
        top_db = 80.
        s2db = transforms.SpectrogramToDB("power", top_db)

        audio_orig = self.sig.clone()  # (64000, 1)
        audio_scaled = transforms.Scale()(audio_orig)  # (64000, 1)
        audio_scaled = transforms.LC2CL()(audio_scaled)  # (1, 64000)
        mel_transform = transforms.MelSpectrogram()
        # check defaults
        spectrogram_torch = s2db(mel_transform(audio_scaled))  # (1, 319, 40)
        self.assertTrue(spectrogram_torch.dim() == 3)
        # top_db clipping: no value may be more than top_db below the max
        self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
        self.assertEqual(spectrogram_torch.size(-1), mel_transform.n_mels)
        # check correctness of filterbank conversion matrix
        self.assertTrue(mel_transform.fm.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform.fm.fb.sum(1).ge(0.).all())
        # check options
        kwargs = {"window": torch.hamming_window, "pad": 10, "ws": 500, "hop": 125, "n_fft": 800, "n_mels": 50}
        mel_transform2 = transforms.MelSpectrogram(**kwargs)
        spectrogram2_torch = s2db(mel_transform2(audio_scaled))  # (1, 506, 50)
        self.assertTrue(spectrogram2_torch.dim() == 3)
        # BUGFIX: assert on spectrogram2_torch (was spectrogram_torch,
        # re-checking the default transform instead of the custom one).
        self.assertTrue(spectrogram2_torch.ge(spectrogram2_torch.max() - top_db).all())
        self.assertEqual(spectrogram2_torch.size(-1), mel_transform2.n_mels)
        self.assertTrue(mel_transform2.fm.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform2.fm.fb.sum(1).ge(0.).all())
        # check on multi-channel audio
        x_stereo, sr_stereo = torchaudio.load(self.test_filepath)
        spectrogram_stereo = s2db(mel_transform(x_stereo))
        self.assertTrue(spectrogram_stereo.dim() == 3)
        self.assertTrue(spectrogram_stereo.size(0) == 2)
        # BUGFIX: assert on spectrogram_stereo (was spectrogram_torch).
        self.assertTrue(spectrogram_stereo.ge(spectrogram_stereo.max() - top_db).all())
        self.assertEqual(spectrogram_stereo.size(-1), mel_transform.n_mels)
        # check filterbank matrix creation
        fb_matrix_transform = transforms.MelScale(n_mels=100, sr=16000, f_max=None, f_min=0., n_stft=400)
        self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all())
        self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
        self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))

    def test_mfcc(self):
        """MFCC output shapes, melkwargs pass-through, and norm handling."""
        audio_orig = self.sig.clone()
        audio_scaled = transforms.Scale()(audio_orig)  # (64000, 1)
        audio_scaled = transforms.LC2CL()(audio_scaled)  # (1, 64000)

        sample_rate = 16000
        n_mfcc = 40
        n_mels = 128
        mfcc_transform = torchaudio.transforms.MFCC(sr=sample_rate,
                                                    n_mfcc=n_mfcc,
                                                    norm='ortho')
        # check defaults
        torch_mfcc = mfcc_transform(audio_scaled)
        self.assertTrue(torch_mfcc.dim() == 3)
        self.assertTrue(torch_mfcc.shape[2] == n_mfcc)
        self.assertTrue(torch_mfcc.shape[1] == 321)
        # check melkwargs are passed through
        melkwargs = {'ws': 200}
        mfcc_transform2 = torchaudio.transforms.MFCC(sr=sample_rate,
                                                     n_mfcc=n_mfcc,
                                                     norm='ortho',
                                                     melkwargs=melkwargs)
        torch_mfcc2 = mfcc_transform2(audio_scaled)
        # halved window size doubles the frame count (641 vs. 321)
        self.assertTrue(torch_mfcc2.shape[1] == 641)

        # check norms work correctly
        mfcc_transform_norm_none = torchaudio.transforms.MFCC(sr=sample_rate,
                                                              n_mfcc=n_mfcc,
                                                              norm=None)
        torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled)

        norm_check = torch_mfcc.clone()
        # Undo the 'ortho' DCT scaling to recover the unnormalized result.
        norm_check[:, :, 0] *= math.sqrt(n_mels) * 2
        norm_check[:, :, 1:] *= math.sqrt(n_mels / 2) * 2

        self.assertTrue(torch_mfcc_norm_none.allclose(norm_check))

    @unittest.skipIf(not IMPORT_LIBROSA or not IMPORT_SCIPY, 'Librosa and scipy are not available')
    def test_librosa_consistency(self):
        """torchaudio transforms should match librosa's reference output."""
        def _test_librosa_consistency_helper(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):
            input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
            sound, sample_rate = torchaudio.load(input_path)
            sound_librosa = sound.cpu().numpy().squeeze().T  # squeeze batch and channel first

            # test core spectrogram
            # BUGFIX: pass the `power` parameter through instead of the
            # hard-coded 2 (all current callers pass 2.0, so behavior is
            # unchanged, but the parameter was silently ignored before).
            spect_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop=hop_length, power=power)
            out_librosa, _ = librosa.core.spectrum._spectrogram(y=sound_librosa,
                                                                n_fft=n_fft,
                                                                hop_length=hop_length,
                                                                power=power)

            out_torch = spect_transform(sound).squeeze().cpu().t()
            self.assertTrue(torch.allclose(out_torch, torch.from_numpy(out_librosa), atol=1e-5))

            # test mel spectrogram
            melspect_transform = torchaudio.transforms.MelSpectrogram(sr=sample_rate, window=torch.hann_window,
                                                                      hop=hop_length, n_mels=n_mels, n_fft=n_fft)
            librosa_mel = librosa.feature.melspectrogram(y=sound_librosa, sr=sample_rate,
                                                         n_fft=n_fft, hop_length=hop_length, n_mels=n_mels,
                                                         htk=True, norm=None)
            torch_mel = melspect_transform(sound).squeeze().cpu().t()

            # lower tolerance, think it's double vs. float
            self.assertTrue(torch.allclose(torch_mel.type(torch.double), torch.from_numpy(librosa_mel), atol=5e-3))

            # test s2db
            db_transform = torchaudio.transforms.SpectrogramToDB("power", 80.)
            db_torch = db_transform(spect_transform(sound)).squeeze().cpu().t()
            db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
            self.assertTrue(torch.allclose(db_torch, torch.from_numpy(db_librosa), atol=5e-3))

            db_torch = db_transform(melspect_transform(sound)).squeeze().cpu().t()
            db_librosa = librosa.core.spectrum.power_to_db(librosa_mel)

            self.assertTrue(torch.allclose(db_torch.type(torch.double), torch.from_numpy(db_librosa), atol=5e-3))

            # test MFCC
            melkwargs = {'hop': hop_length, 'n_fft': n_fft}
            mfcc_transform = torchaudio.transforms.MFCC(sr=sample_rate,
                                                        n_mfcc=n_mfcc,
                                                        norm='ortho',
                                                        melkwargs=melkwargs)

            # librosa.feature.mfcc doesn't pass kwargs properly since some of the
            # kwargs for melspectrogram and mfcc are the same. We just follow the
            # function body in https://librosa.github.io/librosa/_modules/librosa/feature/spectral.html#melspectrogram
            # to mirror this function call with correct args:
            #
            #     librosa_mfcc = librosa.feature.mfcc(y=sound_librosa,
            #                                         sr=sample_rate,
            #                                         n_mfcc = n_mfcc,
            #                                         hop_length=hop_length,
            #                                         n_fft=n_fft,
            #                                         htk=True,
            #                                         norm=None,
            #                                         n_mels=n_mels)

            librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc]
            torch_mfcc = mfcc_transform(sound).squeeze().cpu().t()

            self.assertTrue(torch.allclose(torch_mfcc.type(torch.double), torch.from_numpy(librosa_mfcc), atol=5e-3))

        kwargs1 = {
            'n_fft': 400,
            'hop_length': 200,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 40,
            'sample_rate': 16000
        }

        kwargs2 = {
            'n_fft': 600,
            'hop_length': 100,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 20,
            'sample_rate': 16000
        }

        kwargs3 = {
            'n_fft': 200,
            'hop_length': 50,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 50,
            'sample_rate': 24000
        }

        _test_librosa_consistency_helper(**kwargs1)
        _test_librosa_consistency_helper(**kwargs2)
        _test_librosa_consistency_helper(**kwargs3)
if __name__ == '__main__':
    unittest.main()