# test_sox_compatibility.py
import unittest

3
4
5
import torch
import torchaudio
import torchaudio.functional as F
6
7
import torchaudio.transforms as T

8
9
import common_utils
from common_utils import AudioBackendScope, BACKENDS
10
11
12


class TestFunctionalFiltering(unittest.TestCase):
    """Compare torchaudio functional/transform ops against the equivalent SoX effects.

    Each test runs an effect through ``torchaudio.sox_effects.SoxEffectsChain`` and
    through the torchaudio implementation, then checks both outputs agree within a
    per-test tolerance. All tests are skipped when the ``sox`` backend is unavailable.
    """

    @staticmethod
    def _apply_sox_effects(filepath, effects):
        """Run ``filepath`` through a SoX effects chain and return the output waveform.

        Args:
            filepath: Path of the audio file fed into the chain.
            effects: Sequence of ``(name, args)`` pairs appended in order. ``args``
                may be ``None`` for effects that take no arguments.

        Returns:
            The waveform tensor from ``sox_build_flow_effects()``; the sample rate it
            also returns is discarded.
        """
        chain = torchaudio.sox_effects.SoxEffectsChain()
        chain.set_input_file(filepath)
        for name, args in effects:
            if args is None:
                chain.append_effect_to_chain(name)
            else:
                chain.append_effect_to_chain(name, args)
        return chain.sox_build_flow_effects()[0]

    def _assert_matches_sox_on_noise(self, effect, effect_args, torch_fn, atol=1e-4, rtol=1e-5):
        """Check ``torch_fn`` against the SoX ``effect`` on the white-noise asset.

        Args:
            effect: SoX effect name, e.g. ``"lowpass"``.
            effect_args: Argument list for the SoX effect, or ``None`` for none.
            torch_fn: Callable ``(waveform, sample_rate) -> Tensor`` implementing the
                same effect with torchaudio.
            atol: Absolute tolerance forwarded to ``torch.testing.assert_allclose``.
            rtol: Relative tolerance forwarded to ``torch.testing.assert_allclose``.
        """
        noise_filepath = common_utils.get_asset_path('whitenoise.wav')
        sox_output_waveform = self._apply_sox_effects(noise_filepath, [(effect, effect_args)])

        waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
        output_waveform = torch_fn(waveform, sample_rate)

        torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=atol, rtol=rtol)

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_gain(self):
        """Test F.gain, compare to SoX ``gain``."""
        test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(test_filepath)

        waveform_gain = F.gain(waveform, 3)
        # Assumed intent: the +3 dB boost must not push the signal past full scale.
        # (assertTrue(x, 1.) would treat 1. as the failure message and never fail.)
        self.assertLessEqual(waveform_gain.abs().max().item(), 1.)

        sox_gain_waveform = self._apply_sox_effects(test_filepath, [("gain", [3])])

        torch.testing.assert_allclose(waveform_gain, sox_gain_waveform, atol=1e-04, rtol=1e-5)

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_dither(self):
        """Test F.dither with and without noise shaping, compare to SoX ``dither``."""
        test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(test_filepath)

        waveform_dithered = F.dither(waveform)
        waveform_dithered_noiseshaped = F.dither(waveform, noise_shaping=True)

        sox_dither_waveform = self._apply_sox_effects(test_filepath, [("dither", [])])
        torch.testing.assert_allclose(waveform_dithered, sox_dither_waveform, atol=1e-04, rtol=1e-5)

        # "-s" is the SoX counterpart of noise_shaping=True; compared at a looser tolerance.
        sox_dither_waveform_ns = self._apply_sox_effects(test_filepath, [("dither", ["-s"])])
        torch.testing.assert_allclose(waveform_dithered_noiseshaped, sox_dither_waveform_ns, atol=1e-02, rtol=1e-5)

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_vctk_transform_pipeline(self):
        """Test a resample + dither pipeline on a VCTK sample, compare to a SoX chain."""
        test_filepath_vctk = common_utils.get_asset_path('VCTK-Corpus', 'wav48', 'p224', 'p224_002.wav')
        wf_vctk, sr_vctk = torchaudio.load(test_filepath_vctk)

        # rate
        resample = T.Resample(sr_vctk, 16000, resampling_method='sinc_interpolation')
        wf_vctk = resample(wf_vctk)
        # dither
        wf_vctk = F.dither(wf_vctk, noise_shaping=True)

        # NOTE(review): the torch side has no counterpart for the gain/channels steps;
        # presumably the asset is already mono and headroom gain is negligible - confirm.
        wf_vctk_sox = self._apply_sox_effects(test_filepath_vctk, [
            ("gain", ["-h"]),
            ("channels", [1]),
            ("rate", [16000]),
            ("gain", ["-rh"]),
            ("dither", ["-s"]),
        ])

        torch.testing.assert_allclose(wf_vctk, wf_vctk_sox, rtol=1e-03, atol=1e-03)

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_lowpass(self):
        """Test biquad lowpass filter, compare to SoX implementation."""
        cutoff_freq = 3000
        self._assert_matches_sox_on_noise(
            "lowpass", [cutoff_freq],
            lambda waveform, sample_rate: F.lowpass_biquad(waveform, sample_rate, cutoff_freq))

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_highpass(self):
        """Test biquad highpass filter, compare to SoX implementation."""
        cutoff_freq = 2000
        # TODO: this fails at the 1e-4 level; investigate why.
        self._assert_matches_sox_on_noise(
            "highpass", [cutoff_freq],
            lambda waveform, sample_rate: F.highpass_biquad(waveform, sample_rate, cutoff_freq),
            atol=1e-3)

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_allpass(self):
        """Test biquad allpass filter, compare to SoX implementation."""
        central_freq = 1000
        q = 0.707
        self._assert_matches_sox_on_noise(
            "allpass", [central_freq, str(q) + 'q'],
            lambda waveform, sample_rate: F.allpass_biquad(waveform, sample_rate, central_freq, q))

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_bandpass_with_csg(self):
        """Test biquad bandpass filter with constant skirt gain ("-c"), compare to SoX."""
        central_freq = 1000
        q = 0.707
        const_skirt_gain = True
        self._assert_matches_sox_on_noise(
            "bandpass", ["-c", central_freq, str(q) + 'q'],
            lambda waveform, sample_rate: F.bandpass_biquad(
                waveform, sample_rate, central_freq, q, const_skirt_gain))

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_bandpass_without_csg(self):
        """Test biquad bandpass filter without constant skirt gain, compare to SoX."""
        central_freq = 1000
        q = 0.707
        const_skirt_gain = False
        self._assert_matches_sox_on_noise(
            "bandpass", [central_freq, str(q) + 'q'],
            lambda waveform, sample_rate: F.bandpass_biquad(
                waveform, sample_rate, central_freq, q, const_skirt_gain))

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_bandreject(self):
        """Test biquad bandreject filter, compare to SoX implementation."""
        central_freq = 1000
        q = 0.707
        self._assert_matches_sox_on_noise(
            "bandreject", [central_freq, str(q) + 'q'],
            lambda waveform, sample_rate: F.bandreject_biquad(waveform, sample_rate, central_freq, q))

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_band_with_noise(self):
        """Test biquad band filter with noise mode ("-n"), compare to SoX implementation."""
        central_freq = 1000
        q = 0.707
        noise = True
        self._assert_matches_sox_on_noise(
            "band", ["-n", central_freq, str(q) + 'q'],
            lambda waveform, sample_rate: F.band_biquad(waveform, sample_rate, central_freq, q, noise))

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_band_without_noise(self):
        """Test biquad band filter without noise mode, compare to SoX implementation."""
        central_freq = 1000
        q = 0.707
        noise = False
        self._assert_matches_sox_on_noise(
            "band", [central_freq, str(q) + 'q'],
            lambda waveform, sample_rate: F.band_biquad(waveform, sample_rate, central_freq, q, noise))

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_treble(self):
        """Test biquad treble filter, compare to SoX implementation."""
        central_freq = 1000
        q = 0.707
        gain = 40
        self._assert_matches_sox_on_noise(
            "treble", [gain, central_freq, str(q) + 'q'],
            lambda waveform, sample_rate: F.treble_biquad(waveform, sample_rate, gain, central_freq, q))

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_deemph(self):
        """Test biquad deemph filter, compare to SoX implementation."""
        self._assert_matches_sox_on_noise(
            "deemph", None,
            lambda waveform, sample_rate: F.deemph_biquad(waveform, sample_rate))

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_riaa(self):
        """Test biquad riaa filter, compare to SoX implementation."""
        self._assert_matches_sox_on_noise(
            "riaa", None,
            lambda waveform, sample_rate: F.riaa_biquad(waveform, sample_rate))

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_contrast(self):
        """Test contrast effect, compare to SoX implementation."""
        enhancement_amount = 80.
        self._assert_matches_sox_on_noise(
            "contrast", [enhancement_amount],
            lambda waveform, _sample_rate: F.contrast(waveform, enhancement_amount))

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_equalizer(self):
        """Test biquad peaking equalizer filter, compare to SoX implementation."""
        center_freq = 300
        q = 0.707
        gain = 1
        # Width is given with an explicit 'q' suffix, like every other biquad test here.
        self._assert_matches_sox_on_noise(
            "equalizer", [center_freq, str(q) + 'q', gain],
            lambda waveform, sample_rate: F.equalizer_biquad(waveform, sample_rate, center_freq, gain, q))

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_perf_biquad_filtering(self):
        """Test F.lfilter with fixed biquad coefficients, compare to SoX ``biquad``."""
        b0, b1, b2 = 0.4, 0.2, 0.9
        a0, a1, a2 = 0.7, 0.2, 0.6
        # SoX takes the numerator (b) coefficients first; F.lfilter takes a_coeffs first.
        self._assert_matches_sox_on_noise(
            "biquad", [b0, b1, b2, a0, a1, a2],
            lambda waveform, _sample_rate: F.lfilter(
                waveform, torch.tensor([a0, a1, a2]), torch.tensor([b0, b1, b2])))


if __name__ == "__main__":
    # Discover and run every TestCase defined in this module when executed directly.
    unittest.main(module="__main__")