from __future__ import absolute_import, division, print_function, unicode_literals
import math
import os

import torch
import torchaudio
from torchaudio.common_utils import IMPORT_LIBROSA, IMPORT_SCIPY
import torchaudio.transforms as transforms
import unittest
import common_utils

if IMPORT_LIBROSA:
    import librosa

if IMPORT_SCIPY:
    import scipy


class Tester(unittest.TestCase):

    # create a sinewave signal for testing
    sample_rate = 16000
    freq = 440
    volume = .3
    waveform = (torch.cos(2 * math.pi * torch.arange(0, 4 * sample_rate).float() * freq / sample_rate))
    waveform.unsqueeze_(0)  # (1, 64000)
    waveform = (waveform * volume * 2**31).long()  # scale to the int32 range
    # file for stereo stft test
    test_dirpath, test_dir = common_utils.create_temp_assets_dir()
    test_filepath = os.path.join(test_dirpath, 'assets',
                                 'steam-train-whistle-daniel_simon.mp3')

    def scale(self, waveform, factor=float(2**31)):
        # scales a waveform by a factor
        if not waveform.is_floating_point():
            waveform = waveform.to(torch.get_default_dtype())
        return waveform / factor

    def test_mu_law_companding(self):

        quantization_channels = 256

        waveform = self.waveform.clone().to(torch.get_default_dtype())
        waveform /= torch.abs(waveform).max()
        self.assertTrue(waveform.min() >= -1. and waveform.max() <= 1.)

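        # MuLawEncoding applies mu-law companding,
        #   y = sign(x) * log(1 + mu * |x|) / log(1 + mu),  mu = quantization_channels - 1,
        # and quantizes the result to integers in [0, mu]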
        waveform_mu = transforms.MuLawEncoding(quantization_channels)(waveform)
        self.assertTrue(waveform_mu.min() >= 0. and waveform_mu.max() <= quantization_channels)

        waveform_exp = transforms.MuLawDecoding(quantization_channels)(waveform_mu)
        self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.)

    def test_melscale_load_save(self):
        specgram = torch.ones(1, 1000, 100)
        melscale_transform = transforms.MelScale()
        melscale_transform(specgram)

        melscale_transform_copy = transforms.MelScale(n_stft=1000)
        melscale_transform_copy.load_state_dict(melscale_transform.state_dict())

        fb = melscale_transform.fb
        fb_copy = melscale_transform_copy.fb

        self.assertEqual(fb_copy.size(), (1000, 128))
        self.assertTrue(torch.allclose(fb, fb_copy))

    def test_melspectrogram_load_save(self):
        waveform = self.waveform.float()
        mel_spectrogram_transform = transforms.MelSpectrogram()
        mel_spectrogram_transform(waveform)

        mel_spectrogram_transform_copy = transforms.MelSpectrogram()
        mel_spectrogram_transform_copy.load_state_dict(mel_spectrogram_transform.state_dict())

        window = mel_spectrogram_transform.spectrogram.window
        window_copy = mel_spectrogram_transform_copy.spectrogram.window

        fb = mel_spectrogram_transform.mel_scale.fb
        fb_copy = mel_spectrogram_transform_copy.mel_scale.fb

        self.assertTrue(torch.allclose(window, window_copy))
        # with the defaults n_fft = 400 and n_mels = 128, fb is (201, 128)
        self.assertEqual(fb_copy.size(), (201, 128))
        self.assertTrue(torch.allclose(fb, fb_copy))

    def test_mel2(self):
        top_db = 80.
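        # AmplitudeToDB('power', top_db) computes 10 * log10(x) and clips
        # everything below (max - top_db), which the assertions below rely on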
        s2db = transforms.AmplitudeToDB('power', top_db)

        waveform = self.waveform.clone()  # (1, 64000)
        waveform_scaled = self.scale(waveform)  # (1, 64000)
        mel_transform = transforms.MelSpectrogram()
        # check defaults
        spectrogram_torch = s2db(mel_transform(waveform_scaled))  # (1, 128, 321)
        self.assertTrue(spectrogram_torch.dim() == 3)
        self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
        self.assertEqual(spectrogram_torch.size(1), mel_transform.n_mels)
        # check correctness of filterbank conversion matrix
        self.assertTrue(mel_transform.mel_scale.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform.mel_scale.fb.sum(1).ge(0.).all())
        # check options
        kwargs = {'window_fn': torch.hamming_window, 'pad': 10, 'win_length': 500,
                  'hop_length': 125, 'n_fft': 800, 'n_mels': 50}
        mel_transform2 = transforms.MelSpectrogram(**kwargs)
        spectrogram2_torch = s2db(mel_transform2(waveform_scaled))  # (1, 50, 513)
        self.assertTrue(spectrogram2_torch.dim() == 3)
        self.assertTrue(spectrogram2_torch.ge(spectrogram2_torch.max() - top_db).all())
        self.assertEqual(spectrogram2_torch.size(1), mel_transform2.n_mels)
        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).ge(0.).all())
        # check on multi-channel audio
        x_stereo, sr_stereo = torchaudio.load(self.test_filepath)  # (2, 278756), 44100
        spectrogram_stereo = s2db(mel_transform(x_stereo))  # (2, 128, 1394)
        self.assertTrue(spectrogram_stereo.dim() == 3)
        self.assertTrue(spectrogram_stereo.size(0) == 2)
        self.assertTrue(spectrogram_stereo.ge(spectrogram_stereo.max() - top_db).all())
        self.assertEqual(spectrogram_stereo.size(1), mel_transform.n_mels)
        # check filterbank matrix creation
        fb_matrix_transform = transforms.MelScale(
            n_mels=100, sample_rate=16000, f_min=0., f_max=None, n_stft=400)
        self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all())
        self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
        self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))

    def test_mfcc(self):
        audio_orig = self.waveform.clone()
        audio_scaled = self.scale(audio_orig)  # (1, 64000)

        sample_rate = 16000
        n_mfcc = 40
        n_mels = 128
        mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                    n_mfcc=n_mfcc,
                                                    norm='ortho')
        # check defaults
        torch_mfcc = mfcc_transform(audio_scaled)  # (1, 40, 321)
        self.assertTrue(torch_mfcc.dim() == 3)
        self.assertTrue(torch_mfcc.shape[1] == n_mfcc)
        self.assertTrue(torch_mfcc.shape[2] == 321)
        # check melkwargs are passed through
        melkwargs = {'win_length': 200}
        mfcc_transform2 = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                     n_mfcc=n_mfcc,
                                                     norm='ortho',
                                                     melkwargs=melkwargs)
        torch_mfcc2 = mfcc_transform2(audio_scaled)  # (1, 40, 641)
        self.assertTrue(torch_mfcc2.shape[2] == 641)

        # check norms work correctly
        mfcc_transform_norm_none = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                              n_mfcc=n_mfcc,
                                                              norm=None)
        torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled)  # (1, 40, 321)

        norm_check = torch_mfcc.clone()
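        # an unnormalized DCT-II equals the orthonormal one scaled by
        # sqrt(4 * n_mels) for coefficient 0 and sqrt(2 * n_mels) for the rest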
        norm_check[:, 0, :] *= math.sqrt(n_mels) * 2
        norm_check[:, 1:, :] *= math.sqrt(n_mels / 2) * 2

        self.assertTrue(torch_mfcc_norm_none.allclose(norm_check))

    @unittest.skipIf(not IMPORT_LIBROSA or not IMPORT_SCIPY, 'Librosa or scipy is not available')
    def test_librosa_consistency(self):
        def _test_librosa_consistency_helper(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):
            input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
            sound, sample_rate = torchaudio.load(input_path)
            sound_librosa = sound.cpu().numpy().squeeze()  # (64000)

            # test core spectrogram
            spect_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop_length=hop_length, power=power)
            out_librosa, _ = librosa.core.spectrum._spectrogram(y=sound_librosa,
                                                                n_fft=n_fft,
                                                                hop_length=hop_length,
                                                                power=power)

            out_torch = spect_transform(sound).squeeze().cpu()
            self.assertTrue(torch.allclose(out_torch, torch.from_numpy(out_librosa), atol=1e-5))

            # test mel spectrogram
            melspect_transform = torchaudio.transforms.MelSpectrogram(
                sample_rate=sample_rate, window_fn=torch.hann_window,
                hop_length=hop_length, n_mels=n_mels, n_fft=n_fft)
            librosa_mel = librosa.feature.melspectrogram(y=sound_librosa, sr=sample_rate,
                                                         n_fft=n_fft, hop_length=hop_length, n_mels=n_mels,
                                                         htk=True, norm=None)
            librosa_mel_tensor = torch.from_numpy(librosa_mel)
            torch_mel = melspect_transform(sound).squeeze().cpu()

            self.assertTrue(torch.allclose(torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3))

            # test s2db
            db_transform = torchaudio.transforms.AmplitudeToDB('power', 80.)
            db_torch = db_transform(spect_transform(sound)).squeeze().cpu()
            db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
            self.assertTrue(torch.allclose(db_torch, torch.from_numpy(db_librosa), atol=5e-3))

            db_torch = db_transform(melspect_transform(sound)).squeeze().cpu()
            db_librosa = librosa.core.spectrum.power_to_db(librosa_mel)
            db_librosa_tensor = torch.from_numpy(db_librosa)

            self.assertTrue(torch.allclose(db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3))

            # test MFCC
            melkwargs = {'hop_length': hop_length, 'n_fft': n_fft}
            mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                        n_mfcc=n_mfcc,
                                                        norm='ortho',
                                                        melkwargs=melkwargs)

            # librosa.feature.mfcc cannot forward these kwargs cleanly because
            # melspectrogram and mfcc share some parameter names, so we follow the
            # function body in https://librosa.github.io/librosa/_modules/librosa/feature/spectral.html#melspectrogram
            # and mirror the call with the correct arguments:

            # librosa_mfcc = librosa.feature.mfcc(y=sound_librosa,
            #                                     sr=sample_rate,
            #                                     n_mfcc=n_mfcc,
            #                                     hop_length=hop_length,
            #                                     n_fft=n_fft,
            #                                     htk=True,
            #                                     norm=None,
            #                                     n_mels=n_mels)

            librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc]
            librosa_mfcc_tensor = torch.from_numpy(librosa_mfcc)
            torch_mfcc = mfcc_transform(sound).squeeze().cpu()

            self.assertTrue(torch.allclose(torch_mfcc.type(librosa_mfcc_tensor.dtype), librosa_mfcc_tensor, atol=5e-3))

        kwargs1 = {
            'n_fft': 400,
            'hop_length': 200,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 40,
            'sample_rate': 16000
        }

        kwargs2 = {
            'n_fft': 600,
            'hop_length': 100,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 20,
            'sample_rate': 16000
        }

        kwargs3 = {
            'n_fft': 200,
            'hop_length': 50,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 50,
            'sample_rate': 24000
        }

        _test_librosa_consistency_helper(**kwargs1)
        _test_librosa_consistency_helper(**kwargs2)
        _test_librosa_consistency_helper(**kwargs3)

    def test_resample_size(self):
        input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
        waveform, sample_rate = torchaudio.load(input_path)

        upsample_rate = sample_rate * 2
        downsample_rate = sample_rate // 2
        invalid_resample = torchaudio.transforms.Resample(sample_rate, upsample_rate, resampling_method='foo')

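        # an unrecognized resampling_method should raise ValueError at call time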
        self.assertRaises(ValueError, invalid_resample, waveform)

        upsample_resample = torchaudio.transforms.Resample(
            sample_rate, upsample_rate, resampling_method='sinc_interpolation')
        up_sampled = upsample_resample(waveform)

        # we expect the upsampled signal to have twice as many samples
        self.assertTrue(up_sampled.size(-1) == waveform.size(-1) * 2)

        downsample_resample = torchaudio.transforms.Resample(
            sample_rate, downsample_rate, resampling_method='sinc_interpolation')
        down_sampled = downsample_resample(waveform)

        # we expect the downsampled signal to have half as many samples
        self.assertTrue(down_sampled.size(-1) == waveform.size(-1) // 2)

if __name__ == '__main__':
    unittest.main()