compliance_kaldi_test.py 11.4 KB
Newer Older
1
import os
moto's avatar
moto committed
2
3
import math

4
5
6
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
7

8
from torchaudio_unittest import common_utils
9
from .compliance import utils as compliance_utils
10
from parameterized import parameterized
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48


def extract_window(window, wave, f, frame_length, frame_shift, snip_edges):
    """Write frame ``f`` of ``wave`` into row ``f`` of ``window`` (in place).

    A direct Python port of Kaldi's ``ExtractWindow`` (feature-window.cc). It is used as
    the reference implementation that ``kaldi._get_strided`` is checked against.

    Args:
        window: 2-D tensor whose row ``f`` receives the frame; its rows are expected to
            have ``frame_length`` columns.
        wave: 1-D waveform. Only ``wave.size(0)`` and integer/slice indexing are used.
        f (int): Index of the frame to extract.
        frame_length (int): Number of samples per frame.
        frame_shift (int): Number of samples between the starts of consecutive frames.
        snip_edges: If truthy, frame ``f`` starts at ``f * frame_shift`` and must lie
            entirely inside ``wave``. Otherwise the frame is positioned around
            ``f * frame_shift + frame_shift // 2``, and samples falling outside ``wave``
            are taken by reflecting their index back into range.

    Raises:
        AssertionError: If ``snip_edges`` is truthy and the frame does not fit in ``wave``.
    """
    # Position of wave[0] within the overall signal; always 0 in this port.
    offset = 0
    total = offset + wave.size(0)

    # FirstSampleOfFrame
    if snip_edges:
        first = f * frame_shift
    else:
        centre = f * frame_shift + frame_shift // 2
        first = centre - frame_length // 2
    last = first + frame_length

    if snip_edges:
        assert first >= offset and last <= total
    else:
        assert offset == 0 or first >= offset

    begin = first - offset
    size = wave.size(0)
    if 0 <= begin and begin + frame_length <= size:
        # Fast path: the whole frame lies inside the wave.
        window[f, :] = wave[begin:begin + frame_length]
        return

    # Part of the frame falls outside the wave. Mirror each out-of-range index back
    # into [0, size), repeating until it lands inside.
    for col in range(frame_length):
        idx = begin + col
        while not 0 <= idx < size:
            idx = -idx - 1 if idx < 0 else 2 * size - 1 - idx
        window[f, col] = wave[idx]


@common_utils.skipIfNoSox
class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
    """Tests for ``torchaudio.compliance.kaldi``.

    The class covers the following:

    * ``_get_strided``, checked frame by frame against the ``extract_window`` port of Kaldi's ``ExtractWindow``;
    * comparisons against Kaldi-generated ``.ark`` files, via ``_compliance_test_helper``;
    * input validation of ``mfcc``;
    * output size and accuracy of ``resample_waveform``.
    """

    kaldi_output_dir = common_utils.get_asset_path('kaldi')
    test_filepath = common_utils.get_asset_path('kaldi_file.wav')

    def _group_kaldi_outputs(output_dir, prefixes):
        # Plain function, not a method: it is called once while the class body is
        # evaluated and deleted right after. This keeps loop variables out of the class
        # namespace.
        #
        # It groups the '.ark' files in ``output_dir`` by the part of the file name
        # before the first '-' (e.g. 'spec', 'fbank'). Every file must have one of the
        # known prefixes.
        grouped = {prefix: [] for prefix in prefixes}
        for name in os.listdir(output_dir):
            dash_idx = name.find('-')
            assert name.endswith('.ark') and dash_idx != -1
            key = name[:dash_idx]
            assert key in grouped
            grouped[key].append(name)
        return grouped

    # Maps each prefix in ``compliance_utils.TEST_PREFIX`` to the names of its Kaldi '.ark' outputs.
    test_filepaths = _group_kaldi_outputs(kaldi_output_dir, compliance_utils.TEST_PREFIX)
    del _group_kaldi_outputs

    def setUp(self):
        super().setUp()

        # test signal for testing resampling
        self.test_signal_sr = 16000
        self.test_signal = common_utils.get_whitenoise(
            sample_rate=self.test_signal_sr, duration=0.5,
        )

    def _test_get_strided_helper(self, num_samples, window_size, window_shift, snip_edges):
        """Compare ``kaldi._get_strided`` on ``arange(num_samples)`` with the ``extract_window`` reference."""
        waveform = torch.arange(num_samples).float()
        output = kaldi._get_strided(waveform, window_size, window_shift, snip_edges)

        # Expected number of frames, from NumFrames in feature-window.cc
        if snip_edges:
            num_frames = 0 if num_samples < window_size else 1 + (num_samples - window_size) // window_shift
        else:
            num_frames = (num_samples + (window_shift // 2)) // window_shift

        self.assertEqual(output.dim(), 2)
        self.assertEqual(output.shape[0], num_frames)
        self.assertEqual(output.shape[1], window_size)

        window = torch.empty((num_frames, window_size))
        for r in range(num_frames):
            extract_window(window, waveform, r, window_size, window_shift, snip_edges)
        self.assertEqual(window, output)

    def test_get_strided(self):
        # every combination with 0 < window_size <= num_samples and 0 < window_shift <= 2 * num_samples
        for num_samples in range(1, 20):
            for window_size in range(1, num_samples + 1):
                for window_shift in range(1, 2 * num_samples + 1):
                    for snip_edges in (False, True):
                        self._test_get_strided_helper(num_samples, window_size, window_shift, snip_edges)

    def _create_data_set(self):
        # used to generate the dataset to test on. this is not used in testing (offline procedure)
        sr = 16000
        x = torch.arange(0, 20).float()
        # between [-6,6]
        y = torch.cos(2 * math.pi * x) + 3 * torch.sin(math.pi * x) + 2 * torch.cos(x)
        # between [-2^30, 2^30]
        y = (y / 6 * (1 << 30)).long()
        # clear the last 16 bits because they aren't used anyways
        y = ((y >> 16) << 16).float()
        torchaudio.save(self.test_filepath, y, sr)
        sound, sample_rate = common_utils.load_wav(self.test_filepath, normalize=False)
        # Shift the integer view: y is a float tensor, and bit shifts are only defined for
        # integer tensors in newer PyTorch releases. The values are exact multiples of 2^16,
        # so the conversion is lossless.
        print(y.long() >> 16)
        self.assertEqual(sample_rate, sr)
        self.assertEqual(y, sound)

    def _print_diagnostic(self, output, expect_output):
        """Print the absolute and relative error (mean squared and max) of ``output`` vs ``expect_output``.

        Relative errors are ``inf``/``nan`` wherever ``expect_output`` is zero.
        """
        abs_error = output - expect_output
        abs_mse = abs_error.pow(2).sum() / output.numel()
        abs_max_error = torch.max(abs_error.abs())

        relative_error = abs_error / expect_output
        relative_mse = relative_error.pow(2).sum() / output.numel()
        relative_max_error = torch.max(relative_error.abs())

        print('abs_mse:', abs_mse.item(), 'abs_max_error:', abs_max_error.item())
        print('relative_mse:', relative_mse.item(), 'relative_max_error:', relative_max_error.item())

    def _compliance_test_helper(self, sound_filepath, filepath_key, expected_num_files,
                                expected_num_args, get_output_fn, atol=1e-5, rtol=1e-7):
        """Compare ``get_output_fn`` against every Kaldi output file registered under ``filepath_key``.

        The Kaldi configuration is recovered from the file name. The name is split on '-',
        the extension is stripped from the last piece, and each piece is parsed with
        ``compliance_utils.parse``.

        Inputs:
            sound_filepath (str): The location of the sound file
            filepath_key (str): A key to `test_filepaths` which matches which files to use
            expected_num_files (int): The expected number of kaldi files to read
            expected_num_args (int): The expected number of arguments used in a kaldi configuration
            get_output_fn (Callable[[Tensor, List], Tensor]): A function that takes in a sound signal
                and a configuration and returns an output
            atol (float): absolute tolerance
            rtol (float): relative tolerance
        """
        sound, sr = common_utils.load_wav(sound_filepath, normalize=False)
        files = self.test_filepaths[filepath_key]

        assert len(files) == expected_num_files, \
            ('number of kaldi {} file changed to {}'.format(
                filepath_key, len(files)))

        for f in files:
            print(f)

            # Read kaldi's output from file
            kaldi_output_path = os.path.join(self.kaldi_output_dir, f)
            kaldi_output_dict = dict(torchaudio.kaldi_io.read_mat_ark(kaldi_output_path))

            assert len(kaldi_output_dict) == 1 and 'my_id' in kaldi_output_dict, 'invalid test kaldi ark file'
            kaldi_output = kaldi_output_dict['my_id']

            # Construct the same configuration used by kaldi
            args = f.split('-')
            args[-1] = os.path.splitext(args[-1])[0]
            assert len(args) == expected_num_args, 'invalid test kaldi file name'
            args = [compliance_utils.parse(arg) for arg in args]

            output = get_output_fn(sound, args)

            self._print_diagnostic(output, kaldi_output)
            self.assertEqual(output, kaldi_output, atol=atol, rtol=rtol)

    def test_mfcc_empty(self):
        # Passing in an empty tensor should result in an error
        with self.assertRaises(AssertionError):
            kaldi.mfcc(torch.empty(0))

    @parameterized.expand([("sinc_interpolation",), ("kaiser_window",)])
    def test_resample_waveform_upsample_size(self, resampling_method):
        upsample_sound = kaldi.resample_waveform(self.test_signal, self.test_signal_sr, self.test_signal_sr * 2,
                                                 resampling_method=resampling_method)
        self.assertEqual(upsample_sound.size(-1), self.test_signal.size(-1) * 2)

    @parameterized.expand([("sinc_interpolation",), ("kaiser_window",)])
    def test_resample_waveform_downsample_size(self, resampling_method):
        downsample_sound = kaldi.resample_waveform(self.test_signal, self.test_signal_sr, self.test_signal_sr // 2,
                                                   resampling_method=resampling_method)
        self.assertEqual(downsample_sound.size(-1), self.test_signal.size(-1) // 2)

    @parameterized.expand([("sinc_interpolation",), ("kaiser_window",)])
    def test_resample_waveform_identity_size(self, resampling_method):
        resampled_sound = kaldi.resample_waveform(self.test_signal, self.test_signal_sr, self.test_signal_sr,
                                                  resampling_method=resampling_method)
        self.assertEqual(resampled_sound.size(-1), self.test_signal.size(-1))

    def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_factor=None,
                                         resampling_method="sinc_interpolation", atol=1e-1, rtol=1e-4):
        """Resample a 5 s, 3 Hz cosine sampled at 1 kHz and compare it to the cosine evaluated at the new rate.

        The new rate is ``1000 * up_scale_factor`` (if given) floor-divided by
        ``down_scale_factor`` (if given). ``n_to_trim`` samples at each end are ignored.
        """
        # resample the signal and compare it to the ground truth
        n_to_trim = 20
        sample_rate = 1000
        new_sample_rate = sample_rate

        if up_scale_factor is not None:
            new_sample_rate *= up_scale_factor

        if down_scale_factor is not None:
            new_sample_rate //= down_scale_factor

        duration = 5  # seconds
        original_timestamps = torch.arange(0, duration, 1.0 / sample_rate)

        sound = 123 * torch.cos(2 * math.pi * 3 * original_timestamps).unsqueeze(0)
        estimate = kaldi.resample_waveform(sound, sample_rate, new_sample_rate,
                                           resampling_method=resampling_method).squeeze()

        new_timestamps = torch.arange(0, duration, 1.0 / new_sample_rate)[:estimate.size(0)]
        ground_truth = 123 * torch.cos(2 * math.pi * 3 * new_timestamps)

        # trim the first/last n samples as these points have boundary effects
        ground_truth = ground_truth[..., n_to_trim:-n_to_trim]
        estimate = estimate[..., n_to_trim:-n_to_trim]

        self.assertEqual(estimate, ground_truth, atol=atol, rtol=rtol)

    @parameterized.expand([("sinc_interpolation",), ("kaiser_window",)])
    def test_resample_waveform_downsample_accuracy(self, resampling_method):
        for i in range(1, 20):
            self._test_resample_waveform_accuracy(down_scale_factor=i * 2, resampling_method=resampling_method)

    @parameterized.expand([("sinc_interpolation",), ("kaiser_window",)])
    def test_resample_waveform_upsample_accuracy(self, resampling_method):
        for i in range(1, 20):
            self._test_resample_waveform_accuracy(up_scale_factor=1.0 + i / 20.0, resampling_method=resampling_method)

    @parameterized.expand([("sinc_interpolation",), ("kaiser_window",)])
    def test_resample_waveform_multi_channel(self, resampling_method):
        num_channels = 3

        multi_sound = self.test_signal.repeat(num_channels, 1)  # (num_channels, 8000 smp)

        for i in range(num_channels):
            multi_sound[i, :] *= (i + 1) * 1.5

        multi_sound_sampled = kaldi.resample_waveform(multi_sound, self.test_signal_sr, self.test_signal_sr // 2,
                                                      resampling_method=resampling_method)

        # check that sampling is same whether using separately or in a tensor of size (c, n)
        for i in range(num_channels):
            single_channel = self.test_signal * (i + 1) * 1.5
            single_channel_sampled = kaldi.resample_waveform(single_channel, self.test_signal_sr,
                                                             self.test_signal_sr // 2,
                                                             resampling_method=resampling_method)
            self.assertEqual(multi_sound_sampled[i, :], single_channel_sampled[0], rtol=1e-4, atol=1e-7)