"""Test suites for checking numerical compatibility against Kaldi"""
import json
import shutil
import unittest
import subprocess

import kaldi_io
import torch
import torchaudio.functional as F
import torchaudio.compliance.kaldi

from . import common_utils
from parameterized import parameterized, param


def _not_available(cmd):
    return shutil.which(cmd) is None


def _convert_args(**kwargs):
    args = []
    for key, value in kwargs.items():
        key = '--' + key.replace('_', '-')
        value = str(value).lower() if value in [True, False] else str(value)
        args.append('%s=%s' % (key, value))
    return args


def _run_kaldi(command, input_type, input_value):
30
31
    """Run provided Kaldi command, pass a tensor and get the resulting tensor

32
33
34
35
36
37
    Arguments:
        input_type: str
            'ark' or 'scp'
        input_value:
            Tensor for 'ark'
            string for 'scp' (path to an audio file)
38
    """
39
    key = 'foo'
40
    process = subprocess.Popen(command, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
41
    if input_type == 'ark':
moto's avatar
moto committed
42
        kaldi_io.write_mat(process.stdin, input_value.cpu().numpy(), key=key)
43
44
45
46
    elif input_type == 'scp':
        process.stdin.write(f'{key} {input_value}'.encode('utf8'))
    else:
        raise NotImplementedError('Unexpected type')
47
48
49
50
51
    process.stdin.close()
    result = dict(kaldi_io.read_mat_ark(process.stdout))['foo']
    return torch.from_numpy(result.copy())  # copy supresses some torch warning


def _load_params(path):
    """Load parameterized test cases from a JSON Lines file.

    Each non-blank line of ``path`` is decoded as one JSON value and wrapped
    in ``param``, so the result can be handed to ``parameterized.expand``.

    Arguments:
        path: str
            Path to the JSON Lines file.

    Returns:
        list: one ``param`` per non-blank line, in file order.
    """
    # JSON text is UTF-8; be explicit rather than depending on the
    # platform's locale encoding.
    with open(path, 'r', encoding='utf-8') as file:
        # Blank lines carry no test case and would make json.loads fail.
        return [param(json.loads(line)) for line in file if line.strip()]


class Kaldi(common_utils.TestBaseMixin):
    """Numerical-compatibility checks of torchaudio features against Kaldi tools.

    Each test computes a feature with torchaudio, computes the same feature
    with the matching Kaldi command-line tool, and compares the two results.
    A test is skipped when its Kaldi executable cannot be found.

    ``self.dtype`` and ``self.device`` are not defined here. Presumably they
    are supplied by ``common_utils.TestBaseMixin`` or a concrete subclass;
    TODO confirm.
    """

    def assert_equal(self, output, *, expected, rtol=None, atol=None):
        """Check ``output`` against ``expected`` cast to this test's dtype and device.

        ``rtol`` and ``atol`` are forwarded unchanged to ``self.assertEqual``.
        """
        reference = expected.to(dtype=self.dtype, device=self.device)
        self.assertEqual(output, reference, rtol=rtol, atol=atol)

    def _compare_with_kaldi_binary(self, binary, feature_fn, kwargs):
        """Compare ``feature_fn`` with the Kaldi ``binary`` on the ``kaldi_file.wav`` asset.

        The first element returned by ``torchaudio.load_wav`` is moved to this
        test's dtype/device and passed to ``feature_fn`` together with
        ``kwargs``. The same file is fed to ``binary`` as an scp entry, with
        ``kwargs`` converted to command-line options. The two outputs are
        compared with ``rtol=1e-4`` and ``atol=1e-8``.
        """
        wav_path = common_utils.get_asset_path('kaldi_file.wav')
        waveform = torchaudio.load_wav(wav_path)[0]
        waveform = waveform.to(dtype=self.dtype, device=self.device)
        actual = feature_fn(waveform, **kwargs)
        cli = [binary, *_convert_args(**kwargs), 'scp:-', 'ark:-']
        reference = _run_kaldi(cli, 'scp', wav_path)
        self.assert_equal(actual, expected=reference, rtol=1e-4, atol=1e-8)

    @unittest.skipIf(_not_available('apply-cmvn-sliding'), '`apply-cmvn-sliding` not available')
    def test_sliding_window_cmn(self):
        """sliding_window_cmn should be numerically compatible with apply-cmvn-sliding"""
        options = dict(
            cmn_window=600,
            min_cmn_window=100,
            center=False,
            norm_vars=False,
        )
        features = torch.randn(40, 10, dtype=self.dtype, device=self.device)
        actual = F.sliding_window_cmn(features, **options)
        cli = ['apply-cmvn-sliding', *_convert_args(**options), 'ark:-', 'ark:-']
        reference = _run_kaldi(cli, 'ark', features)
        self.assert_equal(actual, expected=reference)

    @parameterized.expand(_load_params(common_utils.get_asset_path('kaldi_test_fbank_args.json')))
    @unittest.skipIf(_not_available('compute-fbank-feats'), '`compute-fbank-feats` not available')
    def test_fbank(self, kwargs):
        """fbank should be numerically compatible with compute-fbank-feats"""
        self._compare_with_kaldi_binary(
            'compute-fbank-feats', torchaudio.compliance.kaldi.fbank, kwargs)

    @parameterized.expand(_load_params(common_utils.get_asset_path('kaldi_test_spectrogram_args.json')))
    @unittest.skipIf(_not_available('compute-spectrogram-feats'), '`compute-spectrogram-feats` not available')
    def test_spectrogram(self, kwargs):
        """spectrogram should be numerically compatible with compute-spectrogram-feats"""
        self._compare_with_kaldi_binary(
            'compute-spectrogram-feats', torchaudio.compliance.kaldi.spectrogram, kwargs)

    @parameterized.expand(_load_params(common_utils.get_asset_path('kaldi_test_mfcc_args.json')))
    @unittest.skipIf(_not_available('compute-mfcc-feats'), '`compute-mfcc-feats` not available')
    def test_mfcc(self, kwargs):
        """mfcc should be numerically compatible with compute-mfcc-feats"""
        self._compare_with_kaldi_binary(
            'compute-mfcc-feats', torchaudio.compliance.kaldi.mfcc, kwargs)