compliance_kaldi_test.py 6.92 KB
Newer Older
1
import os
moto's avatar
moto committed
2
3
import math

4
5
6
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
7

8
from torchaudio_unittest import common_utils
9
from .compliance import utils as compliance_utils
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47


def extract_window(window, wave, f, frame_length, frame_shift, snip_edges):
    """Write frame ``f`` of ``wave`` into row ``f`` of ``window`` (in place).

    Python port of Kaldi's ``ExtractWindow`` / ``FirstSampleOfFrame``
    (feature-window.cc), used as a reference for ``kaldi._get_strided``.

    Args:
        window: tensor indexed as ``window[f, s]``; row ``f`` is overwritten.
        wave: 1-D tensor of samples.
        f: index of the frame to extract.
        frame_length: number of samples per frame.
        frame_shift: distance in samples between consecutive frames.
        snip_edges: when truthy, frame ``f`` starts at ``f * frame_shift`` and
            must lie entirely inside ``wave`` (an ``AssertionError`` is raised
            otherwise). When falsy, the frame is centred on
            ``f * frame_shift + frame_shift // 2`` and out-of-range sample
            indices are reflected back into ``wave``.

    Note: with ``snip_edges`` falsy and an empty ``wave``, the reflection loop
    never terminates, so ``wave`` is expected to be non-empty.
    """
    num_samples = wave.size(0)

    # FirstSampleOfFrame
    if snip_edges:
        first = f * frame_shift
    else:
        first = f * frame_shift + frame_shift // 2 - frame_length // 2
    last = first + frame_length

    # Kaldi's sample_offset is always 0 here; the only check that can fail is
    # the snip_edges bounds check.
    if snip_edges:
        assert first >= 0 and last <= num_samples

    if 0 <= first and last <= num_samples:
        # The whole frame is inside the waveform: copy the slice directly.
        window[f, :] = wave[first:last]
        return

    # Part of the frame falls outside the waveform: mirror each index back in.
    for offset in range(frame_length):
        idx = first + offset
        while not 0 <= idx < num_samples:
            idx = -idx - 1 if idx < 0 else 2 * num_samples - 1 - idx
        window[f, offset] = wave[idx]


48
@common_utils.skipIfNoSox
49
class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
moto's avatar
moto committed
50

51
    kaldi_output_dir = common_utils.get_asset_path('kaldi')
    test_filepath = common_utils.get_asset_path('kaldi_file.wav')

    def _group_kaldi_files(directory, prefixes):
        # Group the '.ark' files in `directory` by the text before the first
        # '-' in their name (e.g. 'spec', 'fbank'). This is a plain function
        # called once at class-creation time, so loop temporaries do not leak
        # into the class namespace as attributes.
        grouped = {prefix: [] for prefix in prefixes}
        # os.listdir order is arbitrary; sort so files are visited deterministically.
        for name in sorted(os.listdir(directory)):
            key, dash, _ = name.partition('-')
            assert name.endswith('.ark') and dash, 'unexpected kaldi file name: {}'.format(name)
            assert key in grouped, 'unknown kaldi file prefix: {}'.format(name)
            grouped[key].append(name)
        return grouped

    # separating test files by their types (e.g 'spec', 'fbank', etc.)
    test_filepaths = _group_kaldi_files(kaldi_output_dir, compliance_utils.TEST_PREFIX)
    del _group_kaldi_files
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80

    def _test_get_strided_helper(self, num_samples, window_size, window_shift, snip_edges):
        """Check ``kaldi._get_strided`` against the ``extract_window`` reference.

        Frames ``torch.arange(num_samples)``. It then verifies two things:
        the frame count matches Kaldi's ``NumFrames``, and every frame matches
        the Python port of ``ExtractWindow``.
        """
        waveform = torch.arange(num_samples).float()
        output = kaldi._get_strided(waveform, window_size, window_shift, snip_edges)

        # Expected number of frames, from NumFrames in feature-window.cc
        if snip_edges:
            num_frames = 0 if num_samples < window_size else 1 + (num_samples - window_size) // window_shift
        else:
            num_frames = (num_samples + (window_shift // 2)) // window_shift

        # assertEqual (unlike assertTrue(a == b)) reports the mismatching values.
        self.assertEqual(output.dim(), 2)
        self.assertEqual(tuple(output.shape), (num_frames, window_size))

        expected = torch.empty((num_frames, window_size))
        for frame in range(num_frames):
            extract_window(expected, waveform, frame, window_size, window_shift, snip_edges)
        self.assertEqual(expected, output)
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101

    def test_get_strided(self):
        # Exhaustively cover every combination with
        # 0 < window_size <= num_samples < 20, 0 < window_shift <= 2 * num_samples,
        # and snip_edges in {0, 1}.
        cases = [
            (num_samples, window_size, window_shift, snip_edges)
            for num_samples in range(1, 20)
            for window_size in range(1, num_samples + 1)
            for window_shift in range(1, 2 * num_samples + 1)
            for snip_edges in range(0, 2)
        ]
        for case in cases:
            self._test_get_strided_helper(*case)

    def _create_data_set(self):
        # Offline procedure that generates the test waveform and writes it to
        # ``self.test_filepath`` (overwriting whatever is there); it is not
        # used in testing.
        sample_rate_hz = 16000
        t = torch.arange(0, 20).float()
        # between [-6,6]
        signal = torch.cos(2 * math.pi * t) + 3 * torch.sin(math.pi * t) + 2 * torch.cos(t)
        # between [-2^30, 2^30]
        signal = (signal / 6 * (1 << 30)).long()
        # clear the last 16 bits because they aren't used anyways
        signal = ((signal >> 16) << 16).float()

        # Round-trip through the file and check nothing was lost.
        torchaudio.save(self.test_filepath, signal, sample_rate_hz)
        sound, loaded_rate = common_utils.load_wav(self.test_filepath, normalize=False)

        print(signal >> 16)
        self.assertTrue(loaded_rate == sample_rate_hz)
        self.assertEqual(signal, sound)
107

jamarshon's avatar
jamarshon committed
108
109
110
111
112
113
114
115
116
117
118
119
120
    def _print_diagnostic(self, output, expect_output):
        """Print error statistics of ``output`` relative to ``expect_output``.

        Two lines are printed:

        * absolute errors: the mean squared error and the maximum magnitude
          of ``output - expect_output``;
        * relative errors: the same two statistics for that difference
          divided elementwise by ``expect_output``.

        Zeros in ``expect_output`` make the relative figures inf/nan.
        Nothing is returned.
        """
        count = output.numel()
        diff = output - expect_output
        rel = diff / expect_output

        print('abs_mse:', (diff.pow(2).sum() / count).item(),
              'abs_max_error:', diff.abs().max().item())
        print('relative_mse:', (rel.pow(2).sum() / count).item(),
              'relative_max_error:', rel.abs().max().item())

jamarshon's avatar
jamarshon committed
121
    def _compliance_test_helper(self, sound_filepath, filepath_key, expected_num_files,
                                expected_num_args, get_output_fn, atol=1e-5, rtol=1e-7):
        """
        Compare torchaudio against every stored kaldi output of one type.

        For each '.ark' file registered under ``filepath_key``, the kaldi
        configuration is rebuilt from the '-'-separated file name.
        ``get_output_fn`` is then run on the sound, and the result is compared
        with the matrix stored in the ark file.

        Inputs:
            sound_filepath (str): The location of the sound file
            filepath_key (str): A key to `test_filepaths` which matches which files to use
            expected_num_files (int): The expected number of kaldi files to read
            expected_num_args (int): The expected number of arguments used in a kaldi configuration
            get_output_fn (Callable[[Tensor, List], Tensor]): A function that takes in a sound signal
                and a configuration and returns an output
            atol (float): absolute tolerance
            rtol (float): relative tolerance
        """
        sound, _ = common_utils.load_wav(sound_filepath, normalize=False)

        kaldi_files = self.test_filepaths[filepath_key]
        assert len(kaldi_files) == expected_num_files, \
            ('number of kaldi {} file changed to {}'.format(
                filepath_key, len(kaldi_files)))

        for ark_name in kaldi_files:
            print(ark_name)

            # Each reference ark must hold exactly one matrix, keyed 'my_id'.
            ark_path = os.path.join(self.kaldi_output_dir, ark_name)
            ark_contents = dict(torchaudio.kaldi_io.read_mat_ark(ark_path))
            assert len(ark_contents) == 1 and 'my_id' in ark_contents, 'invalid test kaldi ark file'
            expected = ark_contents['my_id']

            # The file name encodes the kaldi configuration: '-'-separated
            # values, with the extension stripped from the last one.
            *head, tail = ark_name.split('-')
            raw_args = head + [os.path.splitext(tail)[0]]
            assert len(raw_args) == expected_num_args, 'invalid test kaldi file name'
            config = [compliance_utils.parse(arg) for arg in raw_args]

            output = get_output_fn(sound, config)
            self._print_diagnostic(output, expected)
            self.assertEqual(output, expected, atol=atol, rtol=rtol)
jamarshon's avatar
jamarshon committed
161

jamarshon's avatar
jamarshon committed
162
163
164
    def test_mfcc_empty(self):
        # kaldi.mfcc must reject an empty waveform with an AssertionError.
        with self.assertRaises(AssertionError):
            kaldi.mfcc(torch.empty(0))