OpenDAS / Torchaudio
Commit e43a8e76 (Unverified)
Authored Jan 19, 2021 by Alexandre Défossez; committed by GitHub on Jan 19, 2021

Make resampling simpler and faster (#1087)

Parent: f1d8d1e0
Showing 1 changed file with 64 additions and 201 deletions

torchaudio/compliance/kaldi.py (+64, -201)
@@ -3,6 +3,7 @@ from typing import Tuple
 import math
 import torch
 from torch import Tensor
+from torch.nn import functional as F
 import torchaudio
 import torchaudio._internal.fft
@@ -752,141 +753,54 @@ def mfcc(
     return feature
-def _get_LR_indices_and_weights(orig_freq: float,
-                                new_freq: float,
-                                output_samples_in_unit: int,
-                                window_width: float,
-                                lowpass_cutoff: float,
-                                lowpass_filter_width: int,
-                                device: torch.device,
-                                dtype: int) -> Tuple[Tensor, Tensor]:
-    r"""Based on LinearResample::SetIndexesAndWeights where it retrieves the weights for
-    resampling as well as the indices in which they are valid. LinearResample (LR) means
-    that the output signal is at linearly spaced intervals (i.e the output signal has a
-    frequency of ``new_freq``). It uses sinc/bandlimited interpolation to upsample/downsample
-    the signal.
-
-    The reason why the same filter is not used for multiple convolutions is because the
-    sinc function could sampled at different points in time. For example, suppose
-    a signal is sampled at the timestamps (seconds)
-    0 16 32
-    and we want it to be sampled at the timestamps (seconds)
-    0 5 10 15 20 25 30 35
-    at the timestamp of 16, the delta timestamps are
-    16 11 6 1 4 9 14 19
-    at the timestamp of 32, the delta timestamps are
-    32 27 22 17 12 8 2 3
-
-    As we can see from deltas, the sinc function is sampled at different points of time
-    assuming the center of the sinc function is at 0, 16, and 32 (the deltas [..., 6, 1, 4, ....]
-    for 16 vs [...., 2, 3, ....] for 32)
-
-    Example, one case is when the ``orig_freq`` and ``new_freq`` are multiples of each other then
-    there needs to be one filter.
-
-    A windowed filter function (i.e. Hanning * sinc) because the ideal case of sinc function
-    has infinite support (non-zero for all values) so instead it is truncated and multiplied by
-    a window function which gives it less-than-perfect rolloff [1].
-
-    [1] Chapter 16: Windowed-Sinc Filters, https://www.dspguide.com/ch16/1.htm
-
-    Args:
-        orig_freq (float): The original frequency of the signal
-        new_freq (float): The desired frequency
-        output_samples_in_unit (int): The number of output samples in the smallest repeating unit:
-            num_samp_out = new_freq / Gcd(orig_freq, new_freq)
-        window_width (float): The width of the window which is nonzero
-        lowpass_cutoff (float): The filter cutoff in Hz. The filter cutoff needs to be less
-            than samp_rate_in_hz/2 and less than samp_rate_out_hz/2.
-        lowpass_filter_width (int): Controls the sharpness of the filter, more == sharper but less
-            efficient. We suggest around 4 to 10 for normal use
-
-    Returns:
-        (Tensor, Tensor): A tuple of ``min_input_index`` (which is the minimum indices
-        where the window is valid, size (``output_samples_in_unit``)) and ``weights`` (which is the weights
-        which correspond with min_input_index, size (``output_samples_in_unit``, ``max_weight_width``)).
-    """
-    assert lowpass_cutoff < min(orig_freq, new_freq) / 2
-    output_t = torch.arange(0., output_samples_in_unit, device=device, dtype=dtype) / new_freq
-    min_t = output_t - window_width
-    max_t = output_t + window_width
-    min_input_index = torch.ceil(min_t * orig_freq)  # size (output_samples_in_unit)
-    max_input_index = torch.floor(max_t * orig_freq)  # size (output_samples_in_unit)
-    num_indices = max_input_index - min_input_index + 1  # size (output_samples_in_unit)
-    max_weight_width = num_indices.max()
-    # create a group of weights of size (output_samples_in_unit, max_weight_width)
-    j = torch.arange(max_weight_width, device=device, dtype=dtype).unsqueeze(0)
-    input_index = min_input_index.unsqueeze(1) + j
-    delta_t = (input_index / orig_freq) - output_t.unsqueeze(1)
-    weights = torch.zeros_like(delta_t)
-    inside_window_indices = delta_t.abs().lt(window_width)
-    # raised-cosine (Hanning) window with width `window_width`
-    weights[inside_window_indices] = 0.5 * (1 + torch.cos(2 * math.pi * lowpass_cutoff /
-                                            lowpass_filter_width * delta_t[inside_window_indices]))
-    t_eq_zero_indices = delta_t.eq(0.0)
-    t_not_eq_zero_indices = ~t_eq_zero_indices
-    # sinc filter function
-    weights[t_not_eq_zero_indices] *= torch.sin(
-        2 * math.pi * lowpass_cutoff * delta_t[t_not_eq_zero_indices]) / (math.pi * delta_t[t_not_eq_zero_indices])
-    # limit of the function at t = 0
-    weights[t_eq_zero_indices] *= 2 * lowpass_cutoff
-    weights /= orig_freq  # size (output_samples_in_unit, max_weight_width)
-    return min_input_index, weights
+def _get_sinc_resample_kernel(orig_freq: int, new_freq: int, lowpass_filter_width: int,
+                              device: torch.device, dtype: torch.dtype):
+    assert lowpass_filter_width > 0
+    kernels = []
+    base_freq = min(orig_freq, new_freq)
+    # This will perform antialiasing filtering by removing the highest frequencies.
+    # At first I thought I only needed this when downsampling, but when upsampling
+    # you will get edge artifacts without this, as the edge is equivalent to zero padding,
+    # which will add high freq artifacts.
+    base_freq *= 0.99
+
+    # The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor)
+    # using the sinc interpolation formula:
+    #   x(t) = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - t))
+    # We can then sample the function x(t) with a different sample rate:
+    #    y[j] = x(j / new_freq)
+    # or,
+    #    y[j] = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))
+
+    # We see here that y[j] is the convolution of x[i] with a specific filter, for which
+    # we take an FIR approximation, stopping when we see at least `lowpass_filter_width` zeros crossing.
+    # But y[j+1] is going to have a different set of weights and so on, until y[j + new_freq].
+    # Indeed:
+    # y[j + new_freq] = sum_i x[i] sinc(pi * orig_freq * ((i / orig_freq - (j + new_freq) / new_freq))
+    #                 = sum_i x[i] sinc(pi * orig_freq * ((i - orig_freq) / orig_freq - j / new_freq))
+    #                 = sum_i x[i + orig_freq] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))
+    # so y[j+new_freq] uses the same filter as y[j], but on a shifted version of x by `orig_freq`.
+    # This will explain the F.conv1d after, with a stride of orig_freq.
+    width = math.ceil(lowpass_filter_width * orig_freq / base_freq)
+    # If orig_freq is still big after GCD reduction, most filters will be very unbalanced, i.e.,
+    # they will have a lot of almost zero values to the left or to the right...
+    # There is probably a way to evaluate those filters more efficiently, but this is kept for
+    # future work.
+    idx = torch.arange(-width, width + orig_freq, device=device, dtype=dtype)
+
+    for i in range(new_freq):
+        t = (-i / new_freq + idx / orig_freq) * base_freq
+        t = t.clamp_(-lowpass_filter_width, lowpass_filter_width)
+        t *= math.pi
+        # we do not use torch.hann_window here as we need to evaluate the window
+        # at spectifics positions, not over a regular grid.
+        window = torch.cos(t / lowpass_filter_width / 2)**2
+        kernel = torch.where(t == 0, torch.tensor(1.).to(t), torch.sin(t) / t)
+        kernel.mul_(window)
+        kernels.append(kernel)
+
+    scale = base_freq / orig_freq
+    return torch.stack(kernels).view(new_freq, 1, -1).mul_(scale), width
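
Note on the periodicity argument: the comment block in `_get_sinc_resample_kernel` can be written out as below. This only restates the code comment, using its convention sinc(u) = sin(u)/u. The symbols f_o and f_n are shorthand introduced here for `orig_freq` and `new_freq`, assumed to be the GCD-reduced integer rates.

```latex
\begin{align*}
x(t) &= \sum_i x[i]\,\operatorname{sinc}\!\Big(\pi f_o \Big(\frac{i}{f_o} - t\Big)\Big) \\
y[j] &= x\!\Big(\frac{j}{f_n}\Big)
      = \sum_i x[i]\,\operatorname{sinc}\!\Big(\pi f_o \Big(\frac{i}{f_o} - \frac{j}{f_n}\Big)\Big) \\
y[j + f_n] &= \sum_i x[i]\,\operatorname{sinc}\!\Big(\pi f_o \Big(\frac{i}{f_o} - \frac{j + f_n}{f_n}\Big)\Big) \\
           &= \sum_i x[i]\,\operatorname{sinc}\!\Big(\pi f_o \Big(\frac{i - f_o}{f_o} - \frac{j}{f_n}\Big)\Big) \\
           &= \sum_i x[i + f_o]\,\operatorname{sinc}\!\Big(\pi f_o \Big(\frac{i}{f_o} - \frac{j}{f_n}\Big)\Big)
\end{align*}
```

The last step reindexes the sum (i becomes i + f_o), which is valid because f_o is an integer. Only f_n distinct filters are therefore needed, each applied to the input with a stride of f_o samples.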
-def _lcm(a: int, b: int) -> int:
-    return abs(a * b) // math.gcd(a, b)
-def _get_num_LR_output_samples(input_num_samp: int,
-                               samp_rate_in: float,
-                               samp_rate_out: float) -> int:
-    r"""Based on LinearResample::GetNumOutputSamples. LinearResample (LR) means that
-    the output signal is at linearly spaced intervals (i.e the output signal has a
-    frequency of ``new_freq``). It uses sinc/bandlimited interpolation to upsample/downsample
-    the signal.
-
-    Args:
-        input_num_samp (int): The number of samples in the input
-        samp_rate_in (float): The original frequency of the signal
-        samp_rate_out (float): The desired frequency
-
-    Returns:
-        int: The number of output samples
-    """
-    # For exact computation, we measure time in "ticks" of 1.0 / tick_freq,
-    # where tick_freq is the least common multiple of samp_rate_in and
-    # samp_rate_out.
-    samp_rate_in = int(samp_rate_in)
-    samp_rate_out = int(samp_rate_out)
-    tick_freq = _lcm(samp_rate_in, samp_rate_out)
-    ticks_per_input_period = tick_freq // samp_rate_in
-    # work out the number of ticks in the time interval
-    # [ 0, input_num_samp/samp_rate_in ).
-    interval_length_in_ticks = input_num_samp * ticks_per_input_period
-    if interval_length_in_ticks <= 0:
-        return 0
-    ticks_per_output_period = tick_freq // samp_rate_out
-    # Get the last output-sample in the closed interval, i.e. replacing [ ) with
-    # [ ]. Note: integer division rounds down. See
-    # http://en.wikipedia.org/wiki/Interval_(mathematics) for an explanation of
-    # the notation.
-    last_output_samp = interval_length_in_ticks // ticks_per_output_period
-    # We need the last output-sample in the open interval, so if it takes us to
-    # the end of the interval exactly, subtract one.
-    if last_output_samp * ticks_per_output_period == interval_length_in_ticks:
-        last_output_samp -= 1
-    # First output-sample index is zero, so the number of output samples
-    # is the last output-sample plus one.
-    num_output_samp = last_output_samp + 1
-    return num_output_samp
 def resample_waveform(waveform: Tensor,
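
Note on output length: the removed `_get_num_LR_output_samples` helper counts output samples with integer "ticks", while the new code uses `int(math.ceil(new_freq * length / orig_freq))`. The two should agree for integer sample rates, since counting the output samples in the half-open interval [0, length / orig_freq) is a ceiling. The sketch below is a plain-Python port written for this note, not code from the commit; the function names `old_num_output_samples` and `new_target_length` and the list of rate pairs are illustrative assumptions.

```python
import math


def old_num_output_samples(input_num_samp: int, samp_rate_in: int, samp_rate_out: int) -> int:
    """Tick-based count, ported from the removed `_get_num_LR_output_samples`."""
    # Least common multiple of the two rates (the removed `_lcm` helper).
    tick_freq = abs(samp_rate_in * samp_rate_out) // math.gcd(samp_rate_in, samp_rate_out)
    ticks_per_input_period = tick_freq // samp_rate_in
    interval_length_in_ticks = input_num_samp * ticks_per_input_period
    if interval_length_in_ticks <= 0:
        return 0
    ticks_per_output_period = tick_freq // samp_rate_out
    last_output_samp = interval_length_in_ticks // ticks_per_output_period
    # Exclude a sample that lands exactly on the end of the half-open interval.
    if last_output_samp * ticks_per_output_period == interval_length_in_ticks:
        last_output_samp -= 1
    return last_output_samp + 1


def new_target_length(length: int, orig_freq: int, new_freq: int) -> int:
    """Ceiling-based length, following the added lines in `resample_waveform`."""
    gcd = math.gcd(orig_freq, new_freq)
    orig_freq, new_freq = orig_freq // gcd, new_freq // gcd
    return int(math.ceil(new_freq * length / orig_freq))


# Illustrative rate pairs (assumption: chosen only for this check).
RATE_PAIRS = [(16000, 8000), (8000, 16000), (44100, 48000), (48000, 44100), (3, 7)]

mismatches = [
    (length, rate_in, rate_out)
    for rate_in, rate_out in RATE_PAIRS
    for length in range(0, 500)
    if old_num_output_samples(length, rate_in, rate_out)
    != new_target_length(length, rate_in, rate_out)
]
print(mismatches)
```

For these rates and lengths the list is expected to be empty. The check does not cover non-integer rates, which both versions truncate with `int(...)`.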
@@ -912,72 +826,21 @@ def resample_waveform(waveform: Tensor,
     Returns:
         Tensor: The waveform at the new frequency
     """
-    device, dtype = waveform.device, waveform.dtype
     assert waveform.dim() == 2
     assert orig_freq > 0.0 and new_freq > 0.0
-    min_freq = min(orig_freq, new_freq)
-    lowpass_cutoff = 0.99 * 0.5 * min_freq
-    assert lowpass_cutoff * 2 <= min_freq
-    base_freq = math.gcd(int(orig_freq), int(new_freq))
-    input_samples_in_unit = int(orig_freq) // base_freq
-    output_samples_in_unit = int(new_freq) // base_freq
-    window_width = lowpass_filter_width / (2.0 * lowpass_cutoff)
-    first_indices, weights = _get_LR_indices_and_weights(
-        orig_freq, new_freq, output_samples_in_unit,
-        window_width, lowpass_cutoff, lowpass_filter_width, device, dtype)
-    assert first_indices.dim() == 1
-    # TODO figure a better way to do this. conv1d reaches every element i*stride + padding
-    # all the weights have the same stride but have different padding.
-    # Current implementation takes the input and applies the various padding before
-    # doing a conv1d for that specific weight.
-    conv_stride = input_samples_in_unit
-    conv_transpose_stride = output_samples_in_unit
-    num_channels, wave_len = waveform.size()
-    window_size = weights.size(1)
-    tot_output_samp = _get_num_LR_output_samples(wave_len, orig_freq, new_freq)
-    output = torch.zeros((num_channels, tot_output_samp),
-                         device=device, dtype=dtype)
-    # eye size: (num_channels, num_channels, 1)
-    eye = torch.eye(num_channels, device=device, dtype=dtype).unsqueeze(2)
-    for i in range(first_indices.size(0)):
-        wave_to_conv = waveform
-        first_index = int(first_indices[i].item())
-        if first_index >= 0:
-            # trim the signal as the filter will not be applied before the first_index
-            wave_to_conv = wave_to_conv[..., first_index:]
-        # pad the right of the signal to allow partial convolutions meaning compute
-        # values for partial windows (e.g. end of the window is outside the signal length)
-        max_unit_index = (tot_output_samp - 1) // output_samples_in_unit
-        end_index_of_last_window = max_unit_index * conv_stride + window_size
-        current_wave_len = wave_len - first_index
-        right_padding = max(0, end_index_of_last_window + 1 - current_wave_len)
-        left_padding = max(0, -first_index)
-        if left_padding != 0 or right_padding != 0:
-            wave_to_conv = torch.nn.functional.pad(wave_to_conv, (left_padding, right_padding))
-        conv_wave = torch.nn.functional.conv1d(
-            wave_to_conv.unsqueeze(0), weights[i].repeat(num_channels, 1, 1),
-            stride=conv_stride, groups=num_channels)
-        # we want conv_wave[:, i] to be at output[:, i + n*conv_transpose_stride]
-        dilated_conv_wave = torch.nn.functional.conv_transpose1d(
-            conv_wave, eye, stride=conv_transpose_stride).squeeze(0)
-        # pad dilated_conv_wave so it reaches the output length if needed.
-        dialated_conv_wave_len = dilated_conv_wave.size(-1)
-        left_padding = i
-        right_padding = max(0, tot_output_samp - (left_padding + dialated_conv_wave_len))
-        dilated_conv_wave = torch.nn.functional.pad(
-            dilated_conv_wave, (left_padding, right_padding))[..., :tot_output_samp]
-        output += dilated_conv_wave
-    return output
+    orig_freq = int(orig_freq)
+    new_freq = int(new_freq)
+    gcd = math.gcd(orig_freq, new_freq)
+    orig_freq = orig_freq // gcd
+    new_freq = new_freq // gcd
+
+    kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, lowpass_filter_width,
+                                              waveform.device, waveform.dtype)
+
+    num_wavs, length = waveform.shape
+    waveform = F.pad(waveform, (width, width + orig_freq))
+    resampled = F.conv1d(waveform[:, None], kernel, stride=orig_freq)
+    resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
+    target_length = int(math.ceil(new_freq * length / orig_freq))
+    return resampled[..., :target_length]
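
Note on the new code path: the added lines build `new_freq` windowed-sinc filters and apply them with a single strided convolution. A minimal plain-Python sketch of that idea follows, without torch, so the indexing is explicit. It is an illustration under assumptions, not the commit's implementation: the names `sinc_resample_kernels` and `resample` are hypothetical, the default `lowpass_filter_width=6` is an assumed value, and the input is a single channel given as a list of floats.

```python
import math


def sinc_resample_kernels(orig_freq, new_freq, lowpass_filter_width=6):
    """Return (kernels, width) for integer rates already divided by their GCD."""
    base_freq = min(orig_freq, new_freq) * 0.99  # anti-aliasing cutoff, as in the diff
    width = math.ceil(lowpass_filter_width * orig_freq / base_freq)
    scale = base_freq / orig_freq
    kernels = []
    for i in range(new_freq):  # one filter per output phase
        kernel = []
        for k in range(-width, width + orig_freq):
            t = (-i / new_freq + k / orig_freq) * base_freq
            t = max(-lowpass_filter_width, min(lowpass_filter_width, t)) * math.pi
            window = math.cos(t / lowpass_filter_width / 2) ** 2  # Hann window at t
            sinc = 1.0 if t == 0 else math.sin(t) / t
            kernel.append(sinc * window * scale)
        kernels.append(kernel)
    return kernels, width


def resample(samples, orig_freq, new_freq, lowpass_filter_width=6):
    """Resample one channel; mirrors pad -> strided conv -> interleave -> trim."""
    gcd = math.gcd(orig_freq, new_freq)
    orig_freq, new_freq = orig_freq // gcd, new_freq // gcd
    kernels, width = sinc_resample_kernels(orig_freq, new_freq, lowpass_filter_width)
    # Same zero padding as F.pad(waveform, (width, width + orig_freq)).
    padded = [0.0] * width + list(samples) + [0.0] * (width + orig_freq)
    kernel_len = 2 * width + orig_freq
    out = []
    # Stride of orig_freq: each frame yields new_freq outputs, one per filter.
    for start in range(0, len(padded) - kernel_len + 1, orig_freq):
        frame = padded[start:start + kernel_len]
        for kernel in kernels:
            out.append(sum(w * s for w, s in zip(kernel, frame)))
    target_length = math.ceil(new_freq * len(samples) / orig_freq)
    return out[:target_length]


if __name__ == "__main__":
    ones = [1.0] * 64
    print(len(resample(ones, 16000, 8000)), len(resample(ones, 8000, 16000)))
```

Appending the outputs frame by frame, with all phases of a frame together, plays the role of `transpose(1, 2).reshape(num_wavs, -1)` in the diff. Away from the zero-padded edges a constant input should come out close to constant, because each filter is scaled by `base_freq / orig_freq`.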