test_transforms.py 9.71 KB
Newer Older
1
import math
2
import os
3
import unittest
4

David Pollack's avatar
David Pollack committed
5
6
7
import torch
import torchaudio
import torchaudio.transforms as transforms
Vincent QB's avatar
Vincent QB committed
8
import torchaudio.functional as F
David Pollack's avatar
David Pollack committed
9

10
from common_utils import AudioBackendScope, BACKENDS, create_temp_assets_dir
11
12


David Pollack's avatar
David Pollack committed
13
14
class Tester(unittest.TestCase):
    """Shape and value-range tests for ``torchaudio.transforms``."""

    # create a sinewave (cosine) signal for testing: 4 seconds at 16 kHz,
    # 440 Hz, stored as int32-range longs
    sample_rate = 16000
    freq = 440
    volume = .3
    waveform = (torch.cos(2 * math.pi * torch.arange(0, 4 * sample_rate).float() * freq / sample_rate))
    waveform.unsqueeze_(0)  # (1, 64000)
    waveform = (waveform * volume * 2**31).long()

    # file for stereo stft test
    test_dirpath, test_dir = create_temp_assets_dir()
    test_filepath = os.path.join(test_dirpath, 'assets',
                                 'steam-train-whistle-daniel_simon.wav')

    def scale(self, waveform, factor=float(2**31)):
        # scales a waveform by a factor, promoting integer tensors to the
        # default float dtype first so the division is true division
        if not waveform.is_floating_point():
            waveform = waveform.to(torch.get_default_dtype())
        return waveform / factor

    def test_mu_law_companding(self):
        """Mu-law encode then decode stays within the expected value ranges."""
        quantization_channels = 256

        waveform = self.waveform.clone()
        # BUGFIX: cast to floating point before normalizing — in-place true
        # division (`/=`) on an integer tensor raises a RuntimeError in
        # current PyTorch (and floor-divided to zeros in older versions).
        if not waveform.is_floating_point():
            waveform = waveform.to(torch.get_default_dtype())
        waveform /= torch.abs(waveform).max()
        self.assertTrue(waveform.min() >= -1. and waveform.max() <= 1.)

        waveform_mu = transforms.MuLawEncoding(quantization_channels)(waveform)
        self.assertTrue(waveform_mu.min() >= 0. and waveform_mu.max() <= quantization_channels)

        waveform_exp = transforms.MuLawDecoding(quantization_channels)(waveform_mu)
        self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.)

    def test_AmplitudeToDB(self):
        """20*log10(|x|) in 'magnitude' mode equals 10*log10(x**2) in 'power' mode."""
        waveform, sample_rate = torchaudio.load(self.test_filepath)

        mag_to_db_transform = transforms.AmplitudeToDB('magnitude', 80.)
        power_to_db_transform = transforms.AmplitudeToDB('power', 80.)

        mag_to_db_torch = mag_to_db_transform(torch.abs(waveform))
        power_to_db_torch = power_to_db_transform(torch.pow(waveform, 2))

        self.assertTrue(torch.allclose(mag_to_db_torch, power_to_db_torch))

    def test_melscale_load_save(self):
        """MelScale's filterbank round-trips through state_dict/load_state_dict."""
        specgram = torch.ones(1, 1000, 100)
        melscale_transform = transforms.MelScale()
        melscale_transform(specgram)

        melscale_transform_copy = transforms.MelScale(n_stft=1000)
        melscale_transform_copy.load_state_dict(melscale_transform.state_dict())

        fb = melscale_transform.fb
        fb_copy = melscale_transform_copy.fb

        self.assertEqual(fb_copy.size(), (1000, 128))
        self.assertTrue(torch.allclose(fb, fb_copy))

    def test_melspectrogram_load_save(self):
        """MelSpectrogram's window and filterbank round-trip through state_dict."""
        waveform = self.waveform.float()
        mel_spectrogram_transform = transforms.MelSpectrogram()
        mel_spectrogram_transform(waveform)

        mel_spectrogram_transform_copy = transforms.MelSpectrogram()
        mel_spectrogram_transform_copy.load_state_dict(mel_spectrogram_transform.state_dict())

        window = mel_spectrogram_transform.spectrogram.window
        window_copy = mel_spectrogram_transform_copy.spectrogram.window

        fb = mel_spectrogram_transform.mel_scale.fb
        fb_copy = mel_spectrogram_transform_copy.mel_scale.fb

        self.assertTrue(torch.allclose(window, window_copy))
        # the default for n_fft = 400 and n_mels = 128
        self.assertEqual(fb_copy.size(), (201, 128))
        self.assertTrue(torch.allclose(fb, fb_copy))

    def test_mel2(self):
        """MelSpectrogram output shape, dB range, and filterbank sanity checks."""
        top_db = 80.
        s2db = transforms.AmplitudeToDB('power', top_db)

        waveform = self.waveform.clone()  # (1, 64000)
        waveform_scaled = self.scale(waveform)  # (1, 64000)
        mel_transform = transforms.MelSpectrogram()
        # check defaults
        spectrogram_torch = s2db(mel_transform(waveform_scaled))  # (1, 128, 321)
        self.assertTrue(spectrogram_torch.dim() == 3)
        self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
        self.assertEqual(spectrogram_torch.size(1), mel_transform.n_mels)
        # check correctness of filterbank conversion matrix
        self.assertTrue(mel_transform.mel_scale.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform.mel_scale.fb.sum(1).ge(0.).all())
        # check options
        kwargs = {'window_fn': torch.hamming_window, 'pad': 10, 'win_length': 500,
                  'hop_length': 125, 'n_fft': 800, 'n_mels': 50}
        mel_transform2 = transforms.MelSpectrogram(**kwargs)
        spectrogram2_torch = s2db(mel_transform2(waveform_scaled))  # (1, 50, 513)
        self.assertTrue(spectrogram2_torch.dim() == 3)
        # BUGFIX: previously re-asserted spectrogram_torch (copy-paste);
        # validate the spectrogram actually produced by mel_transform2
        self.assertTrue(spectrogram2_torch.ge(spectrogram2_torch.max() - top_db).all())
        self.assertEqual(spectrogram2_torch.size(1), mel_transform2.n_mels)
        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).ge(0.).all())
        # check on multi-channel audio
        x_stereo, sr_stereo = torchaudio.load(self.test_filepath)  # (2, 278756), 44100
        spectrogram_stereo = s2db(mel_transform(x_stereo))  # (2, 128, 1394)
        self.assertTrue(spectrogram_stereo.dim() == 3)
        self.assertTrue(spectrogram_stereo.size(0) == 2)
        # BUGFIX: same copy-paste issue — validate the stereo spectrogram
        self.assertTrue(spectrogram_stereo.ge(spectrogram_stereo.max() - top_db).all())
        self.assertEqual(spectrogram_stereo.size(1), mel_transform.n_mels)
        # check filterbank matrix creation
        fb_matrix_transform = transforms.MelScale(
            n_mels=100, sample_rate=16000, f_min=0., f_max=None, n_stft=400)
        self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all())
        self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
        self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))

    def test_mfcc(self):
        """MFCC default shape, melkwargs pass-through, and norm handling."""
        audio_orig = self.waveform.clone()
        audio_scaled = self.scale(audio_orig)  # (1, 64000)

        sample_rate = 16000
        n_mfcc = 40
        n_mels = 128
        mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                    n_mfcc=n_mfcc,
                                                    norm='ortho')
        # check defaults
        torch_mfcc = mfcc_transform(audio_scaled)  # (1, 40, 321)
        self.assertTrue(torch_mfcc.dim() == 3)
        self.assertTrue(torch_mfcc.shape[1] == n_mfcc)
        self.assertTrue(torch_mfcc.shape[2] == 321)
        # check melkwargs are passed through
        melkwargs = {'win_length': 200}
        mfcc_transform2 = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                     n_mfcc=n_mfcc,
                                                     norm='ortho',
                                                     melkwargs=melkwargs)
        torch_mfcc2 = mfcc_transform2(audio_scaled)  # (1, 40, 641)
        self.assertTrue(torch_mfcc2.shape[2] == 641)

        # check norms work correctly: the un-normalized DCT differs from the
        # 'ortho' one by these per-row scale factors
        mfcc_transform_norm_none = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                              n_mfcc=n_mfcc,
                                                              norm=None)
        torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled)  # (1, 40, 321)

        norm_check = torch_mfcc.clone()
        norm_check[:, 0, :] *= math.sqrt(n_mels) * 2
        norm_check[:, 1:, :] *= math.sqrt(n_mels / 2) * 2

        self.assertTrue(torch_mfcc_norm_none.allclose(norm_check))

    def test_resample_size(self):
        """Resample doubles/halves the sample count; an invalid method raises."""
        input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
        waveform, sample_rate = torchaudio.load(input_path)

        upsample_rate = sample_rate * 2
        downsample_rate = sample_rate // 2
        invalid_resample = torchaudio.transforms.Resample(sample_rate, upsample_rate, resampling_method='foo')
        self.assertRaises(ValueError, invalid_resample, waveform)

        upsample_resample = torchaudio.transforms.Resample(
            sample_rate, upsample_rate, resampling_method='sinc_interpolation')
        up_sampled = upsample_resample(waveform)
        # we expect the upsampled signal to have twice as many samples
        self.assertTrue(up_sampled.size(-1) == waveform.size(-1) * 2)

        downsample_resample = torchaudio.transforms.Resample(
            sample_rate, downsample_rate, resampling_method='sinc_interpolation')
        down_sampled = downsample_resample(waveform)
        # we expect the downsampled signal to have half as many samples
        self.assertTrue(down_sampled.size(-1) == waveform.size(-1) // 2)

    def test_compute_deltas(self):
        """ComputeDeltas preserves the input shape."""
        channel = 13
        n_mfcc = channel * 3
        time = 1021
        win_length = 2 * 7 + 1
        specgram = torch.randn(channel, n_mfcc, time)
        transform = transforms.ComputeDeltas(win_length=win_length)
        computed = transform(specgram)
        self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))

    def test_compute_deltas_transform_same_as_functional(self, atol=1e-6, rtol=1e-8):
        """The ComputeDeltas transform must agree with F.compute_deltas."""
        channel = 13
        n_mfcc = channel * 3
        time = 1021
        win_length = 2 * 7 + 1
        specgram = torch.randn(channel, n_mfcc, time)

        transform = transforms.ComputeDeltas(win_length=win_length)
        computed_transform = transform(specgram)

        computed_functional = F.compute_deltas(specgram, win_length=win_length)
        torch.testing.assert_allclose(computed_functional, computed_transform, atol=atol, rtol=rtol)

    def test_compute_deltas_twochannel(self):
        """Two identical channels produce identical, known delta values."""
        specgram = torch.tensor([1., 2., 3., 4.]).repeat(1, 2, 1)
        expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5],
                                  [0.5, 1.0, 1.0, 0.5]]])
        # BUGFIX: `expected` was computed but never asserted against; its
        # values correspond to win_length=3 (the transform's default is 5),
        # so pin win_length and check the actual deltas.
        transform = transforms.ComputeDeltas(win_length=3)
        computed = transform(specgram)
        self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))
        self.assertTrue(torch.allclose(computed, expected))


David Pollack's avatar
David Pollack committed
222
223
# Run the full suite when this file is executed directly.
if __name__ == '__main__':
    unittest.main()