test_batch_consistency.py 3.07 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""Test numerical consistency between single (unbatched) inputs and batched inputs."""
import os
import unittest

import torch
import torchaudio
import torchaudio.functional as F

import common_utils


def _test_batch_shape(functional, tensor, *args, **kwargs):

    kwargs_compare = {}
    if 'atol' in kwargs:
        atol = kwargs['atol']
        del kwargs['atol']
        kwargs_compare['atol'] = atol

    if 'rtol' in kwargs:
        rtol = kwargs['rtol']
        del kwargs['rtol']
        kwargs_compare['rtol'] = rtol

    # Single then transform then batch

    torch.random.manual_seed(42)
    expected = functional(tensor.clone(), *args, **kwargs)
    expected = expected.unsqueeze(0).unsqueeze(0)

    # 1-Batch then transform
    tensors = tensor.unsqueeze(0).unsqueeze(0)

    torch.random.manual_seed(42)
    computed = functional(tensors.clone(), *args, **kwargs)

    assert expected.shape == computed.shape, (expected.shape, computed.shape)
    assert torch.allclose(expected, computed, **kwargs_compare)

    return tensors, expected


def _test_batch(functional, tensor, *args, **kwargs):
    """Check that ``functional`` agrees on single, 1-batch and 3-batch inputs.

    First runs the 1-batch check in ``_test_batch_shape``. Then it builds a
    3-batch input by repeating the 1-batch input three times along its first
    dimension and applies ``functional`` to it. The output must match the
    single-input result repeated the same way, both in shape and in values.

    Args:
        functional: Callable under test.
        tensor: The unbatched input.
        *args, **kwargs: Forwarded to ``functional``. The keys ``atol`` and
            ``rtol``, if present, are removed and passed to
            ``torch.allclose`` instead.

    Raises:
        AssertionError: If any shape differs or any values are not close.
    """
    # ``**kwargs`` is unpacked into a fresh dict for the callee, so the
    # tolerances are still present here and must be stripped again.
    tensors, expected = _test_batch_shape(functional, tensor, *args, **kwargs)

    # Tolerances belong to the comparison only; they must not reach ``functional``.
    kwargs_compare = {key: kwargs.pop(key) for key in ('atol', 'rtol') if key in kwargs}

    # 3-batch input: repeat 3x along the leading (size-1) dimension and keep
    # every other dimension unchanged. Do the same for the expected output.
    tensors = tensors.repeat(3, *[1] * (tensors.dim() - 1))
    expected = expected.repeat(3, *[1] * (expected.dim() - 1))

    torch.random.manual_seed(42)
    computed = functional(tensors.clone(), *args, **kwargs)

    # Without these checks the 3-batch result would be computed and then ignored.
    assert expected.shape == computed.shape, (expected.shape, computed.shape)
    assert torch.allclose(expected, computed, **kwargs_compare)


class TestFunctional(unittest.TestCase):
    """Batch-consistency checks for the functions in the `functional` module."""

    def test_griffinlim(self):
        """``F.griffinlim`` gives matching results for single and batched inputs."""
        specgram = torch.rand((1, 201, 6))
        n_fft, win_length, hop_length = 400, 400, 200
        window = torch.hann_window(win_length)
        power, normalize = 2, False
        n_iter, momentum = 32, 0.99
        length = 1000
        _test_batch(
            F.griffinlim,
            specgram,
            window,
            n_fft,
            hop_length,
            win_length,
            power,
            normalize,
            n_iter,
            momentum,
            length,
            0,
            atol=5e-5,
        )

    def test_detect_pitch_frequency(self):
        """``F.detect_pitch_frequency`` is batch-consistent on mono and stereo assets."""
        asset_dir = os.path.join(common_utils.TEST_DIR_PATH, 'assets')
        for name in (
            'steam-train-whistle-daniel_simon.wav',  # 2ch 44100Hz
            # Files from https://www.mediacollege.com/audio/tone/download/
            '100Hz_44100Hz_16bit_05sec.wav',  # 1ch
            '440Hz_44100Hz_16bit_05sec.wav',  # 1ch
        ):
            waveform, sample_rate = torchaudio.load(os.path.join(asset_dir, name))
            _test_batch(F.detect_pitch_frequency, waveform, sample_rate)

    def test_istft(self):
        """``F.istft`` gives matching results for single and batched inputs."""
        # Shape (3, 5, 2), all zeros except 4. at [0, :, 0].
        stft = torch.zeros(3, 5, 2)
        stft[0, :, 0] = 4.
        _test_batch(F.istft, stft, n_fft=4, length=4)