Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
15a7f78c
Unverified
Commit
15a7f78c
authored
Jun 03, 2021
by
Caroline Chen
Committed by
GitHub
Jun 03, 2021
Browse files
Migrate resample tests from kaldi to functional (#1520)
parent
68823423
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
75 additions
and
87 deletions
+75
-87
test/torchaudio_unittest/compliance_kaldi_test.py
test/torchaudio_unittest/compliance_kaldi_test.py
+0
-87
test/torchaudio_unittest/functional/batch_consistency_test.py
.../torchaudio_unittest/functional/batch_consistency_test.py
+11
-0
test/torchaudio_unittest/functional/functional_impl.py
test/torchaudio_unittest/functional/functional_impl.py
+64
-0
No files found.
test/torchaudio_unittest/compliance_kaldi_test.py
View file @
15a7f78c
...
@@ -7,7 +7,6 @@ import torchaudio.compliance.kaldi as kaldi
...
@@ -7,7 +7,6 @@ import torchaudio.compliance.kaldi as kaldi
from
torchaudio_unittest
import
common_utils
from
torchaudio_unittest
import
common_utils
from
.compliance
import
utils
as
compliance_utils
from
.compliance
import
utils
as
compliance_utils
from
parameterized
import
parameterized
def
extract_window
(
window
,
wave
,
f
,
frame_length
,
frame_shift
,
snip_edges
):
def
extract_window
(
window
,
wave
,
f
,
frame_length
,
frame_shift
,
snip_edges
):
...
@@ -53,15 +52,6 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
...
@@ -53,15 +52,6 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
test_filepath
=
common_utils
.
get_asset_path
(
'kaldi_file.wav'
)
test_filepath
=
common_utils
.
get_asset_path
(
'kaldi_file.wav'
)
test_filepaths
=
{
prefix
:
[]
for
prefix
in
compliance_utils
.
TEST_PREFIX
}
test_filepaths
=
{
prefix
:
[]
for
prefix
in
compliance_utils
.
TEST_PREFIX
}
def
setUp
(
self
):
super
().
setUp
()
# test signal for testing resampling
self
.
test_signal_sr
=
16000
self
.
test_signal
=
common_utils
.
get_whitenoise
(
sample_rate
=
self
.
test_signal_sr
,
duration
=
0.5
,
)
# separating test files by their types (e.g 'spec', 'fbank', etc.)
# separating test files by their types (e.g 'spec', 'fbank', etc.)
for
f
in
os
.
listdir
(
kaldi_output_dir
):
for
f
in
os
.
listdir
(
kaldi_output_dir
):
dash_idx
=
f
.
find
(
'-'
)
dash_idx
=
f
.
find
(
'-'
)
...
@@ -172,80 +162,3 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
...
@@ -172,80 +162,3 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
def
test_mfcc_empty
(
self
):
def
test_mfcc_empty
(
self
):
# Passing in an empty tensor should result in an error
# Passing in an empty tensor should result in an error
self
.
assertRaises
(
AssertionError
,
kaldi
.
mfcc
,
torch
.
empty
(
0
))
self
.
assertRaises
(
AssertionError
,
kaldi
.
mfcc
,
torch
.
empty
(
0
))
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
def
test_resample_waveform_upsample_size
(
self
,
resampling_method
):
upsample_sound
=
kaldi
.
resample_waveform
(
self
.
test_signal
,
self
.
test_signal_sr
,
self
.
test_signal_sr
*
2
,
resampling_method
=
resampling_method
)
self
.
assertTrue
(
upsample_sound
.
size
(
-
1
)
==
self
.
test_signal
.
size
(
-
1
)
*
2
)
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
def
test_resample_waveform_downsample_size
(
self
,
resampling_method
):
downsample_sound
=
kaldi
.
resample_waveform
(
self
.
test_signal
,
self
.
test_signal_sr
,
self
.
test_signal_sr
//
2
,
resampling_method
=
resampling_method
)
self
.
assertTrue
(
downsample_sound
.
size
(
-
1
)
==
self
.
test_signal
.
size
(
-
1
)
//
2
)
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
def
test_resample_waveform_identity_size
(
self
,
resampling_method
):
downsample_sound
=
kaldi
.
resample_waveform
(
self
.
test_signal
,
self
.
test_signal_sr
,
self
.
test_signal_sr
,
resampling_method
=
resampling_method
)
self
.
assertTrue
(
downsample_sound
.
size
(
-
1
)
==
self
.
test_signal
.
size
(
-
1
))
def
_test_resample_waveform_accuracy
(
self
,
up_scale_factor
=
None
,
down_scale_factor
=
None
,
resampling_method
=
"sinc_interpolation"
,
atol
=
1e-1
,
rtol
=
1e-4
):
# resample the signal and compare it to the ground truth
n_to_trim
=
20
sample_rate
=
1000
new_sample_rate
=
sample_rate
if
up_scale_factor
is
not
None
:
new_sample_rate
*=
up_scale_factor
if
down_scale_factor
is
not
None
:
new_sample_rate
//=
down_scale_factor
duration
=
5
# seconds
original_timestamps
=
torch
.
arange
(
0
,
duration
,
1.0
/
sample_rate
)
sound
=
123
*
torch
.
cos
(
2
*
math
.
pi
*
3
*
original_timestamps
).
unsqueeze
(
0
)
estimate
=
kaldi
.
resample_waveform
(
sound
,
sample_rate
,
new_sample_rate
,
resampling_method
=
resampling_method
).
squeeze
()
new_timestamps
=
torch
.
arange
(
0
,
duration
,
1.0
/
new_sample_rate
)[:
estimate
.
size
(
0
)]
ground_truth
=
123
*
torch
.
cos
(
2
*
math
.
pi
*
3
*
new_timestamps
)
# trim the first/last n samples as these points have boundary effects
ground_truth
=
ground_truth
[...,
n_to_trim
:
-
n_to_trim
]
estimate
=
estimate
[...,
n_to_trim
:
-
n_to_trim
]
self
.
assertEqual
(
estimate
,
ground_truth
,
atol
=
atol
,
rtol
=
rtol
)
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
def
test_resample_waveform_downsample_accuracy
(
self
,
resampling_method
):
for
i
in
range
(
1
,
20
):
self
.
_test_resample_waveform_accuracy
(
down_scale_factor
=
i
*
2
,
resampling_method
=
resampling_method
)
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
def
test_resample_waveform_upsample_accuracy
(
self
,
resampling_method
):
for
i
in
range
(
1
,
20
):
self
.
_test_resample_waveform_accuracy
(
up_scale_factor
=
1.0
+
i
/
20.0
,
resampling_method
=
resampling_method
)
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
def
test_resample_waveform_multi_channel
(
self
,
resampling_method
):
num_channels
=
3
multi_sound
=
self
.
test_signal
.
repeat
(
num_channels
,
1
)
# (num_channels, 8000 smp)
for
i
in
range
(
num_channels
):
multi_sound
[
i
,
:]
*=
(
i
+
1
)
*
1.5
multi_sound_sampled
=
kaldi
.
resample_waveform
(
multi_sound
,
self
.
test_signal_sr
,
self
.
test_signal_sr
//
2
,
resampling_method
=
resampling_method
)
# check that sampling is same whether using separately or in a tensor of size (c, n)
for
i
in
range
(
num_channels
):
single_channel
=
self
.
test_signal
*
(
i
+
1
)
*
1.5
single_channel_sampled
=
kaldi
.
resample_waveform
(
single_channel
,
self
.
test_signal_sr
,
self
.
test_signal_sr
//
2
,
resampling_method
=
resampling_method
)
self
.
assertEqual
(
multi_sound_sampled
[
i
,
:],
single_channel_sampled
[
0
],
rtol
=
1e-4
,
atol
=
1e-7
)
test/torchaudio_unittest/functional/batch_consistency_test.py
View file @
15a7f78c
...
@@ -197,6 +197,17 @@ class TestFunctional(common_utils.TorchaudioTestCase):
...
@@ -197,6 +197,17 @@ class TestFunctional(common_utils.TorchaudioTestCase):
F
.
sliding_window_cmn
,
spectrogram
,
center
=
center
,
F
.
sliding_window_cmn
,
spectrogram
,
center
=
center
,
norm_vars
=
norm_vars
)
norm_vars
=
norm_vars
)
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
def
test_resample_waveform
(
self
,
resampling_method
):
num_channels
=
3
sr
=
16000
new_sr
=
sr
//
2
multi_sound
=
common_utils
.
get_whitenoise
(
sample_rate
=
sr
,
n_channels
=
num_channels
,
duration
=
0.5
,)
self
.
assert_batch_consistency
(
F
.
resample
,
multi_sound
,
orig_freq
=
sr
,
new_freq
=
new_sr
,
resampling_method
=
resampling_method
,
rtol
=
1e-4
,
atol
=
1e-7
)
@
common_utils
.
skipIfNoKaldi
@
common_utils
.
skipIfNoKaldi
def
test_compute_kaldi_pitch
(
self
):
def
test_compute_kaldi_pitch
(
self
):
sample_rate
=
44100
sample_rate
=
44100
...
...
test/torchaudio_unittest/functional/functional_impl.py
View file @
15a7f78c
...
@@ -13,6 +13,35 @@ from torchaudio_unittest.common_utils import TestBaseMixin, get_sinusoid, nested
...
@@ -13,6 +13,35 @@ from torchaudio_unittest.common_utils import TestBaseMixin, get_sinusoid, nested
class
Functional
(
TestBaseMixin
):
class
Functional
(
TestBaseMixin
):
def
_test_resample_waveform_accuracy
(
self
,
up_scale_factor
=
None
,
down_scale_factor
=
None
,
resampling_method
=
"sinc_interpolation"
,
atol
=
1e-1
,
rtol
=
1e-4
):
# resample the signal and compare it to the ground truth
n_to_trim
=
20
sample_rate
=
1000
new_sample_rate
=
sample_rate
if
up_scale_factor
is
not
None
:
new_sample_rate
*=
up_scale_factor
if
down_scale_factor
is
not
None
:
new_sample_rate
//=
down_scale_factor
duration
=
5
# seconds
original_timestamps
=
torch
.
arange
(
0
,
duration
,
1.0
/
sample_rate
)
sound
=
123
*
torch
.
cos
(
2
*
math
.
pi
*
3
*
original_timestamps
).
unsqueeze
(
0
)
estimate
=
F
.
resample
(
sound
,
sample_rate
,
new_sample_rate
,
resampling_method
=
resampling_method
).
squeeze
()
new_timestamps
=
torch
.
arange
(
0
,
duration
,
1.0
/
new_sample_rate
)[:
estimate
.
size
(
0
)]
ground_truth
=
123
*
torch
.
cos
(
2
*
math
.
pi
*
3
*
new_timestamps
)
# trim the first/last n samples as these points have boundary effects
ground_truth
=
ground_truth
[...,
n_to_trim
:
-
n_to_trim
]
estimate
=
estimate
[...,
n_to_trim
:
-
n_to_trim
]
self
.
assertEqual
(
estimate
,
ground_truth
,
atol
=
atol
,
rtol
=
rtol
)
def
test_lfilter_simple
(
self
):
def
test_lfilter_simple
(
self
):
"""
"""
Create a very basic signal,
Create a very basic signal,
...
@@ -269,6 +298,41 @@ class Functional(TestBaseMixin):
...
@@ -269,6 +298,41 @@ class Functional(TestBaseMixin):
resampled
=
F
.
resample
(
waveform
,
sample_rate
,
sample_rate
)
resampled
=
F
.
resample
(
waveform
,
sample_rate
,
sample_rate
)
self
.
assertEqual
(
waveform
,
resampled
)
self
.
assertEqual
(
waveform
,
resampled
)
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
def
test_resample_waveform_upsample_size
(
self
,
resampling_method
):
sr
=
16000
waveform
=
get_whitenoise
(
sample_rate
=
sr
,
duration
=
0.5
,)
upsampled
=
F
.
resample
(
waveform
,
sr
,
sr
*
2
,
resampling_method
=
resampling_method
)
assert
upsampled
.
size
(
-
1
)
==
waveform
.
size
(
-
1
)
*
2
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
def
test_resample_waveform_downsample_size
(
self
,
resampling_method
):
sr
=
16000
waveform
=
get_whitenoise
(
sample_rate
=
sr
,
duration
=
0.5
,)
downsampled
=
F
.
resample
(
waveform
,
sr
,
sr
//
2
,
resampling_method
=
resampling_method
)
assert
downsampled
.
size
(
-
1
)
==
waveform
.
size
(
-
1
)
//
2
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
def
test_resample_waveform_identity_size
(
self
,
resampling_method
):
sr
=
16000
waveform
=
get_whitenoise
(
sample_rate
=
sr
,
duration
=
0.5
,)
resampled
=
F
.
resample
(
waveform
,
sr
,
sr
,
resampling_method
=
resampling_method
)
assert
resampled
.
size
(
-
1
)
==
waveform
.
size
(
-
1
)
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
"sinc_interpolation"
,
"kaiser_window"
],
list
(
range
(
1
,
20
)),
)))
def
test_resample_waveform_downsample_accuracy
(
self
,
resampling_method
,
i
):
self
.
_test_resample_waveform_accuracy
(
down_scale_factor
=
i
*
2
,
resampling_method
=
resampling_method
)
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
"sinc_interpolation"
,
"kaiser_window"
],
list
(
range
(
1
,
20
)),
)))
def
test_resample_waveform_upsample_accuracy
(
self
,
resampling_method
,
i
):
self
.
_test_resample_waveform_accuracy
(
up_scale_factor
=
1.0
+
i
/
20.0
,
resampling_method
=
resampling_method
)
def
test_resample_no_warning
(
self
):
def
test_resample_no_warning
(
self
):
sample_rate
=
44100
sample_rate
=
44100
waveform
=
get_whitenoise
(
sample_rate
=
sample_rate
,
duration
=
0.1
)
waveform
=
get_whitenoise
(
sample_rate
=
sample_rate
,
duration
=
0.1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment