functional_impl.py 21.1 KB
Newer Older
dhthompson's avatar
dhthompson committed
1
2
3
4
5
6
"""Test definition common to CPU and CUDA"""
import math
import itertools
import warnings

import numpy as np
7
8
import torch
import torchaudio.functional as F
9
from parameterized import parameterized
10
from scipy import signal
11

12
13
14
15
16
17
18
from torchaudio_unittest.common_utils import (
    TestBaseMixin,
    get_sinusoid,
    nested_params,
    get_whitenoise,
    rnnt_utils,
)
19
20


dhthompson's avatar
dhthompson committed
21
class Functional(TestBaseMixin):
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
    def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_factor=None,
                                         resampling_method="sinc_interpolation", atol=1e-1, rtol=1e-4):
        """Resample a 3 Hz cosine and compare it against the analytically sampled signal.

        The source is a cosine of amplitude 123 sampled at 1000 Hz for 5 seconds.
        The target rate is the source rate multiplied by ``up_scale_factor`` and/or
        floor-divided by ``down_scale_factor`` (each applied only when given).
        The first and last 20 samples are excluded from the comparison because of
        boundary effects.
        """
        orig_sr = 1000
        target_sr = orig_sr
        if up_scale_factor is not None:
            target_sr *= up_scale_factor
        if down_scale_factor is not None:
            target_sr //= down_scale_factor

        duration = 5  # seconds
        amplitude = 123
        freq = 3
        edge = 20

        t_in = torch.arange(0, duration, 1.0 / orig_sr)
        sound = amplitude * torch.cos(2 * math.pi * freq * t_in).unsqueeze(0)
        estimate = F.resample(sound, orig_sr, target_sr,
                              resampling_method=resampling_method).squeeze()

        # Sample the ideal cosine at the new rate, truncated to the resampler's output length.
        t_out = torch.arange(0, duration, 1.0 / target_sr)[:estimate.size(0)]
        expected = amplitude * torch.cos(2 * math.pi * freq * t_out)

        # Drop samples near both ends, where boundary effects dominate.
        interior = slice(edge, -edge)
        self.assertEqual(estimate[..., interior], expected[..., interior], atol=atol, rtol=rtol)

51
52
53
54
55
56
57
58
59
    def _test_costs_and_gradients(
        self, data, ref_costs, ref_gradients, atol=1e-6, rtol=1e-2
    ):
        """Run the PyTorch transducer on ``data`` and check costs/gradients against references.

        Also verifies that the returned gradients have the same shape as ``data["logits"]``.
        """
        expected_grad_shape = data["logits"].shape
        costs, gradients = rnnt_utils.compute_with_pytorch_transducer(data=data)
        tolerances = {"atol": atol, "rtol": rtol}
        self.assertEqual(costs, ref_costs, **tolerances)
        self.assertEqual(expected_grad_shape, gradients.shape)
        self.assertEqual(gradients, ref_gradients, **tolerances)

60
    def test_lfilter_simple(self):
        """
        Filter a random signal with a pure delay (b = [0, 0, 0, 1], a = [1, 0, 0, 0]).
        The output must equal the input shifted right by three samples.
        """
        torch.random.manual_seed(42)
        opts = {"dtype": self.dtype, "device": self.device}
        waveform = torch.rand(2, 44100 * 1, **opts)
        b_coeffs = torch.tensor([0, 0, 0, 1], **opts)
        a_coeffs = torch.tensor([1, 0, 0, 0], **opts)

        delayed = F.lfilter(waveform, a_coeffs, b_coeffs)

        shift = 3
        self.assertEqual(delayed[:, shift:], waveform[:, :-shift], atol=1e-5, rtol=1e-5)

75
    def test_lfilter_clamp(self):
        """A constant input through a first-order recursive filter (b=[1, 0], a=[1, -0.95])
        stays at or below 1 with ``clamp=True`` and exceeds 1 with ``clamp=False``."""
        opts = {"dtype": self.dtype, "device": self.device}
        ones = torch.ones(1, 44100 * 1, **opts)
        b_coeffs = torch.tensor([1, 0], **opts)
        a_coeffs = torch.tensor([1, -0.95], **opts)

        clamped = F.lfilter(ones, a_coeffs, b_coeffs, clamp=True)
        assert clamped.max() <= 1

        unclamped = F.lfilter(ones, a_coeffs, b_coeffs, clamp=False)
        assert unclamped.max() > 1
83

84
    @parameterized.expand([
        ((44100,), (4,), (44100,)),
        ((3, 44100), (4,), (3, 44100,)),
        ((2, 3, 44100), (4,), (2, 3, 44100,)),
        ((1, 2, 3, 44100), (4,), (1, 2, 3, 44100,)),
        ((44100,), (2, 4), (2, 44100)),
        ((3, 44100), (1, 4), (3, 1, 44100)),
        ((1, 2, 44100), (3, 4), (1, 2, 3, 44100))
    ])
    def test_lfilter_shape(self, input_shape, coeff_shape, target_shape):
        """Check the output shape of ``lfilter`` with ``batching=False``.

        With 1D coefficients the output keeps the input shape. With 2D coefficients,
        the coefficients' leading dimension is inserted just before the last axis.
        The input tensor itself must keep its shape.
        """
        torch.random.manual_seed(42)
        opts = {"dtype": self.dtype, "device": self.device}
        waveform = torch.rand(*input_shape, **opts)
        b_coeffs = torch.rand(*coeff_shape, **opts)
        a_coeffs = torch.rand(*coeff_shape, **opts)

        filtered = F.lfilter(waveform, a_coeffs, b_coeffs, batching=False)

        assert waveform.size() == input_shape
        assert filtered.size() == target_shape
101

102
    def test_lfilter_9th_order_filter_stability(self):
        """
        Compare ``lfilter`` against scipy for a 9th-order Butterworth high-pass filter.
        The scipy reference runs cascaded second-order sections (``sosfilt``), which is
        numerically more accurate than a single high-order (b, a) transfer function.
        """
        def to_tensor(arr):
            return torch.from_numpy(arr).to(self.dtype).to(self.device)

        order, cutoff, fs = 9, 850, 22050

        # Unit impulse input.
        impulse = torch.zeros(1024, dtype=self.dtype, device=self.device)
        impulse[0] = 1

        # Reference impulse response from second-order sections.
        sos = signal.butter(order, cutoff, 'hp', fs=fs, output='sos')
        expected = to_tensor(signal.sosfilt(sos, impulse.cpu().numpy()))

        # The same filter expressed as transfer-function coefficients.
        b_np, a_np = signal.butter(order, cutoff, 'hp', fs=fs, output='ba')
        b, a = to_tensor(b_np), to_tensor(a_np)

        actual = F.lfilter(impulse, a, b, clamp=False)
        self.assertEqual(actual, expected, atol=1e-4, rtol=1e-5)

124
    @parameterized.expand([(0., ), (1., ), (2., ), (3., )])
    def test_spectogram_grad_at_zero(self, power):
        """The gradient of power spectrogram should not be nan but zero near x=0

        https://github.com/pytorch/audio/issues/993
        """
        # Build the input with the suite's dtype/device. These tests are shared between
        # the CPU and CUDA cases, and a hard-coded CPU float32 input would make the
        # CUDA run merely repeat the CPU one.
        x = torch.zeros(1, 22050, dtype=self.dtype, device=self.device, requires_grad=True)
        spec = F.spectrogram(
            x,
            pad=0,
            window=None,
            n_fft=2048,
            hop_length=None,
            win_length=None,
            power=power,
            normalized=False,
        )
        spec.sum().backward()
        assert not x.grad.isnan().any()
143

dhthompson's avatar
dhthompson committed
144
    def test_compute_deltas_one_channel(self):
        """compute_deltas with a 3-frame window on a single-channel linear ramp."""
        opts = {"dtype": self.dtype, "device": self.device}
        ramp = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]], **opts)
        expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5]]], **opts)
        computed = F.compute_deltas(ramp, win_length=3)
        self.assertEqual(computed, expected)

    def test_compute_deltas_two_channels(self):
        """With two identical channels, each output row matches the single-channel result."""
        opts = {"dtype": self.dtype, "device": self.device}
        row_in = [1.0, 2.0, 3.0, 4.0]
        row_out = [0.5, 1.0, 1.0, 0.5]
        specgram = torch.tensor([[row_in, row_in]], **opts)
        expected = torch.tensor([[row_out, row_out]], **opts)
        computed = F.compute_deltas(specgram, win_length=3)
        self.assertEqual(computed, expected)

    @parameterized.expand([(100,), (440,)])
    def test_detect_pitch_frequency_pitch(self, frequency):
        """Every pitch detected on a pure sinusoid must lie within 1 of its frequency."""
        sample_rate = 44100
        tone = get_sinusoid(frequency=frequency, sample_rate=sample_rate, duration=5)

        detected = F.detect_pitch_frequency(tone, sample_rate)

        tolerance = 1
        num_off = ((detected - frequency).abs() > tolerance).sum()
        self.assertFalse(num_off)

    @parameterized.expand([([100, 100],), ([2, 100, 100],), ([2, 2, 100, 100],)])
    def test_amplitude_to_DB_reversible(self, shape):
        """Converting to dB and back must recover the input, for several input ranks.

        Covers both the amplitude convention (multiplier 20, inverted with power 0.5)
        and the power convention (multiplier 10, inverted with power 1). This also
        exercises `DB_to_amplitude`.
        """
        amplitude_mult = 20.
        power_mult = 10.
        amin = 1e-10
        ref = 1.0
        db_mult = math.log10(max(amin, ref))

        torch.manual_seed(0)
        spec = torch.rand(*shape, dtype=self.dtype, device=self.device) * 200

        round_trips = [
            # (amplitude_to_DB multiplier, DB_to_amplitude power, assertEqual tolerances)
            (amplitude_mult, 0.5, {"atol": 5e-5, "rtol": 1e-5}),
            (power_mult, 1., {}),
        ]
        for multiplier, power, tolerances in round_trips:
            db = F.amplitude_to_DB(spec, multiplier, amin, db_mult, top_db=None)
            recovered = F.DB_to_amplitude(db, ref, power)
            self.assertEqual(recovered, spec, **tolerances)

    @parameterized.expand([([100, 100],), ([2, 100, 100],), ([2, 2, 100, 100],)])
    def test_amplitude_to_DB_top_db_clamp(self, shape):
        """Values must be clamped to ``max_dB - top_db`` when `top_db` is supplied, and no tighter."""
        amplitude_mult = 20.
        amin = 1e-10
        ref = 1.0
        db_mult = math.log10(max(amin, ref))
        top_db = 40.

        torch.manual_seed(0)
        # Start from random data for entropy, then pin each spectrogram's extrema so the
        # cutoff is predictable. The max sets the dB ceiling, and the min must sit far
        # enough below it to trigger the clamp.
        spec = torch.rand(*shape, dtype=self.dtype, device=self.device)
        last_two = [-2, -1]
        spec -= spec.amin(last_two)[..., None, None]  # per-spectrogram min -> 0
        spec /= spec.amax(last_two)[..., None, None]  # per-spectrogram max -> 1
        spec *= 200  # stretch to [0, 200], wide enough to exercise clamping

        decibels = F.amplitude_to_DB(spec, amplitude_mult, amin,
                                     db_mult, top_db=top_db)

        # Expected floor: 20 * log10(200) - top_db ~= 6.0206 dB.
        below_limit = decibels < 6.0205
        assert not below_limit.any(), (
            f"{below_limit.sum().item()} decibel values were below the expected cutoff:\n"
            f"{decibels}"
        )
        # Something must sit right at the floor; otherwise the clamp was too aggressive.
        close_to_limit = decibels < 6.0207
        assert close_to_limit.any(), (
            f"No values were close to the limit. Did it over-clamp?\n{decibels}"
        )

    @parameterized.expand(
        list(itertools.product([(1, 2, 1025, 400, 2), (1025, 400, 2)], [1, 2, 0.7]))
    )
    def test_complex_norm(self, shape, power):
        """complex_norm of a pseudo-complex tensor (trailing dim of size 2) equals
        (re^2 + im^2) ** (power / 2)."""
        torch.random.manual_seed(42)
        pseudo_complex = torch.randn(*shape, dtype=self.dtype, device=self.device)
        squared_magnitude = pseudo_complex.pow(2).sum(-1)
        expected = squared_magnitude.pow(power / 2)
        actual = F.complex_norm(pseudo_complex, power)
        self.assertEqual(actual, expected, atol=1e-5, rtol=1e-5)

    @parameterized.expand(
        list(itertools.product([(2, 1025, 400), (1, 201, 100)], [100], [0., 30.], [1, 2]))
    )
    def test_mask_along_axis(self, shape, mask_param, mask_value, axis):
        """mask_along_axis must keep the shape and mask fewer than ``mask_param`` slices."""
        torch.random.manual_seed(42)
        specgram = torch.randn(*shape, dtype=self.dtype, device=self.device)

        masked = F.mask_along_axis(specgram, mask_param, mask_value, axis)

        # A slice along `axis` counts as masked when every entry across `other_axis` equals mask_value.
        other_axis = 1 if axis == 2 else 2
        fully_masked = (masked == mask_value).sum(other_axis) == masked.size(other_axis)
        # Normalize the total count by the leading dimension (size(0)).
        num_masked = torch.div(fully_masked.sum(), masked.size(0), rounding_mode='floor')

        assert masked.size() == specgram.size()
        assert num_masked < mask_param

    @parameterized.expand(list(itertools.product([100], [0., 30.], [2, 3])))
    def test_mask_along_axis_iid(self, mask_param, mask_value, axis):
        """mask_along_axis_iid must keep the shape, and every entry of the two leading
        dimensions must have fewer than ``mask_param`` fully masked slices."""
        torch.random.manual_seed(42)
        specgrams = torch.randn(4, 2, 1025, 400, dtype=self.dtype, device=self.device)

        masked = F.mask_along_axis_iid(specgrams, mask_param, mask_value, axis)

        other_axis = 2 if axis == 3 else 3
        fully_masked = (masked == mask_value).sum(other_axis) == masked.size(other_axis)
        counts = fully_masked.sum(-1)  # one count per entry of the leading two dims

        assert masked.size() == specgrams.size()
        assert (counts < mask_param).all()
277

278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
    @parameterized.expand(
        list(itertools.product([(2, 1025, 400), (1, 201, 100)], [100], [0., 30.], [1, 2]))
    )
    def test_mask_along_axis_preserve(self, shape, mask_param, mask_value, axis):
        """mask_along_axis must leave its input tensor unmodified.

        Repeated 5 times to bound the probability of no masking occurring to 1e-10.
        See https://github.com/pytorch/audio/issues/1478
        """
        torch.random.manual_seed(42)
        num_trials = 5
        for _ in range(num_trials):
            specgram = torch.randn(*shape, dtype=self.dtype, device=self.device)
            snapshot = specgram.clone()
            F.mask_along_axis(specgram, mask_param, mask_value, axis)
            self.assertEqual(specgram, snapshot)

    @parameterized.expand(list(itertools.product([100], [0., 30.], [2, 3])))
    def test_mask_along_axis_iid_preserve(self, mask_param, mask_value, axis):
        """mask_along_axis_iid must leave its input tensor unmodified.

        Repeated 5 times to bound the probability of no masking occurring to 1e-10.
        See https://github.com/pytorch/audio/issues/1478
        """
        torch.random.manual_seed(42)
        num_trials = 5
        for _ in range(num_trials):
            specgrams = torch.randn(4, 2, 1025, 400, dtype=self.dtype, device=self.device)
            snapshot = specgrams.clone()
            F.mask_along_axis_iid(specgrams, mask_param, mask_value, axis)
            self.assertEqual(specgrams, snapshot)

310
311
312
313
314
315
316
317
318
319
    @parameterized.expand(list(itertools.product(
        ["sinc_interpolation", "kaiser_window"],
        [16000, 44100],
    )))
    def test_resample_identity(self, resampling_method, sample_rate):
        """Resampling to the same rate must return the input unchanged, for every method."""
        waveform = get_whitenoise(sample_rate=sample_rate, duration=1)

        # Forward the parameterized method; without it, every case runs only the
        # default resampling method and the parameterization is meaningless.
        resampled = F.resample(waveform, sample_rate, sample_rate,
                               resampling_method=resampling_method)
        self.assertEqual(waveform, resampled)

320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
    @parameterized.expand([("sinc_interpolation",), ("kaiser_window",)])
    def test_resample_waveform_upsample_size(self, resampling_method):
        """Doubling the sample rate doubles the number of samples."""
        sr = 16000
        waveform = get_whitenoise(sample_rate=sr, duration=0.5)
        upsampled = F.resample(waveform, sr, 2 * sr, resampling_method=resampling_method)
        assert upsampled.size(-1) == 2 * waveform.size(-1)

    @parameterized.expand([("sinc_interpolation",), ("kaiser_window",)])
    def test_resample_waveform_downsample_size(self, resampling_method):
        """Halving the sample rate halves (floor) the number of samples."""
        sr = 16000
        waveform = get_whitenoise(sample_rate=sr, duration=0.5)
        downsampled = F.resample(waveform, sr, sr // 2, resampling_method=resampling_method)
        assert downsampled.size(-1) == waveform.size(-1) // 2

    @parameterized.expand([("sinc_interpolation",), ("kaiser_window",)])
    def test_resample_waveform_identity_size(self, resampling_method):
        """Resampling to the same rate keeps the number of samples."""
        sr = 16000
        waveform = get_whitenoise(sample_rate=sr, duration=0.5)
        same_rate = F.resample(waveform, sr, sr, resampling_method=resampling_method)
        assert same_rate.size(-1) == waveform.size(-1)

    @parameterized.expand(list(itertools.product(
        ["sinc_interpolation", "kaiser_window"],
        list(range(1, 20)),
    )))
    def test_resample_waveform_downsample_accuracy(self, resampling_method, i):
        """Downsample by the even factors 2, 4, ..., 38 and compare to the analytic signal."""
        factor = 2 * i
        self._test_resample_waveform_accuracy(down_scale_factor=factor,
                                              resampling_method=resampling_method)

    @parameterized.expand(list(itertools.product(
        ["sinc_interpolation", "kaiser_window"],
        list(range(1, 20)),
    )))
    def test_resample_waveform_upsample_accuracy(self, resampling_method, i):
        """Upsample by the factors 1.05, 1.10, ..., 1.95 and compare to the analytic signal."""
        factor = 1.0 + i / 20.0
        self._test_resample_waveform_accuracy(up_scale_factor=factor,
                                              resampling_method=resampling_method)

355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
    def test_resample_no_warning(self):
        """Float-typed frequencies with integral values must not trigger any warning."""
        sample_rate = 44100
        waveform = get_whitenoise(sample_rate=sample_rate, duration=0.1)

        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            F.resample(waveform, float(sample_rate), sample_rate / 2.)
        assert not caught

    def test_resample_warning(self):
        """A non-integral target frequency (5512.5) must produce exactly one warning."""
        sample_rate = 44100
        waveform = get_whitenoise(sample_rate=sample_rate, duration=0.1)

        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            F.resample(waveform, sample_rate, 5512.5)
        assert len(caught) == 1

374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
    @nested_params(
        [0.5, 1.01, 1.3],
        [True, False],
    )
    def test_phase_vocoder_shape(self, rate, test_pseudo_complex):
        """phase_vocoder keeps the rank and stretches the last axis to ceil(num_frames / rate)."""
        hop_length = 256
        num_freq = 1025
        num_frames = 400
        batch_size = 2

        torch.random.manual_seed(42)
        spec = torch.randn(
            batch_size, num_freq, num_frames, dtype=self.complex_dtype, device=self.device)
        if test_pseudo_complex:
            # Real-valued view with a trailing dimension of size 2 (real, imag).
            spec = torch.view_as_real(spec)

        phase_advance = torch.linspace(
            0, np.pi * hop_length, num_freq, dtype=self.dtype, device=self.device
        )[..., None]

        stretched = F.phase_vocoder(spec, rate=rate, phase_advance=phase_advance)

        assert spec.dim() == stretched.dim()
        expected_shape = torch.Size([batch_size, num_freq, int(np.ceil(num_frames / rate))])
        if test_pseudo_complex:
            actual_shape = torch.view_as_complex(stretched).shape
        else:
            actual_shape = stretched.shape
        assert actual_shape == expected_shape

yangarbiter's avatar
yangarbiter committed
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
    @parameterized.expand(
        [
            # words
            ["", "", 0],  # equal
            ["abc", "abc", 0],
            ["ᑌᑎIᑕO", "ᑌᑎIᑕO", 0],

            ["abc", "", 3],  # deletion
            ["aa", "aaa", 1],
            ["aaa", "aa", 1],
            ["ᑌᑎI", "ᑌᑎIᑕO", 2],

            ["aaa", "aba", 1],  # substitution
            ["aba", "aaa", 1],
            ["aba", "   ", 3],

            ["abc", "bcd", 2],  # mix deletion and substitution
            ["0ᑌᑎI", "ᑌᑎIᑕO", 3],

            # sentences
            [["hello", "", "Tᕮ᙭T"], ["hello", "", "Tᕮ᙭T"], 0],  # equal
            [[], [], 0],

            [["hello", "world"], ["hello", "world", "!"], 1],  # deletion
            [["hello", "world"], ["world"], 1],
            [["hello", "world"], [], 2],

            [["Tᕮ᙭T", ], ["world"], 1],  # substitution
            [["Tᕮ᙭T", "XD"], ["world", "hello"], 2],
            [["", "XD"], ["world", ""], 2],

            [["hello", "world"], ["world", "hello", "!"], 2],  # mix deletion and substitution
            [["Tᕮ᙭T", "world", "LOL", "XD"], ["world", "hello", "ʕ•́ᴥ•̀ʔっ"], 3],
        ]
    )
    def test_simple_case_edit_distance(self, seq1, seq2, distance):
        """edit_distance on strings and token lists matches the expected unit-cost distance.

        Each case is checked in both argument orders, since the distance must be symmetric.
        """
        assert F.edit_distance(seq1, seq2) == distance
        assert F.edit_distance(seq2, seq1) == distance

444
445
446
447
448
449
450
451
452
453
    @nested_params(
        [-4, -2, 0, 2, 4],
    )
    def test_pitch_shift_shape(self, n_steps):
        """pitch_shift must return a tensor with the same shape as its input."""
        sample_rate = 16000
        torch.random.manual_seed(42)
        waveform = torch.rand(2, 44100 * 1, dtype=self.dtype, device=self.device)
        shifted = F.pitch_shift(waveform, sample_rate, n_steps)
        assert shifted.size() == waveform.size()

454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
    def test_rnnt_loss_basic_backward(self):
        """A backward pass through rnnt_loss on the basic fixture must run without error."""
        logits, targets, logit_lengths, target_lengths = rnnt_utils.get_basic_data(self.device)
        F.rnnt_loss(logits, targets, logit_lengths, target_lengths).backward()

    def test_rnnt_loss_basic_forward_no_grad(self):
        """rnnt_loss must run when `logits` does not require grad.

        Early versions crashed with a segmentation fault in this situation; this test
        guards against a regression.

        See https://github.com/pytorch/audio/pull/1707
        """
        logits, targets, logit_lengths, target_lengths = rnnt_utils.get_basic_data(self.device)
        # Turn off gradient tracking on the logits to reproduce the former crash scenario.
        logits.requires_grad_(False)
        F.rnnt_loss(logits, targets, logit_lengths, target_lengths)

    @parameterized.expand([
        (rnnt_utils.get_B1_T2_U3_D5_data, torch.float32, 1e-6, 1e-2),
        (rnnt_utils.get_B2_T4_U3_D3_data, torch.float32, 1e-6, 1e-2),
        (rnnt_utils.get_B1_T2_U3_D5_data, torch.float16, 1e-3, 1e-2),
        (rnnt_utils.get_B2_T4_U3_D3_data, torch.float16, 1e-3, 1e-2),
    ])
    def test_rnnt_loss_costs_and_gradients(self, data_func, dtype, atol, rtol):
        """Costs and gradients must match the reference values returned by each fixture."""
        data, ref_costs, ref_gradients = data_func(dtype=dtype, device=self.device)
        self._test_costs_and_gradients(
            data=data, ref_costs=ref_costs, ref_gradients=ref_gradients, atol=atol, rtol=rtol,
        )

    def test_rnnt_loss_costs_and_gradients_random_data_with_numpy_fp32(self):
        """Match the NumPy reference transducer on five seeded float32 random inputs."""
        base_seed = 777
        for seed in range(base_seed, base_seed + 5):
            data = rnnt_utils.get_random_data(dtype=torch.float32, device=self.device, seed=seed)
            ref_costs, ref_gradients = rnnt_utils.compute_with_numpy_transducer(data=data)
            self._test_costs_and_gradients(
                data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
            )

498
499

class FunctionalCPUOnly(TestBaseMixin):
    """CPU-only functional tests: warning behavior of ``F.melscale_fbanks``."""

    @staticmethod
    def _melscale_fbanks_warnings(*args):
        """Call ``F.melscale_fbanks(*args)`` and return the list of warnings it emitted."""
        # Positional args follow torchaudio's (n_freqs, f_min, f_max, n_mels, sample_rate).
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            F.melscale_fbanks(*args)
        return caught

    def test_melscale_fbanks_no_warning_high_n_freq(self):
        """288 frequency bins with 128 mel bands must not warn."""
        assert len(self._melscale_fbanks_warnings(288, 0, 8000, 128, 16000)) == 0

    def test_melscale_fbanks_no_warning_low_n_mels(self):
        """201 frequency bins with 89 mel bands must not warn."""
        assert len(self._melscale_fbanks_warnings(201, 0, 8000, 89, 16000)) == 0

    def test_melscale_fbanks_warning(self):
        """201 frequency bins with 128 mel bands must emit exactly one warning."""
        assert len(self._melscale_fbanks_warnings(201, 0, 8000, 128, 16000)) == 1