import torch
import torchaudio.compliance.kaldi as kaldi
from torchaudio_unittest import common_utils


def extract_window(window, wave, f, frame_length, frame_shift, snip_edges):
    """Write frame ``f`` of ``wave`` into row ``f`` of ``window`` (in place).

    This is a Python copy of ``ExtractWindow`` from ``feature-window.cc`` and
    serves as the reference implementation the tests compare against.

    Args:
        window: 2-D tensor that receives the frame; only row ``f`` is written.
        wave: Waveform tensor, indexed along dimension 0.
        f: Index of the frame to extract.
        frame_length: Number of samples per frame.
        frame_shift: Number of samples between the starts of adjacent frames.
        snip_edges: If truthy, frame ``f`` starts at ``f * frame_shift`` and
            must lie fully inside ``wave``. Otherwise the frame is centred
            around ``f * frame_shift + frame_shift // 2`` and samples that
            fall outside ``wave`` are obtained by reflecting the index back
            into range.

    Returns:
        None. ``window`` is modified in place.
    """

    def first_sample_of_frame(frame, window_size, window_shift, snip_edges):
        # Index (possibly negative) of the first sample of ``frame``.
        if snip_edges:
            return frame * window_shift
        midpoint_of_frame = frame * window_shift + window_shift // 2
        beginning_of_frame = midpoint_of_frame - window_size // 2
        return beginning_of_frame

    # The copied routine supports a non-zero offset; here it is always 0.
    sample_offset = 0
    num_samples = sample_offset + wave.size(0)
    start_sample = first_sample_of_frame(f, frame_length, frame_shift, snip_edges)
    end_sample = start_sample + frame_length

    if snip_edges:
        assert start_sample >= sample_offset and end_sample <= num_samples
    else:
        assert sample_offset == 0 or start_sample >= sample_offset

    wave_start = start_sample - sample_offset
    wave_end = wave_start + frame_length
    wave_dim = wave.size(0)

    if wave_start >= 0 and wave_end <= wave_dim:
        # The frame lies entirely inside the waveform: plain slice copy.
        window[f, :] = wave[wave_start:wave_end]
        return

    # The frame hangs over an edge: reflect each out-of-range index back into
    # [0, wave_dim). The loop repeats because a single reflection may still
    # land outside when the frame is longer than the waveform.
    for s in range(frame_length):
        s_in_wave = s + wave_start
        while s_in_wave < 0 or s_in_wave >= wave_dim:
            if s_in_wave < 0:
                s_in_wave = -s_in_wave - 1
            else:
                s_in_wave = 2 * wave_dim - 1 - s_in_wave
        window[f, s] = wave[s_in_wave]


class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
    """Tests for helpers in ``torchaudio.compliance.kaldi``."""

    def _test_get_strided_helper(self, num_samples, window_size, window_shift, snip_edges):
        """Compare ``kaldi._get_strided`` with the ``extract_window`` reference.

        Builds the waveform ``[0, 1, ..., num_samples - 1]``, frames it with
        ``kaldi._get_strided`` and checks both the shape of the result and
        its contents against frames produced one by one by ``extract_window``.
        """
        waveform = torch.arange(num_samples).float()
        output = kaldi._get_strided(waveform, window_size, window_shift, snip_edges)

        # from NumFrames in feature-window.cc
        n = window_size
        if snip_edges:
            m = 0 if num_samples < window_size else 1 + (num_samples - window_size) // window_shift
        else:
            m = (num_samples + (window_shift // 2)) // window_shift

        # assertEqual (rather than assertTrue(a == b)) reports the mismatching
        # values when the check fails.
        self.assertEqual(output.dim(), 2)
        self.assertEqual(tuple(output.shape), (m, n))

        # Every row is overwritten by extract_window, so no initialisation is
        # needed.
        window = torch.empty((m, window_size))
        for r in range(m):
            extract_window(window, waveform, r, window_size, window_shift, snip_edges)
        self.assertEqual(window, output)

    def test_get_strided(self):
        """Sweep small sizes/shifts with and without ``snip_edges``."""
        # generate any combination where 0 < window_size <= num_samples and
        # 0 < window_shift.
        for num_samples in range(1, 20):
            for window_size in range(1, num_samples + 1):
                for window_shift in range(1, 2 * num_samples + 1):
                    for snip_edges in (False, True):
                        self._test_get_strided_helper(num_samples, window_size, window_shift, snip_edges)

    def test_mfcc_empty(self):
        """``kaldi.mfcc`` must reject an empty tensor."""
        # Passing in an empty tensor should result in an error
        self.assertRaises(AssertionError, kaldi.mfcc, torch.empty(0))