Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
fad19fab
Unverified
Commit
fad19fab
authored
Jun 01, 2021
by
Caroline Chen
Committed by
GitHub
Jun 01, 2021
Browse files
Ensure resampling identity is unchanged (#1537)
parent
f1a0b605
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
48 additions
and
29 deletions
+48
-29
test/torchaudio_unittest/compliance_kaldi_test.py
test/torchaudio_unittest/compliance_kaldi_test.py
+15
-25
test/torchaudio_unittest/functional/functional_impl.py
test/torchaudio_unittest/functional/functional_impl.py
+10
-0
test/torchaudio_unittest/transforms/transforms_test_impl.py
test/torchaudio_unittest/transforms/transforms_test_impl.py
+13
-0
torchaudio/functional/functional.py
torchaudio/functional/functional.py
+3
-0
torchaudio/transforms.py
torchaudio/transforms.py
+7
-4
No files found.
test/torchaudio_unittest/compliance_kaldi_test.py
View file @
fad19fab
...
...
@@ -56,15 +56,12 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
def
setUp
(
self
):
super
().
setUp
()
#
1.
test signal for testing resampling
self
.
test
1
_signal_sr
=
16000
self
.
test
1
_signal
=
common_utils
.
get_whitenoise
(
sample_rate
=
self
.
test
1
_signal_sr
,
duration
=
0.5
,
# test signal for testing resampling
self
.
test_signal_sr
=
16000
self
.
test_signal
=
common_utils
.
get_whitenoise
(
sample_rate
=
self
.
test_signal_sr
,
duration
=
0.5
,
)
# 2. test audio file corresponding to saved kaldi ark files
self
.
test2_filepath
=
common_utils
.
get_asset_path
(
'kaldi_file_8000.wav'
)
# separating test files by their types (e.g 'spec', 'fbank', etc.)
for
f
in
os
.
listdir
(
kaldi_output_dir
):
dash_idx
=
f
.
find
(
'-'
)
...
...
@@ -176,30 +173,23 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
# Passing in an empty tensor should result in an error
self
.
assertRaises
(
AssertionError
,
kaldi
.
mfcc
,
torch
.
empty
(
0
))
def
test_resample_waveform
(
self
):
def
get_output_fn
(
sound
,
args
):
output
=
kaldi
.
resample_waveform
(
sound
.
to
(
torch
.
float32
),
args
[
1
],
args
[
2
])
return
output
self
.
_compliance_test_helper
(
self
.
test2_filepath
,
'resample'
,
32
,
3
,
get_output_fn
,
atol
=
1e-2
,
rtol
=
1e-5
)
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
def
test_resample_waveform_upsample_size
(
self
,
resampling_method
):
upsample_sound
=
kaldi
.
resample_waveform
(
self
.
test
1
_signal
,
self
.
test
1
_signal_sr
,
self
.
test
1
_signal_sr
*
2
,
upsample_sound
=
kaldi
.
resample_waveform
(
self
.
test_signal
,
self
.
test_signal_sr
,
self
.
test_signal_sr
*
2
,
resampling_method
=
resampling_method
)
self
.
assertTrue
(
upsample_sound
.
size
(
-
1
)
==
self
.
test
1
_signal
.
size
(
-
1
)
*
2
)
self
.
assertTrue
(
upsample_sound
.
size
(
-
1
)
==
self
.
test_signal
.
size
(
-
1
)
*
2
)
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
def
test_resample_waveform_downsample_size
(
self
,
resampling_method
):
downsample_sound
=
kaldi
.
resample_waveform
(
self
.
test
1
_signal
,
self
.
test
1
_signal_sr
,
self
.
test
1
_signal_sr
//
2
,
downsample_sound
=
kaldi
.
resample_waveform
(
self
.
test_signal
,
self
.
test_signal_sr
,
self
.
test_signal_sr
//
2
,
resampling_method
=
resampling_method
)
self
.
assertTrue
(
downsample_sound
.
size
(
-
1
)
==
self
.
test
1
_signal
.
size
(
-
1
)
//
2
)
self
.
assertTrue
(
downsample_sound
.
size
(
-
1
)
==
self
.
test_signal
.
size
(
-
1
)
//
2
)
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
def
test_resample_waveform_identity_size
(
self
,
resampling_method
):
downsample_sound
=
kaldi
.
resample_waveform
(
self
.
test
1
_signal
,
self
.
test
1
_signal_sr
,
self
.
test
1
_signal_sr
,
downsample_sound
=
kaldi
.
resample_waveform
(
self
.
test_signal
,
self
.
test_signal_sr
,
self
.
test_signal_sr
,
resampling_method
=
resampling_method
)
self
.
assertTrue
(
downsample_sound
.
size
(
-
1
)
==
self
.
test
1
_signal
.
size
(
-
1
))
self
.
assertTrue
(
downsample_sound
.
size
(
-
1
)
==
self
.
test_signal
.
size
(
-
1
))
def
_test_resample_waveform_accuracy
(
self
,
up_scale_factor
=
None
,
down_scale_factor
=
None
,
resampling_method
=
"sinc_interpolation"
,
atol
=
1e-1
,
rtol
=
1e-4
):
...
...
@@ -244,18 +234,18 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
def
test_resample_waveform_multi_channel
(
self
,
resampling_method
):
num_channels
=
3
multi_sound
=
self
.
test
1
_signal
.
repeat
(
num_channels
,
1
)
# (num_channels, 8000 smp)
multi_sound
=
self
.
test_signal
.
repeat
(
num_channels
,
1
)
# (num_channels, 8000 smp)
for
i
in
range
(
num_channels
):
multi_sound
[
i
,
:]
*=
(
i
+
1
)
*
1.5
multi_sound_sampled
=
kaldi
.
resample_waveform
(
multi_sound
,
self
.
test
1
_signal_sr
,
self
.
test
1
_signal_sr
//
2
,
multi_sound_sampled
=
kaldi
.
resample_waveform
(
multi_sound
,
self
.
test_signal_sr
,
self
.
test_signal_sr
//
2
,
resampling_method
=
resampling_method
)
# check that sampling is same whether using separately or in a tensor of size (c, n)
for
i
in
range
(
num_channels
):
single_channel
=
self
.
test
1
_signal
*
(
i
+
1
)
*
1.5
single_channel_sampled
=
kaldi
.
resample_waveform
(
single_channel
,
self
.
test
1
_signal_sr
,
self
.
test
1
_signal_sr
//
2
,
single_channel
=
self
.
test_signal
*
(
i
+
1
)
*
1.5
single_channel_sampled
=
kaldi
.
resample_waveform
(
single_channel
,
self
.
test_signal_sr
,
self
.
test_signal_sr
//
2
,
resampling_method
=
resampling_method
)
self
.
assertEqual
(
multi_sound_sampled
[
i
,
:],
single_channel_sampled
[
0
],
rtol
=
1e-4
,
atol
=
1e-7
)
test/torchaudio_unittest/functional/functional_impl.py
View file @
fad19fab
...
...
@@ -259,6 +259,16 @@ class Functional(TestBaseMixin):
self
.
assertEqual
(
specgrams
,
specgrams_copy
)
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
"sinc_interpolation"
,
"kaiser_window"
],
[
16000
,
44100
],
)))
def
test_resample_identity
(
self
,
resampling_method
,
sample_rate
):
waveform
=
get_whitenoise
(
sample_rate
=
sample_rate
,
duration
=
1
)
resampled
=
F
.
resample
(
waveform
,
sample_rate
,
sample_rate
)
self
.
assertEqual
(
waveform
,
resampled
)
def
test_resample_no_warning
(
self
):
sample_rate
=
44100
waveform
=
get_whitenoise
(
sample_rate
=
sample_rate
,
duration
=
0.1
)
...
...
test/torchaudio_unittest/transforms/transforms_test_impl.py
View file @
fad19fab
import
itertools
import
warnings
import
torch
...
...
@@ -8,6 +9,7 @@ from torchaudio_unittest.common_utils import (
get_whitenoise
,
get_spectrogram
,
)
from
parameterized
import
parameterized
def
_get_ratio
(
mat
):
...
...
@@ -77,3 +79,14 @@ class TransformsTestBase(TestBaseMixin):
warnings
.
simplefilter
(
"always"
)
T
.
MelScale
(
n_mels
=
64
,
sample_rate
=
8000
,
n_stft
=
201
)
assert
len
(
caught_warnings
)
==
0
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
"sinc_interpolation"
,
"kaiser_window"
],
[
16000
,
44100
],
)))
def
test_resample_identity
(
self
,
resampling_method
,
sample_rate
):
waveform
=
get_whitenoise
(
sample_rate
=
sample_rate
,
duration
=
1
)
resampler
=
T
.
Resample
(
sample_rate
,
sample_rate
)
resampled
=
resampler
(
waveform
)
self
.
assertEqual
(
waveform
,
resampled
)
torchaudio/functional/functional.py
View file @
fad19fab
...
...
@@ -1449,6 +1449,9 @@ def resample(
assert
orig_freq
>
0.0
and
new_freq
>
0.0
if
orig_freq
==
new_freq
:
return
waveform
gcd
=
math
.
gcd
(
int
(
orig_freq
),
int
(
new_freq
))
kernel
,
width
=
_get_sinc_resample_kernel
(
orig_freq
,
new_freq
,
gcd
,
lowpass_filter_width
,
rolloff
,
...
...
torchaudio/transforms.py
View file @
fad19fab
...
...
@@ -696,10 +696,11 @@ class Resample(torch.nn.Module):
self
.
lowpass_filter_width
=
lowpass_filter_width
self
.
rolloff
=
rolloff
kernel
,
self
.
width
=
_get_sinc_resample_kernel
(
self
.
orig_freq
,
self
.
new_freq
,
self
.
gcd
,
self
.
lowpass_filter_width
,
self
.
rolloff
,
self
.
resampling_method
,
beta
)
self
.
register_buffer
(
'kernel'
,
kernel
)
if
self
.
orig_freq
!=
self
.
new_freq
:
kernel
,
self
.
width
=
_get_sinc_resample_kernel
(
self
.
orig_freq
,
self
.
new_freq
,
self
.
gcd
,
self
.
lowpass_filter_width
,
self
.
rolloff
,
self
.
resampling_method
,
beta
)
self
.
register_buffer
(
'kernel'
,
kernel
)
def
forward
(
self
,
waveform
:
Tensor
)
->
Tensor
:
r
"""
...
...
@@ -709,6 +710,8 @@ class Resample(torch.nn.Module):
Returns:
Tensor: Output signal of dimension (..., time).
"""
if
self
.
orig_freq
==
self
.
new_freq
:
return
waveform
return
_apply_sinc_resample_kernel
(
waveform
,
self
.
orig_freq
,
self
.
new_freq
,
self
.
gcd
,
self
.
kernel
,
self
.
width
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment