Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
15a7f78c
"src/vscode:/vscode.git/clone" did not exist on "8bf046b7fb8aa41691cb42596313038868b251ac"
Unverified
Commit
15a7f78c
authored
Jun 03, 2021
by
Caroline Chen
Committed by
GitHub
Jun 03, 2021
Browse files
Migrate resample tests from kaldi to functional (#1520)
parent
68823423
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
75 additions
and
87 deletions
+75
-87
test/torchaudio_unittest/compliance_kaldi_test.py
test/torchaudio_unittest/compliance_kaldi_test.py
+0
-87
test/torchaudio_unittest/functional/batch_consistency_test.py
.../torchaudio_unittest/functional/batch_consistency_test.py
+11
-0
test/torchaudio_unittest/functional/functional_impl.py
test/torchaudio_unittest/functional/functional_impl.py
+64
-0
No files found.
test/torchaudio_unittest/compliance_kaldi_test.py
View file @
15a7f78c
...
@@ -7,7 +7,6 @@ import torchaudio.compliance.kaldi as kaldi
...
@@ -7,7 +7,6 @@ import torchaudio.compliance.kaldi as kaldi
from
torchaudio_unittest
import
common_utils
from
torchaudio_unittest
import
common_utils
from
.compliance
import
utils
as
compliance_utils
from
.compliance
import
utils
as
compliance_utils
from
parameterized
import
parameterized
def
extract_window
(
window
,
wave
,
f
,
frame_length
,
frame_shift
,
snip_edges
):
def
extract_window
(
window
,
wave
,
f
,
frame_length
,
frame_shift
,
snip_edges
):
...
@@ -53,15 +52,6 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
...
@@ -53,15 +52,6 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
test_filepath
=
common_utils
.
get_asset_path
(
'kaldi_file.wav'
)
test_filepath
=
common_utils
.
get_asset_path
(
'kaldi_file.wav'
)
test_filepaths
=
{
prefix
:
[]
for
prefix
in
compliance_utils
.
TEST_PREFIX
}
test_filepaths
=
{
prefix
:
[]
for
prefix
in
compliance_utils
.
TEST_PREFIX
}
def
setUp
(
self
):
super
().
setUp
()
# test signal for testing resampling
self
.
test_signal_sr
=
16000
self
.
test_signal
=
common_utils
.
get_whitenoise
(
sample_rate
=
self
.
test_signal_sr
,
duration
=
0.5
,
)
# separating test files by their types (e.g 'spec', 'fbank', etc.)
# separating test files by their types (e.g 'spec', 'fbank', etc.)
for
f
in
os
.
listdir
(
kaldi_output_dir
):
for
f
in
os
.
listdir
(
kaldi_output_dir
):
dash_idx
=
f
.
find
(
'-'
)
dash_idx
=
f
.
find
(
'-'
)
...
@@ -172,80 +162,3 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
...
@@ -172,80 +162,3 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
def
test_mfcc_empty
(
self
):
def
test_mfcc_empty
(
self
):
# Passing in an empty tensor should result in an error
# Passing in an empty tensor should result in an error
self
.
assertRaises
(
AssertionError
,
kaldi
.
mfcc
,
torch
.
empty
(
0
))
self
.
assertRaises
(
AssertionError
,
kaldi
.
mfcc
,
torch
.
empty
(
0
))
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
def
test_resample_waveform_upsample_size
(
self
,
resampling_method
):
upsample_sound
=
kaldi
.
resample_waveform
(
self
.
test_signal
,
self
.
test_signal_sr
,
self
.
test_signal_sr
*
2
,
resampling_method
=
resampling_method
)
self
.
assertTrue
(
upsample_sound
.
size
(
-
1
)
==
self
.
test_signal
.
size
(
-
1
)
*
2
)
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
def
test_resample_waveform_downsample_size
(
self
,
resampling_method
):
downsample_sound
=
kaldi
.
resample_waveform
(
self
.
test_signal
,
self
.
test_signal_sr
,
self
.
test_signal_sr
//
2
,
resampling_method
=
resampling_method
)
self
.
assertTrue
(
downsample_sound
.
size
(
-
1
)
==
self
.
test_signal
.
size
(
-
1
)
//
2
)
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
def
test_resample_waveform_identity_size
(
self
,
resampling_method
):
downsample_sound
=
kaldi
.
resample_waveform
(
self
.
test_signal
,
self
.
test_signal_sr
,
self
.
test_signal_sr
,
resampling_method
=
resampling_method
)
self
.
assertTrue
(
downsample_sound
.
size
(
-
1
)
==
self
.
test_signal
.
size
(
-
1
))
def
_test_resample_waveform_accuracy
(
self
,
up_scale_factor
=
None
,
down_scale_factor
=
None
,
resampling_method
=
"sinc_interpolation"
,
atol
=
1e-1
,
rtol
=
1e-4
):
# resample the signal and compare it to the ground truth
n_to_trim
=
20
sample_rate
=
1000
new_sample_rate
=
sample_rate
if
up_scale_factor
is
not
None
:
new_sample_rate
*=
up_scale_factor
if
down_scale_factor
is
not
None
:
new_sample_rate
//=
down_scale_factor
duration
=
5
# seconds
original_timestamps
=
torch
.
arange
(
0
,
duration
,
1.0
/
sample_rate
)
sound
=
123
*
torch
.
cos
(
2
*
math
.
pi
*
3
*
original_timestamps
).
unsqueeze
(
0
)
estimate
=
kaldi
.
resample_waveform
(
sound
,
sample_rate
,
new_sample_rate
,
resampling_method
=
resampling_method
).
squeeze
()
new_timestamps
=
torch
.
arange
(
0
,
duration
,
1.0
/
new_sample_rate
)[:
estimate
.
size
(
0
)]
ground_truth
=
123
*
torch
.
cos
(
2
*
math
.
pi
*
3
*
new_timestamps
)
# trim the first/last n samples as these points have boundary effects
ground_truth
=
ground_truth
[...,
n_to_trim
:
-
n_to_trim
]
estimate
=
estimate
[...,
n_to_trim
:
-
n_to_trim
]
self
.
assertEqual
(
estimate
,
ground_truth
,
atol
=
atol
,
rtol
=
rtol
)
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
def
test_resample_waveform_downsample_accuracy
(
self
,
resampling_method
):
for
i
in
range
(
1
,
20
):
self
.
_test_resample_waveform_accuracy
(
down_scale_factor
=
i
*
2
,
resampling_method
=
resampling_method
)
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
def
test_resample_waveform_upsample_accuracy
(
self
,
resampling_method
):
for
i
in
range
(
1
,
20
):
self
.
_test_resample_waveform_accuracy
(
up_scale_factor
=
1.0
+
i
/
20.0
,
resampling_method
=
resampling_method
)
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
def
test_resample_waveform_multi_channel
(
self
,
resampling_method
):
num_channels
=
3
multi_sound
=
self
.
test_signal
.
repeat
(
num_channels
,
1
)
# (num_channels, 8000 smp)
for
i
in
range
(
num_channels
):
multi_sound
[
i
,
:]
*=
(
i
+
1
)
*
1.5
multi_sound_sampled
=
kaldi
.
resample_waveform
(
multi_sound
,
self
.
test_signal_sr
,
self
.
test_signal_sr
//
2
,
resampling_method
=
resampling_method
)
# check that sampling is same whether using separately or in a tensor of size (c, n)
for
i
in
range
(
num_channels
):
single_channel
=
self
.
test_signal
*
(
i
+
1
)
*
1.5
single_channel_sampled
=
kaldi
.
resample_waveform
(
single_channel
,
self
.
test_signal_sr
,
self
.
test_signal_sr
//
2
,
resampling_method
=
resampling_method
)
self
.
assertEqual
(
multi_sound_sampled
[
i
,
:],
single_channel_sampled
[
0
],
rtol
=
1e-4
,
atol
=
1e-7
)
test/torchaudio_unittest/functional/batch_consistency_test.py
View file @
15a7f78c
...
@@ -197,6 +197,17 @@ class TestFunctional(common_utils.TorchaudioTestCase):
...
@@ -197,6 +197,17 @@ class TestFunctional(common_utils.TorchaudioTestCase):
F
.
sliding_window_cmn
,
spectrogram
,
center
=
center
,
F
.
sliding_window_cmn
,
spectrogram
,
center
=
center
,
norm_vars
=
norm_vars
)
norm_vars
=
norm_vars
)
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
def
test_resample_waveform
(
self
,
resampling_method
):
num_channels
=
3
sr
=
16000
new_sr
=
sr
//
2
multi_sound
=
common_utils
.
get_whitenoise
(
sample_rate
=
sr
,
n_channels
=
num_channels
,
duration
=
0.5
,)
self
.
assert_batch_consistency
(
F
.
resample
,
multi_sound
,
orig_freq
=
sr
,
new_freq
=
new_sr
,
resampling_method
=
resampling_method
,
rtol
=
1e-4
,
atol
=
1e-7
)
@
common_utils
.
skipIfNoKaldi
@
common_utils
.
skipIfNoKaldi
def
test_compute_kaldi_pitch
(
self
):
def
test_compute_kaldi_pitch
(
self
):
sample_rate
=
44100
sample_rate
=
44100
...
...
test/torchaudio_unittest/functional/functional_impl.py
View file @
15a7f78c
...
@@ -13,6 +13,35 @@ from torchaudio_unittest.common_utils import TestBaseMixin, get_sinusoid, nested
...
@@ -13,6 +13,35 @@ from torchaudio_unittest.common_utils import TestBaseMixin, get_sinusoid, nested
class
Functional
(
TestBaseMixin
):
class
Functional
(
TestBaseMixin
):
def
_test_resample_waveform_accuracy
(
self
,
up_scale_factor
=
None
,
down_scale_factor
=
None
,
resampling_method
=
"sinc_interpolation"
,
atol
=
1e-1
,
rtol
=
1e-4
):
# resample the signal and compare it to the ground truth
n_to_trim
=
20
sample_rate
=
1000
new_sample_rate
=
sample_rate
if
up_scale_factor
is
not
None
:
new_sample_rate
*=
up_scale_factor
if
down_scale_factor
is
not
None
:
new_sample_rate
//=
down_scale_factor
duration
=
5
# seconds
original_timestamps
=
torch
.
arange
(
0
,
duration
,
1.0
/
sample_rate
)
sound
=
123
*
torch
.
cos
(
2
*
math
.
pi
*
3
*
original_timestamps
).
unsqueeze
(
0
)
estimate
=
F
.
resample
(
sound
,
sample_rate
,
new_sample_rate
,
resampling_method
=
resampling_method
).
squeeze
()
new_timestamps
=
torch
.
arange
(
0
,
duration
,
1.0
/
new_sample_rate
)[:
estimate
.
size
(
0
)]
ground_truth
=
123
*
torch
.
cos
(
2
*
math
.
pi
*
3
*
new_timestamps
)
# trim the first/last n samples as these points have boundary effects
ground_truth
=
ground_truth
[...,
n_to_trim
:
-
n_to_trim
]
estimate
=
estimate
[...,
n_to_trim
:
-
n_to_trim
]
self
.
assertEqual
(
estimate
,
ground_truth
,
atol
=
atol
,
rtol
=
rtol
)
def
test_lfilter_simple
(
self
):
def
test_lfilter_simple
(
self
):
"""
"""
Create a very basic signal,
Create a very basic signal,
...
@@ -269,6 +298,41 @@ class Functional(TestBaseMixin):
...
@@ -269,6 +298,41 @@ class Functional(TestBaseMixin):
resampled
=
F
.
resample
(
waveform
,
sample_rate
,
sample_rate
)
resampled
=
F
.
resample
(
waveform
,
sample_rate
,
sample_rate
)
self
.
assertEqual
(
waveform
,
resampled
)
self
.
assertEqual
(
waveform
,
resampled
)
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
def
test_resample_waveform_upsample_size
(
self
,
resampling_method
):
sr
=
16000
waveform
=
get_whitenoise
(
sample_rate
=
sr
,
duration
=
0.5
,)
upsampled
=
F
.
resample
(
waveform
,
sr
,
sr
*
2
,
resampling_method
=
resampling_method
)
assert
upsampled
.
size
(
-
1
)
==
waveform
.
size
(
-
1
)
*
2
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
def
test_resample_waveform_downsample_size
(
self
,
resampling_method
):
sr
=
16000
waveform
=
get_whitenoise
(
sample_rate
=
sr
,
duration
=
0.5
,)
downsampled
=
F
.
resample
(
waveform
,
sr
,
sr
//
2
,
resampling_method
=
resampling_method
)
assert
downsampled
.
size
(
-
1
)
==
waveform
.
size
(
-
1
)
//
2
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
def
test_resample_waveform_identity_size
(
self
,
resampling_method
):
sr
=
16000
waveform
=
get_whitenoise
(
sample_rate
=
sr
,
duration
=
0.5
,)
resampled
=
F
.
resample
(
waveform
,
sr
,
sr
,
resampling_method
=
resampling_method
)
assert
resampled
.
size
(
-
1
)
==
waveform
.
size
(
-
1
)
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
"sinc_interpolation"
,
"kaiser_window"
],
list
(
range
(
1
,
20
)),
)))
def
test_resample_waveform_downsample_accuracy
(
self
,
resampling_method
,
i
):
self
.
_test_resample_waveform_accuracy
(
down_scale_factor
=
i
*
2
,
resampling_method
=
resampling_method
)
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
"sinc_interpolation"
,
"kaiser_window"
],
list
(
range
(
1
,
20
)),
)))
def
test_resample_waveform_upsample_accuracy
(
self
,
resampling_method
,
i
):
self
.
_test_resample_waveform_accuracy
(
up_scale_factor
=
1.0
+
i
/
20.0
,
resampling_method
=
resampling_method
)
def
test_resample_no_warning
(
self
):
def
test_resample_no_warning
(
self
):
sample_rate
=
44100
sample_rate
=
44100
waveform
=
get_whitenoise
(
sample_rate
=
sample_rate
,
duration
=
0.1
)
waveform
=
get_whitenoise
(
sample_rate
=
sample_rate
,
duration
=
0.1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment