test_transforms.py 9.64 KB
Newer Older
1
import math
2
import unittest
3

David Pollack's avatar
David Pollack committed
4
5
6
import torch
import torchaudio
import torchaudio.transforms as transforms
Vincent QB's avatar
Vincent QB committed
7
import torchaudio.functional as F
David Pollack's avatar
David Pollack committed
8

9
import common_utils
10
11


David Pollack's avatar
David Pollack committed
12
13
class Tester(unittest.TestCase):
    """Tests for torchaudio.transforms, run against a synthetic test signal."""

    # create a sinewave signal for testing
    sample_rate = 16000
    freq = 440
    volume = .3
    # 4 seconds of a 440 Hz cosine sampled at 16 kHz.
    waveform = (torch.cos(2 * math.pi * torch.arange(0, 4 * sample_rate).float() * freq / sample_rate))
    waveform.unsqueeze_(0)  # (1, 64000)
    # Scale into the int32 range and quantize to integers, mimicking raw PCM data.
    waveform = (waveform * volume * 2**31).long()
David Pollack's avatar
David Pollack committed
21

moto's avatar
moto committed
22
    def scale(self, waveform, factor=2.0**31):
        """Return *waveform* divided by *factor*, promoting integral tensors to float first."""
        # Integer tensors are promoted to the default floating dtype so the
        # division below produces fractional values instead of failing/truncating.
        if waveform.is_floating_point():
            floating = waveform
        else:
            floating = waveform.to(torch.get_default_dtype())
        return floating / factor
27

David Pollack's avatar
David Pollack committed
28
29
30
31
    def test_mu_law_companding(self):
        """Mu-law encode/decode round trip stays inside the expected value ranges."""
        n_channels = 256

        # Normalize the fixture signal in place (same in-place `/=` as always,
        # so the tensor's dtype handling is unchanged).
        sig = self.waveform.clone()
        sig /= torch.abs(sig).max()
        self.assertTrue(sig.min() >= -1.)
        self.assertTrue(sig.max() <= 1.)

        # Encoded values are quantization indices in [0, n_channels].
        encoded = transforms.MuLawEncoding(n_channels)(sig)
        self.assertTrue(encoded.min() >= 0.)
        self.assertTrue(encoded.max() <= n_channels)

        # Decoding maps back into the normalized [-1, 1] range.
        decoded = transforms.MuLawDecoding(n_channels)(encoded)
        self.assertTrue(decoded.min() >= -1.)
        self.assertTrue(decoded.max() <= 1.)
41

42
    def test_AmplitudeToDB(self):
        """Magnitude and power inputs must map to identical dB outputs."""
        filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        waveform, sample_rate = torchaudio.load(filepath)

        # 20*log10(|x|) and 10*log10(x^2) are the same quantity, so feeding the
        # magnitude to the 'magnitude' transform and the squared signal to the
        # 'power' transform must agree.
        db_from_mag = transforms.AmplitudeToDB('magnitude', 80.)(torch.abs(waveform))
        db_from_power = transforms.AmplitudeToDB('power', 80.)(torch.pow(waveform, 2))

        torch.testing.assert_allclose(db_from_mag, db_from_power)
53

54
55
56
57
58
59
60
61
62
63
64
65
    def test_melscale_load_save(self):
        """A MelScale restored from a state_dict reproduces the same filterbank."""
        specgram = torch.ones(1, 1000, 100)
        original = transforms.MelScale()
        # One forward pass so the filterbank buffer is populated before saving.
        original(specgram)

        restored = transforms.MelScale(n_stft=1000)
        restored.load_state_dict(original.state_dict())

        self.assertEqual(restored.fb.size(), (1000, 128))
        torch.testing.assert_allclose(original.fb, restored.fb)
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81

    def test_melspectrogram_load_save(self):
        """MelSpectrogram buffers (window + mel filterbank) survive a state_dict round trip."""
        waveform = self.waveform.float()
        original = transforms.MelSpectrogram()
        original(waveform)

        restored = transforms.MelSpectrogram()
        restored.load_state_dict(original.state_dict())

        torch.testing.assert_allclose(original.spectrogram.window,
                                      restored.spectrogram.window)

        # the default for n_fft = 400 and n_mels = 128
        self.assertEqual(restored.mel_scale.fb.size(), (201, 128))
        torch.testing.assert_allclose(original.mel_scale.fb, restored.mel_scale.fb)
