functional_impl.py 23.2 KB
Newer Older
dhthompson's avatar
dhthompson committed
1
2
"""Test definition common to CPU and CUDA"""
import itertools
3
import math
dhthompson's avatar
dhthompson committed
4
5
6
import warnings

import numpy as np
7
8
import torch
import torchaudio.functional as F
9
from parameterized import parameterized
10
from scipy import signal
11
12
13
14
15
16
17
from torchaudio_unittest.common_utils import (
    TestBaseMixin,
    get_sinusoid,
    nested_params,
    get_whitenoise,
    rnnt_utils,
)
18
19


dhthompson's avatar
dhthompson committed
20
class Functional(TestBaseMixin):
21
22
23
    def _test_resample_waveform_accuracy(
        self, up_scale_factor=None, down_scale_factor=None, resampling_method="sinc_interpolation", atol=1e-1, rtol=1e-4
    ):
        """Resample a cosine and compare it with the same cosine sampled at the new rate.

        Args:
            up_scale_factor: If given, the target rate is multiplied by this factor.
            down_scale_factor: If given, the target rate is divided by this factor.
            resampling_method: Forwarded to ``F.resample``.
            atol: Absolute tolerance of the comparison.
            rtol: Relative tolerance of the comparison.
        """
        duration = 5  # seconds
        sample_rate = 1000
        n_to_trim = 20

        def _tone(timestamps):
            # Cosine evaluated at the given time points.
            return 123 * torch.cos(2 * math.pi * 3 * timestamps)

        # Derive the target rate; each scaling step truncates to an integer.
        new_sample_rate = sample_rate
        if up_scale_factor is not None:
            new_sample_rate = int(new_sample_rate * up_scale_factor)
        if down_scale_factor is not None:
            new_sample_rate = int(new_sample_rate / down_scale_factor)

        source_times = torch.arange(0, duration, 1.0 / sample_rate)
        sound = _tone(source_times).unsqueeze(0)
        estimate = F.resample(sound, sample_rate, new_sample_rate, resampling_method=resampling_method).squeeze()

        # Evaluate the reference on exactly as many points as the estimate has.
        target_times = torch.arange(0, duration, 1.0 / new_sample_rate)[: estimate.size(0)]
        ground_truth = _tone(target_times)

        # Drop the first/last samples, which suffer from boundary effects.
        interior = slice(n_to_trim, -n_to_trim)
        ground_truth = ground_truth[..., interior]
        estimate = estimate[..., interior]

        self.assertEqual(estimate, ground_truth, atol=atol, rtol=rtol)

50
    def _test_costs_and_gradients(self, data, ref_costs, ref_gradients, atol=1e-6, rtol=1e-2):
        """Compare costs and gradients from ``rnnt_utils.compute_with_pytorch_transducer`` with references.

        Args:
            data: Mapping that contains at least a ``"logits"`` entry.
            ref_costs: Expected costs.
            ref_gradients: Expected gradients.
            atol: Absolute tolerance of the comparisons.
            rtol: Relative tolerance of the comparisons.
        """
        # Capture the shape before the computation runs.
        expected_shape = data["logits"].shape
        costs, gradients = rnnt_utils.compute_with_pytorch_transducer(data=data)

        tolerances = {"atol": atol, "rtol": rtol}
        self.assertEqual(costs, ref_costs, **tolerances)
        self.assertEqual(expected_shape, gradients.shape)
        self.assertEqual(gradients, ref_gradients, **tolerances)

57
    def test_lfilter_simple(self):
        """Filter random noise with a pure delay and check the output is the shifted input.

        The numerator ``[0, 0, 0, 1]`` with denominator ``[1, 0, 0, 0]``
        delays the signal by three samples.
        """
        delay = 3
        tensor_kwargs = {"dtype": self.dtype, "device": self.device}

        torch.random.manual_seed(42)
        waveform = torch.rand(2, 44100 * 1, **tensor_kwargs)
        b_coeffs = torch.tensor([0, 0, 0, 1], **tensor_kwargs)
        a_coeffs = torch.tensor([1, 0, 0, 0], **tensor_kwargs)

        filtered = F.lfilter(waveform, a_coeffs, b_coeffs)

        self.assertEqual(filtered[:, delay:], waveform[:, 0:-delay], atol=1e-5, rtol=1e-5)

72
    def test_lfilter_clamp(self):
        """Check that ``clamp=True`` caps the ``F.lfilter`` output at 1 and ``clamp=False`` does not."""
        tensor_kwargs = {"dtype": self.dtype, "device": self.device}
        ones = torch.ones(1, 44100 * 1, **tensor_kwargs)
        b_coeffs = torch.tensor([1, 0], **tensor_kwargs)
        a_coeffs = torch.tensor([1, -0.95], **tensor_kwargs)

        clamped = F.lfilter(ones, a_coeffs, b_coeffs, clamp=True)
        assert clamped.max() <= 1

        unclamped = F.lfilter(ones, a_coeffs, b_coeffs, clamp=False)
        assert unclamped.max() > 1
80

81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
    @parameterized.expand(
        [
            ((44100,), (4,), (44100,)),
            ((3, 44100), (4,), (3, 44100)),
            ((2, 3, 44100), (4,), (2, 3, 44100)),
            ((1, 2, 3, 44100), (4,), (1, 2, 3, 44100)),
            ((44100,), (2, 4), (2, 44100)),
            ((3, 44100), (1, 4), (3, 1, 44100)),
            ((1, 2, 44100), (3, 4), (1, 2, 3, 44100)),
        ]
    )
    def test_lfilter_shape(self, input_shape, coeff_shape, target_shape):
        """Check the output shape of ``F.lfilter`` called with ``batching=False``.

        Args:
            input_shape: Shape of the random input waveform.
            coeff_shape: Shape shared by the random ``a`` and ``b`` coefficients.
            target_shape: Expected shape of the filtered output.
        """
        tensor_kwargs = {"dtype": self.dtype, "device": self.device}

        torch.random.manual_seed(42)
        # Draw order (waveform, b, a) is kept so the seeded values are unchanged.
        waveform = torch.rand(*input_shape, **tensor_kwargs)
        b_coeffs = torch.rand(*coeff_shape, **tensor_kwargs)
        a_coeffs = torch.rand(*coeff_shape, **tensor_kwargs)

        filtered = F.lfilter(waveform, a_coeffs, b_coeffs, batching=False)

        assert input_shape == waveform.size()
        assert target_shape == filtered.size()
124

125
    def test_lfilter_9th_order_filter_stability(self):
        """Check ``F.lfilter`` precision against scipy for a 9th-order Butterworth filter.

        The reference impulse response comes from ``scipy.signal.sosfilt``,
        which applies the filter as cascaded second-order sections.
        """

        def _as_tensor(array):
            # numpy array -> tensor with the dtype and device under test
            return torch.from_numpy(array).to(self.dtype).to(self.device)

        butter_args = (9, 850, "hp")

        # Unit impulse input.
        impulse = torch.zeros(1024, dtype=self.dtype, device=self.device)
        impulse[0] = 1

        # Target impulse response from the second-order-section form.
        sos = signal.butter(*butter_args, fs=22050, output="sos")
        expected = _as_tensor(signal.sosfilt(sos, impulse.cpu().numpy()))

        # Same filter as transfer-function coefficients for lfilter.
        b, a = signal.butter(*butter_args, fs=22050, output="ba")
        b, a = _as_tensor(b), _as_tensor(a)

        predicted = F.lfilter(impulse, a, b, clamp=False)
        self.assertEqual(predicted, expected, atol=1e-4, rtol=1e-5)

146
147
148
149
150
    def test_filtfilt_simple(self):
        """Verify that ``F.filtfilt`` with pure-delay coefficients leaves the signal unshifted.

        The white-noise input is extended with three trailing zero samples
        before filtering, and the output is compared against that padded input.
        """
        tensor_kwargs = {"dtype": self.dtype, "device": self.device}

        noise = get_whitenoise(sample_rate=8000, n_channels=2, dtype=self.dtype).to(device=self.device)
        b_coeffs = torch.tensor([0, 0, 0, 1], **tensor_kwargs)
        a_coeffs = torch.tensor([1, 0, 0, 0], **tensor_kwargs)

        tail = torch.zeros(2, 3, **tensor_kwargs)
        padded_waveform = torch.cat((noise, tail), axis=1)

        filtered = F.filtfilt(padded_waveform, a_coeffs, b_coeffs)

        self.assertEqual(filtered, padded_waveform, atol=1e-5, rtol=1e-5)

    def test_filtfilt_filter_sinusoid(self):
        """Verify that ``F.filtfilt`` removes the faster of two summed sinusoids without delay.

        A 5-cycle and a 200-cycle sinusoid are added, filtered with low-pass
        Butterworth coefficients, and the result is compared with the
        5-cycle component alone.
        """
        T = 1.0
        samples = 1000
        tensor_kwargs = {"dtype": self.dtype, "device": self.device}

        def _tone(frequency):
            # Single sinusoid with the leading dimension squeezed away.
            return get_sinusoid(
                frequency=frequency,
                sample_rate=samples // T,
                dtype=self.dtype,
                device=self.device,
            ).squeeze(0)

        waveform_k0 = _tone(5)
        waveform_k1 = _tone(200)
        waveform = waveform_k0 + waveform_k1

        # Transfer function numerator and denominator polynomial coefficients
        # corresponding to 8th-order Butterworth filter with 100-cycle/T cutoff.
        # Generated with
        # >>> from scipy import signal
        # >>> b_coeffs, a_coeffs = signal.butter(8, 0.2)
        numerator = [
            2.39596441e-05,
            1.91677153e-04,
            6.70870035e-04,
            1.34174007e-03,
            1.67717509e-03,
            1.34174007e-03,
            6.70870035e-04,
            1.91677153e-04,
            2.39596441e-05,
        ]
        denominator = [
            1.0,
            -4.78451489,
            10.44504107,
            -13.45771989,
            11.12933104,
            -6.0252604,
            2.0792738,
            -0.41721716,
            0.0372001,
        ]
        b_coeffs = torch.tensor(numerator, **tensor_kwargs)
        a_coeffs = torch.tensor(denominator, **tensor_kwargs)

        # Extend waveform in each direction, preserving periodicity.
        padded_waveform = torch.cat((waveform[:-1], waveform, waveform[1:]))

        filtered = F.filtfilt(padded_waveform, a_coeffs, b_coeffs)

        # Strip the padding again; what remains should closely match waveform_k0.
        unpadded = filtered[samples - 1 : 2 * samples - 1]
        self.assertEqual(unpadded, waveform_k0, atol=1e-3, rtol=1e-3)

229
    @parameterized.expand([(0.0,), (1.0,), (2.0,), (3.0,)])
    def test_spectrogram_grad_at_zero(self, power):
        """The gradient of power spectrogram should not be nan but zero near x=0

        https://github.com/pytorch/audio/issues/993

        Args:
            power: Exponent forwarded to ``F.spectrogram``.
        """
        # Build the input with the dtype/device under test so that each
        # concrete subclass of this mixin exercises its own configuration.
        x = torch.zeros(1, 22050, dtype=self.dtype, device=self.device, requires_grad=True)
        spec = F.spectrogram(
            x,
            pad=0,
            window=None,
            n_fft=2048,
            hop_length=None,
            win_length=None,
            power=power,
            normalized=False,
        )
        spec.sum().backward()
        assert not x.grad.isnan().sum()
248

dhthompson's avatar
dhthompson committed
249
    def test_compute_deltas_one_channel(self):
        """Check ``F.compute_deltas`` with ``win_length=3`` on a single-row ramp."""
        tensor_kwargs = {"dtype": self.dtype, "device": self.device}
        ramp = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]], **tensor_kwargs)
        expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5]]], **tensor_kwargs)

        self.assertEqual(F.compute_deltas(ramp, win_length=3), expected)

    def test_compute_deltas_two_channels(self):
        """Check ``F.compute_deltas`` with ``win_length=3`` on two identical ramp rows."""
        tensor_kwargs = {"dtype": self.dtype, "device": self.device}
        ramp_row = [1.0, 2.0, 3.0, 4.0]
        delta_row = [0.5, 1.0, 1.0, 0.5]
        specgram = torch.tensor([[ramp_row, ramp_row]], **tensor_kwargs)
        expected = torch.tensor([[delta_row, delta_row]], **tensor_kwargs)

        self.assertEqual(F.compute_deltas(specgram, win_length=3), expected)

    @parameterized.expand([(100,), (440,)])
    def test_detect_pitch_frequency_pitch(self, frequency):
        """Check that ``F.detect_pitch_frequency`` stays within 1 of a sine wave's frequency.

        Args:
            frequency: Frequency of the generated sinusoid.
        """
        sample_rate = 44100
        threshold = 1

        sine = get_sinusoid(frequency=frequency, sample_rate=sample_rate, duration=5)
        detected = F.detect_pitch_frequency(sine, sample_rate)

        # Count the estimates that deviate from the true frequency by more than the threshold.
        num_off_pitch = ((detected - frequency).abs() > threshold).sum()
        self.assertFalse(num_off_pitch)

    @parameterized.expand([([100, 100],), ([2, 100, 100],), ([2, 2, 100, 100],)])
    def test_amplitude_to_DB_reversible(self, shape):
        """Round trip between amplitude and db should return the original for various shape

        This implicitly also tests `DB_to_amplitude`.

        Args:
            shape: Shape of the random spectrogram.
        """
        amplitude_mult = 20.0
        power_mult = 10.0
        amin = 1e-10
        ref = 1.0
        db_mult = math.log10(max(amin, ref))

        torch.manual_seed(0)
        spec = torch.rand(*shape, dtype=self.dtype, device=self.device) * 200

        # Spectrogram amplitude -> DB -> amplitude
        amplitude_db = F.amplitude_to_DB(spec, amplitude_mult, amin, db_mult, top_db=None)
        amplitude_back = F.DB_to_amplitude(amplitude_db, ref, 0.5)
        self.assertEqual(amplitude_back, spec, atol=5e-5, rtol=1e-5)

        # Spectrogram power -> DB -> power
        power_db = F.amplitude_to_DB(spec, power_mult, amin, db_mult, top_db=None)
        power_back = F.DB_to_amplitude(power_db, ref, 1.0)
        self.assertEqual(power_back, spec)

    @parameterized.expand([([100, 100],), ([2, 100, 100],), ([2, 2, 100, 100],)])
    def test_amplitude_to_DB_top_db_clamp(self, shape):
        """Ensure values are properly clamped when `top_db` is supplied.

        Args:
            shape: Shape of the random spectrogram.
        """
        amplitude_mult = 20.0
        top_db = 40.0
        amin = 1e-10
        ref = 1.0
        db_mult = math.log10(max(amin, ref))

        torch.manual_seed(0)
        # Random values add entropy, but each spectrogram's min and max must
        # stay predictable: the max sets the decibel cutoff and the min has to
        # sit far enough below it to trigger the clamp.
        spec = torch.rand(*shape, dtype=self.dtype, device=self.device)
        last_two_dims = [-2, -1]
        # Rescale every spectrogram to a min of 0 and a max of 1 ...
        spec -= spec.amin(last_two_dims)[..., None, None]
        spec /= spec.amax(last_two_dims)[..., None, None]
        # ... then stretch to (0, 200), wide enough to properly test clamping.
        spec *= 200

        decibels = F.amplitude_to_DB(spec, amplitude_mult, amin, db_mult, top_db=top_db)

        # NOTE(review): the two thresholds presumably bracket
        # 20 * log10(200) - top_db, which is roughly 6.0206 -- confirm.
        # Ensure the clamp was applied
        below_limit = decibels < 6.0205
        assert not below_limit.any(), "{} decibel values were below the expected cutoff:\n{}".format(
            below_limit.sum().item(), decibels
        )
        # Ensure it didn't over-clamp
        close_to_limit = decibels < 6.0207
        assert close_to_limit.any(), f"No values were close to the limit. Did it over-clamp?\n{decibels}"
dhthompson's avatar
dhthompson committed
330

331
    @parameterized.expand(list(itertools.product([(2, 1025, 400), (1, 201, 100)], [100], [0.0, 30.0], [1, 2])))
    def test_mask_along_axis(self, shape, mask_param, mask_value, axis):
        """Check that ``F.mask_along_axis`` keeps the shape and masks fewer than ``mask_param`` lines.

        Args:
            shape: Shape of the random spectrogram.
            mask_param: Upper bound checked against the number of masked lines.
            mask_value: Value written into masked positions.
            axis: Axis along which the mask is applied (1 or 2).
        """
        torch.random.manual_seed(42)
        specgram = torch.randn(*shape, dtype=self.dtype, device=self.device)
        masked = F.mask_along_axis(specgram, mask_param, mask_value, axis)

        if axis == 2:
            other_axis = 1
        else:
            other_axis = 2

        # A line counts as masked when every entry along the other axis equals mask_value.
        hits_per_line = (masked == mask_value).sum(other_axis)
        fully_masked = (hits_per_line == masked.size(other_axis)).sum()
        # Floor-divide the total count by the size of the leading dimension.
        num_masked_columns = torch.div(fully_masked, masked.size(0), rounding_mode="floor")

        assert masked.size() == specgram.size()
        assert num_masked_columns < mask_param

346
    @parameterized.expand(list(itertools.product([100], [0.0, 30.0], [2, 3])))
    def test_mask_along_axis_iid(self, mask_param, mask_value, axis):
        """Check that ``F.mask_along_axis_iid`` keeps the shape and masks fewer than ``mask_param`` lines.

        Args:
            mask_param: Upper bound checked against every masked-line count.
            mask_value: Value written into masked positions.
            axis: Axis along which the mask is applied (2 or 3).
        """
        torch.random.manual_seed(42)
        specgrams = torch.randn(4, 2, 1025, 400, dtype=self.dtype, device=self.device)
        masked = F.mask_along_axis_iid(specgrams, mask_param, mask_value, axis)

        if axis == 3:
            other_axis = 2
        else:
            other_axis = 3

        # A line counts as masked when every entry along the other axis equals mask_value.
        hits_per_line = (masked == mask_value).sum(other_axis)
        num_masked_columns = (hits_per_line == masked.size(other_axis)).sum(-1)

        assert masked.size() == specgrams.size()
        # Every count must stay below the bound.
        assert (num_masked_columns < mask_param).all()
360

361
    @parameterized.expand(list(itertools.product([(2, 1025, 400), (1, 201, 100)], [100], [0.0, 30.0], [1, 2])))
    def test_mask_along_axis_preserve(self, shape, mask_param, mask_value, axis):
        """mask_along_axis should not alter original input Tensor

        Test is run 5 times to bound the probability of no masking occurring to 1e-10
        See https://github.com/pytorch/audio/issues/1478
        """
        num_trials = 5
        torch.random.manual_seed(42)
        for _ in range(num_trials):
            specgram = torch.randn(*shape, dtype=self.dtype, device=self.device)
            snapshot = specgram.clone()
            # The return value is ignored on purpose; only the input is inspected.
            F.mask_along_axis(specgram, mask_param, mask_value, axis)

            self.assertEqual(specgram, snapshot)

376
    @parameterized.expand(list(itertools.product([100], [0.0, 30.0], [2, 3])))
    def test_mask_along_axis_iid_preserve(self, mask_param, mask_value, axis):
        """mask_along_axis_iid should not alter original input Tensor

        Test is run 5 times to bound the probability of no masking occurring to 1e-10
        See https://github.com/pytorch/audio/issues/1478
        """
        num_trials = 5
        torch.random.manual_seed(42)
        for _ in range(num_trials):
            specgrams = torch.randn(4, 2, 1025, 400, dtype=self.dtype, device=self.device)
            snapshot = specgrams.clone()
            # The return value is ignored on purpose; only the input is inspected.
            F.mask_along_axis_iid(specgrams, mask_param, mask_value, axis)

            self.assertEqual(specgrams, snapshot)

391
392
393
394
395
396
397
398
    @parameterized.expand(
        list(
            itertools.product(
                ["sinc_interpolation", "kaiser_window"],
                [16000, 44100],
            )
        )
    )
    def test_resample_identity(self, resampling_method, sample_rate):
        """Resampling to the same rate must return the waveform unchanged.

        Args:
            resampling_method: Resampling method forwarded to ``F.resample``.
            sample_rate: Rate used as both the source and the target rate.
        """
        waveform = get_whitenoise(sample_rate=sample_rate, duration=1)

        # Forward the parameterized method; without it every parameter
        # combination would silently test only the default method.
        resampled = F.resample(waveform, sample_rate, sample_rate, resampling_method=resampling_method)
        self.assertEqual(waveform, resampled)

405
406
407
    # Each case is an explicit 1-tuple; a bare parenthesised string is just a string.
    @parameterized.expand([("sinc_interpolation",), ("kaiser_window",)])
    def test_resample_waveform_upsample_size(self, resampling_method):
        """Doubling the sample rate must double the number of samples.

        Args:
            resampling_method: Resampling method forwarded to ``F.resample``.
        """
        sr = 16000
        waveform = get_whitenoise(
            sample_rate=sr,
            duration=0.5,
        )
        upsampled = F.resample(waveform, sr, sr * 2, resampling_method=resampling_method)
        assert upsampled.size(-1) == waveform.size(-1) * 2

    # Each case is an explicit 1-tuple; a bare parenthesised string is just a string.
    @parameterized.expand([("sinc_interpolation",), ("kaiser_window",)])
    def test_resample_waveform_downsample_size(self, resampling_method):
        """Halving the sample rate must halve the number of samples.

        Args:
            resampling_method: Resampling method forwarded to ``F.resample``.
        """
        sr = 16000
        waveform = get_whitenoise(
            sample_rate=sr,
            duration=0.5,
        )
        downsampled = F.resample(waveform, sr, sr // 2, resampling_method=resampling_method)
        assert downsampled.size(-1) == waveform.size(-1) // 2

    # Each case is an explicit 1-tuple; a bare parenthesised string is just a string.
    @parameterized.expand([("sinc_interpolation",), ("kaiser_window",)])
    def test_resample_waveform_identity_size(self, resampling_method):
        """Resampling to the same rate must keep the number of samples.

        Args:
            resampling_method: Resampling method forwarded to ``F.resample``.
        """
        sr = 16000
        waveform = get_whitenoise(
            sample_rate=sr,
            duration=0.5,
        )
        resampled = F.resample(waveform, sr, sr, resampling_method=resampling_method)
        assert resampled.size(-1) == waveform.size(-1)

435
436
437
438
439
440
441
442
    @parameterized.expand(list(itertools.product(["sinc_interpolation", "kaiser_window"], range(1, 20))))
    def test_resample_waveform_downsample_accuracy(self, resampling_method, i):
        """Run the resampling accuracy check for the down-scale factor ``2 * i``.

        Args:
            resampling_method: Resampling method under test.
            i: Integer from 1 to 19; the down-scale factor is twice this value.
        """
        factor = i * 2
        self._test_resample_waveform_accuracy(down_scale_factor=factor, resampling_method=resampling_method)

446
447
448
449
450
451
452
453
    @parameterized.expand(list(itertools.product(["sinc_interpolation", "kaiser_window"], range(1, 20))))
    def test_resample_waveform_upsample_accuracy(self, resampling_method, i):
        """Run the resampling accuracy check for the up-scale factor ``1 + i / 20``.

        Args:
            resampling_method: Resampling method under test.
            i: Integer from 1 to 19 from which the up-scale factor is derived.
        """
        factor = 1.0 + i / 20.0
        self._test_resample_waveform_accuracy(up_scale_factor=factor, resampling_method=resampling_method)

457
458
    @nested_params([0.5, 1.01, 1.3])
    def test_phase_vocoder_shape(self, rate):
        """Verify the output shape of phase vocoder

        Args:
            rate: Rate forwarded to ``F.phase_vocoder``; the expected frame
                count is ``ceil(num_frames / rate)``.
        """
        batch_size, num_freq, num_frames = 2, 1025, 400
        hop_length = 256

        torch.random.manual_seed(42)
        spec = torch.randn(batch_size, num_freq, num_frames, dtype=self.complex_dtype, device=self.device)

        # One value per frequency bin, with a trailing singleton dimension.
        phase_advance = torch.linspace(0, np.pi * hop_length, num_freq, dtype=self.dtype, device=self.device)
        phase_advance = phase_advance[..., None]

        stretched = F.phase_vocoder(spec, rate=rate, phase_advance=phase_advance)

        assert spec.dim() == stretched.dim()
        expected_frames = int(np.ceil(num_frames / rate))
        expected_shape = torch.Size([batch_size, num_freq, expected_frames])
        assert stretched.shape == expected_shape

yangarbiter's avatar
yangarbiter committed
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
    @parameterized.expand(
        [
            # words
            ["", "", 0],  # equal
            ["abc", "abc", 0],
            ["ᑌᑎIᑕO", "ᑌᑎIᑕO", 0],
            ["abc", "", 3],  # deletion
            ["aa", "aaa", 1],
            ["aaa", "aa", 1],
            ["ᑌᑎI", "ᑌᑎIᑕO", 2],
            ["aaa", "aba", 1],  # substitution
            ["aba", "aaa", 1],
            ["aba", "   ", 3],
            ["abc", "bcd", 2],  # mix deletion and substitution
            ["0ᑌᑎI", "ᑌᑎIᑕO", 3],
            # sentences
            [["hello", "", "Tᕮ᙭T"], ["hello", "", "Tᕮ᙭T"], 0],  # equal
            [[], [], 0],
            [["hello", "world"], ["hello", "world", "!"], 1],  # deletion
            [["hello", "world"], ["world"], 1],
            [["hello", "world"], [], 2],
            [["Tᕮ᙭T"], ["world"], 1],  # substitution
            [["Tᕮ᙭T", "XD"], ["world", "hello"], 2],
            [["", "XD"], ["world", ""], 2],
            ["aba", "   ", 3],
            [["hello", "world"], ["world", "hello", "!"], 2],  # mix deletion and substitution
            [["Tᕮ᙭T", "world", "LOL", "XD"], ["world", "hello", "ʕ•́ᴥ•̀ʔっ"], 3],
        ]
    )
    def test_simple_case_edit_distance(self, seq1, seq2, distance):
        """Check ``F.edit_distance`` in both argument orders.

        Args:
            seq1: First sequence (a string or a list of words).
            seq2: Second sequence (a string or a list of words).
            distance: Expected edit distance, asserted for both argument orders.
        """
        for left, right in ((seq1, seq2), (seq2, seq1)):
            assert F.edit_distance(left, right) == distance

516
517
518
519
520
521
522
523
524
525
    @nested_params([-4, -2, 0, 2, 4])
    def test_pitch_shift_shape(self, n_steps):
        """Check that ``F.pitch_shift`` returns a waveform of the same size as its input.

        Args:
            n_steps: Number of steps forwarded to ``F.pitch_shift``.
        """
        sample_rate = 16000
        torch.random.manual_seed(42)
        waveform = torch.rand(2, 44100 * 1, dtype=self.dtype, device=self.device)
        shifted = F.pitch_shift(waveform, sample_rate, n_steps)
        assert waveform.size() == shifted.size()

526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
    def test_rnnt_loss_basic_backward(self):
        """Smoke test: ``F.rnnt_loss`` on the basic data can be back-propagated without error."""
        basic_inputs = rnnt_utils.get_basic_data(self.device)
        logits, targets, logit_lengths, target_lengths = basic_inputs
        F.rnnt_loss(logits, targets, logit_lengths, target_lengths).backward()

    def test_rnnt_loss_basic_forward_no_grad(self):
        """Run ``rnnt_loss`` on logits that do not require gradients.

        Early versions crashed with a segmentation fault when `logits` had
        `requires_grad = False`; this test guards against a regression by
        making sure the call completes without error.

        See https://github.com/pytorch/audio/pull/1707
        """
        basic_inputs = rnnt_utils.get_basic_data(self.device)
        logits, targets, logit_lengths, target_lengths = basic_inputs
        logits.requires_grad_(False)
        F.rnnt_loss(logits, targets, logit_lengths, target_lengths)

542
543
544
545
546
547
548
549
    @parameterized.expand(
        [
            (rnnt_utils.get_B1_T2_U3_D5_data, torch.float32, 1e-6, 1e-2),
            (rnnt_utils.get_B2_T4_U3_D3_data, torch.float32, 1e-6, 1e-2),
            (rnnt_utils.get_B1_T2_U3_D5_data, torch.float16, 1e-3, 1e-2),
            (rnnt_utils.get_B2_T4_U3_D3_data, torch.float16, 1e-3, 1e-2),
        ]
    )
    def test_rnnt_loss_costs_and_gradients(self, data_func, dtype, atol, rtol):
        """Compare computed RNN-T costs and gradients with the references from ``data_func``.

        Args:
            data_func: Callable returning ``(data, ref_costs, ref_gradients)``.
            dtype: Data type passed to ``data_func``.
            atol: Absolute tolerance of the comparison.
            rtol: Relative tolerance of the comparison.
        """
        data, ref_costs, ref_gradients = data_func(dtype=dtype, device=self.device)
        self._test_costs_and_gradients(
            data=data, ref_costs=ref_costs, ref_gradients=ref_gradients, atol=atol, rtol=rtol
        )

    def test_rnnt_loss_costs_and_gradients_random_data_with_numpy_fp32(self):
        """Compare costs and gradients with the numpy transducer on five seeded random inputs."""
        base_seed = 777
        for offset in range(5):
            data = rnnt_utils.get_random_data(dtype=torch.float32, device=self.device, seed=(base_seed + offset))
            # The numpy implementation provides the reference values.
            ref_costs, ref_gradients = rnnt_utils.compute_with_numpy_transducer(data=data)
            self._test_costs_and_gradients(data=data, ref_costs=ref_costs, ref_gradients=ref_gradients)
569

570
571

class FunctionalCPUOnly(TestBaseMixin):
    """Tests that check when ``F.melscale_fbanks`` emits a warning."""

    def _record_melscale_fbanks_warnings(self, n_freqs, n_mels):
        """Call ``F.melscale_fbanks`` and return every warning it raised.

        The second, third and fifth arguments are fixed at 0, 8000 and 16000.
        """
        with warnings.catch_warnings(record=True) as recorded:
            # Record every warning, including ones that would normally be filtered.
            warnings.simplefilter("always")
            F.melscale_fbanks(n_freqs, 0, 8000, n_mels, 16000)
        return recorded

    def test_melscale_fbanks_no_warning_high_n_freq(self):
        """No warning is expected for 288 frequency bins and 128 mel bands."""
        recorded = self._record_melscale_fbanks_warnings(288, 128)
        assert len(recorded) == 0

    def test_melscale_fbanks_no_warning_low_n_mels(self):
        """No warning is expected for 201 frequency bins and 89 mel bands."""
        recorded = self._record_melscale_fbanks_warnings(201, 89)
        assert len(recorded) == 0

    def test_melscale_fbanks_warning(self):
        """Exactly one warning is expected for 201 frequency bins and 128 mel bands."""
        recorded = self._record_melscale_fbanks_warnings(201, 128)
        assert len(recorded) == 1