"torchvision/models/vscode:/vscode.git/clone" did not exist on "eb00e2ad9da0df409d62033a1d7078572df67fb5"
test_transforms.py 11.9 KB
Newer Older
1
from __future__ import print_function
import math
import os

import torch
import torchaudio
from torchaudio.common_utils import IMPORT_LIBROSA, IMPORT_SCIPY
import torchaudio.transforms as transforms
import unittest
import test.common_utils

if IMPORT_LIBROSA:
    import librosa

if IMPORT_SCIPY:
    import scipy


class Tester(unittest.TestCase):

    # create a sinewave signal for testing
    sample_rate = 16000
    freq = 440
    volume = .3
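    # the waveform below is 4 seconds of a 440 Hz cosine at 16 kHz, scaled
    # into the 32-bit integer range so that scale() maps it back to [-1., 1.]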
    waveform = (torch.cos(2 * math.pi * torch.arange(0, 4 * sample_rate).float() * freq / sample_rate))
    waveform.unsqueeze_(0)  # (1, 64000)
    waveform = (waveform * volume * 2**31).long()
    # file for stereo stft test
    test_dirpath, test_dir = test.common_utils.create_temp_assets_dir()
    test_filepath = os.path.join(test_dirpath, 'assets',
                                 'steam-train-whistle-daniel_simon.mp3')

    def scale(self, waveform, factor=float(2**31)):
        # scales a waveform by a factor
        if not waveform.is_floating_point():
            waveform = waveform.to(torch.get_default_dtype())
        return waveform / factor

    def test_pad_trim(self):
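        # PadTrim should zero-pad a short waveform up to max_len and
        # truncate a long one down to it, so both directions are checked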

        waveform = self.waveform.clone()
        length_orig = waveform.size(1)
        length_new = int(length_orig * 1.2)

        result = transforms.PadTrim(max_len=length_new)(waveform)
        self.assertEqual(result.size(1), length_new)

        length_new = int(length_orig * 0.8)

        result = transforms.PadTrim(max_len=length_new)(waveform)
        self.assertEqual(result.size(1), length_new)

    def test_mu_law_companding(self):

        quantization_channels = 256
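        # mu-law companding compresses a [-1., 1.] signal into the integer
        # bins [0, quantization_channels - 1] via
        # sign(x) * log1p(mu * |x|) / log1p(mu) with mu = quantization_channels - 1,
        # and expanding inverts that mapping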

        # convert to float first: in-place true division is not defined for
        # integer tensors, and the test needs a [-1., 1.] float signal
        waveform = self.waveform.clone().to(torch.get_default_dtype())
        waveform /= torch.abs(waveform).max()
        self.assertTrue(waveform.min() >= -1. and waveform.max() <= 1.)

        waveform_mu = transforms.MuLawEncoding(quantization_channels)(waveform)
        self.assertTrue(waveform_mu.min() >= 0. and waveform_mu.max() <= quantization_channels)

        waveform_exp = transforms.MuLawExpanding(quantization_channels)(waveform_mu)
        self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.)

    def test_mel2(self):
        top_db = 80.
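        # SpectrogramToDB converts a power spectrogram to decibels and clamps
        # its output to top_db below the peak, hence the ge() checks below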
        s2db = transforms.SpectrogramToDB('power', top_db)

        waveform = self.waveform.clone()  # (1, 64000)
        waveform_scaled = self.scale(waveform)  # (1, 64000)
        mel_transform = transforms.MelSpectrogram()
        # check defaults
        spectrogram_torch = s2db(mel_transform(waveform_scaled))  # (1, 128, 321)
        self.assertTrue(spectrogram_torch.dim() == 3)
        self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
        self.assertEqual(spectrogram_torch.size(1), mel_transform.n_mels)
        # check correctness of filterbank conversion matrix
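        # each STFT bin falls on at most two overlapping triangular mel
        # filters whose weights sum to at most 1, so row sums lie in [0., 1.]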
        self.assertTrue(mel_transform.mel_scale.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform.mel_scale.fb.sum(1).ge(0.).all())
        # check options
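        # for a centered STFT the expected frame count is
        # (num_samples + 2 * pad) // hop_length + 1 = (64000 + 20) // 125 + 1 = 513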
        kwargs = {'window_fn': torch.hamming_window, 'pad': 10, 'win_length': 500,
                  'hop_length': 125, 'n_fft': 800, 'n_mels': 50}
        mel_transform2 = transforms.MelSpectrogram(**kwargs)
        spectrogram2_torch = s2db(mel_transform2(waveform_scaled))  # (1, 50, 513)
        self.assertTrue(spectrogram2_torch.dim() == 3)
        self.assertTrue(spectrogram2_torch.ge(spectrogram2_torch.max() - top_db).all())
        self.assertEqual(spectrogram2_torch.size(1), mel_transform2.n_mels)
        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).ge(0.).all())
        # check on multi-channel audio
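        # transforms treat the leading dimension as channels, so stereo input
        # should produce a (2, n_mels, time) output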
        x_stereo, sr_stereo = torchaudio.load(self.test_filepath)  # (2, 278756), 44100
        spectrogram_stereo = s2db(mel_transform(x_stereo))  # (2, 128, 1394)
        self.assertTrue(spectrogram_stereo.dim() == 3)
        self.assertTrue(spectrogram_stereo.size(0) == 2)
        self.assertTrue(spectrogram_stereo.ge(spectrogram_stereo.max() - top_db).all())
        self.assertEqual(spectrogram_stereo.size(1), mel_transform.n_mels)
        # check filterbank matrix creation
        fb_matrix_transform = transforms.MelScale(
            n_mels=100, sample_rate=16000, f_min=0., f_max=None, n_stft=400)
        self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all())
        self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
        self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))

    def test_mfcc(self):
        audio_orig = self.waveform.clone()
        audio_scaled = self.scale(audio_orig)  # (1, 64000)

        sample_rate = 16000
        n_mfcc = 40
        n_mels = 128
        mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                    n_mfcc=n_mfcc,
                                                    norm='ortho')
        # check defaults
        torch_mfcc = mfcc_transform(audio_scaled)  # (1, 40, 321)
        self.assertTrue(torch_mfcc.dim() == 3)
        self.assertTrue(torch_mfcc.shape[1] == n_mfcc)
        self.assertTrue(torch_mfcc.shape[2] == 321)
        # check melkwargs are passed through
        melkwargs = {'win_length': 200}
        mfcc_transform2 = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                     n_mfcc=n_mfcc,
                                                     norm='ortho',
                                                     melkwargs=melkwargs)
        torch_mfcc2 = mfcc_transform2(audio_scaled)  # (1, 40, 641)
        self.assertTrue(torch_mfcc2.shape[2] == 641)

        # check norms work correctly
        mfcc_transform_norm_none = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                              n_mfcc=n_mfcc,
                                                              norm=None)
        torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled)  # (1, 40, 321)

        norm_check = torch_mfcc.clone()
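        # an unnormalized DCT-II is larger than the 'ortho' DCT-II by
        # 2 * sqrt(n_mels) in coefficient 0 and 2 * sqrt(n_mels / 2) in the
        # rest, so the norm=None output can be checked against a rescaled
        # copy of the ortho output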
        norm_check[:, 0, :] *= math.sqrt(n_mels) * 2
        norm_check[:, 1:, :] *= math.sqrt(n_mels / 2) * 2

        self.assertTrue(torch_mfcc_norm_none.allclose(norm_check))

    @unittest.skipIf(not IMPORT_LIBROSA or not IMPORT_SCIPY, 'Librosa or scipy not available')
    def test_librosa_consistency(self):
        def _test_librosa_consistency_helper(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):
            input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
            sound, sample_rate = torchaudio.load(input_path)
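            # the sample_rate argument is overwritten here by the rate that
            # torchaudio.load reads from the file itself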
            sound_librosa = sound.cpu().numpy().squeeze()  # (64000)

            # test core spectrogram
            spect_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop_length=hop_length, power=power)
            out_librosa, _ = librosa.core.spectrum._spectrogram(y=sound_librosa,
                                                                n_fft=n_fft,
                                                                hop_length=hop_length,
                                                                power=power)

            out_torch = spect_transform(sound).squeeze().cpu()
            self.assertTrue(torch.allclose(out_torch, torch.from_numpy(out_librosa), atol=1e-5))

            # test mel spectrogram
            melspect_transform = torchaudio.transforms.MelSpectrogram(
                sample_rate=sample_rate, window_fn=torch.hann_window,
                hop_length=hop_length, n_mels=n_mels, n_fft=n_fft)
            librosa_mel = librosa.feature.melspectrogram(y=sound_librosa, sr=sample_rate,
                                                         n_fft=n_fft, hop_length=hop_length, n_mels=n_mels,
                                                         htk=True, norm=None)
            librosa_mel_tensor = torch.from_numpy(librosa_mel)
            torch_mel = melspect_transform(sound).squeeze().cpu()

            self.assertTrue(torch.allclose(torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3))

            # test s2db
            db_transform = torchaudio.transforms.SpectrogramToDB('power', 80.)
            db_torch = db_transform(spect_transform(sound)).squeeze().cpu()
            db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
            self.assertTrue(torch.allclose(db_torch, torch.from_numpy(db_librosa), atol=5e-3))

            db_torch = db_transform(melspect_transform(sound)).squeeze().cpu()
            db_librosa = librosa.core.spectrum.power_to_db(librosa_mel)
            db_librosa_tensor = torch.from_numpy(db_librosa)

            self.assertTrue(torch.allclose(db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3))

            # test MFCC
            melkwargs = {'hop_length': hop_length, 'n_fft': n_fft}
            mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                        n_mfcc=n_mfcc,
                                                        norm='ortho',
                                                        melkwargs=melkwargs)

            # librosa.feature.mfcc doesn't pass kwargs properly since some of the
            # kwargs for melspectrogram and mfcc are the same. We just follow the
            # function body in https://librosa.github.io/librosa/_modules/librosa/feature/spectral.html#melspectrogram
            # to mirror this function call with correct args:

            # librosa_mfcc = librosa.feature.mfcc(y=sound_librosa,
            #                                     sr=sample_rate,
            #                                     n_mfcc=n_mfcc,
            #                                     hop_length=hop_length,
            #                                     n_fft=n_fft,
            #                                     htk=True,
            #                                     norm=None,
            #                                     n_mels=n_mels)

            librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc]
            librosa_mfcc_tensor = torch.from_numpy(librosa_mfcc)
            torch_mfcc = mfcc_transform(sound).squeeze().cpu()

            self.assertTrue(torch.allclose(torch_mfcc.type(librosa_mfcc_tensor.dtype), librosa_mfcc_tensor, atol=5e-3))

        kwargs1 = {
            'n_fft': 400,
            'hop_length': 200,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 40,
            'sample_rate': 16000
        }

        kwargs2 = {
            'n_fft': 600,
            'hop_length': 100,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 20,
            'sample_rate': 16000
        }

        kwargs3 = {
            'n_fft': 200,
            'hop_length': 50,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 50,
            'sample_rate': 24000
        }

        _test_librosa_consistency_helper(**kwargs1)
        _test_librosa_consistency_helper(**kwargs2)
        _test_librosa_consistency_helper(**kwargs3)

    def test_resample_size(self):
        input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
        waveform, sample_rate = torchaudio.load(input_path)

        upsample_rate = sample_rate * 2
        downsample_rate = sample_rate // 2
        invalid_resample = torchaudio.transforms.Resample(sample_rate, upsample_rate, resampling_method='foo')
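        # the unsupported resampling_method is rejected only when the
        # transform is applied, not at construction time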

        self.assertRaises(ValueError, invalid_resample, waveform)

        upsample_resample = torchaudio.transforms.Resample(
            sample_rate, upsample_rate, resampling_method='sinc_interpolation')
        up_sampled = upsample_resample(waveform)

        # we expect the upsampled signal to have twice as many samples
        self.assertTrue(up_sampled.size(-1) == waveform.size(-1) * 2)

        downsample_resample = torchaudio.transforms.Resample(
            sample_rate, downsample_rate, resampling_method='sinc_interpolation')
        down_sampled = downsample_resample(waveform)

        # we expect the downsampled signal to have half as many samples
        self.assertTrue(down_sampled.size(-1) == waveform.size(-1) // 2)

if __name__ == '__main__':
    unittest.main()