Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
c2740644
Unverified
Commit
c2740644
authored
May 10, 2021
by
Caroline Chen
Committed by
GitHub
May 10, 2021
Browse files
Add rolloff param to resample (#1488)
parent
32f661f0
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
28 additions
and
9 deletions
+28
-9
torchaudio/compliance/kaldi.py
torchaudio/compliance/kaldi.py
+5
-2
torchaudio/functional/functional.py
torchaudio/functional/functional.py
+13
-5
torchaudio/transforms.py
torchaudio/transforms.py
+10
-2
No files found.
torchaudio/compliance/kaldi.py
View file @
c2740644
...
@@ -755,7 +755,8 @@ def mfcc(
...
@@ -755,7 +755,8 @@ def mfcc(
def
resample_waveform
(
waveform
:
Tensor
,
def
resample_waveform
(
waveform
:
Tensor
,
orig_freq
:
float
,
orig_freq
:
float
,
new_freq
:
float
,
new_freq
:
float
,
lowpass_filter_width
:
int
=
6
)
->
Tensor
:
lowpass_filter_width
:
int
=
6
,
rolloff
:
float
=
0.99
)
->
Tensor
:
r
"""Resamples the waveform at the new frequency.
r
"""Resamples the waveform at the new frequency.
This is a wrapper around ``torchaudio.functional.resample``.
This is a wrapper around ``torchaudio.functional.resample``.
...
@@ -766,8 +767,10 @@ def resample_waveform(waveform: Tensor,
...
@@ -766,8 +767,10 @@ def resample_waveform(waveform: Tensor,
new_freq (float): The desired frequency
new_freq (float): The desired frequency
lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``)
Returns:
Returns:
Tensor: The waveform at the new frequency
Tensor: The waveform at the new frequency
"""
"""
return
torchaudio
.
functional
.
resample
(
waveform
,
orig_freq
,
new_freq
,
lowpass_filter_width
)
return
torchaudio
.
functional
.
resample
(
waveform
,
orig_freq
,
new_freq
,
lowpass_filter_width
,
rolloff
)
torchaudio/functional/functional.py
View file @
c2740644
...
@@ -1298,8 +1298,13 @@ def compute_kaldi_pitch(
...
@@ -1298,8 +1298,13 @@ def compute_kaldi_pitch(
return
result
return
result
def
_get_sinc_resample_kernel
(
orig_freq
:
int
,
new_freq
:
int
,
lowpass_filter_width
:
int
,
def
_get_sinc_resample_kernel
(
device
:
torch
.
device
,
dtype
:
torch
.
dtype
):
orig_freq
:
int
,
new_freq
:
int
,
lowpass_filter_width
:
int
,
rolloff
:
float
,
device
:
torch
.
device
,
dtype
:
torch
.
dtype
):
assert
lowpass_filter_width
>
0
assert
lowpass_filter_width
>
0
kernels
=
[]
kernels
=
[]
base_freq
=
min
(
orig_freq
,
new_freq
)
base_freq
=
min
(
orig_freq
,
new_freq
)
...
@@ -1307,7 +1312,7 @@ def _get_sinc_resample_kernel(orig_freq: int, new_freq: int, lowpass_filter_widt
...
@@ -1307,7 +1312,7 @@ def _get_sinc_resample_kernel(orig_freq: int, new_freq: int, lowpass_filter_widt
# At first I thought I only needed this when downsampling, but when upsampling
# At first I thought I only needed this when downsampling, but when upsampling
# you will get edge artifacts without this, as the edge is equivalent to zero padding,
# you will get edge artifacts without this, as the edge is equivalent to zero padding,
# which will add high freq artifacts.
# which will add high freq artifacts.
base_freq
*=
0.99
base_freq
*=
rolloff
# The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor)
# The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor)
# using the sinc interpolation formula:
# using the sinc interpolation formula:
...
@@ -1352,7 +1357,8 @@ def resample(
...
@@ -1352,7 +1357,8 @@ def resample(
waveform
:
Tensor
,
waveform
:
Tensor
,
orig_freq
:
float
,
orig_freq
:
float
,
new_freq
:
float
,
new_freq
:
float
,
lowpass_filter_width
:
int
=
6
lowpass_filter_width
:
int
=
6
,
rolloff
:
float
=
0.99
,
)
->
Tensor
:
)
->
Tensor
:
r
"""Resamples the waveform at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform
r
"""Resamples the waveform at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform
which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample
which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample
...
@@ -1369,6 +1375,8 @@ def resample(
...
@@ -1369,6 +1375,8 @@ def resample(
new_freq (float): The desired frequency
new_freq (float): The desired frequency
lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``)
Returns:
Returns:
Tensor: The waveform at the new frequency of dimension (..., time).
Tensor: The waveform at the new frequency of dimension (..., time).
...
@@ -1386,7 +1394,7 @@ def resample(
...
@@ -1386,7 +1394,7 @@ def resample(
new_freq
=
new_freq
//
gcd
new_freq
=
new_freq
//
gcd
kernel
,
width
=
_get_sinc_resample_kernel
(
orig_freq
,
new_freq
,
lowpass_filter_width
,
kernel
,
width
=
_get_sinc_resample_kernel
(
orig_freq
,
new_freq
,
lowpass_filter_width
,
waveform
.
device
,
waveform
.
dtype
)
rolloff
,
waveform
.
device
,
waveform
.
dtype
)
num_wavs
,
length
=
waveform
.
shape
num_wavs
,
length
=
waveform
.
shape
waveform
=
torch
.
nn
.
functional
.
pad
(
waveform
,
(
width
,
width
+
orig_freq
))
waveform
=
torch
.
nn
.
functional
.
pad
(
waveform
,
(
width
,
width
+
orig_freq
))
...
...
torchaudio/transforms.py
View file @
c2740644
...
@@ -640,16 +640,24 @@ class Resample(torch.nn.Module):
...
@@ -640,16 +640,24 @@ class Resample(torch.nn.Module):
orig_freq (float, optional): The original frequency of the signal. (Default: ``16000``)
orig_freq (float, optional): The original frequency of the signal. (Default: ``16000``)
new_freq (float, optional): The desired frequency. (Default: ``16000``)
new_freq (float, optional): The desired frequency. (Default: ``16000``)
resampling_method (str, optional): The resampling method. (Default: ``'sinc_interpolation'``)
resampling_method (str, optional): The resampling method. (Default: ``'sinc_interpolation'``)
lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``)
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
orig_freq
:
int
=
16000
,
orig_freq
:
int
=
16000
,
new_freq
:
int
=
16000
,
new_freq
:
int
=
16000
,
resampling_method
:
str
=
'sinc_interpolation'
)
->
None
:
resampling_method
:
str
=
'sinc_interpolation'
,
lowpass_filter_width
:
int
=
6
,
rolloff
:
float
=
0.99
)
->
None
:
super
(
Resample
,
self
).
__init__
()
super
(
Resample
,
self
).
__init__
()
self
.
orig_freq
=
orig_freq
self
.
orig_freq
=
orig_freq
self
.
new_freq
=
new_freq
self
.
new_freq
=
new_freq
self
.
resampling_method
=
resampling_method
self
.
resampling_method
=
resampling_method
self
.
lowpass_filter_width
=
lowpass_filter_width
self
.
rolloff
=
rolloff
def
forward
(
self
,
waveform
:
Tensor
)
->
Tensor
:
def
forward
(
self
,
waveform
:
Tensor
)
->
Tensor
:
r
"""
r
"""
...
@@ -660,7 +668,7 @@ class Resample(torch.nn.Module):
...
@@ -660,7 +668,7 @@ class Resample(torch.nn.Module):
Tensor: Output signal of dimension (..., time).
Tensor: Output signal of dimension (..., time).
"""
"""
if
self
.
resampling_method
==
'sinc_interpolation'
:
if
self
.
resampling_method
==
'sinc_interpolation'
:
return
F
.
resample
(
waveform
,
self
.
orig_freq
,
self
.
new_freq
)
return
F
.
resample
(
waveform
,
self
.
orig_freq
,
self
.
new_freq
,
self
.
lowpass_filter_width
,
self
.
rolloff
)
raise
ValueError
(
'Invalid resampling method: {}'
.
format
(
self
.
resampling_method
))
raise
ValueError
(
'Invalid resampling method: {}'
.
format
(
self
.
resampling_method
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment