Unverified commit 51a67867, authored by moto and committed by GitHub
Browse files

Add compatibility test for `compute-fbank-feats` (#602)

* Add one fbank compatibility test

* Update util
parent 8e813596
...@@ -6,10 +6,13 @@ import subprocess ...@@ -6,10 +6,13 @@ import subprocess
import kaldi_io import kaldi_io
import torch import torch
import torchaudio.functional as F import torchaudio.functional as F
import torchaudio.compliance.kaldi
import common_utils
def _not_available(cmd):
    """Return True when ``cmd`` cannot be resolved to an executable.

    The lookup is delegated to ``shutil.which``; the result is meant to be
    passed directly to ``unittest.skipIf``.
    """
    resolved = shutil.which(cmd)
    return resolved is None
def _convert_args(**kwargs): def _convert_args(**kwargs):
...@@ -21,22 +24,31 @@ def _convert_args(**kwargs): ...@@ -21,22 +24,31 @@ def _convert_args(**kwargs):
return args return args
def _run_kaldi(command, input_type, input_value):
    """Run provided Kaldi command, pass the input and get the resulting tensor

    The command is fed through its stdin and is expected to write a single
    matrix ark to its stdout (i.e. the command ends with ``... ark:-``).

    Arguments:
        command: list of str
            Command line passed to ``subprocess.Popen``.
        input_type: str
            'ark' or 'scp'
        input_value:
            Tensor for 'ark'
            string for 'scp' (path to an audio file)

    Returns:
        Tensor built from the matrix the command wrote to stdout.

    Raises:
        NotImplementedError: if ``input_type`` is neither 'ark' nor 'scp'.
    """
    # Validate before spawning so that a bad call does not leave a child
    # process (and its pipes) behind.
    if input_type not in ('ark', 'scp'):
        raise NotImplementedError('Unexpected type')

    key = 'foo'
    # The context manager closes both pipes and waits for the child on exit,
    # so no zombie process or open file descriptor outlives this call.
    with subprocess.Popen(command, stdin=subprocess.PIPE, stdout=subprocess.PIPE) as process:
        if input_type == 'ark':
            kaldi_io.write_mat(process.stdin, input_value.numpy(), key=key)
        else:
            process.stdin.write(f'{key} {input_value}'.encode('utf8'))
        # Closing stdin signals end-of-input to the command before we read.
        process.stdin.close()
        # Look the result up under the same key that was written above.
        result = dict(kaldi_io.read_mat_ark(process.stdout))[key]
    return torch.from_numpy(result.copy())  # copy suppresses some torch warning
class TestFunctional: class TestFunctional:
@unittest.skipUnless(_exe_exists('apply-cmvn-sliding'), '`apply-cmvn-sliding` not available') @unittest.skipIf(_not_available('apply-cmvn-sliding'), '`apply-cmvn-sliding` not available')
def test_sliding_window_cmn(self): def test_sliding_window_cmn(self):
"""sliding_window_cmn should be numerically compatible with apply-cmvn-sliding""" """sliding_window_cmn should be numerically compatible with apply-cmvn-sliding"""
kwargs = { kwargs = {
...@@ -49,5 +61,39 @@ class TestFunctional: ...@@ -49,5 +61,39 @@ class TestFunctional:
tensor = torch.randn(40, 10) tensor = torch.randn(40, 10)
result = F.sliding_window_cmn(tensor, **kwargs) result = F.sliding_window_cmn(tensor, **kwargs)
command = ['apply-cmvn-sliding'] + _convert_args(**kwargs) + ['ark:-', 'ark:-'] command = ['apply-cmvn-sliding'] + _convert_args(**kwargs) + ['ark:-', 'ark:-']
kaldi_result = _run_kaldi(command, tensor) kaldi_result = _run_kaldi(command, 'ark', tensor)
torch.testing.assert_allclose(result, kaldi_result)
@unittest.skipIf(_not_available('compute-fbank-feats'), '`compute-fbank-feats` not available')
def test_fbank(self):
    """fbank should be numerically compatible with compute-fbank-feats

    The same option set is handed to both ``torchaudio.compliance.kaldi.fbank``
    and the ``compute-fbank-feats`` command (via ``_convert_args``), and the
    two outputs are compared with ``torch.testing.assert_allclose``.
    """
    kwargs = {
        'blackman_coeff': 4.3926,
        'dither': 0.0,
        'energy_floor': 2.0617,
        'frame_length': 0.5625,
        'frame_shift': 0.0625,
        'high_freq': 4253,
        'htk_compat': True,
        'low_freq': 1367,
        'num_mel_bins': 5,
        'preemphasis_coefficient': 0.84,
        'raw_energy': False,
        'remove_dc_offset': True,
        'round_to_power_of_two': True,
        'snip_edges': True,
        'subtract_mean': False,
        'use_energy': True,
        'use_log_fbank': True,
        'use_power': False,
        'vtln_high': 2112,
        'vtln_low': 1445,
        'vtln_warp': 1.0000,
        'window_type': 'hamming',
    }
    wave_file = common_utils.get_asset_path('kaldi_file.wav')

    # torchaudio side: only the first element returned by load_wav is used.
    loaded = torchaudio.load_wav(wave_file)
    result = torchaudio.compliance.kaldi.fbank(loaded[0], **kwargs)

    # Kaldi side: the wave file path is passed as an scp entry on stdin and
    # the features are read back as an ark from stdout.
    command = ['compute-fbank-feats', *_convert_args(**kwargs), 'scp:-', 'ark:-']
    kaldi_result = _run_kaldi(command, 'scp', wave_file)

    torch.testing.assert_allclose(result, kaldi_result)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment