test_transforms.py 9.72 KB
Newer Older
1
import math
2
import unittest
3

David Pollack's avatar
David Pollack committed
4
import torch
5
from torch.testing._internal.common_utils import TestCase
David Pollack's avatar
David Pollack committed
6
7
import torchaudio
import torchaudio.transforms as transforms
Vincent QB's avatar
Vincent QB committed
8
import torchaudio.functional as F
David Pollack's avatar
David Pollack committed
9

10
from . import common_utils
11
12


13
class Tester(TestCase):
David Pollack's avatar
David Pollack committed
14

15
    # create a sinewave signal for testing
16
    sample_rate = 16000
David Pollack's avatar
David Pollack committed
17
    freq = 440
18
    volume = .3
19
20
21
    waveform = (torch.cos(2 * math.pi * torch.arange(0, 4 * sample_rate).float() * freq / sample_rate))
    waveform.unsqueeze_(0)  # (1, 64000)
    waveform = (waveform * volume * 2**31).long()
David Pollack's avatar
David Pollack committed
22

moto's avatar
moto committed
23
    def scale(self, waveform, factor=2.0**31):
        """Return *waveform* divided by *factor*, converting integer input to float first."""
        if waveform.is_floating_point():
            return waveform / factor
        return waveform.to(torch.get_default_dtype()) / factor
28

David Pollack's avatar
David Pollack committed
29
30
31
32
    def test_mu_law_companding(self):
        """Mu-law encode then decode keeps values inside the expected ranges."""
        n_channels = 256

        x = self.waveform.clone()
        if not x.is_floating_point():
            x = x.to(torch.get_default_dtype())
        x = x / torch.abs(x).max()  # normalize into [-1, 1]
        self.assertTrue(x.min() >= -1. and x.max() <= 1.)

        encoded = transforms.MuLawEncoding(n_channels)(x)
        self.assertTrue(encoded.min() >= 0. and encoded.max() <= n_channels)

        decoded = transforms.MuLawDecoding(n_channels)(encoded)
        self.assertTrue(decoded.min() >= -1. and decoded.max() <= 1.)
45

46
    def test_AmplitudeToDB(self):
        """Magnitude- and power-scaled dB conversions must agree on real audio."""
        filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        waveform, sample_rate = torchaudio.load(filepath)

        top_db = 80.
        db_from_mag = transforms.AmplitudeToDB('magnitude', top_db)(waveform.abs())
        db_from_power = transforms.AmplitudeToDB('power', top_db)(waveform.pow(2))

        # 20*log10(|x|) == 10*log10(x**2), so both paths must produce the same dB values
        self.assertEqual(db_from_mag, db_from_power)
57

58
59
60
61
62
63
64
65
66
67
68
69
    def test_melscale_load_save(self):
        """MelScale filterbank survives a state_dict save/load round trip."""
        specgram = torch.ones(1, 1000, 100)
        original = transforms.MelScale()
        original(specgram)  # run once so the lazily-built filterbank exists

        restored = transforms.MelScale(n_stft=1000)
        restored.load_state_dict(original.state_dict())

        self.assertEqual(restored.fb.size(), (1000, 128))
        self.assertEqual(original.fb, restored.fb)
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85

    def test_melspectrogram_load_save(self):
        """MelSpectrogram window and filterbank survive a state_dict round trip."""
        waveform = self.waveform.float()
        original = transforms.MelSpectrogram()
        original(waveform)  # run once so all buffers are materialized

        restored = transforms.MelSpectrogram()
        restored.load_state_dict(original.state_dict())

        self.assertEqual(original.spectrogram.window, restored.spectrogram.window)

        # defaults are n_fft = 400 and n_mels = 128, so fb is (400 // 2 + 1, 128)
        self.assertEqual(restored.mel_scale.fb.size(), (201, 128))
        self.assertEqual(original.mel_scale.fb, restored.mel_scale.fb)
90

91
    def test_mel2(self):
        """End-to-end checks for MelSpectrogram: defaults, custom options,
        multi-channel input, and MelScale filterbank construction.
        """
        top_db = 80.
        s2db = transforms.AmplitudeToDB('power', top_db)

        waveform = self.waveform.clone()  # (1, 64000)
        waveform_scaled = self.scale(waveform)  # (1, 64000)
        mel_transform = transforms.MelSpectrogram()
        # check defaults
        spectrogram_torch = s2db(mel_transform(waveform_scaled))  # (1, 128, 321)
        self.assertTrue(spectrogram_torch.dim() == 3)
        self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
        self.assertEqual(spectrogram_torch.size(1), mel_transform.n_mels)
        # check correctness of filterbank conversion matrix
        self.assertTrue(mel_transform.mel_scale.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform.mel_scale.fb.sum(1).ge(0.).all())
        # check options
        kwargs = {'window_fn': torch.hamming_window, 'pad': 10, 'win_length': 500,
                  'hop_length': 125, 'n_fft': 800, 'n_mels': 50}
        mel_transform2 = transforms.MelSpectrogram(**kwargs)
        spectrogram2_torch = s2db(mel_transform2(waveform_scaled))  # (1, 50, 513)
        self.assertTrue(spectrogram2_torch.dim() == 3)
        # bugfix: validate the spectrogram produced with the custom options,
        # not the default one already checked above
        self.assertTrue(spectrogram2_torch.ge(spectrogram2_torch.max() - top_db).all())
        self.assertEqual(spectrogram2_torch.size(1), mel_transform2.n_mels)
        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).ge(0.).all())
        # check on multi-channel audio
        filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        x_stereo, sr_stereo = torchaudio.load(filepath)  # (2, 278756), 44100
        spectrogram_stereo = s2db(mel_transform(x_stereo))  # (2, 128, 1394)
        self.assertTrue(spectrogram_stereo.dim() == 3)
        self.assertTrue(spectrogram_stereo.size(0) == 2)
        # bugfix: validate the stereo spectrogram, not the mono default one
        self.assertTrue(spectrogram_stereo.ge(spectrogram_stereo.max() - top_db).all())
        self.assertEqual(spectrogram_stereo.size(1), mel_transform.n_mels)
        # check filterbank matrix creation
        fb_matrix_transform = transforms.MelScale(
            n_mels=100, sample_rate=16000, f_min=0., f_max=None, n_stft=400)
        self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all())
        self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
        self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))
Soumith Chintala's avatar
Soumith Chintala committed
130

PCerles's avatar
PCerles committed
131
    def test_mfcc(self):
        """MFCC defaults, melkwargs pass-through, and norm=None scaling."""
        audio_orig = self.waveform.clone()
        audio_scaled = self.scale(audio_orig)  # (1, 64000)

        sample_rate = 16000
        n_mfcc = 40
        n_mels = 128

        # defaults: output is (channels, n_mfcc, frames)
        mfcc_transform = torchaudio.transforms.MFCC(
            sample_rate=sample_rate, n_mfcc=n_mfcc, norm='ortho')
        torch_mfcc = mfcc_transform(audio_scaled)  # (1, 40, 321)
        self.assertTrue(torch_mfcc.dim() == 3)
        self.assertTrue(torch_mfcc.shape[1] == n_mfcc)
        self.assertTrue(torch_mfcc.shape[2] == 321)

        # melkwargs are forwarded to the underlying MelSpectrogram
        mfcc_transform2 = torchaudio.transforms.MFCC(
            sample_rate=sample_rate, n_mfcc=n_mfcc, norm='ortho',
            melkwargs={'win_length': 200})
        torch_mfcc2 = mfcc_transform2(audio_scaled)  # (1, 40, 641)
        self.assertTrue(torch_mfcc2.shape[2] == 641)

        # norm=None differs from 'ortho' by known per-row scale factors
        mfcc_transform_norm_none = torchaudio.transforms.MFCC(
            sample_rate=sample_rate, n_mfcc=n_mfcc, norm=None)
        torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled)  # (1, 40, 321)

        norm_check = torch_mfcc.clone()
        norm_check[:, 0, :] *= math.sqrt(n_mels) * 2
        norm_check[:, 1:, :] *= math.sqrt(n_mels / 2) * 2

        self.assertTrue(torch_mfcc_norm_none.allclose(norm_check))

jamarshon's avatar
jamarshon committed
167
    def test_resample_size(self):
        """Resample doubles/halves sample counts and rejects unknown methods."""
        input_path = common_utils.get_asset_path('sinewave.wav')
        waveform, sample_rate = torchaudio.load(input_path)

        upsample_rate = sample_rate * 2
        downsample_rate = sample_rate // 2

        # an unrecognized resampling method should raise at call time
        invalid_resample = torchaudio.transforms.Resample(
            sample_rate, upsample_rate, resampling_method='foo')
        self.assertRaises(ValueError, invalid_resample, waveform)

        up_sampled = torchaudio.transforms.Resample(
            sample_rate, upsample_rate, resampling_method='sinc_interpolation')(waveform)
        # upsampling by 2x yields twice as many samples
        self.assertTrue(up_sampled.size(-1) == waveform.size(-1) * 2)

        down_sampled = torchaudio.transforms.Resample(
            sample_rate, downsample_rate, resampling_method='sinc_interpolation')(waveform)
        # downsampling by 2x yields half as many samples
        self.assertTrue(down_sampled.size(-1) == waveform.size(-1) // 2)
PCerles's avatar
PCerles committed
190

Vincent QB's avatar
Vincent QB committed
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
    def test_compute_deltas(self):
        """ComputeDeltas preserves the shape of its input spectrogram."""
        channel, time = 13, 1021
        win_length = 2 * 7 + 1
        specgram = torch.randn(channel, channel * 3, time)
        computed = transforms.ComputeDeltas(win_length=win_length)(specgram)
        self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))

    def test_compute_deltas_transform_same_as_functional(self, atol=1e-6, rtol=1e-8):
        """The ComputeDeltas transform must match functional.compute_deltas."""
        channel, time = 13, 1021
        win_length = 2 * 7 + 1
        specgram = torch.randn(channel, channel * 3, time)

        via_transform = transforms.ComputeDeltas(win_length=win_length)(specgram)
        via_functional = F.compute_deltas(specgram, win_length=win_length)

        self.assertEqual(via_functional, via_transform, atol=atol, rtol=rtol)
Vincent QB's avatar
Vincent QB committed
213
214
215

    def test_compute_deltas_twochannel(self):
        """Deltas of a known ramp, duplicated across two channels."""
        specgram = torch.tensor([1., 2., 3., 4.]).repeat(1, 2, 1)
        expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5],
                                  [0.5, 1.0, 1.0, 0.5]]])
        computed = transforms.ComputeDeltas(win_length=3)(specgram)
        assert computed.shape == expected.shape, (computed.shape, expected.shape)
        self.assertEqual(computed, expected, atol=1e-6, rtol=1e-8)
Vincent QB's avatar
Vincent QB committed
222
223


David Pollack's avatar
David Pollack committed
224
225
# Allow this test module to be executed directly as well as via a test runner.
if __name__ == '__main__':
    unittest.main()