Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
25a8adf6
Unverified
Commit
25a8adf6
authored
Oct 13, 2021
by
Caroline Chen
Committed by
GitHub
Oct 13, 2021
Browse files
[BC-Breaking] Ensure integer input frequencies for resample (#1857)
parent
483d8fae
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
25 additions
and
47 deletions
+25
-47
test/torchaudio_unittest/functional/functional_impl.py
test/torchaudio_unittest/functional/functional_impl.py
+2
-21
test/torchaudio_unittest/functional/torchscript_consistency_impl.py
...audio_unittest/functional/torchscript_consistency_impl.py
+3
-3
test/torchaudio_unittest/transforms/torchscript_consistency_impl.py
...audio_unittest/transforms/torchscript_consistency_impl.py
+1
-1
torchaudio/functional/functional.py
torchaudio/functional/functional.py
+15
-18
torchaudio/transforms.py
torchaudio/transforms.py
+4
-4
No files found.
test/torchaudio_unittest/functional/functional_impl.py
View file @
25a8adf6
...
@@ -27,10 +27,10 @@ class Functional(TestBaseMixin):
...
@@ -27,10 +27,10 @@ class Functional(TestBaseMixin):
new_sample_rate
=
sample_rate
new_sample_rate
=
sample_rate
if
up_scale_factor
is
not
None
:
if
up_scale_factor
is
not
None
:
new_sample_rate
*
=
up_scale_factor
new_sample_rate
=
int
(
new_sample_rate
*
up_scale_factor
)
if
down_scale_factor
is
not
None
:
if
down_scale_factor
is
not
None
:
new_sample_rate
//=
down_scale_factor
new_sample_rate
=
int
(
new_sample_rate
/
down_scale_factor
)
duration
=
5
# seconds
duration
=
5
# seconds
original_timestamps
=
torch
.
arange
(
0
,
duration
,
1.0
/
sample_rate
)
original_timestamps
=
torch
.
arange
(
0
,
duration
,
1.0
/
sample_rate
)
...
@@ -439,25 +439,6 @@ class Functional(TestBaseMixin):
...
@@ -439,25 +439,6 @@ class Functional(TestBaseMixin):
def
test_resample_waveform_upsample_accuracy
(
self
,
resampling_method
,
i
):
def
test_resample_waveform_upsample_accuracy
(
self
,
resampling_method
,
i
):
self
.
_test_resample_waveform_accuracy
(
up_scale_factor
=
1.0
+
i
/
20.0
,
resampling_method
=
resampling_method
)
self
.
_test_resample_waveform_accuracy
(
up_scale_factor
=
1.0
+
i
/
20.0
,
resampling_method
=
resampling_method
)
def
test_resample_no_warning
(
self
):
sample_rate
=
44100
waveform
=
get_whitenoise
(
sample_rate
=
sample_rate
,
duration
=
0.1
)
with
warnings
.
catch_warnings
(
record
=
True
)
as
w
:
warnings
.
simplefilter
(
"always"
)
F
.
resample
(
waveform
,
float
(
sample_rate
),
sample_rate
/
2.
)
assert
len
(
w
)
==
0
def
test_resample_warning
(
self
):
"""resample should throw a warning if an input frequency is not of an integer value"""
sample_rate
=
44100
waveform
=
get_whitenoise
(
sample_rate
=
sample_rate
,
duration
=
0.1
)
with
warnings
.
catch_warnings
(
record
=
True
)
as
w
:
warnings
.
simplefilter
(
"always"
)
F
.
resample
(
waveform
,
sample_rate
,
5512.5
)
assert
len
(
w
)
==
1
@
nested_params
(
@
nested_params
(
[
0.5
,
1.01
,
1.3
],
[
0.5
,
1.01
,
1.3
],
[
True
,
False
],
[
True
,
False
],
...
...
test/torchaudio_unittest/functional/torchscript_consistency_impl.py
View file @
25a8adf6
...
@@ -659,7 +659,7 @@ class Functional(TempDirMixin, TestBaseMixin):
...
@@ -659,7 +659,7 @@ class Functional(TempDirMixin, TestBaseMixin):
def
test_resample_sinc
(
self
):
def
test_resample_sinc
(
self
):
def
func
(
tensor
):
def
func
(
tensor
):
sr1
,
sr2
=
16000
.
,
8000
.
sr1
,
sr2
=
16000
,
8000
return
F
.
resample
(
tensor
,
sr1
,
sr2
,
resampling_method
=
"sinc_interpolation"
)
return
F
.
resample
(
tensor
,
sr1
,
sr2
,
resampling_method
=
"sinc_interpolation"
)
tensor
=
common_utils
.
get_whitenoise
(
sample_rate
=
16000
)
tensor
=
common_utils
.
get_whitenoise
(
sample_rate
=
16000
)
...
@@ -667,11 +667,11 @@ class Functional(TempDirMixin, TestBaseMixin):
...
@@ -667,11 +667,11 @@ class Functional(TempDirMixin, TestBaseMixin):
def
test_resample_kaiser
(
self
):
def
test_resample_kaiser
(
self
):
def
func
(
tensor
):
def
func
(
tensor
):
sr1
,
sr2
=
16000
.
,
8000
.
sr1
,
sr2
=
16000
,
8000
return
F
.
resample
(
tensor
,
sr1
,
sr2
,
resampling_method
=
"kaiser_window"
)
return
F
.
resample
(
tensor
,
sr1
,
sr2
,
resampling_method
=
"kaiser_window"
)
def
func_beta
(
tensor
):
def
func_beta
(
tensor
):
sr1
,
sr2
=
16000
.
,
8000
.
sr1
,
sr2
=
16000
,
8000
beta
=
6.
beta
=
6.
return
F
.
resample
(
tensor
,
sr1
,
sr2
,
resampling_method
=
"kaiser_window"
,
beta
=
beta
)
return
F
.
resample
(
tensor
,
sr1
,
sr2
,
resampling_method
=
"kaiser_window"
,
beta
=
beta
)
...
...
test/torchaudio_unittest/transforms/torchscript_consistency_impl.py
View file @
25a8adf6
...
@@ -84,7 +84,7 @@ class Transforms(TestBaseMixin):
...
@@ -84,7 +84,7 @@ class Transforms(TestBaseMixin):
def
test_Resample
(
self
):
def
test_Resample
(
self
):
sr1
,
sr2
=
16000
,
8000
sr1
,
sr2
=
16000
,
8000
tensor
=
common_utils
.
get_whitenoise
(
sample_rate
=
sr1
)
tensor
=
common_utils
.
get_whitenoise
(
sample_rate
=
sr1
)
self
.
_assert_consistency
(
T
.
Resample
(
float
(
sr1
)
,
float
(
sr2
)
)
,
tensor
)
self
.
_assert_consistency
(
T
.
Resample
(
sr1
,
sr2
),
tensor
)
def
test_ComplexNorm
(
self
):
def
test_ComplexNorm
(
self
):
tensor
=
torch
.
rand
((
1
,
2
,
201
,
2
))
tensor
=
torch
.
rand
((
1
,
2
,
201
,
2
))
...
...
torchaudio/functional/functional.py
View file @
25a8adf6
...
@@ -1471,8 +1471,8 @@ def compute_kaldi_pitch(
...
@@ -1471,8 +1471,8 @@ def compute_kaldi_pitch(
def
_get_sinc_resample_kernel
(
def
_get_sinc_resample_kernel
(
orig_freq
:
floa
t
,
orig_freq
:
in
t
,
new_freq
:
floa
t
,
new_freq
:
in
t
,
gcd
:
int
,
gcd
:
int
,
lowpass_filter_width
:
int
,
lowpass_filter_width
:
int
,
rolloff
:
float
,
rolloff
:
float
,
...
@@ -1482,16 +1482,13 @@ def _get_sinc_resample_kernel(
...
@@ -1482,16 +1482,13 @@ def _get_sinc_resample_kernel(
dtype
:
Optional
[
torch
.
dtype
]
=
None
):
dtype
:
Optional
[
torch
.
dtype
]
=
None
):
if
not
(
int
(
orig_freq
)
==
orig_freq
and
int
(
new_freq
)
==
new_freq
):
if
not
(
int
(
orig_freq
)
==
orig_freq
and
int
(
new_freq
)
==
new_freq
):
warnings
.
warn
(
raise
Exception
(
"Non-integer frequencies are being cast to ints and may result in poor resampling quality "
"Frequencies must be of integer type to ensure quality resampling computation. "
"because the underlying algorithm requires an integer ratio between `orig_freq` and `new_freq`. "
"To work around this, manually convert both frequencies to integer values "
"Using non-integer valued frequencies will throw an error in release 0.10. "
"that maintain their resampling rate ratio before passing them into the function. "
"To work around this issue, manually convert both frequencies to integer values "
"that maintain their resampling rate ratio before passing them into the function "
"Example: To downsample a 44100 hz waveform by a factor of 8, use "
"Example: To downsample a 44100 hz waveform by a factor of 8, use "
"`orig_freq=8` and `new_freq=1` instead of `orig_freq=44100` and `new_freq=5512.5` "
"`orig_freq=8` and `new_freq=1` instead of `orig_freq=44100` and `new_freq=5512.5`. "
"For more information or to leave feedback about this change, please refer to "
"For more information, please refer to https://github.com/pytorch/audio/issues/1487."
"https://github.com/pytorch/audio/issues/1487."
)
)
if
resampling_method
not
in
[
'sinc_interpolation'
,
'kaiser_window'
]:
if
resampling_method
not
in
[
'sinc_interpolation'
,
'kaiser_window'
]:
...
@@ -1562,8 +1559,8 @@ def _get_sinc_resample_kernel(
...
@@ -1562,8 +1559,8 @@ def _get_sinc_resample_kernel(
def
_apply_sinc_resample_kernel
(
def
_apply_sinc_resample_kernel
(
waveform
:
Tensor
,
waveform
:
Tensor
,
orig_freq
:
floa
t
,
orig_freq
:
in
t
,
new_freq
:
floa
t
,
new_freq
:
in
t
,
gcd
:
int
,
gcd
:
int
,
kernel
:
Tensor
,
kernel
:
Tensor
,
width
:
int
,
width
:
int
,
...
@@ -1589,8 +1586,8 @@ def _apply_sinc_resample_kernel(
...
@@ -1589,8 +1586,8 @@ def _apply_sinc_resample_kernel(
def
resample
(
def
resample
(
waveform
:
Tensor
,
waveform
:
Tensor
,
orig_freq
:
floa
t
,
orig_freq
:
in
t
,
new_freq
:
floa
t
,
new_freq
:
in
t
,
lowpass_filter_width
:
int
=
6
,
lowpass_filter_width
:
int
=
6
,
rolloff
:
float
=
0.99
,
rolloff
:
float
=
0.99
,
resampling_method
:
str
=
"sinc_interpolation"
,
resampling_method
:
str
=
"sinc_interpolation"
,
...
@@ -1606,8 +1603,8 @@ def resample(
...
@@ -1606,8 +1603,8 @@ def resample(
Args:
Args:
waveform (Tensor): The input signal of dimension `(..., time)`
waveform (Tensor): The input signal of dimension `(..., time)`
orig_freq (
floa
t): The original frequency of the signal
orig_freq (
in
t): The original frequency of the signal
new_freq (
floa
t): The desired frequency
new_freq (
in
t): The desired frequency
lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
but less efficient. (Default: ``6``)
but less efficient. (Default: ``6``)
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
...
@@ -1736,7 +1733,7 @@ def pitch_shift(
...
@@ -1736,7 +1733,7 @@ def pitch_shift(
win_length
=
win_length
,
win_length
=
win_length
,
window
=
window
,
window
=
window
,
length
=
len_stretch
)
length
=
len_stretch
)
waveform_shift
=
resample
(
waveform_stretch
,
sample_rate
/
/
rate
,
float
(
sample_rate
)
)
waveform_shift
=
resample
(
waveform_stretch
,
int
(
sample_rate
/
rate
)
,
sample_rate
)
shift_len
=
waveform_shift
.
size
()[
-
1
]
shift_len
=
waveform_shift
.
size
()[
-
1
]
if
shift_len
>
ori_len
:
if
shift_len
>
ori_len
:
waveform_shift
=
waveform_shift
[...,
:
ori_len
]
waveform_shift
=
waveform_shift
[...,
:
ori_len
]
...
...
torchaudio/transforms.py
View file @
25a8adf6
...
@@ -815,8 +815,8 @@ class Resample(torch.nn.Module):
...
@@ -815,8 +815,8 @@ class Resample(torch.nn.Module):
Alternatively, you could rewrite a transform that caches a higher precision kernel.
Alternatively, you could rewrite a transform that caches a higher precision kernel.
Args:
Args:
orig_freq (
floa
t, optional): The original frequency of the signal. (Default: ``16000``)
orig_freq (
in
t, optional): The original frequency of the signal. (Default: ``16000``)
new_freq (
floa
t, optional): The desired frequency. (Default: ``16000``)
new_freq (
in
t, optional): The desired frequency. (Default: ``16000``)
resampling_method (str, optional): The resampling method to use.
resampling_method (str, optional): The resampling method to use.
Options: [``sinc_interpolation``, ``kaiser_window``] (Default: ``'sinc_interpolation'``)
Options: [``sinc_interpolation``, ``kaiser_window``] (Default: ``'sinc_interpolation'``)
lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
...
@@ -840,8 +840,8 @@ class Resample(torch.nn.Module):
...
@@ -840,8 +840,8 @@ class Resample(torch.nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
orig_freq
:
floa
t
=
16000
,
orig_freq
:
in
t
=
16000
,
new_freq
:
floa
t
=
16000
,
new_freq
:
in
t
=
16000
,
resampling_method
:
str
=
'sinc_interpolation'
,
resampling_method
:
str
=
'sinc_interpolation'
,
lowpass_filter_width
:
int
=
6
,
lowpass_filter_width
:
int
=
6
,
rolloff
:
float
=
0.99
,
rolloff
:
float
=
0.99
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment