import math
import os
import test.common_utils
import test.compliance.utils
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
import unittest


def extract_window(window, wave, f, frame_length, frame_shift, snip_edges):
    # just a copy of ExtractWindow from feature-window.cc in python
    def first_sample_of_frame(frame, window_size, window_shift, snip_edges):
        if snip_edges:
            return frame * window_shift
        else:
            midpoint_of_frame = frame * window_shift + window_shift // 2
            beginning_of_frame = midpoint_of_frame - window_size // 2
            return beginning_of_frame

    sample_offset = 0
    num_samples = sample_offset + wave.size(0)
    start_sample = first_sample_of_frame(f, frame_length, frame_shift, snip_edges)
    end_sample = start_sample + frame_length

    if snip_edges:
        assert(start_sample >= sample_offset and end_sample <= num_samples)
    else:
        assert(sample_offset == 0 or start_sample >= sample_offset)

    wave_start = start_sample - sample_offset
    wave_end = wave_start + frame_length
    if wave_start >= 0 and wave_end <= wave.size(0):
        window[f, :] = wave[wave_start:(wave_start + frame_length)]
    else:
        wave_dim = wave.size(0)
        for s in range(frame_length):
            s_in_wave = s + wave_start
            while s_in_wave < 0 or s_in_wave >= wave_dim:
                if s_in_wave < 0:
                    s_in_wave = - s_in_wave - 1
                else:
                    s_in_wave = 2 * wave_dim - 1 - s_in_wave
            window[f, s] = wave[s_in_wave]


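# A minimal illustrative sketch (not invoked by the test suite) of the
# mirror-reflection edge handling in extract_window above; the wave values and
# the helper name are made up for this example.
def _extract_window_example():
    wave = torch.tensor([10., 11., 12., 13.])
    window = torch.empty((1, 6))
    # With frame_length=6, frame_shift=3 and snip_edges=False, the first frame
    # starts before the wave, so out-of-range samples mirror at the left edge.
    extract_window(window, wave, 0, 6, 3, False)
    # window[0] is now [11., 10., 10., 11., 12., 13.]
    return window

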
class Test_Kaldi(unittest.TestCase):
    test_dirpath, test_dir = test.common_utils.create_temp_assets_dir()
    test_filepath = os.path.join(test_dirpath, 'assets', 'kaldi_file.wav')
    kaldi_output_dir = os.path.join(test_dirpath, 'assets', 'kaldi')
    test_filepaths = {'spec': [], 'fbank': []}

    # separating test files by their types (e.g. 'spec', 'fbank', etc.)
    for f in os.listdir(kaldi_output_dir):
        dash_idx = f.find('-')
        assert f.endswith('.ark') and dash_idx != -1
        key = f[:dash_idx]
        assert key in test_filepaths
        test_filepaths[key].append(f)

    def _test_get_strided_helper(self, num_samples, window_size, window_shift, snip_edges):
        waveform = torch.arange(num_samples).float()
        output = kaldi._get_strided(waveform, window_size, window_shift, snip_edges)

        # from NumFrames in feature-window.cc
        n = window_size
        if snip_edges:
            m = 0 if num_samples < window_size else 1 + (num_samples - window_size) // window_shift
        else:
            m = (num_samples + (window_shift // 2)) // window_shift
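        # For example, num_samples=10, window_size=4, window_shift=2 gives
        # m = 1 + (10 - 4) // 2 = 4 with snip_edges, and m = (10 + 1) // 2 = 5 without.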

        self.assertTrue(output.dim() == 2)
        self.assertTrue(output.shape[0] == m and output.shape[1] == n)

        window = torch.empty((m, window_size))

        for r in range(m):
            extract_window(window, waveform, r, window_size, window_shift, snip_edges)
        self.assertTrue(torch.allclose(window, output))

    def test_get_strided(self):
        # generate any combination where 0 < window_size <= num_samples and
        # 0 < window_shift.
        for num_samples in range(1, 20):
            for window_size in range(1, num_samples + 1):
                for window_shift in range(1, 2 * num_samples + 1):
                    for snip_edges in range(0, 2):
                        self._test_get_strided_helper(num_samples, window_size, window_shift, snip_edges)

    def _create_data_set(self):
        # used to generate the dataset to test on. this is not used in testing (offline procedure)
        test_dirpath = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
        test_filepath = os.path.join(test_dirpath, 'assets', 'kaldi_file.wav')
        sr = 16000
        x = torch.arange(0, 20).float()
        # between [-6,6]
        y = torch.cos(2 * math.pi * x) + 3 * torch.sin(math.pi * x) + 2 * torch.cos(x)
        # between [-2^30, 2^30]
        y = (y / 6 * (1 << 30)).long()
        # clear the last 16 bits because they aren't used anyways
        y = ((y >> 16) << 16).float()
        torchaudio.save(test_filepath, y, sr)
        sound, sample_rate = torchaudio.load(test_filepath, normalization=False)
        print(y >> 16)
        self.assertTrue(sample_rate == sr)
        self.assertTrue(torch.allclose(y, sound))

    def _print_diagnostic(self, output, expect_output):
        # given an output and expected output, it will print the absolute/relative errors (max and mean squared)
        abs_error = output - expect_output
        abs_mse = abs_error.pow(2).sum() / output.numel()
        abs_max_error = torch.max(abs_error.abs())

        relative_error = abs_error / expect_output
        relative_mse = relative_error.pow(2).sum() / output.numel()
        relative_max_error = torch.max(relative_error.abs())

        print('abs_mse:', abs_mse.item(), 'abs_max_error:', abs_max_error.item())
        print('relative_mse:', relative_mse.item(), 'relative_max_error:', relative_max_error.item())

    def _compliance_test_helper(self, filepath_key, expected_num_files, expected_num_args, get_output_fn):
        """
        Inputs:
            filepath_key (str): A key to `test_filepaths` which matches which files to use
            expected_num_files (int): The expected number of kaldi files to read
            expected_num_args (int): The expected number of arguments used in a kaldi configuration
            get_output_fn (Callable[[Tensor, List], Tensor]): A function that takes in a sound signal
                and a configuration and returns an output
        """
        sound, sample_rate = torchaudio.load_wav(self.test_filepath)
        files = self.test_filepaths[filepath_key]

        assert len(files) == expected_num_files, ('number of kaldi %s files changed to %d' % (filepath_key, len(files)))

        for f in files:
            print(f)

            # Read kaldi's output from file
            kaldi_output_path = os.path.join(self.kaldi_output_dir, f)
            kaldi_output_dict = {k: v for k, v in torchaudio.kaldi_io.read_mat_ark(kaldi_output_path)}

            assert len(kaldi_output_dict) == 1 and 'my_id' in kaldi_output_dict, 'invalid test kaldi ark file'
            kaldi_output = kaldi_output_dict['my_id']

            # Construct the same configuration used by kaldi
            args = f.split('-')
            args[-1] = os.path.splitext(args[-1])[0]
            assert len(args) == expected_num_args, 'invalid test kaldi file name'
            args = [test.compliance.utils.parse(arg) for arg in args]

            output = get_output_fn(sound, args)

            self._print_diagnostic(output, kaldi_output)
            self.assertEqual(output.shape, kaldi_output.shape)
            self.assertTrue(torch.allclose(output, kaldi_output, atol=1e-3, rtol=1e-1))

    def test_spectrogram(self):
        def get_output_fn(sound, args):
            output = kaldi.spectrogram(
                sound,
                blackman_coeff=args[1],
                dither=args[2],
                energy_floor=args[3],
                frame_length=args[4],
                frame_shift=args[5],
                preemphasis_coefficient=args[6],
                raw_energy=args[7],
                remove_dc_offset=args[8],
                round_to_power_of_two=args[9],
                snip_edges=args[10],
                subtract_mean=args[11],
                window_type=args[12])
            return output

        self._compliance_test_helper('spec', 131, 13, get_output_fn)

    def test_fbank(self):
        def get_output_fn(sound, args):
            output = kaldi.fbank(
                sound,
                blackman_coeff=args[1],
                dither=0.0,
                energy_floor=args[2],
                frame_length=args[3],
                frame_shift=args[4],
                high_freq=args[5],
                htk_compat=args[6],
                low_freq=args[7],
                num_mel_bins=args[8],
                preemphasis_coefficient=args[9],
                raw_energy=args[10],
                remove_dc_offset=args[11],
                round_to_power_of_two=args[12],
                snip_edges=args[13],
                subtract_mean=args[14],
                use_energy=args[15],
                use_log_fbank=args[16],
                use_power=args[17],
                vtln_high=args[18],
                vtln_low=args[19],
                vtln_warp=args[20],
                window_type=args[21])
            return output

        self._compliance_test_helper('fbank', 97, 22, get_output_fn)


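# A minimal usage sketch (not exercised by the tests above) of the API under
# test; `wav_path` and the helper name are placeholders, not part of the
# original test suite.
def _spectrogram_usage_example(wav_path):
    # Load the wav without normalization (as the compliance tests do) and
    # compute a (frame, frequency) spectrogram with Kaldi-style defaults.
    waveform, _ = torchaudio.load_wav(wav_path)
    return kaldi.spectrogram(waveform, frame_length=25.0, frame_shift=10.0)

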
if __name__ == '__main__':
    unittest.main()