from __future__ import print_function
import math
import os

import torch
import torchaudio
from torchaudio.common_utils import IMPORT_LIBROSA, IMPORT_SCIPY
import torchaudio.transforms as transforms
import unittest
import common_utils

if IMPORT_LIBROSA:
    import librosa

if IMPORT_SCIPY:
    import scipy


class Tester(unittest.TestCase):

    # create a sinewave signal for testing
    sample_rate = 16000
    freq = 440
    volume = .3
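    # 4 seconds of a 440 Hz cosine at 16 kHz, scaled by the volume into the int32 range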
    waveform = (torch.cos(2 * math.pi * torch.arange(0, 4 * sample_rate).float() * freq / sample_rate))
    waveform.unsqueeze_(0)  # (1, 64000)
    waveform = (waveform * volume * 2**31).long()
    # file for stereo stft test
    test_dirpath, test_dir = common_utils.create_temp_assets_dir()
    test_filepath = os.path.join(test_dirpath, 'assets',
                                 'steam-train-whistle-daniel_simon.mp3')

    def scale(self, waveform, factor=float(2**31)):
        # scales a waveform by a factor
        if not waveform.is_floating_point():
            waveform = waveform.to(torch.get_default_dtype())
        return waveform / factor

    def test_mu_law_companding(self):
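        # a mu-law encode/decode round trip should map a signal in [-1., 1.] to
        # integer codes in [0, quantization_channels - 1] and back into [-1., 1.]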

        quantization_channels = 256

        # convert the integer test signal to floating point and normalize it to [-1., 1.]
        waveform = self.waveform.to(torch.get_default_dtype())
        waveform /= torch.abs(waveform).max()
        self.assertTrue(waveform.min() >= -1. and waveform.max() <= 1.)

        waveform_mu = transforms.MuLawEncoding(quantization_channels)(waveform)
        self.assertTrue(waveform_mu.min() >= 0. and waveform_mu.max() <= quantization_channels)

        waveform_exp = transforms.MuLawDecoding(quantization_channels)(waveform_mu)
        self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.)

    def test_mel2(self):
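        # MelSpectrogram: check the default configuration, custom options,
        # multi-channel input, and MelScale filterbank construction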
        top_db = 80.
        s2db = transforms.AmplitudeToDB('power', top_db)

        waveform = self.waveform.clone()  # (1, 64000)
        waveform_scaled = self.scale(waveform)  # (1, 64000)
        mel_transform = transforms.MelSpectrogram()
        # check defaults
        spectrogram_torch = s2db(mel_transform(waveform_scaled))  # (1, 128, 321)
        self.assertTrue(spectrogram_torch.dim() == 3)
        self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
        self.assertEqual(spectrogram_torch.size(1), mel_transform.n_mels)
        # check correctness of filterbank conversion matrix
        self.assertTrue(mel_transform.mel_scale.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform.mel_scale.fb.sum(1).ge(0.).all())
        # check options
        kwargs = {'window_fn': torch.hamming_window, 'pad': 10, 'win_length': 500,
                  'hop_length': 125, 'n_fft': 800, 'n_mels': 50}
        mel_transform2 = transforms.MelSpectrogram(**kwargs)
        spectrogram2_torch = s2db(mel_transform2(waveform_scaled))  # (1, 50, 513)
        self.assertTrue(spectrogram2_torch.dim() == 3)
        self.assertTrue(spectrogram2_torch.ge(spectrogram2_torch.max() - top_db).all())
        self.assertEqual(spectrogram2_torch.size(1), mel_transform2.n_mels)
        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).ge(0.).all())
        # check on multi-channel audio
        x_stereo, sr_stereo = torchaudio.load(self.test_filepath)  # (2, 278756), 44100
        spectrogram_stereo = s2db(mel_transform(x_stereo))  # (2, 128, 1394)
        self.assertTrue(spectrogram_stereo.dim() == 3)
        self.assertTrue(spectrogram_stereo.size(0) == 2)
        self.assertTrue(spectrogram_stereo.ge(spectrogram_stereo.max() - top_db).all())
        self.assertEqual(spectrogram_stereo.size(1), mel_transform.n_mels)
        # check filterbank matrix creation
        fb_matrix_transform = transforms.MelScale(
            n_mels=100, sample_rate=16000, f_min=0., f_max=None, n_stft=400)
        self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all())
        self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
        self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))

    def test_mfcc(self):
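        # MFCC: check the default output shape, that melkwargs reach the
        # underlying MelSpectrogram, and the scaling between DCT norms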
        audio_orig = self.waveform.clone()
        audio_scaled = self.scale(audio_orig)  # (1, 64000)

        sample_rate = 16000
        n_mfcc = 40
        n_mels = 128
        mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                    n_mfcc=n_mfcc,
                                                    norm='ortho')
        # check defaults
        torch_mfcc = mfcc_transform(audio_scaled)  # (1, 40, 321)
        self.assertTrue(torch_mfcc.dim() == 3)
        self.assertTrue(torch_mfcc.shape[1] == n_mfcc)
        self.assertTrue(torch_mfcc.shape[2] == 321)
        # check melkwargs are passed through
        melkwargs = {'win_length': 200}
        mfcc_transform2 = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                     n_mfcc=n_mfcc,
                                                     norm='ortho',
                                                     melkwargs=melkwargs)
        torch_mfcc2 = mfcc_transform2(audio_scaled)  # (1, 40, 641)
        self.assertTrue(torch_mfcc2.shape[2] == 641)

        # check norms work correctly
        mfcc_transform_norm_none = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                              n_mfcc=n_mfcc,
                                                              norm=None)
        torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled)  # (1, 40, 321)

        norm_check = torch_mfcc.clone()
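        # the orthonormal DCT scales coefficient 0 by 1 / (2 * sqrt(n_mels)) and the
        # remaining coefficients by 1 / sqrt(2 * n_mels); undoing that scaling on the
        # 'ortho' output should reproduce the norm=None output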
        norm_check[:, 0, :] *= math.sqrt(n_mels) * 2
        norm_check[:, 1:, :] *= math.sqrt(n_mels / 2) * 2

        self.assertTrue(torch_mfcc_norm_none.allclose(norm_check))

    @unittest.skipIf(not IMPORT_LIBROSA or not IMPORT_SCIPY, 'Librosa or scipy is not available')
    def test_librosa_consistency(self):
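        # compare Spectrogram, MelSpectrogram, AmplitudeToDB and MFCC against
        # librosa/scipy reference implementations on the sinewave asset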
        def _test_librosa_consistency_helper(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):
            input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
            sound, sample_rate = torchaudio.load(input_path)
            sound_librosa = sound.cpu().numpy().squeeze()  # (64000)

            # test core spectrogram
            spect_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop_length=hop_length, power=2)
            out_librosa, _ = librosa.core.spectrum._spectrogram(y=sound_librosa,
                                                                n_fft=n_fft,
                                                                hop_length=hop_length,
                                                                power=2)

            out_torch = spect_transform(sound).squeeze().cpu()
            self.assertTrue(torch.allclose(out_torch, torch.from_numpy(out_librosa), atol=1e-5))

            # test mel spectrogram
            melspect_transform = torchaudio.transforms.MelSpectrogram(
                sample_rate=sample_rate, window_fn=torch.hann_window,
                hop_length=hop_length, n_mels=n_mels, n_fft=n_fft)
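            # librosa is asked for an HTK-style mel scale without filterbank
            # normalization (htk=True, norm=None) to match torchaudio's filterbank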
            librosa_mel = librosa.feature.melspectrogram(y=sound_librosa, sr=sample_rate,
                                                         n_fft=n_fft, hop_length=hop_length, n_mels=n_mels,
                                                         htk=True, norm=None)
            librosa_mel_tensor = torch.from_numpy(librosa_mel)
            torch_mel = melspect_transform(sound).squeeze().cpu()

            self.assertTrue(torch.allclose(torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3))

            # test s2db
            db_transform = torchaudio.transforms.AmplitudeToDB('power', 80.)
            db_torch = db_transform(spect_transform(sound)).squeeze().cpu()
            db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
            self.assertTrue(torch.allclose(db_torch, torch.from_numpy(db_librosa), atol=5e-3))

            db_torch = db_transform(melspect_transform(sound)).squeeze().cpu()
            db_librosa = librosa.core.spectrum.power_to_db(librosa_mel)
            db_librosa_tensor = torch.from_numpy(db_librosa)

            self.assertTrue(torch.allclose(db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3))

            # test MFCC
            melkwargs = {'hop_length': hop_length, 'n_fft': n_fft}
            mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                        n_mfcc=n_mfcc,
                                                        norm='ortho',
                                                        melkwargs=melkwargs)

            # librosa.feature.mfcc doesn't pass kwargs properly since some of the
            # kwargs for melspectrogram and mfcc are the same. We just follow the
            # function body in https://librosa.github.io/librosa/_modules/librosa/feature/spectral.html#melspectrogram
            # to mirror this function call with correct args:

            # librosa_mfcc = librosa.feature.mfcc(y=sound_librosa,
            #                                     sr=sample_rate,
            #                                     n_mfcc=n_mfcc,
            #                                     hop_length=hop_length,
            #                                     n_fft=n_fft,
            #                                     htk=True,
            #                                     norm=None,
            #                                     n_mels=n_mels)

            librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc]
            librosa_mfcc_tensor = torch.from_numpy(librosa_mfcc)
            torch_mfcc = mfcc_transform(sound).squeeze().cpu()

            self.assertTrue(torch.allclose(torch_mfcc.type(librosa_mfcc_tensor.dtype), librosa_mfcc_tensor, atol=5e-3))

        kwargs1 = {
            'n_fft': 400,
            'hop_length': 200,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 40,
            'sample_rate': 16000
        }

        kwargs2 = {
            'n_fft': 600,
            'hop_length': 100,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 20,
            'sample_rate': 16000
        }

        kwargs3 = {
            'n_fft': 200,
            'hop_length': 50,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 50,
            'sample_rate': 24000
        }

        _test_librosa_consistency_helper(**kwargs1)
        _test_librosa_consistency_helper(**kwargs2)
        _test_librosa_consistency_helper(**kwargs3)

    def test_resample_size(self):
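        # Resample: an unknown resampling_method should raise a ValueError when
        # applied, and sinc interpolation should double/halve the sample count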
        input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
        waveform, sample_rate = torchaudio.load(input_path)

        upsample_rate = sample_rate * 2
        downsample_rate = sample_rate // 2
        invalid_resample = torchaudio.transforms.Resample(sample_rate, upsample_rate, resampling_method='foo')

        self.assertRaises(ValueError, invalid_resample, waveform)

        upsample_resample = torchaudio.transforms.Resample(
            sample_rate, upsample_rate, resampling_method='sinc_interpolation')
        up_sampled = upsample_resample(waveform)

        # we expect the upsampled signal to have twice as many samples
        self.assertTrue(up_sampled.size(-1) == waveform.size(-1) * 2)

        downsample_resample = torchaudio.transforms.Resample(
            sample_rate, downsample_rate, resampling_method='sinc_interpolation')
        down_sampled = downsample_resample(waveform)

        # we expect the downsampled signal to have half as many samples
        self.assertTrue(down_sampled.size(-1) == waveform.size(-1) // 2)

if __name__ == '__main__':
    unittest.main()