86

87
    def test_mel2(self):
        """End-to-end checks for MelSpectrogram: defaults, custom kwargs,
        multi-channel input, and filterbank matrix construction."""
        top_db = 80.
        s2db = transforms.AmplitudeToDB('power', top_db)

        waveform = self.waveform.clone()  # (1, 64000)
        waveform_scaled = self.scale(waveform)  # (1, 64000)
        mel_transform = transforms.MelSpectrogram()
        # check defaults
        spectrogram_torch = s2db(mel_transform(waveform_scaled))  # (1, 128, 321)
        self.assertTrue(spectrogram_torch.dim() == 3)
        # AmplitudeToDB with top_db clamps the output at (max - top_db),
        # so no element may fall below that floor.
        self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
        self.assertEqual(spectrogram_torch.size(1), mel_transform.n_mels)
        # check correctness of filterbank conversion matrix
        self.assertTrue(mel_transform.mel_scale.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform.mel_scale.fb.sum(1).ge(0.).all())
        # check options
        kwargs = {'window_fn': torch.hamming_window, 'pad': 10, 'win_length': 500,
                  'hop_length': 125, 'n_fft': 800, 'n_mels': 50}
        mel_transform2 = transforms.MelSpectrogram(**kwargs)
        spectrogram2_torch = s2db(mel_transform2(waveform_scaled))  # (1, 50, 513)
        self.assertTrue(spectrogram2_torch.dim() == 3)
        # BUGFIX: this previously re-checked spectrogram_torch (copy-paste);
        # the range assertion must run against the new output.
        self.assertTrue(spectrogram2_torch.ge(spectrogram2_torch.max() - top_db).all())
        self.assertEqual(spectrogram2_torch.size(1), mel_transform2.n_mels)
        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).ge(0.).all())
        # check on multi-channel audio
        filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        x_stereo, sr_stereo = torchaudio.load(filepath)  # (2, 278756), 44100
        spectrogram_stereo = s2db(mel_transform(x_stereo))  # (2, 128, 1394)
        self.assertTrue(spectrogram_stereo.dim() == 3)
        self.assertTrue(spectrogram_stereo.size(0) == 2)
        # BUGFIX: this previously re-checked spectrogram_torch (copy-paste);
        # the range assertion must run against the stereo output.
        self.assertTrue(spectrogram_stereo.ge(spectrogram_stereo.max() - top_db).all())
        self.assertEqual(spectrogram_stereo.size(1), mel_transform.n_mels)
        # check filterbank matrix creation
        fb_matrix_transform = transforms.MelScale(
            n_mels=100, sample_rate=16000, f_min=0., f_max=None, n_stft=400)
        self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all())
        self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
        self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))
Soumith Chintala's avatar
Soumith Chintala committed
126

PCerles's avatar
PCerles committed
127
    def test_mfcc(self):
        """MFCC: default output shape, melkwargs pass-through, and norm=None scaling."""
        audio_scaled = self.scale(self.waveform.clone())  # (1, 64000)

        sample_rate = 16000
        n_mfcc = 40
        n_mels = 128

        # check defaults
        mfcc_ortho = transforms.MFCC(sample_rate=sample_rate,
                                     n_mfcc=n_mfcc,
                                     norm='ortho')
        coeffs = mfcc_ortho(audio_scaled)  # (1, 40, 321)
        self.assertTrue(coeffs.dim() == 3)
        self.assertTrue(coeffs.shape[1] == n_mfcc)
        self.assertTrue(coeffs.shape[2] == 321)

        # check melkwargs are passed through
        mfcc_shortwin = transforms.MFCC(sample_rate=sample_rate,
                                        n_mfcc=n_mfcc,
                                        norm='ortho',
                                        melkwargs={'win_length': 200})
        coeffs_shortwin = mfcc_shortwin(audio_scaled)  # (1, 40, 641)
        self.assertTrue(coeffs_shortwin.shape[2] == 641)

        # check norms work correctly: the 'ortho' result should differ from the
        # unnormalized one only by fixed per-row DCT scale factors
        mfcc_unnormed = transforms.MFCC(sample_rate=sample_rate,
                                        n_mfcc=n_mfcc,
                                        norm=None)
        coeffs_unnormed = mfcc_unnormed(audio_scaled)  # (1, 40, 321)

        norm_check = coeffs.clone()
        norm_check[:, 0, :] *= math.sqrt(n_mels) * 2
        norm_check[:, 1:, :] *= math.sqrt(n_mels / 2) * 2

        self.assertTrue(coeffs_unnormed.allclose(norm_check))

jamarshon's avatar
jamarshon committed
163
    def test_resample_size(self):
        """Resample doubles/halves the sample count and rejects unknown methods."""
        input_path = common_utils.get_asset_path('sinewave.wav')
        waveform, sample_rate = torchaudio.load(input_path)

        upsample_rate = sample_rate * 2
        downsample_rate = sample_rate // 2

        # An unrecognized resampling method only fails when the transform is applied.
        bad_transform = torchaudio.transforms.Resample(sample_rate, upsample_rate, resampling_method='foo')
        with self.assertRaises(ValueError):
            bad_transform(waveform)

        # we expect the upsampled signal to have twice as many samples
        upsampler = torchaudio.transforms.Resample(
            sample_rate, upsample_rate, resampling_method='sinc_interpolation')
        up_sampled = upsampler(waveform)
        self.assertTrue(up_sampled.size(-1) == waveform.size(-1) * 2)

        # we expect the downsampled signal to have half as many samples
        downsampler = torchaudio.transforms.Resample(
            sample_rate, downsample_rate, resampling_method='sinc_interpolation')
        down_sampled = downsampler(waveform)
        self.assertTrue(down_sampled.size(-1) == waveform.size(-1) // 2)
PCerles's avatar
PCerles committed
186

Vincent QB's avatar
Vincent QB committed
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
    def test_compute_deltas(self):
        """ComputeDeltas must preserve the shape of its input spectrogram."""
        n_channel, n_time = 13, 1021
        n_coeff = n_channel * 3
        window = 2 * 7 + 1
        specgram = torch.randn(n_channel, n_coeff, n_time)
        deltas = transforms.ComputeDeltas(win_length=window)(specgram)
        self.assertTrue(deltas.shape == specgram.shape, (deltas.shape, specgram.shape))

    def test_compute_deltas_transform_same_as_functional(self, atol=1e-6, rtol=1e-8):
        """The ComputeDeltas transform must agree with functional.compute_deltas."""
        n_channel, n_time = 13, 1021
        n_coeff = n_channel * 3
        window = 2 * 7 + 1
        specgram = torch.randn(n_channel, n_coeff, n_time)

        via_transform = transforms.ComputeDeltas(win_length=window)(specgram)
        via_functional = F.compute_deltas(specgram, win_length=window)
        torch.testing.assert_allclose(via_functional, via_transform, atol=atol, rtol=rtol)

    def test_compute_deltas_twochannel(self):
        """Deltas of a two-channel ramp match the hand-computed expectation."""
        # (1, 2, 4): the same ascending ramp duplicated on both channels.
        specgram = torch.tensor([1., 2., 3., 4.]).repeat(1, 2, 1)
        expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5],
                                  [0.5, 1.0, 1.0, 0.5]]])
        transform = transforms.ComputeDeltas(win_length=3)
        computed = transform(specgram)
        # CONSISTENCY FIX: use unittest's assertion instead of a bare `assert`,
        # which is silently stripped under `python -O` and diverges from the
        # self.assert* style used throughout this file.
        self.assertEqual(computed.shape, expected.shape)
        torch.testing.assert_allclose(computed, expected, atol=1e-6, rtol=1e-8)
Vincent QB's avatar
Vincent QB committed
218
219


David Pollack's avatar
David Pollack committed
220
221
if __name__ == '__main__':
    # Discover and run every test in this module when invoked directly.
    unittest.main()