Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
hehl2
Torchaudio
Commits
fad19fab
"vscode:/vscode.git/clone" did not exist on "290b1cb15b3ebff8f9a383f7577b7ee33114fc16"
Unverified
Commit
fad19fab
authored
Jun 01, 2021
by
Caroline Chen
Committed by
GitHub
Jun 01, 2021
Browse files
Ensure resampling identity is unchanged (#1537)
parent
f1a0b605
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
48 additions
and
29 deletions
+48
-29
test/torchaudio_unittest/compliance_kaldi_test.py
test/torchaudio_unittest/compliance_kaldi_test.py
+15
-25
test/torchaudio_unittest/functional/functional_impl.py
test/torchaudio_unittest/functional/functional_impl.py
+10
-0
test/torchaudio_unittest/transforms/transforms_test_impl.py
test/torchaudio_unittest/transforms/transforms_test_impl.py
+13
-0
torchaudio/functional/functional.py
torchaudio/functional/functional.py
+3
-0
torchaudio/transforms.py
torchaudio/transforms.py
+7
-4
No files found.
test/torchaudio_unittest/compliance_kaldi_test.py
View file @
fad19fab
...
@@ -56,15 +56,12 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
...
@@ -56,15 +56,12 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
def
setUp
(
self
):
def
setUp
(
self
):
super
().
setUp
()
super
().
setUp
()
#
1.
test signal for testing resampling
# test signal for testing resampling
self
.
test
1
_signal_sr
=
16000
self
.
test_signal_sr
=
16000
self
.
test
1
_signal
=
common_utils
.
get_whitenoise
(
self
.
test_signal
=
common_utils
.
get_whitenoise
(
sample_rate
=
self
.
test
1
_signal_sr
,
duration
=
0.5
,
sample_rate
=
self
.
test_signal_sr
,
duration
=
0.5
,
)
)
# 2. test audio file corresponding to saved kaldi ark files
self
.
test2_filepath
=
common_utils
.
get_asset_path
(
'kaldi_file_8000.wav'
)
# separating test files by their types (e.g 'spec', 'fbank', etc.)
# separating test files by their types (e.g 'spec', 'fbank', etc.)
for
f
in
os
.
listdir
(
kaldi_output_dir
):
for
f
in
os
.
listdir
(
kaldi_output_dir
):
dash_idx
=
f
.
find
(
'-'
)
dash_idx
=
f
.
find
(
'-'
)
...
@@ -176,30 +173,23 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
...
@@ -176,30 +173,23 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
# Passing in an empty tensor should result in an error
# Passing in an empty tensor should result in an error
self
.
assertRaises
(
AssertionError
,
kaldi
.
mfcc
,
torch
.
empty
(
0
))
self
.
assertRaises
(
AssertionError
,
kaldi
.
mfcc
,
torch
.
empty
(
0
))
def
test_resample_waveform
(
self
):
def
get_output_fn
(
sound
,
args
):
output
=
kaldi
.
resample_waveform
(
sound
.
to
(
torch
.
float32
),
args
[
1
],
args
[
2
])
return
output
self
.
_compliance_test_helper
(
self
.
test2_filepath
,
'resample'
,
32
,
3
,
get_output_fn
,
atol
=
1e-2
,
rtol
=
1e-5
)
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
def
test_resample_waveform_upsample_size
(
self
,
resampling_method
):
def
test_resample_waveform_upsample_size
(
self
,
resampling_method
):
upsample_sound
=
kaldi
.
resample_waveform
(
self
.
test
1
_signal
,
self
.
test
1
_signal_sr
,
self
.
test
1
_signal_sr
*
2
,
upsample_sound
=
kaldi
.
resample_waveform
(
self
.
test_signal
,
self
.
test_signal_sr
,
self
.
test_signal_sr
*
2
,
resampling_method
=
resampling_method
)
resampling_method
=
resampling_method
)
self
.
assertTrue
(
upsample_sound
.
size
(
-
1
)
==
self
.
test
1
_signal
.
size
(
-
1
)
*
2
)
self
.
assertTrue
(
upsample_sound
.
size
(
-
1
)
==
self
.
test_signal
.
size
(
-
1
)
*
2
)
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
def
test_resample_waveform_downsample_size
(
self
,
resampling_method
):
def
test_resample_waveform_downsample_size
(
self
,
resampling_method
):
downsample_sound
=
kaldi
.
resample_waveform
(
self
.
test
1
_signal
,
self
.
test
1
_signal_sr
,
self
.
test
1
_signal_sr
//
2
,
downsample_sound
=
kaldi
.
resample_waveform
(
self
.
test_signal
,
self
.
test_signal_sr
,
self
.
test_signal_sr
//
2
,
resampling_method
=
resampling_method
)
resampling_method
=
resampling_method
)
self
.
assertTrue
(
downsample_sound
.
size
(
-
1
)
==
self
.
test
1
_signal
.
size
(
-
1
)
//
2
)
self
.
assertTrue
(
downsample_sound
.
size
(
-
1
)
==
self
.
test_signal
.
size
(
-
1
)
//
2
)
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
@
parameterized
.
expand
([(
"sinc_interpolation"
),
(
"kaiser_window"
)])
def
test_resample_waveform_identity_size
(
self
,
resampling_method
):
def
test_resample_waveform_identity_size
(
self
,
resampling_method
):
downsample_sound
=
kaldi
.
resample_waveform
(
self
.
test
1
_signal
,
self
.
test
1
_signal_sr
,
self
.
test
1
_signal_sr
,
downsample_sound
=
kaldi
.
resample_waveform
(
self
.
test_signal
,
self
.
test_signal_sr
,
self
.
test_signal_sr
,
resampling_method
=
resampling_method
)
resampling_method
=
resampling_method
)
self
.
assertTrue
(
downsample_sound
.
size
(
-
1
)
==
self
.
test
1
_signal
.
size
(
-
1
))
self
.
assertTrue
(
downsample_sound
.
size
(
-
1
)
==
self
.
test_signal
.
size
(
-
1
))
def
_test_resample_waveform_accuracy
(
self
,
up_scale_factor
=
None
,
down_scale_factor
=
None
,
def
_test_resample_waveform_accuracy
(
self
,
up_scale_factor
=
None
,
down_scale_factor
=
None
,
resampling_method
=
"sinc_interpolation"
,
atol
=
1e-1
,
rtol
=
1e-4
):
resampling_method
=
"sinc_interpolation"
,
atol
=
1e-1
,
rtol
=
1e-4
):
...
@@ -244,18 +234,18 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
...
@@ -244,18 +234,18 @@ class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
def
test_resample_waveform_multi_channel
(
self
,
resampling_method
):
def
test_resample_waveform_multi_channel
(
self
,
resampling_method
):
num_channels
=
3
num_channels
=
3
multi_sound
=
self
.
test
1
_signal
.
repeat
(
num_channels
,
1
)
# (num_channels, 8000 smp)
multi_sound
=
self
.
test_signal
.
repeat
(
num_channels
,
1
)
# (num_channels, 8000 smp)
for
i
in
range
(
num_channels
):
for
i
in
range
(
num_channels
):
multi_sound
[
i
,
:]
*=
(
i
+
1
)
*
1.5
multi_sound
[
i
,
:]
*=
(
i
+
1
)
*
1.5
multi_sound_sampled
=
kaldi
.
resample_waveform
(
multi_sound
,
self
.
test
1
_signal_sr
,
self
.
test
1
_signal_sr
//
2
,
multi_sound_sampled
=
kaldi
.
resample_waveform
(
multi_sound
,
self
.
test_signal_sr
,
self
.
test_signal_sr
//
2
,
resampling_method
=
resampling_method
)
resampling_method
=
resampling_method
)
# check that sampling is same whether using separately or in a tensor of size (c, n)
# check that sampling is same whether using separately or in a tensor of size (c, n)
for
i
in
range
(
num_channels
):
for
i
in
range
(
num_channels
):
single_channel
=
self
.
test
1
_signal
*
(
i
+
1
)
*
1.5
single_channel
=
self
.
test_signal
*
(
i
+
1
)
*
1.5
single_channel_sampled
=
kaldi
.
resample_waveform
(
single_channel
,
self
.
test
1
_signal_sr
,
single_channel_sampled
=
kaldi
.
resample_waveform
(
single_channel
,
self
.
test_signal_sr
,
self
.
test
1
_signal_sr
//
2
,
self
.
test_signal_sr
//
2
,
resampling_method
=
resampling_method
)
resampling_method
=
resampling_method
)
self
.
assertEqual
(
multi_sound_sampled
[
i
,
:],
single_channel_sampled
[
0
],
rtol
=
1e-4
,
atol
=
1e-7
)
self
.
assertEqual
(
multi_sound_sampled
[
i
,
:],
single_channel_sampled
[
0
],
rtol
=
1e-4
,
atol
=
1e-7
)
test/torchaudio_unittest/functional/functional_impl.py
View file @
fad19fab
...
@@ -259,6 +259,16 @@ class Functional(TestBaseMixin):
...
@@ -259,6 +259,16 @@ class Functional(TestBaseMixin):
self
.
assertEqual
(
specgrams
,
specgrams_copy
)
self
.
assertEqual
(
specgrams
,
specgrams_copy
)
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
"sinc_interpolation"
,
"kaiser_window"
],
[
16000
,
44100
],
)))
def
test_resample_identity
(
self
,
resampling_method
,
sample_rate
):
waveform
=
get_whitenoise
(
sample_rate
=
sample_rate
,
duration
=
1
)
resampled
=
F
.
resample
(
waveform
,
sample_rate
,
sample_rate
)
self
.
assertEqual
(
waveform
,
resampled
)
def
test_resample_no_warning
(
self
):
def
test_resample_no_warning
(
self
):
sample_rate
=
44100
sample_rate
=
44100
waveform
=
get_whitenoise
(
sample_rate
=
sample_rate
,
duration
=
0.1
)
waveform
=
get_whitenoise
(
sample_rate
=
sample_rate
,
duration
=
0.1
)
...
...
test/torchaudio_unittest/transforms/transforms_test_impl.py
View file @
fad19fab
import
itertools
import
warnings
import
warnings
import
torch
import
torch
...
@@ -8,6 +9,7 @@ from torchaudio_unittest.common_utils import (
...
@@ -8,6 +9,7 @@ from torchaudio_unittest.common_utils import (
get_whitenoise
,
get_whitenoise
,
get_spectrogram
,
get_spectrogram
,
)
)
from
parameterized
import
parameterized
def
_get_ratio
(
mat
):
def
_get_ratio
(
mat
):
...
@@ -77,3 +79,14 @@ class TransformsTestBase(TestBaseMixin):
...
@@ -77,3 +79,14 @@ class TransformsTestBase(TestBaseMixin):
warnings
.
simplefilter
(
"always"
)
warnings
.
simplefilter
(
"always"
)
T
.
MelScale
(
n_mels
=
64
,
sample_rate
=
8000
,
n_stft
=
201
)
T
.
MelScale
(
n_mels
=
64
,
sample_rate
=
8000
,
n_stft
=
201
)
assert
len
(
caught_warnings
)
==
0
assert
len
(
caught_warnings
)
==
0
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
"sinc_interpolation"
,
"kaiser_window"
],
[
16000
,
44100
],
)))
def
test_resample_identity
(
self
,
resampling_method
,
sample_rate
):
waveform
=
get_whitenoise
(
sample_rate
=
sample_rate
,
duration
=
1
)
resampler
=
T
.
Resample
(
sample_rate
,
sample_rate
)
resampled
=
resampler
(
waveform
)
self
.
assertEqual
(
waveform
,
resampled
)
torchaudio/functional/functional.py
View file @
fad19fab
...
@@ -1449,6 +1449,9 @@ def resample(
...
@@ -1449,6 +1449,9 @@ def resample(
assert
orig_freq
>
0.0
and
new_freq
>
0.0
assert
orig_freq
>
0.0
and
new_freq
>
0.0
if
orig_freq
==
new_freq
:
return
waveform
gcd
=
math
.
gcd
(
int
(
orig_freq
),
int
(
new_freq
))
gcd
=
math
.
gcd
(
int
(
orig_freq
),
int
(
new_freq
))
kernel
,
width
=
_get_sinc_resample_kernel
(
orig_freq
,
new_freq
,
gcd
,
lowpass_filter_width
,
rolloff
,
kernel
,
width
=
_get_sinc_resample_kernel
(
orig_freq
,
new_freq
,
gcd
,
lowpass_filter_width
,
rolloff
,
...
...
torchaudio/transforms.py
View file @
fad19fab
...
@@ -696,10 +696,11 @@ class Resample(torch.nn.Module):
...
@@ -696,10 +696,11 @@ class Resample(torch.nn.Module):
self
.
lowpass_filter_width
=
lowpass_filter_width
self
.
lowpass_filter_width
=
lowpass_filter_width
self
.
rolloff
=
rolloff
self
.
rolloff
=
rolloff
kernel
,
self
.
width
=
_get_sinc_resample_kernel
(
self
.
orig_freq
,
self
.
new_freq
,
self
.
gcd
,
if
self
.
orig_freq
!=
self
.
new_freq
:
self
.
lowpass_filter_width
,
self
.
rolloff
,
kernel
,
self
.
width
=
_get_sinc_resample_kernel
(
self
.
orig_freq
,
self
.
new_freq
,
self
.
gcd
,
self
.
resampling_method
,
beta
)
self
.
lowpass_filter_width
,
self
.
rolloff
,
self
.
register_buffer
(
'kernel'
,
kernel
)
self
.
resampling_method
,
beta
)
self
.
register_buffer
(
'kernel'
,
kernel
)
def
forward
(
self
,
waveform
:
Tensor
)
->
Tensor
:
def
forward
(
self
,
waveform
:
Tensor
)
->
Tensor
:
r
"""
r
"""
...
@@ -709,6 +710,8 @@ class Resample(torch.nn.Module):
...
@@ -709,6 +710,8 @@ class Resample(torch.nn.Module):
Returns:
Returns:
Tensor: Output signal of dimension (..., time).
Tensor: Output signal of dimension (..., time).
"""
"""
if
self
.
orig_freq
==
self
.
new_freq
:
return
waveform
return
_apply_sinc_resample_kernel
(
waveform
,
self
.
orig_freq
,
self
.
new_freq
,
self
.
gcd
,
return
_apply_sinc_resample_kernel
(
waveform
,
self
.
orig_freq
,
self
.
new_freq
,
self
.
gcd
,
self
.
kernel
,
self
.
width
)
self
.
kernel
,
self
.
width
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment