test_transforms.py 9.61 KB
Newer Older
1
import math
2
import unittest
3

David Pollack's avatar
David Pollack committed
4
import torch
5
from torch.testing._internal.common_utils import TestCase
David Pollack's avatar
David Pollack committed
6
7
import torchaudio
import torchaudio.transforms as transforms
Vincent QB's avatar
Vincent QB committed
8
import torchaudio.functional as F
David Pollack's avatar
David Pollack committed
9

10
import common_utils
11
12


13
class Tester(TestCase):
David Pollack's avatar
David Pollack committed
14

15
    # create a sinewave signal for testing
16
    sample_rate = 16000
David Pollack's avatar
David Pollack committed
17
    freq = 440
18
    volume = .3
19
20
21
    waveform = (torch.cos(2 * math.pi * torch.arange(0, 4 * sample_rate).float() * freq / sample_rate))
    waveform.unsqueeze_(0)  # (1, 64000)
    waveform = (waveform * volume * 2**31).long()
David Pollack's avatar
David Pollack committed
22

moto's avatar
moto committed
23
    def scale(self, waveform, factor=2.0**31):
24
25
26
27
        # scales a waveform by a factor
        if not waveform.is_floating_point():
            waveform = waveform.to(torch.get_default_dtype())
        return waveform / factor
28

David Pollack's avatar
David Pollack committed
29
30
31
32
    def test_mu_law_companding(self):
        """Mu-law encode/decode round trip stays within the documented value ranges."""
        quantization_channels = 256

        # Normalize to [-1., 1.]. Cast to float first: self.waveform is a long
        # tensor, and in-place true division on an integer tensor is rejected
        # by PyTorch (the float result cannot be cast back to long).
        waveform = self.waveform.clone().to(torch.get_default_dtype())
        waveform /= torch.abs(waveform).max()
        self.assertTrue(waveform.min() >= -1. and waveform.max() <= 1.)

        # Encoded values are quantization indices in [0, quantization_channels].
        waveform_mu = transforms.MuLawEncoding(quantization_channels)(waveform)
        self.assertTrue(waveform_mu.min() >= 0. and waveform_mu.max() <= quantization_channels)

        # Decoding maps back into [-1., 1.].
        waveform_exp = transforms.MuLawDecoding(quantization_channels)(waveform_mu)
        self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.)
    def test_AmplitudeToDB(self):
        """Magnitude input and power input must produce identical dB outputs."""
        filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        waveform, sample_rate = torchaudio.load(filepath)

        top_db = 80.
        # |x| fed to the 'magnitude' transform and x**2 fed to the 'power'
        # transform describe the same signal, so the dB values must agree.
        magnitude_db = transforms.AmplitudeToDB('magnitude', top_db)(torch.abs(waveform))
        power_db = transforms.AmplitudeToDB('power', top_db)(torch.pow(waveform, 2))

        self.assertEqual(magnitude_db, power_db)
    def test_melscale_load_save(self):
        """The mel filterbank survives a state_dict save/load round trip."""
        specgram = torch.ones(1, 1000, 100)
        original = transforms.MelScale()
        # run a forward pass before saving (fb appears to be sized from the
        # input on first use — confirm against the MelScale implementation)
        original(specgram)

        restored = transforms.MelScale(n_stft=1000)
        restored.load_state_dict(original.state_dict())

        # 1000 frequency bins from the input, default n_mels = 128
        self.assertEqual(restored.fb.size(), (1000, 128))
        self.assertEqual(original.fb, restored.fb)
    def test_melspectrogram_load_save(self):
        """Window and mel filterbank survive a state_dict save/load round trip."""
        waveform = self.waveform.float()
        original = transforms.MelSpectrogram()
        original(waveform)

        restored = transforms.MelSpectrogram()
        restored.load_state_dict(original.state_dict())

        self.assertEqual(original.spectrogram.window, restored.spectrogram.window)

        # the default for n_fft = 400 and n_mels = 128, so fb is (201, 128)
        self.assertEqual(restored.mel_scale.fb.size(), (201, 128))
        self.assertEqual(original.mel_scale.fb, restored.mel_scale.fb)
    def test_mel2(self):
        """MelSpectrogram defaults, custom options, stereo input, and MelScale filterbanks."""
        top_db = 80.
        s2db = transforms.AmplitudeToDB('power', top_db)

        waveform = self.waveform.clone()  # (1, 64000)
        waveform_scaled = self.scale(waveform)  # (1, 64000)
        mel_transform = transforms.MelSpectrogram()
        # check defaults
        spectrogram_torch = s2db(mel_transform(waveform_scaled))  # (1, 128, 321)
        self.assertTrue(spectrogram_torch.dim() == 3)
        self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
        self.assertEqual(spectrogram_torch.size(1), mel_transform.n_mels)
        # check correctness of filterbank conversion matrix
        self.assertTrue(mel_transform.mel_scale.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform.mel_scale.fb.sum(1).ge(0.).all())
        # check options
        kwargs = {'window_fn': torch.hamming_window, 'pad': 10, 'win_length': 500,
                  'hop_length': 125, 'n_fft': 800, 'n_mels': 50}
        mel_transform2 = transforms.MelSpectrogram(**kwargs)
        spectrogram2_torch = s2db(mel_transform2(waveform_scaled))  # (1, 50, 513)
        self.assertTrue(spectrogram2_torch.dim() == 3)
        # fixed: previously re-asserted on spectrogram_torch instead of spectrogram2_torch
        self.assertTrue(spectrogram2_torch.ge(spectrogram2_torch.max() - top_db).all())
        self.assertEqual(spectrogram2_torch.size(1), mel_transform2.n_mels)
        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).ge(0.).all())
        # check on multi-channel audio
        filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        x_stereo, sr_stereo = torchaudio.load(filepath)  # (2, 278756), 44100
        spectrogram_stereo = s2db(mel_transform(x_stereo))  # (2, 128, 1394)
        self.assertTrue(spectrogram_stereo.dim() == 3)
        self.assertTrue(spectrogram_stereo.size(0) == 2)
        # fixed: previously re-asserted on spectrogram_torch instead of spectrogram_stereo
        self.assertTrue(spectrogram_stereo.ge(spectrogram_stereo.max() - top_db).all())
        self.assertEqual(spectrogram_stereo.size(1), mel_transform.n_mels)
        # check filterbank matrix creation
        fb_matrix_transform = transforms.MelScale(
            n_mels=100, sample_rate=16000, f_min=0., f_max=None, n_stft=400)
        self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all())
        self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
        self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))
    def test_mfcc(self):
        """MFCC transform: default output shape, melkwargs pass-through, norm=None scaling."""
        audio_orig = self.waveform.clone()
        audio_scaled = self.scale(audio_orig)  # (1, 64000)

        sample_rate = 16000
        n_mfcc = 40
        n_mels = 128
        mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                    n_mfcc=n_mfcc,
                                                    norm='ortho')
        # check defaults
        torch_mfcc = mfcc_transform(audio_scaled)  # (1, 40, 321)
        self.assertTrue(torch_mfcc.dim() == 3)
        self.assertTrue(torch_mfcc.shape[1] == n_mfcc)
        self.assertTrue(torch_mfcc.shape[2] == 321)
        # check melkwargs are passed through (win_length=200 halves the default
        # hop, doubling the number of frames: 641 instead of 321)
        melkwargs = {'win_length': 200}
        mfcc_transform2 = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                     n_mfcc=n_mfcc,
                                                     norm='ortho',
                                                     melkwargs=melkwargs)
        torch_mfcc2 = mfcc_transform2(audio_scaled)  # (1, 40, 641)
        self.assertTrue(torch_mfcc2.shape[2] == 641)

        # check norms work correctly
        mfcc_transform_norm_none = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                              n_mfcc=n_mfcc,
                                                              norm=None)
        torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled)  # (1, 40, 321)

        norm_check = torch_mfcc.clone()
        # undo the DCT-II 'ortho' normalization: coefficient 0 is scaled by
        # sqrt(1/(4*n_mels)) and the rest by sqrt(1/(2*n_mels)), so multiplying
        # by the inverses (2*sqrt(n_mels) and 2*sqrt(n_mels/2)) should
        # reproduce the norm=None output
        norm_check[:, 0, :] *= math.sqrt(n_mels) * 2
        norm_check[:, 1:, :] *= math.sqrt(n_mels / 2) * 2

        self.assertTrue(torch_mfcc_norm_none.allclose(norm_check))
    def test_resample_size(self):
        """Resample rejects unknown methods and scales the sample count as expected."""
        input_path = common_utils.get_asset_path('sinewave.wav')
        waveform, sample_rate = torchaudio.load(input_path)

        upsample_rate = sample_rate * 2
        downsample_rate = sample_rate // 2

        # an unknown resampling method raises when the transform is applied
        invalid_resample = torchaudio.transforms.Resample(sample_rate, upsample_rate, resampling_method='foo')
        self.assertRaises(ValueError, invalid_resample, waveform)

        # doubling the rate doubles the number of samples
        up_sampled = torchaudio.transforms.Resample(
            sample_rate, upsample_rate, resampling_method='sinc_interpolation')(waveform)
        self.assertTrue(up_sampled.size(-1) == waveform.size(-1) * 2)

        # halving the rate halves the number of samples
        down_sampled = torchaudio.transforms.Resample(
            sample_rate, downsample_rate, resampling_method='sinc_interpolation')(waveform)
        self.assertTrue(down_sampled.size(-1) == waveform.size(-1) // 2)
    def test_compute_deltas(self):
        """ComputeDeltas preserves the input shape."""
        channel, time = 13, 1021
        n_mfcc = channel * 3
        win_length = 2 * 7 + 1
        specgram = torch.randn(channel, n_mfcc, time)
        computed = transforms.ComputeDeltas(win_length=win_length)(specgram)
        self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))

    def test_compute_deltas_transform_same_as_functional(self, atol=1e-6, rtol=1e-8):
        """The ComputeDeltas transform must agree with the functional implementation."""
        channel, time = 13, 1021
        n_mfcc = channel * 3
        win_length = 2 * 7 + 1
        specgram = torch.randn(channel, n_mfcc, time)

        computed_transform = transforms.ComputeDeltas(win_length=win_length)(specgram)
        computed_functional = F.compute_deltas(specgram, win_length=win_length)

        self.assertEqual(computed_functional, computed_transform, atol=atol, rtol=rtol)
    def test_compute_deltas_twochannel(self):
        """Hand-checked deltas for a simple ramp, replicated across two channels."""
        specgram = torch.tensor([1., 2., 3., 4.]).repeat(1, 2, 1)
        expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5],
                                  [0.5, 1.0, 1.0, 0.5]]])
        transform = transforms.ComputeDeltas(win_length=3)
        computed = transform(specgram)
        # use a unittest assertion (consistent with test_compute_deltas) instead
        # of a bare assert, which is stripped when running under python -O
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertEqual(computed, expected, atol=1e-6, rtol=1e-8)
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()