transforms_test.py 10.3 KB
Newer Older
1
2
import math

David Pollack's avatar
David Pollack committed
3
4
5
import torch
import torchaudio
import torchaudio.transforms as transforms
Vincent QB's avatar
Vincent QB committed
6
import torchaudio.functional as F
David Pollack's avatar
David Pollack committed
7

8
from torchaudio_unittest import common_utils
9
10


moto's avatar
moto committed
11
12
class Tester(common_utils.TorchaudioTestCase):
    """Behavioral tests for torchaudio.transforms."""

    backend = 'default'

    # Shared fixture: a 4-second, 440 Hz cosine at 16 kHz, scaled to 30%
    # volume and stored as int64 values in the int32 sample range.
    sample_rate = 16000
    freq = 440
    volume = .3
    waveform = (torch.cos(2 * math.pi * torch.arange(0, 4 * sample_rate).float() * freq / sample_rate))
    waveform.unsqueeze_(0)  # (1, 64000)
    waveform = (waveform * volume * 2**31).long()

    def scale(self, waveform, factor=2.0**31):
        # Convert an integer waveform to floating point and scale it down
        # by `factor` (default: the int32 full-scale value).
        if not waveform.is_floating_point():
            waveform = waveform.to(torch.get_default_dtype())
        return waveform / factor

    def test_mu_law_companding(self):
        """Mu-law encode/decode round trip stays within the expected ranges."""
        quantization_channels = 256

        waveform = self.waveform.clone()
        if not waveform.is_floating_point():
            waveform = waveform.to(torch.get_default_dtype())
        waveform /= torch.abs(waveform).max()

        # Mu-law companding requires input normalized to [-1, 1].
        self.assertTrue(waveform.min() >= -1. and waveform.max() <= 1.)

        waveform_mu = transforms.MuLawEncoding(quantization_channels)(waveform)
        self.assertTrue(waveform_mu.min() >= 0. and waveform_mu.max() <= quantization_channels)

        waveform_exp = transforms.MuLawDecoding(quantization_channels)(waveform_mu)
        self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.)

    def test_AmplitudeToDB(self):
        """|x| in 'magnitude' mode and x**2 in 'power' mode give the same dB."""
        filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        waveform = common_utils.load_wav(filepath)[0]

        mag_to_db_transform = transforms.AmplitudeToDB('magnitude', 80.)
        power_to_db_transform = transforms.AmplitudeToDB('power', 80.)

        mag_to_db_torch = mag_to_db_transform(torch.abs(waveform))
        power_to_db_torch = power_to_db_transform(torch.pow(waveform, 2))

        self.assertEqual(mag_to_db_torch, power_to_db_torch)

    def test_melscale_load_save(self):
        """MelScale filter bank survives a state_dict save/load round trip."""
        specgram = torch.ones(1, 1000, 100)
        melscale_transform = transforms.MelScale()
        melscale_transform(specgram)  # materializes the lazily-built filter bank

        melscale_transform_copy = transforms.MelScale(n_stft=1000)
        melscale_transform_copy.load_state_dict(melscale_transform.state_dict())

        fb = melscale_transform.fb
        fb_copy = melscale_transform_copy.fb

        self.assertEqual(fb_copy.size(), (1000, 128))
        self.assertEqual(fb, fb_copy)

    def test_melspectrogram_load_save(self):
        """MelSpectrogram window and filter bank survive a state_dict round trip."""
        waveform = self.waveform.float()
        mel_spectrogram_transform = transforms.MelSpectrogram()
        mel_spectrogram_transform(waveform)

        mel_spectrogram_transform_copy = transforms.MelSpectrogram()
        mel_spectrogram_transform_copy.load_state_dict(mel_spectrogram_transform.state_dict())

        window = mel_spectrogram_transform.spectrogram.window
        window_copy = mel_spectrogram_transform_copy.spectrogram.window

        fb = mel_spectrogram_transform.mel_scale.fb
        fb_copy = mel_spectrogram_transform_copy.mel_scale.fb

        self.assertEqual(window, window_copy)
        # the default for n_fft = 400 and n_mels = 128
        self.assertEqual(fb_copy.size(), (201, 128))
        self.assertEqual(fb, fb_copy)

    def test_mel2(self):
        """MelSpectrogram: defaults, custom options, multi-channel input, and
        standalone filter-bank creation."""
        top_db = 80.
        s2db = transforms.AmplitudeToDB('power', top_db)

        waveform = self.waveform.clone()  # (1, 64000)
        waveform_scaled = self.scale(waveform)  # (1, 64000)
        mel_transform = transforms.MelSpectrogram()
        # check defaults
        spectrogram_torch = s2db(mel_transform(waveform_scaled))  # (1, 128, 321)
        self.assertTrue(spectrogram_torch.dim() == 3)
        self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
        self.assertEqual(spectrogram_torch.size(1), mel_transform.n_mels)
        # check correctness of filterbank conversion matrix
        self.assertTrue(mel_transform.mel_scale.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform.mel_scale.fb.sum(1).ge(0.).all())
        # check options
        kwargs = {'window_fn': torch.hamming_window, 'pad': 10, 'win_length': 500,
                  'hop_length': 125, 'n_fft': 800, 'n_mels': 50}
        mel_transform2 = transforms.MelSpectrogram(**kwargs)
        spectrogram2_torch = s2db(mel_transform2(waveform_scaled))  # (1, 50, 513)
        self.assertTrue(spectrogram2_torch.dim() == 3)
        # BUGFIX: the top_db bound must be checked on the spectrogram produced
        # with the custom options, not re-checked on the default one.
        self.assertTrue(spectrogram2_torch.ge(spectrogram2_torch.max() - top_db).all())
        self.assertEqual(spectrogram2_torch.size(1), mel_transform2.n_mels)
        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).ge(0.).all())
        # check on multi-channel audio
        filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        x_stereo = common_utils.load_wav(filepath)[0]  # (2, 278756), 44100
        spectrogram_stereo = s2db(mel_transform(x_stereo))  # (2, 128, 1394)
        self.assertTrue(spectrogram_stereo.dim() == 3)
        self.assertTrue(spectrogram_stereo.size(0) == 2)
        # BUGFIX: bound-check the stereo spectrogram, not the default mono one.
        self.assertTrue(spectrogram_stereo.ge(spectrogram_stereo.max() - top_db).all())
        self.assertEqual(spectrogram_stereo.size(1), mel_transform.n_mels)
        # check filterbank matrix creation
        fb_matrix_transform = transforms.MelScale(
            n_mels=100, sample_rate=16000, f_min=0., f_max=None, n_stft=400)
        self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all())
        self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
        self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))

    def test_mfcc(self):
        """MFCC: default output shape, melkwargs pass-through, and norm=None
        consistency with the 'ortho'-normalized result."""
        audio_orig = self.waveform.clone()
        audio_scaled = self.scale(audio_orig)  # (1, 64000)

        sample_rate = 16000
        n_mfcc = 40
        n_mels = 128
        mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                    n_mfcc=n_mfcc,
                                                    norm='ortho')
        # check defaults
        torch_mfcc = mfcc_transform(audio_scaled)  # (1, 40, 321)
        self.assertTrue(torch_mfcc.dim() == 3)
        self.assertTrue(torch_mfcc.shape[1] == n_mfcc)
        self.assertTrue(torch_mfcc.shape[2] == 321)
        # check melkwargs are passed through
        melkwargs = {'win_length': 200}
        mfcc_transform2 = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                     n_mfcc=n_mfcc,
                                                     norm='ortho',
                                                     melkwargs=melkwargs)
        torch_mfcc2 = mfcc_transform2(audio_scaled)  # (1, 40, 641)
        self.assertTrue(torch_mfcc2.shape[2] == 641)

        # check norms work correctly
        mfcc_transform_norm_none = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                              n_mfcc=n_mfcc,
                                                              norm=None)
        torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled)  # (1, 40, 321)

        # Undo the 'ortho' DCT scaling to reconstruct the unnormalized output.
        norm_check = torch_mfcc.clone()
        norm_check[:, 0, :] *= math.sqrt(n_mels) * 2
        norm_check[:, 1:, :] *= math.sqrt(n_mels / 2) * 2

        self.assertTrue(torch_mfcc_norm_none.allclose(norm_check))

    def test_resample_size(self):
        """Resample doubles/halves sample count and rejects unknown methods."""
        input_path = common_utils.get_asset_path('sinewave.wav')
        waveform, sample_rate = common_utils.load_wav(input_path)

        upsample_rate = sample_rate * 2
        downsample_rate = sample_rate // 2
        invalid_resample = torchaudio.transforms.Resample(sample_rate, upsample_rate, resampling_method='foo')
        self.assertRaises(ValueError, invalid_resample, waveform)

        upsample_resample = torchaudio.transforms.Resample(
            sample_rate, upsample_rate, resampling_method='sinc_interpolation')
        up_sampled = upsample_resample(waveform)
        # we expect the upsampled signal to have twice as many samples
        self.assertTrue(up_sampled.size(-1) == waveform.size(-1) * 2)

        downsample_resample = torchaudio.transforms.Resample(
            sample_rate, downsample_rate, resampling_method='sinc_interpolation')
        down_sampled = downsample_resample(waveform)
        # we expect the downsampled signal to have half as many samples
        self.assertTrue(down_sampled.size(-1) == waveform.size(-1) // 2)

    def test_compute_deltas(self):
        """ComputeDeltas preserves the input spectrogram's shape."""
        channel = 13
        n_mfcc = channel * 3
        time = 1021
        win_length = 2 * 7 + 1
        specgram = torch.randn(channel, n_mfcc, time)
        transform = transforms.ComputeDeltas(win_length=win_length)
        computed = transform(specgram)
        self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))

    def test_compute_deltas_transform_same_as_functional(self, atol=1e-6, rtol=1e-8):
        """transforms.ComputeDeltas matches functional.compute_deltas."""
        channel = 13
        n_mfcc = channel * 3
        time = 1021
        win_length = 2 * 7 + 1
        specgram = torch.randn(channel, n_mfcc, time)

        transform = transforms.ComputeDeltas(win_length=win_length)
        computed_transform = transform(specgram)

        computed_functional = F.compute_deltas(specgram, win_length=win_length)
        self.assertEqual(computed_functional, computed_transform, atol=atol, rtol=rtol)

    def test_compute_deltas_twochannel(self):
        """ComputeDeltas on a known two-channel ramp matches the hand-computed
        regression result for win_length=3."""
        specgram = torch.tensor([1., 2., 3., 4.]).repeat(1, 2, 1)
        expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5],
                                  [0.5, 1.0, 1.0, 0.5]]])
        transform = transforms.ComputeDeltas(win_length=3)
        computed = transform(specgram)
        assert computed.shape == expected.shape, (computed.shape, expected.shape)
        self.assertEqual(computed, expected, atol=1e-6, rtol=1e-8)
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236


class SmokeTest(common_utils.TorchaudioTestCase):
    """Constructor smoke tests: keyword arguments are stored on the transform."""

    def test_spectrogram(self):
        """Spectrogram keeps center/pad_mode/onesided as passed in."""
        transform = transforms.Spectrogram(center=False, pad_mode="reflect", onesided=False)
        self.assertEqual(transform.center, False)
        self.assertEqual(transform.pad_mode, "reflect")
        self.assertEqual(transform.onesided, False)

    def test_melspectrogram(self):
        """MelSpectrogram forwards center/pad_mode/onesided to its inner Spectrogram."""
        transform = transforms.MelSpectrogram(center=True, pad_mode="reflect", onesided=False)
        inner = transform.spectrogram
        self.assertEqual(inner.center, True)
        self.assertEqual(inner.pad_mode, "reflect")
        self.assertEqual(inner.onesided, False)