test_transforms.py 13.1 KB
Newer Older
1
from __future__ import print_function
2
import math
3
import os
4

David Pollack's avatar
David Pollack committed
5
6
import torch
import torchaudio
7
from torchaudio.common_utils import IMPORT_LIBROSA, IMPORT_SCIPY
David Pollack's avatar
David Pollack committed
8
9
import torchaudio.transforms as transforms
import unittest
10
import test.common_utils
David Pollack's avatar
David Pollack committed
11

12
13
14
15
16
17
if IMPORT_LIBROSA:
    import librosa

if IMPORT_SCIPY:
    import scipy

Soumith Chintala's avatar
Soumith Chintala committed
18

David Pollack's avatar
David Pollack committed
19
20
class Tester(unittest.TestCase):

    # Shared fixture: a 4-second 440 Hz cosine sampled at 16 kHz.
    sr = 16000
    freq = 440
    volume = .3
    # Build the waveform as a (64000, 1) column, then scale to the
    # 32-bit signed integer range and store as a long tensor.
    sig = (torch.cos(2 * math.pi * torch.arange(0, 4 * sr).float() * freq / sr)).unsqueeze(1)
    sig = (sig * volume * 2**31).long()
    # Temp assets directory; the mp3 below is a stereo file used by the
    # multi-channel spectrogram test.
    test_dirpath, test_dir = test.common_utils.create_temp_assets_dir()
    test_filepath = os.path.join(test_dirpath, "assets",
                                 "steam-train-whistle-daniel_simon.mp3")
    def test_scale(self):
        """Scale must map integer audio into the unit interval [-1., 1.]."""
        waveform = self.sig.clone()

        # Default factor: output is bounded by [-1., 1.].
        scaled = transforms.Scale()(waveform)
        self.assertTrue(scaled.min() >= -1. and scaled.max() <= 1.)

        # Scaling by the peak magnitude should touch an endpoint exactly
        # while staying inside the interval.
        peak = max(abs(waveform.min()), abs(waveform.max())).item()
        scaled = transforms.Scale(factor=peak)(waveform)
        touches_endpoint = scaled.min() == -1. or scaled.max() == 1.
        self.assertTrue(touches_endpoint and
                        scaled.min() >= -1. and scaled.max() <= 1.)

        # The transform must provide a non-empty repr.
        self.assertTrue(transforms.Scale().__repr__())
    def test_pad_trim(self):
        """PadTrim must pad short signals and trim long ones to max_len."""
        waveform = self.sig.clone()
        orig_len = waveform.size(0)

        # Padding case: target length longer than the signal.
        padded_len = int(orig_len * 1.2)
        out = transforms.PadTrim(max_len=padded_len, channels_first=False)(waveform)
        self.assertEqual(out.size(0), padded_len)

        # Same padding with channels-first layout; length lives on dim 1.
        out = transforms.PadTrim(max_len=padded_len, channels_first=True)(waveform.transpose(0, 1))
        self.assertEqual(out.size(1), padded_len)

        # Trimming case: target length shorter than the signal.
        waveform = self.sig.clone()
        trimmed_len = int(waveform.size(0) * 0.8)
        out = transforms.PadTrim(max_len=trimmed_len, channels_first=False)(waveform)
        self.assertEqual(out.size(0), trimmed_len)

        # The transform must provide a non-empty repr.
        self.assertTrue(transforms.PadTrim(max_len=trimmed_len, channels_first=False).__repr__())
    def test_downmix_mono(self):
        """DownmixMono must collapse a stereo signal to one channel."""
        # Left channel: the fixture as-is. Right channel: the fixture
        # rotated by 10% of its length so the two channels differ.
        left = self.sig.clone()
        right = self.sig.clone()
        shift = int(right.size(0) * 0.1)
        right = torch.cat((right[shift:], right[:shift]))

        stereo = torch.cat((left, right), dim=1)
        self.assertTrue(stereo.size(1) == 2)

        mono = transforms.DownmixMono(channels_first=False)(stereo)
        self.assertTrue(mono.size(1) == 1)

        # The transform must provide a non-empty repr.
        self.assertTrue(transforms.DownmixMono(channels_first=False).__repr__())
    def test_lc2cl(self):
        """LC2CL must transpose (length, channels) into (channels, length)."""
        waveform = self.sig.clone()
        transposed = transforms.LC2CL()(waveform)
        # The output shape is exactly the input shape reversed.
        self.assertTrue(transposed.size()[::-1] == waveform.size())

        # The transform must provide a non-empty repr.
        self.assertTrue(transforms.LC2CL().__repr__())
    def test_compose(self):
        """Compose must apply its transforms in sequence."""
        waveform = self.sig.clone()
        target_len = int(waveform.size(0) * 1.2)
        peak = max(abs(waveform.min()), abs(waveform.max())).item()

        # Scale to unit range, then pad out to the target length.
        pipeline = (transforms.Scale(factor=peak),
                    transforms.PadTrim(max_len=target_len, channels_first=False))
        out = transforms.Compose(pipeline)(waveform)

        # Scaling by the peak makes the extreme magnitude exactly 1.
        self.assertTrue(max(abs(out.min()), abs(out.max())) == 1.)
        # Padding brings the signal to the requested length.
        self.assertTrue(out.size(0) == target_len)

        # The composed transform must provide a non-empty repr.
        self.assertTrue(transforms.Compose(pipeline).__repr__())
    def test_mu_law_companding(self):
        """Mu-law encoding must land in [0, quantization_channels] and
        expanding must map back into [-1., 1.]."""
        quantization_channels = 256

        # Normalize the fixture into the unit interval first.
        sig = self.sig.clone()
        sig = sig / torch.abs(sig).max()
        self.assertTrue(sig.min() >= -1. and sig.max() <= 1.)

        sig_mu = transforms.MuLawEncoding(quantization_channels)(sig)
        # Fix: the upper-bound check previously inspected `sig` (already
        # known to be <= 1.) instead of the encoded tensor `sig_mu`, so
        # the encoded range was never actually verified.
        self.assertTrue(sig_mu.min() >= 0. and sig_mu.max() <= quantization_channels)

        sig_exp = transforms.MuLawExpanding(quantization_channels)(sig_mu)
        self.assertTrue(sig_exp.min() >= -1. and sig_exp.max() <= 1.)

        # Both transforms must provide a non-empty repr.
        repr_test = transforms.MuLawEncoding(quantization_channels)
        self.assertTrue(repr_test.__repr__())
        repr_test = transforms.MuLawExpanding(quantization_channels)
        self.assertTrue(repr_test.__repr__())
    def test_mel2(self):
        """MelSpectrogram defaults, keyword options, multi-channel input,
        and MelScale filterbank construction."""
        top_db = 80.
        s2db = transforms.SpectrogramToDB("power", top_db)

        audio_orig = self.sig.clone()  # (16000, 1)
        audio_scaled = transforms.Scale()(audio_orig)  # (16000, 1)
        audio_scaled = transforms.LC2CL()(audio_scaled)  # (1, 16000)
        mel_transform = transforms.MelSpectrogram()
        # check defaults
        spectrogram_torch = s2db(mel_transform(audio_scaled))  # (1, 319, 40)
        self.assertTrue(spectrogram_torch.dim() == 3)
        self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
        self.assertEqual(spectrogram_torch.size(-1), mel_transform.n_mels)
        # check correctness of filterbank conversion matrix
        self.assertTrue(mel_transform.fm.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform.fm.fb.sum(1).ge(0.).all())
        # check options
        kwargs = {"window": torch.hamming_window, "pad": 10, "ws": 500, "hop": 125, "n_fft": 800, "n_mels": 50}
        mel_transform2 = transforms.MelSpectrogram(**kwargs)
        spectrogram2_torch = s2db(mel_transform2(audio_scaled))  # (1, 506, 50)
        self.assertTrue(spectrogram2_torch.dim() == 3)
        # Fix: the dB-range check previously re-tested `spectrogram_torch`
        # (a copy-paste slip); it must test the second spectrogram.
        self.assertTrue(spectrogram2_torch.ge(spectrogram2_torch.max() - top_db).all())
        self.assertEqual(spectrogram2_torch.size(-1), mel_transform2.n_mels)
        self.assertTrue(mel_transform2.fm.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform2.fm.fb.sum(1).ge(0.).all())
        # check on multi-channel audio
        x_stereo, sr_stereo = torchaudio.load(self.test_filepath)
        spectrogram_stereo = s2db(mel_transform(x_stereo))
        self.assertTrue(spectrogram_stereo.dim() == 3)
        self.assertTrue(spectrogram_stereo.size(0) == 2)
        # Fix: same copy-paste slip — check the stereo spectrogram, not
        # the mono one computed earlier.
        self.assertTrue(spectrogram_stereo.ge(spectrogram_stereo.max() - top_db).all())
        self.assertEqual(spectrogram_stereo.size(-1), mel_transform.n_mels)
        # check filterbank matrix creation
        fb_matrix_transform = transforms.MelScale(n_mels=100, sr=16000, f_max=None, f_min=0., n_stft=400)
        self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all())
        self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
        self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))
    def test_mfcc(self):
        """MFCC defaults, melkwargs pass-through, and norm handling."""
        audio_scaled = transforms.Scale()(self.sig.clone())  # (16000, 1)
        audio_scaled = transforms.LC2CL()(audio_scaled)  # (1, 16000)

        sample_rate = 16000
        n_mfcc = 40
        n_mels = 128

        # Defaults: output is (channel, time, n_mfcc) with 321 frames.
        mfcc_transform = torchaudio.transforms.MFCC(sr=sample_rate,
                                                    n_mfcc=n_mfcc,
                                                    norm='ortho')
        torch_mfcc = mfcc_transform(audio_scaled)
        self.assertTrue(torch_mfcc.dim() == 3)
        self.assertTrue(torch_mfcc.shape[2] == n_mfcc)
        self.assertTrue(torch_mfcc.shape[1] == 321)

        # melkwargs pass-through: a shorter window doubles the frame count.
        mfcc_transform2 = torchaudio.transforms.MFCC(sr=sample_rate,
                                                     n_mfcc=n_mfcc,
                                                     norm='ortho',
                                                     melkwargs={'ws': 200})
        torch_mfcc2 = mfcc_transform2(audio_scaled)
        self.assertTrue(torch_mfcc2.shape[1] == 641)

        # norm=None output should equal the 'ortho' output rescaled by the
        # inverse of the orthonormal DCT factors.
        mfcc_transform_norm_none = torchaudio.transforms.MFCC(sr=sample_rate,
                                                              n_mfcc=n_mfcc,
                                                              norm=None)
        torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled)

        norm_check = torch_mfcc.clone()
        norm_check[:, :, 0] *= math.sqrt(n_mels) * 2
        norm_check[:, :, 1:] *= math.sqrt(n_mels / 2) * 2

        self.assertTrue(torch_mfcc_norm_none.allclose(norm_check))
    @unittest.skipIf(not IMPORT_LIBROSA or not IMPORT_SCIPY, 'Librosa and scipy are not available')
    def test_librosa_consistency(self):
        """torchaudio transforms must agree with the librosa reference
        implementations across several STFT/mel parameter sets."""
        def _test_librosa_consistency_helper(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):
            input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
            sound, sample_rate = torchaudio.load(input_path)
            sound_librosa = sound.cpu().numpy().squeeze().T  # squeeze batch and channel first

            # test core spectrogram
            spect_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop=hop_length, power=2)
            out_librosa, _ = librosa.core.spectrum._spectrogram(y=sound_librosa,
                                                                n_fft=n_fft,
                                                                hop_length=hop_length,
                                                                power=2)
            out_torch = spect_transform(sound).squeeze().cpu().t()
            self.assertTrue(torch.allclose(out_torch, torch.from_numpy(out_librosa), atol=1e-5))

            # test mel spectrogram (htk=True / norm=None matches torchaudio's
            # filterbank construction)
            melspect_transform = torchaudio.transforms.MelSpectrogram(sr=sample_rate, window=torch.hann_window,
                                                                      hop=hop_length, n_mels=n_mels, n_fft=n_fft)
            librosa_mel = librosa.feature.melspectrogram(y=sound_librosa, sr=sample_rate,
                                                         n_fft=n_fft, hop_length=hop_length, n_mels=n_mels,
                                                         htk=True, norm=None)
            librosa_mel_tensor = torch.from_numpy(librosa_mel)
            torch_mel = melspect_transform(sound).squeeze().cpu().t()
            self.assertTrue(torch.allclose(torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3))

            # test s2db on both the plain and the mel spectrogram
            db_transform = torchaudio.transforms.SpectrogramToDB("power", 80.)
            db_torch = db_transform(spect_transform(sound)).squeeze().cpu().t()
            db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
            self.assertTrue(torch.allclose(db_torch, torch.from_numpy(db_librosa), atol=5e-3))

            db_torch = db_transform(melspect_transform(sound)).squeeze().cpu().t()
            db_librosa = librosa.core.spectrum.power_to_db(librosa_mel)
            db_librosa_tensor = torch.from_numpy(db_librosa)
            self.assertTrue(torch.allclose(db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3))

            # test MFCC
            melkwargs = {'hop': hop_length, 'n_fft': n_fft}
            mfcc_transform = torchaudio.transforms.MFCC(sr=sample_rate,
                                                        n_mfcc=n_mfcc,
                                                        norm='ortho',
                                                        melkwargs=melkwargs)

            # librosa.feature.mfcc doesn't pass kwargs properly since some of the
            # kwargs for melspectrogram and mfcc are the same. We just follow the
            # function body in https://librosa.github.io/librosa/_modules/librosa/feature/spectral.html#melspectrogram
            # to mirror this function call with correct args:
            #
            #   librosa_mfcc = librosa.feature.mfcc(y=sound_librosa,
            #                                       sr=sample_rate,
            #                                       n_mfcc = n_mfcc,
            #                                       hop_length=hop_length,
            #                                       n_fft=n_fft,
            #                                       htk=True,
            #                                       norm=None,
            #                                       n_mels=n_mels)
            librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc]
            librosa_mfcc_tensor = torch.from_numpy(librosa_mfcc)
            torch_mfcc = mfcc_transform(sound).squeeze().cpu().t()
            self.assertTrue(torch.allclose(torch_mfcc.type(librosa_mfcc_tensor.dtype), librosa_mfcc_tensor, atol=5e-3))

        # Run the comparison over three distinct parameter sets.
        parameter_sets = [
            {
                'n_fft': 400,
                'hop_length': 200,
                'power': 2.0,
                'n_mels': 128,
                'n_mfcc': 40,
                'sample_rate': 16000
            },
            {
                'n_fft': 600,
                'hop_length': 100,
                'power': 2.0,
                'n_mels': 128,
                'n_mfcc': 20,
                'sample_rate': 16000
            },
            {
                'n_fft': 200,
                'hop_length': 50,
                'power': 2.0,
                'n_mels': 128,
                'n_mfcc': 50,
                'sample_rate': 24000
            },
        ]
        for params in parameter_sets:
            _test_librosa_consistency_helper(**params)
# Allow running this test module directly from the command line.
if __name__ == '__main__':
    unittest.main()