Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
2897f366
Unverified
Commit
2897f366
authored
Mar 15, 2021
by
discort
Committed by
GitHub
Mar 15, 2021
Browse files
Replace torch.assert_allclose with assertEqual (#1387)
parent
f2b75427
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
8 deletions
+8
-8
test/torchaudio_unittest/compliance_kaldi_test.py
test/torchaudio_unittest/compliance_kaldi_test.py
+8
-8
No files found.
test/torchaudio_unittest/compliance_kaldi_test.py
View file @
2897f366
...
@@ -45,9 +45,9 @@ def extract_window(window, wave, f, frame_length, frame_shift, snip_edges):
...
@@ -45,9 +45,9 @@ def extract_window(window, wave, f, frame_length, frame_shift, snip_edges):
window
[
f
,
s
]
=
wave
[
s_in_wave
]
window
[
f
,
s
]
=
wave
[
s_in_wave
]
@
common_utils
.
skipIfNoSox
Backend
@
common_utils
.
skipIfNoSox
class
Test_Kaldi
(
common_utils
.
TempDirMixin
,
common_utils
.
TorchaudioTestCase
):
class
Test_Kaldi
(
common_utils
.
TempDirMixin
,
common_utils
.
TorchaudioTestCase
):
backend
=
'sox'
backend
=
'sox
_io
'
kaldi_output_dir
=
common_utils
.
get_asset_path
(
'kaldi'
)
kaldi_output_dir
=
common_utils
.
get_asset_path
(
'kaldi'
)
test_filepath
=
common_utils
.
get_asset_path
(
'kaldi_file.wav'
)
test_filepath
=
common_utils
.
get_asset_path
(
'kaldi_file.wav'
)
...
@@ -91,7 +91,7 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
...
@@ -91,7 +91,7 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
for
r
in
range
(
m
):
for
r
in
range
(
m
):
extract_window
(
window
,
waveform
,
r
,
window_size
,
window_shift
,
snip_edges
)
extract_window
(
window
,
waveform
,
r
,
window_size
,
window_shift
,
snip_edges
)
torch
.
testing
.
assert_allclose
(
window
,
output
)
self
.
assertEqual
(
window
,
output
)
def
test_get_strided
(
self
):
def
test_get_strided
(
self
):
# generate any combination where 0 < window_size <= num_samples and
# generate any combination where 0 < window_size <= num_samples and
...
@@ -116,7 +116,7 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
...
@@ -116,7 +116,7 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
sound
,
sample_rate
=
torchaudio
.
load
(
self
.
test_filepath
,
normalization
=
False
)
sound
,
sample_rate
=
torchaudio
.
load
(
self
.
test_filepath
,
normalization
=
False
)
print
(
y
>>
16
)
print
(
y
>>
16
)
self
.
assertTrue
(
sample_rate
==
sr
)
self
.
assertTrue
(
sample_rate
==
sr
)
torch
.
testing
.
assert_allclose
(
y
,
sound
)
self
.
assertEqual
(
y
,
sound
)
def
_print_diagnostic
(
self
,
output
,
expect_output
):
def
_print_diagnostic
(
self
,
output
,
expect_output
):
# given an output and expected output, it will print the absolute/relative errors (max and mean squared)
# given an output and expected output, it will print the absolute/relative errors (max and mean squared)
...
@@ -170,7 +170,7 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
...
@@ -170,7 +170,7 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
output
=
get_output_fn
(
sound
,
args
)
output
=
get_output_fn
(
sound
,
args
)
self
.
_print_diagnostic
(
output
,
kaldi_output
)
self
.
_print_diagnostic
(
output
,
kaldi_output
)
torch
.
testing
.
assert_allclose
(
output
,
kaldi_output
,
atol
=
atol
,
rtol
=
rtol
)
self
.
assertEqual
(
output
,
kaldi_output
,
atol
=
atol
,
rtol
=
rtol
)
def
test_mfcc_empty
(
self
):
def
test_mfcc_empty
(
self
):
# Passing in an empty tensor should result in an error
# Passing in an empty tensor should result in an error
...
@@ -178,7 +178,7 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
...
@@ -178,7 +178,7 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
def
test_resample_waveform
(
self
):
def
test_resample_waveform
(
self
):
def
get_output_fn
(
sound
,
args
):
def
get_output_fn
(
sound
,
args
):
output
=
kaldi
.
resample_waveform
(
sound
,
args
[
1
],
args
[
2
])
output
=
kaldi
.
resample_waveform
(
sound
.
to
(
torch
.
float32
)
,
args
[
1
],
args
[
2
])
return
output
return
output
self
.
_compliance_test_helper
(
self
.
test2_filepath
,
'resample'
,
32
,
3
,
get_output_fn
,
atol
=
1e-2
,
rtol
=
1e-5
)
self
.
_compliance_test_helper
(
self
.
test2_filepath
,
'resample'
,
32
,
3
,
get_output_fn
,
atol
=
1e-2
,
rtol
=
1e-5
)
...
@@ -221,7 +221,7 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
...
@@ -221,7 +221,7 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
ground_truth
=
ground_truth
[...,
n_to_trim
:
-
n_to_trim
]
ground_truth
=
ground_truth
[...,
n_to_trim
:
-
n_to_trim
]
estimate
=
estimate
[...,
n_to_trim
:
-
n_to_trim
]
estimate
=
estimate
[...,
n_to_trim
:
-
n_to_trim
]
torch
.
testing
.
assert_allclose
(
estimate
,
ground_truth
,
atol
=
atol
,
rtol
=
rtol
)
self
.
assertEqual
(
estimate
,
ground_truth
,
atol
=
atol
,
rtol
=
rtol
)
def
test_resample_waveform_downsample_accuracy
(
self
):
def
test_resample_waveform_downsample_accuracy
(
self
):
for
i
in
range
(
1
,
20
):
for
i
in
range
(
1
,
20
):
...
@@ -246,4 +246,4 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
...
@@ -246,4 +246,4 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
single_channel
=
self
.
test1_signal
*
(
i
+
1
)
*
1.5
single_channel
=
self
.
test1_signal
*
(
i
+
1
)
*
1.5
single_channel_sampled
=
kaldi
.
resample_waveform
(
single_channel
,
self
.
test1_signal_sr
,
single_channel_sampled
=
kaldi
.
resample_waveform
(
single_channel
,
self
.
test1_signal_sr
,
self
.
test1_signal_sr
//
2
)
self
.
test1_signal_sr
//
2
)
torch
.
testing
.
assert_allclose
(
multi_sound_sampled
[
i
,
:],
single_channel_sampled
[
0
],
rtol
=
1e-4
,
atol
=
1e-7
)
self
.
assertEqual
(
multi_sound_sampled
[
i
,
:],
single_channel_sampled
[
0
],
rtol
=
1e-4
,
atol
=
1e-7
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment