"vscode:/vscode.git/clone" did not exist on "40aa47b998b0203f6aca0759df9f5eeefd64fcc7"
Unverified Commit 301a6e3d authored by hwangjeff's avatar hwangjeff Committed by GitHub
Browse files

Refactor Kaldi compatibility tests (#1359)



* Refactor Kaldi compatibility tests
Co-authored-by: default avatarJeff Hwang <jeffhwang@fb.com>
parent 64551a69
...@@ -54,7 +54,7 @@ The following is an overview of the tests and related modules for `torchaudio`. ...@@ -54,7 +54,7 @@ The following is an overview of the tests and related modules for `torchaudio`.
Test suite for numerical compatibility against librosa. Test suite for numerical compatibility against librosa.
- [SoX compatibility test](./transforms/sox_compatibility_test.py) - [SoX compatibility test](./transforms/sox_compatibility_test.py)
Test suite for numerical compatibility against SoX. Test suite for numerical compatibility against SoX.
- [Kaldi compatibility test](./kaldi_compatibility_test.py) - [Kaldi compatibility test](./transforms/kaldi_compatibility_impl.py)
Test suite for numerical compatibility against Kaldi. Test suite for numerical compatibility against Kaldi.
#### Result consistency with PyTorch framework #### Result consistency with PyTorch framework
......
import torch import torch
from torchaudio_unittest.common_utils import PytorchTestCase from torchaudio_unittest.common_utils import PytorchTestCase
from .kaldi_compatibility_test_impl import KaldiCPUOnly from .kaldi_compatibility_test_impl import Kaldi, KaldiCPUOnly
class TestKaldiCPUOnly(KaldiCPUOnly, PytorchTestCase): class TestKaldiCPUOnly(KaldiCPUOnly, PytorchTestCase):
dtype = torch.float32 dtype = torch.float32
device = torch.device('cpu') device = torch.device('cpu')
class TestKaldiFloat32(Kaldi, PytorchTestCase):
dtype = torch.float32
device = torch.device('cpu')
class TestKaldiFloat64(Kaldi, PytorchTestCase):
dtype = torch.float64
device = torch.device('cpu')
import torch
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from .kaldi_compatibility_test_impl import Kaldi
@skipIfNoCuda
class TestKaldiFloat32(Kaldi, PytorchTestCase):
dtype = torch.float32
device = torch.device('cuda')
@skipIfNoCuda
class TestKaldiFloat64(Kaldi, PytorchTestCase):
dtype = torch.float64
device = torch.device('cuda')
from parameterized import parameterized from parameterized import parameterized
import torch
import torchaudio.functional as F import torchaudio.functional as F
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
...@@ -15,6 +16,28 @@ from torchaudio_unittest.common_utils.kaldi_utils import ( ...@@ -15,6 +16,28 @@ from torchaudio_unittest.common_utils.kaldi_utils import (
) )
class Kaldi(TempDirMixin, TestBaseMixin):
def assert_equal(self, output, *, expected, rtol=None, atol=None):
expected = expected.to(dtype=self.dtype, device=self.device)
self.assertEqual(output, expected, rtol=rtol, atol=atol)
@skipIfNoExec('apply-cmvn-sliding')
def test_sliding_window_cmn(self):
"""sliding_window_cmn should be numerically compatible with apply-cmvn-sliding"""
kwargs = {
'cmn_window': 600,
'min_cmn_window': 100,
'center': False,
'norm_vars': False,
}
tensor = torch.randn(40, 10, dtype=self.dtype, device=self.device)
result = F.sliding_window_cmn(tensor, **kwargs)
command = ['apply-cmvn-sliding'] + convert_args(**kwargs) + ['ark:-', 'ark:-']
kaldi_result = run_kaldi(command, 'ark', tensor)
self.assert_equal(result, expected=kaldi_result)
class KaldiCPUOnly(TempDirMixin, TestBaseMixin): class KaldiCPUOnly(TempDirMixin, TestBaseMixin):
def assert_equal(self, output, *, expected, rtol=None, atol=None): def assert_equal(self, output, *, expected, rtol=None, atol=None):
expected = expected.to(dtype=self.dtype, device=self.device) expected = expected.to(dtype=self.dtype, device=self.device)
......
"""Test suites for checking numerical compatibility against Kaldi""" """Test suites for checking numerical compatibility against Kaldi"""
import torch
import torchaudio.functional as F
import torchaudio.compliance.kaldi import torchaudio.compliance.kaldi
from parameterized import parameterized from parameterized import parameterized
...@@ -23,22 +21,6 @@ class Kaldi(TempDirMixin, TestBaseMixin): ...@@ -23,22 +21,6 @@ class Kaldi(TempDirMixin, TestBaseMixin):
expected = expected.to(dtype=self.dtype, device=self.device) expected = expected.to(dtype=self.dtype, device=self.device)
self.assertEqual(output, expected, rtol=rtol, atol=atol) self.assertEqual(output, expected, rtol=rtol, atol=atol)
@skipIfNoExec('apply-cmvn-sliding')
def test_sliding_window_cmn(self):
"""sliding_window_cmn should be numerically compatible with apply-cmvn-sliding"""
kwargs = {
'cmn_window': 600,
'min_cmn_window': 100,
'center': False,
'norm_vars': False,
}
tensor = torch.randn(40, 10, dtype=self.dtype, device=self.device)
result = F.sliding_window_cmn(tensor, **kwargs)
command = ['apply-cmvn-sliding'] + convert_args(**kwargs) + ['ark:-', 'ark:-']
kaldi_result = run_kaldi(command, 'ark', tensor)
self.assert_equal(result, expected=kaldi_result)
@parameterized.expand(load_params('kaldi_test_fbank_args.json')) @parameterized.expand(load_params('kaldi_test_fbank_args.json'))
@skipIfNoExec('compute-fbank-feats') @skipIfNoExec('compute-fbank-feats')
def test_fbank(self, kwargs): def test_fbank(self, kwargs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment