kaldi_compatibility_impl.py 3.33 KB
Newer Older
1
2
3
"""Test suites for checking numerical compatibility against Kaldi"""
import torch
import torchaudio.functional as F
4
import torchaudio.compliance.kaldi
moto's avatar
moto committed
5
from parameterized import parameterized
6

7
from torchaudio_unittest.common_utils import (
8
    TestBaseMixin,
moto's avatar
moto committed
9
    TempDirMixin,
10
11
12
    load_params,
    skipIfNoExec,
    get_asset_path,
moto's avatar
moto committed
13
14
15
16
17
    load_wav,
)
from torchaudio_unittest.common_utils.kaldi_utils import (
    convert_args,
    run_kaldi,
18
19
)

20

moto's avatar
moto committed
21
class Kaldi(TempDirMixin, TestBaseMixin):
    """Tests checking that torchaudio results match the output of Kaldi executables.

    The methods read ``self.dtype`` and ``self.device`` and call
    ``self.assertEqual``; none of these is defined in this class, so they are
    expected to come from the mixins or from the concrete test class that
    combines this one.
    """

    def assert_equal(self, output, *, expected, rtol=None, atol=None):
        """Assert that ``output`` equals ``expected`` within the given tolerances.

        ``expected`` is first converted to ``self.dtype`` / ``self.device`` so
        that the reference produced by Kaldi is compared in the same dtype and
        on the same device as the tensor under test.

        Args:
            output: Tensor produced by the torchaudio implementation.
            expected: Reference tensor; converted before the comparison.
            rtol: Relative tolerance forwarded to ``self.assertEqual``.
            atol: Absolute tolerance forwarded to ``self.assertEqual``.
        """
        expected = expected.to(dtype=self.dtype, device=self.device)
        self.assertEqual(output, expected, rtol=rtol, atol=atol)

    def _assert_feature_matches_kaldi(self, feature_fn, executable, kwargs):
        """Run ``feature_fn`` and the Kaldi ``executable`` on the same file and compare.

        Shared body of the fbank / spectrogram / mfcc tests, which differ only
        in the torchaudio function and the Kaldi executable they exercise.

        Args:
            feature_fn: Callable invoked as ``feature_fn(waveform, **kwargs)``.
            executable: Name of the Kaldi command-line tool to run.
            kwargs: Options passed to ``feature_fn`` as keyword arguments and
                to the executable through ``convert_args``.
        """
        wave_file = get_asset_path('kaldi_file.wav')
        # Only the first element of what load_wav returns is used as the waveform.
        waveform = load_wav(wave_file, normalize=False)[0]
        waveform = waveform.to(dtype=self.dtype, device=self.device)
        result = feature_fn(waveform, **kwargs)

        # The executable gets the same options and the same file as feature_fn.
        command = [executable] + convert_args(**kwargs) + ['scp:-', 'ark:-']
        kaldi_result = run_kaldi(command, 'scp', wave_file)
        self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8)

    @skipIfNoExec('apply-cmvn-sliding')
    def test_sliding_window_cmn(self):
        """sliding_window_cmn should be numerically compatible with apply-cmvn-sliding"""
        kwargs = {
            'cmn_window': 600,
            'min_cmn_window': 100,
            'center': False,
            'norm_vars': False,
        }

        # Unlike the file-based tests, the input here is a random tensor that
        # is handed both to torchaudio and to the Kaldi executable.
        tensor = torch.randn(40, 10, dtype=self.dtype, device=self.device)
        result = F.sliding_window_cmn(tensor, **kwargs)

        command = ['apply-cmvn-sliding'] + convert_args(**kwargs) + ['ark:-', 'ark:-']
        kaldi_result = run_kaldi(command, 'ark', tensor)
        # No explicit tolerances: assert_equal forwards rtol=None / atol=None.
        self.assert_equal(result, expected=kaldi_result)

    @parameterized.expand(load_params('kaldi_test_fbank_args.json'))
    @skipIfNoExec('compute-fbank-feats')
    def test_fbank(self, kwargs):
        """fbank should be numerically compatible with compute-fbank-feats"""
        self._assert_feature_matches_kaldi(
            torchaudio.compliance.kaldi.fbank, 'compute-fbank-feats', kwargs)

    @parameterized.expand(load_params('kaldi_test_spectrogram_args.json'))
    @skipIfNoExec('compute-spectrogram-feats')
    def test_spectrogram(self, kwargs):
        """spectrogram should be numerically compatible with compute-spectrogram-feats"""
        self._assert_feature_matches_kaldi(
            torchaudio.compliance.kaldi.spectrogram, 'compute-spectrogram-feats', kwargs)

    @parameterized.expand(load_params('kaldi_test_mfcc_args.json'))
    @skipIfNoExec('compute-mfcc-feats')
    def test_mfcc(self, kwargs):
        """mfcc should be numerically compatible with compute-mfcc-feats"""
        self._assert_feature_matches_kaldi(
            torchaudio.compliance.kaldi.mfcc, 'compute-mfcc-feats', kwargs)