Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
52e7bfd9
Unverified
Commit
52e7bfd9
authored
May 14, 2021
by
Caroline Chen
Committed by
GitHub
May 14, 2021
Browse files
Precompute transforms.Resample kernel (#1499)
parent
8a86c463
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
69 additions
and
42 deletions
+69
-42
torchaudio/functional/functional.py
torchaudio/functional/functional.py
+56
-39
torchaudio/transforms.py
torchaudio/transforms.py
+13
-3
No files found.
torchaudio/functional/functional.py
View file @
52e7bfd9
...
@@ -1299,12 +1299,28 @@ def compute_kaldi_pitch(
...
@@ -1299,12 +1299,28 @@ def compute_kaldi_pitch(
def
_get_sinc_resample_kernel
(
def
_get_sinc_resample_kernel
(
orig_freq
:
int
,
orig_freq
:
float
,
new_freq
:
int
,
new_freq
:
float
,
gcd
:
int
,
lowpass_filter_width
:
int
,
lowpass_filter_width
:
int
,
rolloff
:
float
,
rolloff
:
float
):
device
:
torch
.
device
,
dtype
:
torch
.
dtype
):
if
not
(
int
(
orig_freq
)
==
orig_freq
and
int
(
new_freq
)
==
new_freq
):
warnings
.
warn
(
"Non-integer frequencies are being cast to ints and may result in poor resampling quality "
"because the underlying algorithm requires an integer ratio between `orig_freq` and `new_freq`. "
"Using non-integer valued frequencies will throw an error in the next release. "
"To work around this issue, manually convert both frequencies to integer values "
"that maintain their resampling rate ratio before passing them into the function "
"Example: To downsample a 44100 hz waveform by a factor of 8, use "
"`orig_freq=8` and `new_freq=1` instead of `orig_freq=44100` and `new_freq=5512.5` "
"For more information or to leave feedback about this change, please refer to "
"https://github.com/pytorch/audio/issues/1487."
)
orig_freq
=
int
(
orig_freq
)
//
gcd
new_freq
=
int
(
new_freq
)
//
gcd
assert
lowpass_filter_width
>
0
assert
lowpass_filter_width
>
0
kernels
=
[]
kernels
=
[]
base_freq
=
min
(
orig_freq
,
new_freq
)
base_freq
=
min
(
orig_freq
,
new_freq
)
...
@@ -1336,7 +1352,7 @@ def _get_sinc_resample_kernel(
...
@@ -1336,7 +1352,7 @@ def _get_sinc_resample_kernel(
# they will have a lot of almost zero values to the left or to the right...
# they will have a lot of almost zero values to the left or to the right...
# There is probably a way to evaluate those filters more efficiently, but this is kept for
# There is probably a way to evaluate those filters more efficiently, but this is kept for
# future work.
# future work.
idx
=
torch
.
arange
(
-
width
,
width
+
orig_freq
,
device
=
device
,
dtype
=
dtype
)
idx
=
torch
.
arange
(
-
width
,
width
+
orig_freq
)
for
i
in
range
(
new_freq
):
for
i
in
range
(
new_freq
):
t
=
(
-
i
/
new_freq
+
idx
/
orig_freq
)
*
base_freq
t
=
(
-
i
/
new_freq
+
idx
/
orig_freq
)
*
base_freq
...
@@ -1353,6 +1369,34 @@ def _get_sinc_resample_kernel(
...
@@ -1353,6 +1369,34 @@ def _get_sinc_resample_kernel(
return
torch
.
stack
(
kernels
).
view
(
new_freq
,
1
,
-
1
).
mul_
(
scale
),
width
return
torch
.
stack
(
kernels
).
view
(
new_freq
,
1
,
-
1
).
mul_
(
scale
),
width
def
_apply_sinc_resample_kernel
(
waveform
:
Tensor
,
orig_freq
:
float
,
new_freq
:
float
,
gcd
:
int
,
kernel
:
Tensor
,
width
:
int
,
):
orig_freq
=
int
(
orig_freq
)
//
gcd
new_freq
=
int
(
new_freq
)
//
gcd
# pack batch
shape
=
waveform
.
size
()
waveform
=
waveform
.
view
(
-
1
,
shape
[
-
1
])
kernel
=
kernel
.
to
(
device
=
waveform
.
device
,
dtype
=
waveform
.
dtype
)
num_wavs
,
length
=
waveform
.
shape
waveform
=
torch
.
nn
.
functional
.
pad
(
waveform
,
(
width
,
width
+
orig_freq
))
resampled
=
torch
.
nn
.
functional
.
conv1d
(
waveform
[:,
None
],
kernel
,
stride
=
orig_freq
)
resampled
=
resampled
.
transpose
(
1
,
2
).
reshape
(
num_wavs
,
-
1
)
target_length
=
int
(
math
.
ceil
(
new_freq
*
length
/
orig_freq
))
resampled
=
resampled
[...,
:
target_length
]
# unpack batch
resampled
=
resampled
.
view
(
shape
[:
-
1
]
+
resampled
.
shape
[
-
1
:])
return
resampled
def
resample
(
def
resample
(
waveform
:
Tensor
,
waveform
:
Tensor
,
orig_freq
:
float
,
orig_freq
:
float
,
...
@@ -1380,42 +1424,15 @@ def resample(
...
@@ -1380,42 +1424,15 @@ def resample(
Returns:
Returns:
Tensor: The waveform at the new frequency of dimension (..., time).
Tensor: The waveform at the new frequency of dimension (..., time).
Note: ``transforms.Resample`` precomputes and reuses the resampling kernel, so using it will result in
more efficient computation if resampling multiple waveforms with the same resampling parameters.
"""
"""
# pack batch
shape
=
waveform
.
size
()
waveform
=
waveform
.
view
(
-
1
,
shape
[
-
1
])
assert
orig_freq
>
0.0
and
new_freq
>
0.0
assert
orig_freq
>
0.0
and
new_freq
>
0.0
if
not
(
int
(
orig_freq
)
==
orig_freq
and
int
(
new_freq
)
==
new_freq
):
gcd
=
math
.
gcd
(
int
(
orig_freq
),
int
(
new_freq
))
warnings
.
warn
(
"Non-integer frequencies are being cast to ints and may result in poor resampling quality "
"because the underlying algorithm requires an integer ratio between `orig_freq` and `new_freq`. "
"Using non-integer valued frequencies will throw an error in the next release. "
"To work around this issue, manually convert both frequencies to integer values "
"that maintain their resampling rate ratio before passing them into the function "
"Example: To downsample a 44100 hz waveform by a factor of 8, use "
"`orig_freq=8` and `new_freq=1` instead of `orig_freq=44100` and `new_freq=5512.5` "
"For more information or to leave feedback about this change, please refer to "
"https://github.com/pytorch/audio/issues/1487."
)
orig_freq
=
int
(
orig_freq
)
new_freq
=
int
(
new_freq
)
gcd
=
math
.
gcd
(
orig_freq
,
new_freq
)
orig_freq
=
orig_freq
//
gcd
new_freq
=
new_freq
//
gcd
kernel
,
width
=
_get_sinc_resample_kernel
(
orig_freq
,
new_freq
,
lowpass_filter_width
,
rolloff
,
waveform
.
device
,
waveform
.
dtype
)
num_wavs
,
length
=
waveform
.
shape
waveform
=
torch
.
nn
.
functional
.
pad
(
waveform
,
(
width
,
width
+
orig_freq
))
resampled
=
torch
.
nn
.
functional
.
conv1d
(
waveform
[:,
None
],
kernel
,
stride
=
orig_freq
)
resampled
=
resampled
.
transpose
(
1
,
2
).
reshape
(
num_wavs
,
-
1
)
target_length
=
int
(
math
.
ceil
(
new_freq
*
length
/
orig_freq
))
resampled
=
resampled
[...,
:
target_length
]
# unpack batch
kernel
,
width
=
_get_sinc_resample_kernel
(
orig_freq
,
new_freq
,
gcd
,
lowpass_filter_width
,
rolloff
)
resampled
=
resampled
.
view
(
shape
[:
-
1
]
+
resampled
.
shape
[
-
1
:]
)
resampled
=
_apply_sinc_resample_kernel
(
waveform
,
orig_freq
,
new_freq
,
gcd
,
kernel
,
width
)
return
resampled
return
resampled
torchaudio/transforms.py
View file @
52e7bfd9
...
@@ -8,6 +8,10 @@ import torch
...
@@ -8,6 +8,10 @@ import torch
from
torch
import
Tensor
from
torch
import
Tensor
from
torchaudio
import
functional
as
F
from
torchaudio
import
functional
as
F
from
.functional.functional
import
(
_get_sinc_resample_kernel
,
_apply_sinc_resample_kernel
,
)
__all__
=
[
__all__
=
[
'Spectrogram'
,
'Spectrogram'
,
...
@@ -661,18 +665,23 @@ class Resample(torch.nn.Module):
...
@@ -661,18 +665,23 @@ class Resample(torch.nn.Module):
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
orig_freq
:
in
t
=
16000
,
orig_freq
:
floa
t
=
16000
,
new_freq
:
in
t
=
16000
,
new_freq
:
floa
t
=
16000
,
resampling_method
:
str
=
'sinc_interpolation'
,
resampling_method
:
str
=
'sinc_interpolation'
,
lowpass_filter_width
:
int
=
6
,
lowpass_filter_width
:
int
=
6
,
rolloff
:
float
=
0.99
)
->
None
:
rolloff
:
float
=
0.99
)
->
None
:
super
(
Resample
,
self
).
__init__
()
super
(
Resample
,
self
).
__init__
()
self
.
orig_freq
=
orig_freq
self
.
orig_freq
=
orig_freq
self
.
new_freq
=
new_freq
self
.
new_freq
=
new_freq
self
.
gcd
=
math
.
gcd
(
int
(
self
.
orig_freq
),
int
(
self
.
new_freq
))
self
.
resampling_method
=
resampling_method
self
.
resampling_method
=
resampling_method
self
.
lowpass_filter_width
=
lowpass_filter_width
self
.
lowpass_filter_width
=
lowpass_filter_width
self
.
rolloff
=
rolloff
self
.
rolloff
=
rolloff
self
.
kernel
,
self
.
width
=
_get_sinc_resample_kernel
(
self
.
orig_freq
,
self
.
new_freq
,
self
.
gcd
,
self
.
lowpass_filter_width
,
self
.
rolloff
)
def
forward
(
self
,
waveform
:
Tensor
)
->
Tensor
:
def
forward
(
self
,
waveform
:
Tensor
)
->
Tensor
:
r
"""
r
"""
Args:
Args:
...
@@ -682,7 +691,8 @@ class Resample(torch.nn.Module):
...
@@ -682,7 +691,8 @@ class Resample(torch.nn.Module):
Tensor: Output signal of dimension (..., time).
Tensor: Output signal of dimension (..., time).
"""
"""
if
self
.
resampling_method
==
'sinc_interpolation'
:
if
self
.
resampling_method
==
'sinc_interpolation'
:
return
F
.
resample
(
waveform
,
self
.
orig_freq
,
self
.
new_freq
,
self
.
lowpass_filter_width
,
self
.
rolloff
)
return
_apply_sinc_resample_kernel
(
waveform
,
self
.
orig_freq
,
self
.
new_freq
,
self
.
gcd
,
self
.
kernel
,
self
.
width
)
raise
ValueError
(
'Invalid resampling method: {}'
.
format
(
self
.
resampling_method
))
raise
ValueError
(
'Invalid resampling method: {}'
.
format
(
self
.
resampling_method
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment