Unverified commit 8a03087e, authored by Bhargav Kathivarapu and committed by GitHub
Browse files

Migrate kaldi fbank (#672)



* Migrate fbank tests

* Update CCI job environment

* Remove invalid test cases
Signed-off-by: Bhargav Kathivarapu <bhargavkathivarapu31@gmail.com>
parent b0367251
...@@ -13,3 +13,4 @@ dependencies: ...@@ -13,3 +13,4 @@ dependencies:
- clang-format - clang-format
- kaldi-io - kaldi-io
- scipy - scipy
- parameterized
...@@ -14,3 +14,4 @@ dependencies: ...@@ -14,3 +14,4 @@ dependencies:
- PySoundFile - PySoundFile
- librosa - librosa
- future - future
- parameterized
This diff is collapsed.
"""Test suites for checking numerical compatibility against Kaldi""" """Test suites for checking numerical compatibility against Kaldi"""
import json
import shutil import shutil
import unittest import unittest
import subprocess import subprocess
...@@ -9,6 +10,7 @@ import torchaudio.functional as F ...@@ -9,6 +10,7 @@ import torchaudio.functional as F
import torchaudio.compliance.kaldi import torchaudio.compliance.kaldi
from . import common_utils from . import common_utils
from parameterized import parameterized, param
def _not_available(cmd): def _not_available(cmd):
...@@ -47,6 +49,11 @@ def _run_kaldi(command, input_type, input_value): ...@@ -47,6 +49,11 @@ def _run_kaldi(command, input_type, input_value):
return torch.from_numpy(result.copy()) # copy supresses some torch warning return torch.from_numpy(result.copy()) # copy supresses some torch warning
def _load_params(path):
    """Load parameterized test cases from a JSON-lines file.

    Each non-blank line of the file at ``path`` is parsed with ``json.loads``
    and the resulting object is wrapped in a ``param`` as its single
    positional argument.

    Args:
        path: Path to a text file holding one JSON document per line.

    Returns:
        list: One ``param`` per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # JSON text is UTF-8; being explicit avoids depending on the locale.
    with open(path, 'r', encoding='utf-8') as handle:
        # Skip whitespace-only lines (e.g. an extra empty line at the end of
        # the file) so they do not abort test collection with a decode error.
        return [param(json.loads(line)) for line in handle if line.strip()]
class Kaldi(common_utils.TestBaseMixin): class Kaldi(common_utils.TestBaseMixin):
def assert_equal(self, output, *, expected, rtol=None, atol=None): def assert_equal(self, output, *, expected, rtol=None, atol=None):
expected = expected.to(dtype=self.dtype, device=self.device) expected = expected.to(dtype=self.dtype, device=self.device)
...@@ -68,34 +75,10 @@ class Kaldi(common_utils.TestBaseMixin): ...@@ -68,34 +75,10 @@ class Kaldi(common_utils.TestBaseMixin):
kaldi_result = _run_kaldi(command, 'ark', tensor) kaldi_result = _run_kaldi(command, 'ark', tensor)
self.assert_equal(result, expected=kaldi_result) self.assert_equal(result, expected=kaldi_result)
@parameterized.expand(_load_params(common_utils.get_asset_path('kaldi_test_fbank_args.json')))
@unittest.skipIf(_not_available('compute-fbank-feats'), '`compute-fbank-feats` not available') @unittest.skipIf(_not_available('compute-fbank-feats'), '`compute-fbank-feats` not available')
def test_fbank(self): def test_fbank(self, kwargs):
"""fbank should be numerically compatible with compute-fbank-feats""" """fbank should be numerically compatible with compute-fbank-feats"""
kwargs = {
'blackman_coeff': 4.3926,
'dither': 0.0,
'energy_floor': 2.0617,
'frame_length': 0.5625,
'frame_shift': 0.0625,
'high_freq': 4253,
'htk_compat': True,
'low_freq': 1367,
'num_mel_bins': 5,
'preemphasis_coefficient': 0.84,
'raw_energy': False,
'remove_dc_offset': True,
'round_to_power_of_two': True,
'snip_edges': True,
'subtract_mean': False,
'use_energy': True,
'use_log_fbank': True,
'use_power': False,
'vtln_high': 2112,
'vtln_low': 1445,
'vtln_warp': 1.0000,
'window_type': 'hamming',
}
wave_file = common_utils.get_asset_path('kaldi_file.wav') wave_file = common_utils.get_asset_path('kaldi_file.wav')
waveform = torchaudio.load_wav(wave_file)[0].to(dtype=self.dtype, device=self.device) waveform = torchaudio.load_wav(wave_file)[0].to(dtype=self.dtype, device=self.device)
result = torchaudio.compliance.kaldi.fbank(waveform, **kwargs) result = torchaudio.compliance.kaldi.fbank(waveform, **kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment