Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
c2740644
"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "026bc29237a6478ea087362de83b854825d2e207"
Unverified
Commit
c2740644
authored
May 10, 2021
by
Caroline Chen
Committed by
GitHub
May 10, 2021
Browse files
Add rolloff param to resample (#1488)
parent
32f661f0
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
28 additions
and
9 deletions
+28
-9
torchaudio/compliance/kaldi.py
torchaudio/compliance/kaldi.py
+5
-2
torchaudio/functional/functional.py
torchaudio/functional/functional.py
+13
-5
torchaudio/transforms.py
torchaudio/transforms.py
+10
-2
No files found.
torchaudio/compliance/kaldi.py
View file @
c2740644
...
@@ -755,7 +755,8 @@ def mfcc(
...
@@ -755,7 +755,8 @@ def mfcc(
def
resample_waveform
(
waveform
:
Tensor
,
def
resample_waveform
(
waveform
:
Tensor
,
orig_freq
:
float
,
orig_freq
:
float
,
new_freq
:
float
,
new_freq
:
float
,
lowpass_filter_width
:
int
=
6
)
->
Tensor
:
lowpass_filter_width
:
int
=
6
,
rolloff
:
float
=
0.99
)
->
Tensor
:
r
"""Resamples the waveform at the new frequency.
r
"""Resamples the waveform at the new frequency.
This is a wrapper around ``torchaudio.functional.resample``.
This is a wrapper around ``torchaudio.functional.resample``.
...
@@ -766,8 +767,10 @@ def resample_waveform(waveform: Tensor,
...
@@ -766,8 +767,10 @@ def resample_waveform(waveform: Tensor,
new_freq (float): The desired frequency
new_freq (float): The desired frequency
lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
Lower values reduce aliasing, but also attenuate some of the highest frequencies. (Default: ``0.99``)
Returns:
Returns:
Tensor: The waveform at the new frequency
Tensor: The waveform at the new frequency
"""
"""
return
torchaudio
.
functional
.
resample
(
waveform
,
orig_freq
,
new_freq
,
lowpass_filter_width
)
return
torchaudio
.
functional
.
resample
(
waveform
,
orig_freq
,
new_freq
,
lowpass_filter_width
,
rolloff
)
torchaudio/functional/functional.py
View file @
c2740644
...
@@ -1298,8 +1298,13 @@ def compute_kaldi_pitch(
...
@@ -1298,8 +1298,13 @@ def compute_kaldi_pitch(
return
result
return
result
def
_get_sinc_resample_kernel
(
orig_freq
:
int
,
new_freq
:
int
,
lowpass_filter_width
:
int
,
def
_get_sinc_resample_kernel
(
device
:
torch
.
device
,
dtype
:
torch
.
dtype
):
orig_freq
:
int
,
new_freq
:
int
,
lowpass_filter_width
:
int
,
rolloff
:
float
,
device
:
torch
.
device
,
dtype
:
torch
.
dtype
):
assert
lowpass_filter_width
>
0
assert
lowpass_filter_width
>
0
kernels
=
[]
kernels
=
[]
base_freq
=
min
(
orig_freq
,
new_freq
)
base_freq
=
min
(
orig_freq
,
new_freq
)
...
@@ -1307,7 +1312,7 @@ def _get_sinc_resample_kernel(orig_freq: int, new_freq: int, lowpass_filter_widt
...
@@ -1307,7 +1312,7 @@ def _get_sinc_resample_kernel(orig_freq: int, new_freq: int, lowpass_filter_widt
# At first I thought I only needed this when downsampling, but when upsampling
# At first I thought I only needed this when downsampling, but when upsampling
# you will get edge artifacts without this, as the edge is equivalent to zero padding,
# you will get edge artifacts without this, as the edge is equivalent to zero padding,
# which will add high freq artifacts.
# which will add high freq artifacts.
base_freq
*=
0.99
base_freq
*=
rolloff
# The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor)
# The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor)
# using the sinc interpolation formula:
# using the sinc interpolation formula:
...
@@ -1352,7 +1357,8 @@ def resample(
...
@@ -1352,7 +1357,8 @@ def resample(
waveform
:
Tensor
,
waveform
:
Tensor
,
orig_freq
:
float
,
orig_freq
:
float
,
new_freq
:
float
,
new_freq
:
float
,
lowpass_filter_width
:
int
=
6
lowpass_filter_width
:
int
=
6
,
rolloff
:
float
=
0.99
,
)
->
Tensor
:
)
->
Tensor
:
r
"""Resamples the waveform at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform
r
"""Resamples the waveform at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform
which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample
which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample
...
@@ -1369,6 +1375,8 @@ def resample(
...
@@ -1369,6 +1375,8 @@ def resample(
new_freq (float): The desired frequency
new_freq (float): The desired frequency
lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
Lower values reduce aliasing, but also attenuate some of the highest frequencies. (Default: ``0.99``)
Returns:
Returns:
Tensor: The waveform at the new frequency of dimension (..., time).
Tensor: The waveform at the new frequency of dimension (..., time).
...
@@ -1386,7 +1394,7 @@ def resample(
...
@@ -1386,7 +1394,7 @@ def resample(
new_freq
=
new_freq
//
gcd
new_freq
=
new_freq
//
gcd
kernel
,
width
=
_get_sinc_resample_kernel
(
orig_freq
,
new_freq
,
lowpass_filter_width
,
kernel
,
width
=
_get_sinc_resample_kernel
(
orig_freq
,
new_freq
,
lowpass_filter_width
,
waveform
.
device
,
waveform
.
dtype
)
rolloff
,
waveform
.
device
,
waveform
.
dtype
)
num_wavs
,
length
=
waveform
.
shape
num_wavs
,
length
=
waveform
.
shape
waveform
=
torch
.
nn
.
functional
.
pad
(
waveform
,
(
width
,
width
+
orig_freq
))
waveform
=
torch
.
nn
.
functional
.
pad
(
waveform
,
(
width
,
width
+
orig_freq
))
...
...
torchaudio/transforms.py
View file @
c2740644
...
@@ -640,16 +640,24 @@ class Resample(torch.nn.Module):
...
@@ -640,16 +640,24 @@ class Resample(torch.nn.Module):
orig_freq (float, optional): The original frequency of the signal. (Default: ``16000``)
orig_freq (float, optional): The original frequency of the signal. (Default: ``16000``)
new_freq (float, optional): The desired frequency. (Default: ``16000``)
new_freq (float, optional): The desired frequency. (Default: ``16000``)
resampling_method (str, optional): The resampling method. (Default: ``'sinc_interpolation'``)
resampling_method (str, optional): The resampling method. (Default: ``'sinc_interpolation'``)
lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
Lower values reduce aliasing, but also attenuate some of the highest frequencies. (Default: ``0.99``)
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
orig_freq
:
int
=
16000
,
orig_freq
:
int
=
16000
,
new_freq
:
int
=
16000
,
new_freq
:
int
=
16000
,
resampling_method
:
str
=
'sinc_interpolation'
)
->
None
:
resampling_method
:
str
=
'sinc_interpolation'
,
lowpass_filter_width
:
int
=
6
,
rolloff
:
float
=
0.99
)
->
None
:
super
(
Resample
,
self
).
__init__
()
super
(
Resample
,
self
).
__init__
()
self
.
orig_freq
=
orig_freq
self
.
orig_freq
=
orig_freq
self
.
new_freq
=
new_freq
self
.
new_freq
=
new_freq
self
.
resampling_method
=
resampling_method
self
.
resampling_method
=
resampling_method
self
.
lowpass_filter_width
=
lowpass_filter_width
self
.
rolloff
=
rolloff
def
forward
(
self
,
waveform
:
Tensor
)
->
Tensor
:
def
forward
(
self
,
waveform
:
Tensor
)
->
Tensor
:
r
"""
r
"""
...
@@ -660,7 +668,7 @@ class Resample(torch.nn.Module):
...
@@ -660,7 +668,7 @@ class Resample(torch.nn.Module):
Tensor: Output signal of dimension (..., time).
Tensor: Output signal of dimension (..., time).
"""
"""
if
self
.
resampling_method
==
'sinc_interpolation'
:
if
self
.
resampling_method
==
'sinc_interpolation'
:
return
F
.
resample
(
waveform
,
self
.
orig_freq
,
self
.
new_freq
)
return
F
.
resample
(
waveform
,
self
.
orig_freq
,
self
.
new_freq
,
self
.
lowpass_filter_width
,
self
.
rolloff
)
raise
ValueError
(
'Invalid resampling method: {}'
.
format
(
self
.
resampling_method
))
raise
ValueError
(
'Invalid resampling method: {}'
.
format
(
self
.
resampling_method
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment