Unverified Commit faefe689 authored by moto's avatar moto Committed by GitHub
Browse files

Adopt PyTorch's test utility on kaldi compatibility test (#650)

* Adopt PyTorch's test utility on kaldi compatibility test

 - Adopt PyTorch's test utility on kaldi compatibility test
 - Separate CPU test and GPU test

* Fix device/dtype conversion
parent bc6b7f97
import common_utils
from kaldi_compatibility_impl import Kaldi
common_utils.define_test_suites(globals(), [Kaldi], devices=['cpu'])
import common_utils
from kaldi_compatibility_impl import Kaldi
common_utils.define_test_suites(globals(), [Kaldi], devices=['cuda'])
...@@ -48,6 +48,10 @@ def _run_kaldi(command, input_type, input_value): ...@@ -48,6 +48,10 @@ def _run_kaldi(command, input_type, input_value):
class Kaldi(common_utils.TestBaseMixin): class Kaldi(common_utils.TestBaseMixin):
def assert_equal(self, output, *, expected, rtol=None, atol=None):
expected = expected.to(dtype=self.dtype, device=self.device)
self.assertEqual(output, expected, rtol=rtol, atol=atol)
@unittest.skipIf(_not_available('apply-cmvn-sliding'), '`apply-cmvn-sliding` not available') @unittest.skipIf(_not_available('apply-cmvn-sliding'), '`apply-cmvn-sliding` not available')
def test_sliding_window_cmn(self): def test_sliding_window_cmn(self):
"""sliding_window_cmn should be numerically compatible with apply-cmvn-sliding""" """sliding_window_cmn should be numerically compatible with apply-cmvn-sliding"""
...@@ -62,7 +66,7 @@ class Kaldi(common_utils.TestBaseMixin): ...@@ -62,7 +66,7 @@ class Kaldi(common_utils.TestBaseMixin):
result = F.sliding_window_cmn(tensor, **kwargs) result = F.sliding_window_cmn(tensor, **kwargs)
command = ['apply-cmvn-sliding'] + _convert_args(**kwargs) + ['ark:-', 'ark:-'] command = ['apply-cmvn-sliding'] + _convert_args(**kwargs) + ['ark:-', 'ark:-']
kaldi_result = _run_kaldi(command, 'ark', tensor) kaldi_result = _run_kaldi(command, 'ark', tensor)
torch.testing.assert_allclose(result.cpu(), kaldi_result.to(self.dtype)) self.assert_equal(result, expected=kaldi_result)
@unittest.skipIf(_not_available('compute-fbank-feats'), '`compute-fbank-feats` not available') @unittest.skipIf(_not_available('compute-fbank-feats'), '`compute-fbank-feats` not available')
def test_fbank(self): def test_fbank(self):
...@@ -97,7 +101,4 @@ class Kaldi(common_utils.TestBaseMixin): ...@@ -97,7 +101,4 @@ class Kaldi(common_utils.TestBaseMixin):
result = torchaudio.compliance.kaldi.fbank(waveform, **kwargs) result = torchaudio.compliance.kaldi.fbank(waveform, **kwargs)
command = ['compute-fbank-feats'] + _convert_args(**kwargs) + ['scp:-', 'ark:-'] command = ['compute-fbank-feats'] + _convert_args(**kwargs) + ['scp:-', 'ark:-']
kaldi_result = _run_kaldi(command, 'scp', wave_file) kaldi_result = _run_kaldi(command, 'scp', wave_file)
torch.testing.assert_allclose(result.cpu(), kaldi_result.to(dtype=self.dtype), rtol=1e-4, atol=1e-8) self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8)
common_utils.define_test_suites(globals(), [Kaldi])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment