Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
hehl2
Torchaudio
Commits
52e7bfd9
"llm/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "df6dc4fd96ba485a028bb1a59e63500bb7357247"
Unverified
Commit
52e7bfd9
authored
May 14, 2021
by
Caroline Chen
Committed by
GitHub
May 14, 2021
Browse files
Precompute transforms.Resample kernel (#1499)
parent
8a86c463
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
69 additions
and
42 deletions
+69
-42
torchaudio/functional/functional.py
torchaudio/functional/functional.py
+56
-39
torchaudio/transforms.py
torchaudio/transforms.py
+13
-3
No files found.
torchaudio/functional/functional.py
View file @
52e7bfd9
...
@@ -1299,12 +1299,28 @@ def compute_kaldi_pitch(
...
@@ -1299,12 +1299,28 @@ def compute_kaldi_pitch(
def
_get_sinc_resample_kernel
(
def
_get_sinc_resample_kernel
(
orig_freq
:
int
,
orig_freq
:
float
,
new_freq
:
int
,
new_freq
:
float
,
gcd
:
int
,
lowpass_filter_width
:
int
,
lowpass_filter_width
:
int
,
rolloff
:
float
,
rolloff
:
float
):
device
:
torch
.
device
,
dtype
:
torch
.
dtype
):
if
not
(
int
(
orig_freq
)
==
orig_freq
and
int
(
new_freq
)
==
new_freq
):
warnings
.
warn
(
"Non-integer frequencies are being cast to ints and may result in poor resampling quality "
"because the underlying algorithm requires an integer ratio between `orig_freq` and `new_freq`. "
"Using non-integer valued frequencies will throw an error in the next release. "
"To work around this issue, manually convert both frequencies to integer values "
"that maintain their resampling rate ratio before passing them into the function "
"Example: To downsample a 44100 hz waveform by a factor of 8, use "
"`orig_freq=8` and `new_freq=1` instead of `orig_freq=44100` and `new_freq=5512.5` "
"For more information or to leave feedback about this change, please refer to "
"https://github.com/pytorch/audio/issues/1487."
)
orig_freq
=
int
(
orig_freq
)
//
gcd
new_freq
=
int
(
new_freq
)
//
gcd
assert
lowpass_filter_width
>
0
assert
lowpass_filter_width
>
0
kernels
=
[]
kernels
=
[]
base_freq
=
min
(
orig_freq
,
new_freq
)
base_freq
=
min
(
orig_freq
,
new_freq
)
...
@@ -1336,7 +1352,7 @@ def _get_sinc_resample_kernel(
...
@@ -1336,7 +1352,7 @@ def _get_sinc_resample_kernel(
# they will have a lot of almost zero values to the left or to the right...
# they will have a lot of almost zero values to the left or to the right...
# There is probably a way to evaluate those filters more efficiently, but this is kept for
# There is probably a way to evaluate those filters more efficiently, but this is kept for
# future work.
# future work.
idx
=
torch
.
arange
(
-
width
,
width
+
orig_freq
,
device
=
device
,
dtype
=
dtype
)
idx
=
torch
.
arange
(
-
width
,
width
+
orig_freq
)
for
i
in
range
(
new_freq
):
for
i
in
range
(
new_freq
):
t
=
(
-
i
/
new_freq
+
idx
/
orig_freq
)
*
base_freq
t
=
(
-
i
/
new_freq
+
idx
/
orig_freq
)
*
base_freq
...
@@ -1353,6 +1369,34 @@ def _get_sinc_resample_kernel(
...
@@ -1353,6 +1369,34 @@ def _get_sinc_resample_kernel(
return
torch
.
stack
(
kernels
).
view
(
new_freq
,
1
,
-
1
).
mul_
(
scale
),
width
return
torch
.
stack
(
kernels
).
view
(
new_freq
,
1
,
-
1
).
mul_
(
scale
),
width
def
_apply_sinc_resample_kernel
(
waveform
:
Tensor
,
orig_freq
:
float
,
new_freq
:
float
,
gcd
:
int
,
kernel
:
Tensor
,
width
:
int
,
):
orig_freq
=
int
(
orig_freq
)
//
gcd
new_freq
=
int
(
new_freq
)
//
gcd
# pack batch
shape
=
waveform
.
size
()
waveform
=
waveform
.
view
(
-
1
,
shape
[
-
1
])
kernel
=
kernel
.
to
(
device
=
waveform
.
device
,
dtype
=
waveform
.
dtype
)
num_wavs
,
length
=
waveform
.
shape
waveform
=
torch
.
nn
.
functional
.
pad
(
waveform
,
(
width
,
width
+
orig_freq
))
resampled
=
torch
.
nn
.
functional
.
conv1d
(
waveform
[:,
None
],
kernel
,
stride
=
orig_freq
)
resampled
=
resampled
.
transpose
(
1
,
2
).
reshape
(
num_wavs
,
-
1
)
target_length
=
int
(
math
.
ceil
(
new_freq
*
length
/
orig_freq
))
resampled
=
resampled
[...,
:
target_length
]
# unpack batch
resampled
=
resampled
.
view
(
shape
[:
-
1
]
+
resampled
.
shape
[
-
1
:])
return
resampled
def
resample
(
def
resample
(
waveform
:
Tensor
,
waveform
:
Tensor
,
orig_freq
:
float
,
orig_freq
:
float
,
...
@@ -1380,42 +1424,15 @@ def resample(
...
@@ -1380,42 +1424,15 @@ def resample(
Returns:
Returns:
Tensor: The waveform at the new frequency of dimension (..., time).
Tensor: The waveform at the new frequency of dimension (..., time).
Note: ``transforms.Resample`` precomputes and reuses the resampling kernel, so using it will result in
more efficient computation if resampling multiple waveforms with the same resampling parameters.
"""
"""
# pack batch
shape
=
waveform
.
size
()
waveform
=
waveform
.
view
(
-
1
,
shape
[
-
1
])
assert
orig_freq
>
0.0
and
new_freq
>
0.0
assert
orig_freq
>
0.0
and
new_freq
>
0.0
if
not
(
int
(
orig_freq
)
==
orig_freq
and
int
(
new_freq
)
==
new_freq
):
gcd
=
math
.
gcd
(
int
(
orig_freq
),
int
(
new_freq
))
warnings
.
warn
(
"Non-integer frequencies are being cast to ints and may result in poor resampling quality "
"because the underlying algorithm requires an integer ratio between `orig_freq` and `new_freq`. "
"Using non-integer valued frequencies will throw an error in the next release. "
"To work around this issue, manually convert both frequencies to integer values "
"that maintain their resampling rate ratio before passing them into the function "
"Example: To downsample a 44100 hz waveform by a factor of 8, use "
"`orig_freq=8` and `new_freq=1` instead of `orig_freq=44100` and `new_freq=5512.5` "
"For more information or to leave feedback about this change, please refer to "
"https://github.com/pytorch/audio/issues/1487."
)
orig_freq
=
int
(
orig_freq
)
new_freq
=
int
(
new_freq
)
gcd
=
math
.
gcd
(
orig_freq
,
new_freq
)
orig_freq
=
orig_freq
//
gcd
new_freq
=
new_freq
//
gcd
kernel
,
width
=
_get_sinc_resample_kernel
(
orig_freq
,
new_freq
,
lowpass_filter_width
,
rolloff
,
waveform
.
device
,
waveform
.
dtype
)
num_wavs
,
length
=
waveform
.
shape
waveform
=
torch
.
nn
.
functional
.
pad
(
waveform
,
(
width
,
width
+
orig_freq
))
resampled
=
torch
.
nn
.
functional
.
conv1d
(
waveform
[:,
None
],
kernel
,
stride
=
orig_freq
)
resampled
=
resampled
.
transpose
(
1
,
2
).
reshape
(
num_wavs
,
-
1
)
target_length
=
int
(
math
.
ceil
(
new_freq
*
length
/
orig_freq
))
resampled
=
resampled
[...,
:
target_length
]
# unpack batch
kernel
,
width
=
_get_sinc_resample_kernel
(
orig_freq
,
new_freq
,
gcd
,
lowpass_filter_width
,
rolloff
)
resampled
=
resampled
.
view
(
shape
[:
-
1
]
+
resampled
.
shape
[
-
1
:]
)
resampled
=
_apply_sinc_resample_kernel
(
waveform
,
orig_freq
,
new_freq
,
gcd
,
kernel
,
width
)
return
resampled
return
resampled
torchaudio/transforms.py
View file @
52e7bfd9
...
@@ -8,6 +8,10 @@ import torch
...
@@ -8,6 +8,10 @@ import torch
from
torch
import
Tensor
from
torch
import
Tensor
from
torchaudio
import
functional
as
F
from
torchaudio
import
functional
as
F
from
.functional.functional
import
(
_get_sinc_resample_kernel
,
_apply_sinc_resample_kernel
,
)
__all__
=
[
__all__
=
[
'Spectrogram'
,
'Spectrogram'
,
...
@@ -661,18 +665,23 @@ class Resample(torch.nn.Module):
...
@@ -661,18 +665,23 @@ class Resample(torch.nn.Module):
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
orig_freq
:
in
t
=
16000
,
orig_freq
:
floa
t
=
16000
,
new_freq
:
in
t
=
16000
,
new_freq
:
floa
t
=
16000
,
resampling_method
:
str
=
'sinc_interpolation'
,
resampling_method
:
str
=
'sinc_interpolation'
,
lowpass_filter_width
:
int
=
6
,
lowpass_filter_width
:
int
=
6
,
rolloff
:
float
=
0.99
)
->
None
:
rolloff
:
float
=
0.99
)
->
None
:
super
(
Resample
,
self
).
__init__
()
super
(
Resample
,
self
).
__init__
()
self
.
orig_freq
=
orig_freq
self
.
orig_freq
=
orig_freq
self
.
new_freq
=
new_freq
self
.
new_freq
=
new_freq
self
.
gcd
=
math
.
gcd
(
int
(
self
.
orig_freq
),
int
(
self
.
new_freq
))
self
.
resampling_method
=
resampling_method
self
.
resampling_method
=
resampling_method
self
.
lowpass_filter_width
=
lowpass_filter_width
self
.
lowpass_filter_width
=
lowpass_filter_width
self
.
rolloff
=
rolloff
self
.
rolloff
=
rolloff
self
.
kernel
,
self
.
width
=
_get_sinc_resample_kernel
(
self
.
orig_freq
,
self
.
new_freq
,
self
.
gcd
,
self
.
lowpass_filter_width
,
self
.
rolloff
)
def
forward
(
self
,
waveform
:
Tensor
)
->
Tensor
:
def
forward
(
self
,
waveform
:
Tensor
)
->
Tensor
:
r
"""
r
"""
Args:
Args:
...
@@ -682,7 +691,8 @@ class Resample(torch.nn.Module):
...
@@ -682,7 +691,8 @@ class Resample(torch.nn.Module):
Tensor: Output signal of dimension (..., time).
Tensor: Output signal of dimension (..., time).
"""
"""
if
self
.
resampling_method
==
'sinc_interpolation'
:
if
self
.
resampling_method
==
'sinc_interpolation'
:
return
F
.
resample
(
waveform
,
self
.
orig_freq
,
self
.
new_freq
,
self
.
lowpass_filter_width
,
self
.
rolloff
)
return
_apply_sinc_resample_kernel
(
waveform
,
self
.
orig_freq
,
self
.
new_freq
,
self
.
gcd
,
self
.
kernel
,
self
.
width
)
raise
ValueError
(
'Invalid resampling method: {}'
.
format
(
self
.
resampling_method
))
raise
ValueError
(
'Invalid resampling method: {}'
.
format
(
self
.
resampling_method
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment