Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
51a67867
Unverified
Commit
51a67867
authored
May 01, 2020
by
moto
Committed by
GitHub
May 01, 2020
Browse files
Add compatibility test for `compute-fbank-feats` (#602)
* Add one fbank compatibility test * Update util
parent
8e813596
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
55 additions
and
9 deletions
+55
-9
test/test_kaldi_compatibility.py
test/test_kaldi_compatibility.py
+55
-9
No files found.
test/test_kaldi_compatibility.py
View file @
51a67867
...
@@ -6,10 +6,13 @@ import subprocess
...
@@ -6,10 +6,13 @@ import subprocess
import
kaldi_io
import
kaldi_io
import
torch
import
torch
import
torchaudio.functional
as
F
import
torchaudio.functional
as
F
import
torchaudio.compliance.kaldi
import
common_utils
def
_exe_exists
(
cmd
):
return
shutil
.
which
(
cmd
)
is
not
None
def
_not_available
(
cmd
):
return
shutil
.
which
(
cmd
)
is
None
def
_convert_args
(
**
kwargs
):
def
_convert_args
(
**
kwargs
):
...
@@ -21,22 +24,31 @@ def _convert_args(**kwargs):
...
@@ -21,22 +24,31 @@ def _convert_args(**kwargs):
return
args
return
args
def
_run_kaldi
(
command
,
input_t
ensor
):
def
_run_kaldi
(
command
,
input_t
ype
,
input_value
):
"""Run provided Kaldi command, pass a tensor and get the resulting tensor
"""Run provided Kaldi command, pass a tensor and get the resulting tensor
Assumption:
Arguments:
The provided Kaldi command consumes one ark and produces one ark.
input_type: str
i.e. 'ark:- ark:-'
'ark' or 'scp'
input_value:
Tensor for 'ark'
string for 'scp' (path to an audio file)
"""
"""
key
=
'foo'
process
=
subprocess
.
Popen
(
command
,
stdin
=
subprocess
.
PIPE
,
stdout
=
subprocess
.
PIPE
)
process
=
subprocess
.
Popen
(
command
,
stdin
=
subprocess
.
PIPE
,
stdout
=
subprocess
.
PIPE
)
kaldi_io
.
write_mat
(
process
.
stdin
,
input_tensor
.
numpy
(),
key
=
'foo'
)
if
input_type
==
'ark'
:
kaldi_io
.
write_mat
(
process
.
stdin
,
input_value
.
numpy
(),
key
=
key
)
elif
input_type
==
'scp'
:
process
.
stdin
.
write
(
f
'
{
key
}
{
input_value
}
'
.
encode
(
'utf8'
))
else
:
raise
NotImplementedError
(
'Unexpected type'
)
process
.
stdin
.
close
()
process
.
stdin
.
close
()
result
=
dict
(
kaldi_io
.
read_mat_ark
(
process
.
stdout
))[
'foo'
]
result
=
dict
(
kaldi_io
.
read_mat_ark
(
process
.
stdout
))[
'foo'
]
return
torch
.
from_numpy
(
result
.
copy
())
# copy supresses some torch warning
return
torch
.
from_numpy
(
result
.
copy
())
# copy supresses some torch warning
class
TestFunctional
:
class
TestFunctional
:
@
unittest
.
skip
Unless
(
_exe_exists
(
'apply-cmvn-sliding'
),
'`apply-cmvn-sliding` not available'
)
@
unittest
.
skip
If
(
_not_available
(
'apply-cmvn-sliding'
),
'`apply-cmvn-sliding` not available'
)
def
test_sliding_window_cmn
(
self
):
def
test_sliding_window_cmn
(
self
):
"""sliding_window_cmn should be numerically compatible with apply-cmvn-sliding"""
"""sliding_window_cmn should be numerically compatible with apply-cmvn-sliding"""
kwargs
=
{
kwargs
=
{
...
@@ -49,5 +61,39 @@ class TestFunctional:
...
@@ -49,5 +61,39 @@ class TestFunctional:
tensor
=
torch
.
randn
(
40
,
10
)
tensor
=
torch
.
randn
(
40
,
10
)
result
=
F
.
sliding_window_cmn
(
tensor
,
**
kwargs
)
result
=
F
.
sliding_window_cmn
(
tensor
,
**
kwargs
)
command
=
[
'apply-cmvn-sliding'
]
+
_convert_args
(
**
kwargs
)
+
[
'ark:-'
,
'ark:-'
]
command
=
[
'apply-cmvn-sliding'
]
+
_convert_args
(
**
kwargs
)
+
[
'ark:-'
,
'ark:-'
]
kaldi_result
=
_run_kaldi
(
command
,
tensor
)
kaldi_result
=
_run_kaldi
(
command
,
'ark'
,
tensor
)
torch
.
testing
.
assert_allclose
(
result
,
kaldi_result
)
@
unittest
.
skipIf
(
_not_available
(
'compute-fbank-feats'
),
'`compute-fbank-feats` not available'
)
def
test_fbank
(
self
):
"""fbank should be numerically compatible with compute-fbank-feats"""
kwargs
=
{
'blackman_coeff'
:
4.3926
,
'dither'
:
0.0
,
'energy_floor'
:
2.0617
,
'frame_length'
:
0.5625
,
'frame_shift'
:
0.0625
,
'high_freq'
:
4253
,
'htk_compat'
:
True
,
'low_freq'
:
1367
,
'num_mel_bins'
:
5
,
'preemphasis_coefficient'
:
0.84
,
'raw_energy'
:
False
,
'remove_dc_offset'
:
True
,
'round_to_power_of_two'
:
True
,
'snip_edges'
:
True
,
'subtract_mean'
:
False
,
'use_energy'
:
True
,
'use_log_fbank'
:
True
,
'use_power'
:
False
,
'vtln_high'
:
2112
,
'vtln_low'
:
1445
,
'vtln_warp'
:
1.0000
,
'window_type'
:
'hamming'
,
}
wave_file
=
common_utils
.
get_asset_path
(
'kaldi_file.wav'
)
result
=
torchaudio
.
compliance
.
kaldi
.
fbank
(
torchaudio
.
load_wav
(
wave_file
)[
0
],
**
kwargs
)
command
=
[
'compute-fbank-feats'
]
+
_convert_args
(
**
kwargs
)
+
[
'scp:-'
,
'ark:-'
]
kaldi_result
=
_run_kaldi
(
command
,
'scp'
,
wave_file
)
torch
.
testing
.
assert_allclose
(
result
,
kaldi_result
)
torch
.
testing
.
assert_allclose
(
result
,
kaldi_result
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment