OpenDAS / Torchaudio · Commits

Commit e43a8e76 (unverified)
Authored Jan 19, 2021 by Alexandre Défossez; committed by GitHub on Jan 19, 2021
Make resampling simpler and faster (#1087)
Parent: f1d8d1e0

Showing 1 changed file with 64 additions and 201 deletions:
torchaudio/compliance/kaldi.py  (+64, -201)
@@ -3,6 +3,7 @@ from typing import Tuple

import math
import torch
from torch import Tensor
from torch.nn import functional as F
import torchaudio
import torchaudio._internal.fft
@@ -752,141 +753,54 @@ def mfcc(
    return feature
Removed — the old LinearResample-style helpers:

def _get_LR_indices_and_weights(orig_freq: float,
                                new_freq: float,
                                output_samples_in_unit: int,
                                window_width: float,
                                lowpass_cutoff: float,
                                lowpass_filter_width: int,
                                device: torch.device,
                                dtype: int) -> Tuple[Tensor, Tensor]:
    r"""Based on LinearResample::SetIndexesAndWeights; it retrieves the weights for
    resampling as well as the indices at which they are valid. LinearResample (LR) means
    that the output signal is at linearly spaced intervals (i.e. the output signal has a
    frequency of ``new_freq``). It uses sinc/bandlimited interpolation to upsample/downsample
    the signal.

    The reason the same filter is not used for every convolution is that the sinc function
    may be sampled at different points in time. For example, suppose a signal is sampled at
    the timestamps (seconds)

        0 16 32

    and we want it to be sampled at the timestamps (seconds)

        0 5 10 15 20 25 30 35

    At the timestamp of 16, the delta timestamps are

        16 11 6 1 4 9 14 19

    At the timestamp of 32, the delta timestamps are

        32 27 22 17 12 7 2 3

    As we can see from the deltas, the sinc function is sampled at different points in time
    assuming the center of the sinc function is at 0, 16, and 32 (the deltas [..., 6, 1, 4, ...]
    for 16 vs [..., 2, 3, ...] for 32). Conversely, when ``orig_freq`` and ``new_freq`` are
    multiples of each other, a single filter suffices.

    A windowed filter function (i.e. Hanning * sinc) is used because the ideal sinc function
    has infinite support (non-zero for all values), so it is truncated and multiplied by a
    window function, which gives it a less-than-perfect rolloff [1].

    [1] Chapter 16: Windowed-Sinc Filters, https://www.dspguide.com/ch16/1.htm

    Args:
        orig_freq (float): The original frequency of the signal
        new_freq (float): The desired frequency
        output_samples_in_unit (int): The number of output samples in the smallest repeating unit:
            num_samp_out = new_freq / Gcd(orig_freq, new_freq)
        window_width (float): The width of the window which is nonzero
        lowpass_cutoff (float): The filter cutoff in Hz. The filter cutoff needs to be less
            than samp_rate_in_hz/2 and less than samp_rate_out_hz/2.
        lowpass_filter_width (int): Controls the sharpness of the filter, more == sharper but less
            efficient. We suggest around 4 to 10 for normal use

    Returns:
        (Tensor, Tensor): A tuple of ``min_input_index`` (the minimum indices at which the window
        is valid, size (``output_samples_in_unit``,)) and ``weights`` (the weights corresponding
        to ``min_input_index``, size (``output_samples_in_unit``, ``max_weight_width``)).
    """
    assert lowpass_cutoff < min(orig_freq, new_freq) / 2
    output_t = torch.arange(0., output_samples_in_unit, device=device, dtype=dtype) / new_freq
    min_t = output_t - window_width
    max_t = output_t + window_width

    min_input_index = torch.ceil(min_t * orig_freq)       # size (output_samples_in_unit)
    max_input_index = torch.floor(max_t * orig_freq)      # size (output_samples_in_unit)
    num_indices = max_input_index - min_input_index + 1   # size (output_samples_in_unit)

    max_weight_width = num_indices.max()
    # create a group of weights of size (output_samples_in_unit, max_weight_width)
    j = torch.arange(max_weight_width, device=device, dtype=dtype).unsqueeze(0)
    input_index = min_input_index.unsqueeze(1) + j
    delta_t = (input_index / orig_freq) - output_t.unsqueeze(1)

    weights = torch.zeros_like(delta_t)
    inside_window_indices = delta_t.abs().lt(window_width)
    # raised-cosine (Hanning) window with width `window_width`
    weights[inside_window_indices] = 0.5 * (1 + torch.cos(2 * math.pi * lowpass_cutoff /
                                            lowpass_filter_width * delta_t[inside_window_indices]))

    t_eq_zero_indices = delta_t.eq(0.0)
    t_not_eq_zero_indices = ~t_eq_zero_indices
    # sinc filter function
    weights[t_not_eq_zero_indices] *= torch.sin(
        2 * math.pi * lowpass_cutoff * delta_t[t_not_eq_zero_indices]) / (
        math.pi * delta_t[t_not_eq_zero_indices])
    # limit of the function at t = 0
    weights[t_eq_zero_indices] *= 2 * lowpass_cutoff

    # size (output_samples_in_unit, max_weight_width)
    weights /= orig_freq
    return min_input_index, weights
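The delta-timestamp example in the docstring above can be verified with a few lines of plain arithmetic. This is an illustrative sketch added for clarity; it is not part of the commit.

# Sketch (not part of the diff): reproduce the docstring's delta-timestamp example.
output_times = [5 * j for j in range(8)]          # output requested at 0, 5, ..., 35 seconds
for center in (0, 16, 32):                        # the original sample times
    print(center, [abs(t - center) for t in output_times])
# 16 -> [16, 11, 6, 1, 4, 9, 14, 19]
# 32 -> [32, 27, 22, 17, 12, 7, 2, 3]   (a different set of sinc offsets, hence a different filter)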
def _lcm(a: int, b: int) -> int:
    return abs(a * b) // math.gcd(a, b)
def _get_num_LR_output_samples(input_num_samp: int,
                               samp_rate_in: float,
                               samp_rate_out: float) -> int:
    r"""Based on LinearResample::GetNumOutputSamples. LinearResample (LR) means that
    the output signal is at linearly spaced intervals (i.e. the output signal has a
    frequency of ``new_freq``). It uses sinc/bandlimited interpolation to upsample/downsample
    the signal.

    Args:
        input_num_samp (int): The number of samples in the input
        samp_rate_in (float): The original frequency of the signal
        samp_rate_out (float): The desired frequency

    Returns:
        int: The number of output samples
    """
    # For exact computation, we measure time in "ticks" of 1.0 / tick_freq,
    # where tick_freq is the least common multiple of samp_rate_in and
    # samp_rate_out.
    samp_rate_in = int(samp_rate_in)
    samp_rate_out = int(samp_rate_out)

    tick_freq = _lcm(samp_rate_in, samp_rate_out)
    ticks_per_input_period = tick_freq // samp_rate_in

    # work out the number of ticks in the time interval
    # [ 0, input_num_samp/samp_rate_in ).
    interval_length_in_ticks = input_num_samp * ticks_per_input_period
    if interval_length_in_ticks <= 0:
        return 0
    ticks_per_output_period = tick_freq // samp_rate_out
    # Get the last output-sample in the closed interval, i.e. replacing [ ) with
    # [ ]. Note: integer division rounds down. See
    # http://en.wikipedia.org/wiki/Interval_(mathematics) for an explanation of
    # the notation.
    last_output_samp = interval_length_in_ticks // ticks_per_output_period
    # We need the last output-sample in the open interval, so if it takes us to
    # the end of the interval exactly, subtract one.
    if last_output_samp * ticks_per_output_period == interval_length_in_ticks:
        last_output_samp -= 1
    # First output-sample index is zero, so the number of output samples
    # is the last output-sample plus one.
    num_output_samp = last_output_samp + 1
    return num_output_samp
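As a concrete illustration of the tick arithmetic above (not part of the commit), consider 1000 samples at 16 kHz being resampled to 8 kHz; the same steps as the function give 500 output samples.

# Worked example (illustration only): 1000 samples at 16 kHz -> 8 kHz.
import math

samp_rate_in, samp_rate_out, input_num_samp = 16000, 8000, 1000
tick_freq = abs(samp_rate_in * samp_rate_out) // math.gcd(samp_rate_in, samp_rate_out)  # 16000
ticks_per_input_period = tick_freq // samp_rate_in                   # 1
interval_length_in_ticks = input_num_samp * ticks_per_input_period   # 1000
ticks_per_output_period = tick_freq // samp_rate_out                 # 2
last_output_samp = interval_length_in_ticks // ticks_per_output_period  # 500
if last_output_samp * ticks_per_output_period == interval_length_in_ticks:
    last_output_samp -= 1                                            # interval is half-open, so 499
print(last_output_samp + 1)                                          # 500 output samples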
Added — the new sinc-kernel helper:

def _get_sinc_resample_kernel(orig_freq: int,
                              new_freq: int,
                              lowpass_filter_width: int,
                              device: torch.device,
                              dtype: torch.dtype):
    assert lowpass_filter_width > 0
    kernels = []
    base_freq = min(orig_freq, new_freq)
    # This will perform antialiasing filtering by removing the highest frequencies.
    # At first I thought I only needed this when downsampling, but when upsampling
    # you will get edge artifacts without this, as the edge is equivalent to zero padding,
    # which will add high freq artifacts.
    base_freq *= 0.99

    # The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor)
    # using the sinc interpolation formula:
    #   x(t) = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - t))
    # We can then sample the function x(t) with a different sample rate:
    #   y[j] = x(j / new_freq)
    # or,
    #   y[j] = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))

    # We see here that y[j] is the convolution of x[i] with a specific filter, for which
    # we take an FIR approximation, stopping when we see at least `lowpass_filter_width`
    # zero crossings.
    # But y[j+1] is going to have a different set of weights and so on, until y[j + new_freq].
    # Indeed:
    # y[j + new_freq] = sum_i x[i] sinc(pi * orig_freq * ((i / orig_freq - (j + new_freq) / new_freq))
    #                 = sum_i x[i] sinc(pi * orig_freq * ((i - orig_freq) / orig_freq - j / new_freq))
    #                 = sum_i x[i + orig_freq] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))
    # so y[j + new_freq] uses the same filter as y[j], but on a version of x shifted by `orig_freq`.
    # This explains the F.conv1d below, with a stride of orig_freq.
    width = math.ceil(lowpass_filter_width * orig_freq / base_freq)
    # If orig_freq is still big after GCD reduction, most filters will be very unbalanced, i.e.,
    # they will have a lot of almost zero values to the left or to the right...
    # There is probably a way to evaluate those filters more efficiently, but this is kept for
    # future work.
    idx = torch.arange(-width, width + orig_freq, device=device, dtype=dtype)

    for i in range(new_freq):
        t = (-i / new_freq + idx / orig_freq) * base_freq
        t = t.clamp_(-lowpass_filter_width, lowpass_filter_width)
        t *= math.pi
        # we do not use torch.hann_window here as we need to evaluate the window
        # at specific positions, not over a regular grid.
        window = torch.cos(t / lowpass_filter_width / 2)**2
        kernel = torch.where(t == 0, torch.tensor(1.).to(t), torch.sin(t) / t)
        kernel.mul_(window)
        kernels.append(kernel)

    scale = base_freq / orig_freq
    return torch.stack(kernels).view(new_freq, 1, -1).mul_(scale), width
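A quick shape check may help make the polyphase idea concrete. This is an illustrative sketch, not part of the diff, and it assumes the private helper `_get_sinc_resample_kernel` defined above is in scope: after GCD reduction, a 20 kHz to 16 kHz conversion becomes orig_freq=5, new_freq=4, so the kernel stacks 4 polyphase filters.

import torch

kernel, width = _get_sinc_resample_kernel(5, 4, 6, torch.device("cpu"), torch.float32)
print(width)         # 8, i.e. math.ceil(6 * 5 / (4 * 0.99))
print(kernel.shape)  # torch.Size([4, 1, 21]) = (new_freq, 1, 2 * width + orig_freq)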

def resample_waveform(waveform: Tensor,
                      ...

@@ -912,72 +826,21 @@ def resample_waveform(waveform: Tensor,
    ...
    Returns:
        Tensor: The waveform at the new frequency
    """
Removed — the old body of resample_waveform:

    device, dtype = waveform.device, waveform.dtype

    assert waveform.dim() == 2
    assert orig_freq > 0.0 and new_freq > 0.0

    min_freq = min(orig_freq, new_freq)
    lowpass_cutoff = 0.99 * 0.5 * min_freq

    assert lowpass_cutoff * 2 <= min_freq

    base_freq = math.gcd(int(orig_freq), int(new_freq))
    input_samples_in_unit = int(orig_freq) // base_freq
    output_samples_in_unit = int(new_freq) // base_freq

    window_width = lowpass_filter_width / (2.0 * lowpass_cutoff)
    first_indices, weights = _get_LR_indices_and_weights(
        orig_freq, new_freq, output_samples_in_unit,
        window_width, lowpass_cutoff, lowpass_filter_width, device, dtype)

    assert first_indices.dim() == 1
    # TODO figure a better way to do this. conv1d reaches every element i*stride + padding
    # all the weights have the same stride but have different padding.
    # Current implementation takes the input and applies the various padding before
    # doing a conv1d for that specific weight.
    conv_stride = input_samples_in_unit
    conv_transpose_stride = output_samples_in_unit
    num_channels, wave_len = waveform.size()
    window_size = weights.size(1)
    tot_output_samp = _get_num_LR_output_samples(wave_len, orig_freq, new_freq)
    output = torch.zeros((num_channels, tot_output_samp),
                         device=device, dtype=dtype)
    # eye size: (num_channels, num_channels, 1)
    eye = torch.eye(num_channels, device=device, dtype=dtype).unsqueeze(2)
    for i in range(first_indices.size(0)):
        wave_to_conv = waveform
        first_index = int(first_indices[i].item())
        if first_index >= 0:
            # trim the signal as the filter will not be applied before the first_index
            wave_to_conv = wave_to_conv[..., first_index:]

        # pad the right of the signal to allow partial convolutions, i.e. compute
        # values for partial windows (e.g. end of the window is outside the signal length)
        max_unit_index = (tot_output_samp - 1) // output_samples_in_unit
        end_index_of_last_window = max_unit_index * conv_stride + window_size
        current_wave_len = wave_len - first_index
        right_padding = max(0, end_index_of_last_window + 1 - current_wave_len)

        left_padding = max(0, -first_index)
        if left_padding != 0 or right_padding != 0:
            wave_to_conv = torch.nn.functional.pad(wave_to_conv, (left_padding, right_padding))

        conv_wave = torch.nn.functional.conv1d(
            wave_to_conv.unsqueeze(0), weights[i].repeat(num_channels, 1, 1),
            stride=conv_stride, groups=num_channels)

        # we want conv_wave[:, i] to be at output[:, i + n*conv_transpose_stride]
        dilated_conv_wave = torch.nn.functional.conv_transpose1d(
            conv_wave, eye, stride=conv_transpose_stride).squeeze(0)

        # pad dilated_conv_wave so it reaches the output length if needed.
        dilated_conv_wave_len = dilated_conv_wave.size(-1)
        left_padding = i
        right_padding = max(0, tot_output_samp - (left_padding + dilated_conv_wave_len))
        dilated_conv_wave = torch.nn.functional.pad(
            dilated_conv_wave, (left_padding, right_padding))[..., :tot_output_samp]

        output += dilated_conv_wave

    return output
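The conv_transpose1d call in the removed code above, with an identity kernel and a stride, is effectively zero-stuffing: it spaces the per-phase filter outputs `conv_transpose_stride` samples apart so they can be summed into a single output buffer. A tiny standalone illustration (not part of the diff):

import torch

x = torch.tensor([[[1., 2., 3.]]])      # (batch=1, channels=1, time=3)
eye = torch.eye(1).unsqueeze(2)         # identity kernel of shape (1, 1, 1)
y = torch.nn.functional.conv_transpose1d(x, eye, stride=2)
print(y.squeeze())                      # tensor([1., 0., 2., 0., 3.])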
Added — the new body of resample_waveform:

    orig_freq = int(orig_freq)
    new_freq = int(new_freq)
    gcd = math.gcd(orig_freq, new_freq)
    orig_freq = orig_freq // gcd
    new_freq = new_freq // gcd

    kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, lowpass_filter_width,
                                              waveform.device, waveform.dtype)

    num_wavs, length = waveform.shape
    waveform = F.pad(waveform, (width, width + orig_freq))
    resampled = F.conv1d(waveform[:, None], kernel, stride=orig_freq)
    resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
    target_length = int(math.ceil(new_freq * length / orig_freq))
    return resampled[..., :target_length]
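For context, a short usage sketch of the public entry point (illustrative, not part of the diff): resampling one second of stereo audio from 44.1 kHz to 16 kHz. With the GCD-reduced rates 441 and 160, the trimmed length works out to ceil(160 * 44100 / 441) = 16000 samples per channel.

import torch
import torchaudio.compliance.kaldi as kaldi

waveform = torch.randn(2, 44100)   # (num_channels, num_samples): one second of noise at 44.1 kHz
resampled = kaldi.resample_waveform(waveform, orig_freq=44100., new_freq=16000.)
print(resampled.shape)             # torch.Size([2, 16000])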