OpenDAS / Torchaudio
Commit e43a8e76
Authored Jan 19, 2021 by Alexandre Défossez
Committed by GitHub, Jan 19, 2021
Make resampling simpler and faster (#1087)
parent f1d8d1e0
Showing 1 changed file with 64 additions and 201 deletions
torchaudio/compliance/kaldi.py (+64, -201) @ e43a8e76
@@ -3,6 +3,7 @@ from typing import Tuple
 import math
 import torch
 from torch import Tensor
+from torch.nn import functional as F
 import torchaudio
 import torchaudio._internal.fft
@@ -752,141 +753,54 @@ def mfcc(
     return feature
 
 
-def _get_LR_indices_and_weights(orig_freq: float,
-                                new_freq: float,
-                                output_samples_in_unit: int,
-                                window_width: float,
-                                lowpass_cutoff: float,
-                                lowpass_filter_width: int,
-                                device: torch.device,
-                                dtype: int) -> Tuple[Tensor, Tensor]:
-    r"""Based on LinearResample::SetIndexesAndWeights where it retrieves the weights for
-    resampling as well as the indices in which they are valid. LinearResample (LR) means
-    that the output signal is at linearly spaced intervals (i.e the output signal has a
-    frequency of ``new_freq``). It uses sinc/bandlimited interpolation to upsample/downsample
-    the signal.
-
-    The reason why the same filter is not used for multiple convolutions is because the
-    sinc function could sampled at different points in time. For example, suppose
-    a signal is sampled at the timestamps (seconds)
-    0 16 32
-    and we want it to be sampled at the timestamps (seconds)
-    0 5 10 15 20 25 30 35
-    at the timestamp of 16, the delta timestamps are
-    16 11 6 1 4 9 14 19
-    at the timestamp of 32, the delta timestamps are
-    32 27 22 17 12 8 2 3
-
-    As we can see from deltas, the sinc function is sampled at different points of time
-    assuming the center of the sinc function is at 0, 16, and 32 (the deltas [..., 6, 1, 4, ....]
-    for 16 vs [...., 2, 3, ....] for 32)
-
-    Example, one case is when the ``orig_freq`` and ``new_freq`` are multiples of each other then
-    there needs to be one filter.
-
-    A windowed filter function (i.e. Hanning * sinc) because the ideal case of sinc function
-    has infinite support (non-zero for all values) so instead it is truncated and multiplied by
-    a window function which gives it less-than-perfect rolloff [1].
-
-    [1] Chapter 16: Windowed-Sinc Filters, https://www.dspguide.com/ch16/1.htm
-
-    Args:
-        orig_freq (float): The original frequency of the signal
-        new_freq (float): The desired frequency
-        output_samples_in_unit (int): The number of output samples in the smallest repeating unit:
-            num_samp_out = new_freq / Gcd(orig_freq, new_freq)
-        window_width (float): The width of the window which is nonzero
-        lowpass_cutoff (float): The filter cutoff in Hz. The filter cutoff needs to be less
-            than samp_rate_in_hz/2 and less than samp_rate_out_hz/2.
-        lowpass_filter_width (int): Controls the sharpness of the filter, more == sharper but less
-            efficient. We suggest around 4 to 10 for normal use
-
-    Returns:
-        (Tensor, Tensor): A tuple of ``min_input_index`` (which is the minimum indices
-        where the window is valid, size (``output_samples_in_unit``)) and ``weights`` (which is the weights
-        which correspond with min_input_index, size (``output_samples_in_unit``, ``max_weight_width``)).
-    """
-    assert lowpass_cutoff < min(orig_freq, new_freq) / 2
-    output_t = torch.arange(0., output_samples_in_unit, device=device, dtype=dtype) / new_freq
-    min_t = output_t - window_width
-    max_t = output_t + window_width
-
-    min_input_index = torch.ceil(min_t * orig_freq)  # size (output_samples_in_unit)
-    max_input_index = torch.floor(max_t * orig_freq)  # size (output_samples_in_unit)
-    num_indices = max_input_index - min_input_index + 1  # size (output_samples_in_unit)
-
-    max_weight_width = num_indices.max()
-    # create a group of weights of size (output_samples_in_unit, max_weight_width)
-    j = torch.arange(max_weight_width, device=device, dtype=dtype).unsqueeze(0)
-    input_index = min_input_index.unsqueeze(1) + j
-    delta_t = (input_index / orig_freq) - output_t.unsqueeze(1)
-
-    weights = torch.zeros_like(delta_t)
-    inside_window_indices = delta_t.abs().lt(window_width)
-    # raised-cosine (Hanning) window with width `window_width`
-    weights[inside_window_indices] = 0.5 * (1 + torch.cos(2 * math.pi * lowpass_cutoff /
-                                                          lowpass_filter_width * delta_t[inside_window_indices]))
-
-    t_eq_zero_indices = delta_t.eq(0.0)
-    t_not_eq_zero_indices = ~t_eq_zero_indices
-    # sinc filter function
-    weights[t_not_eq_zero_indices] *= torch.sin(
-        2 * math.pi * lowpass_cutoff * delta_t[t_not_eq_zero_indices]) / (math.pi * delta_t[t_not_eq_zero_indices])
-    # limit of the function at t = 0
-    weights[t_eq_zero_indices] *= 2 * lowpass_cutoff
-    weights /= orig_freq  # size (output_samples_in_unit, max_weight_width)
-    return min_input_index, weights
-
-
-def _lcm(a: int, b: int) -> int:
-    return abs(a * b) // math.gcd(a, b)
-
-
-def _get_num_LR_output_samples(input_num_samp: int,
-                               samp_rate_in: float,
-                               samp_rate_out: float) -> int:
-    r"""Based on LinearResample::GetNumOutputSamples. LinearResample (LR) means that
-    the output signal is at linearly spaced intervals (i.e the output signal has a
-    frequency of ``new_freq``). It uses sinc/bandlimited interpolation to upsample/downsample
-    the signal.
-
-    Args:
-        input_num_samp (int): The number of samples in the input
-        samp_rate_in (float): The original frequency of the signal
-        samp_rate_out (float): The desired frequency
-
-    Returns:
-        int: The number of output samples
-    """
-    # For exact computation, we measure time in "ticks" of 1.0 / tick_freq,
-    # where tick_freq is the least common multiple of samp_rate_in and
-    # samp_rate_out.
-    samp_rate_in = int(samp_rate_in)
-    samp_rate_out = int(samp_rate_out)
-
-    tick_freq = _lcm(samp_rate_in, samp_rate_out)
-    ticks_per_input_period = tick_freq // samp_rate_in
-
-    # work out the number of ticks in the time interval
-    # [ 0, input_num_samp/samp_rate_in ).
-    interval_length_in_ticks = input_num_samp * ticks_per_input_period
-    if interval_length_in_ticks <= 0:
-        return 0
-    ticks_per_output_period = tick_freq // samp_rate_out
-
-    # Get the last output-sample in the closed interval, i.e. replacing [ ) with
-    # [ ]. Note: integer division rounds down. See
-    # http://en.wikipedia.org/wiki/Interval_(mathematics) for an explanation of
-    # the notation.
-    last_output_samp = interval_length_in_ticks // ticks_per_output_period
-    # We need the last output-sample in the open interval, so if it takes us to
-    # the end of the interval exactly, subtract one.
-    if last_output_samp * ticks_per_output_period == interval_length_in_ticks:
-        last_output_samp -= 1
-    # First output-sample index is zero, so the number of output samples
-    # is the last output-sample plus one.
-    num_output_samp = last_output_samp + 1
-    return num_output_samp
+def _get_sinc_resample_kernel(orig_freq: int, new_freq: int, lowpass_filter_width: int,
+                              device: torch.device, dtype: torch.dtype):
+    assert lowpass_filter_width > 0
+    kernels = []
+    base_freq = min(orig_freq, new_freq)
+    # This will perform antialiasing filtering by removing the highest frequencies.
+    # At first I thought I only needed this when downsampling, but when upsampling
+    # you will get edge artifacts without this, as the edge is equivalent to zero padding,
+    # which will add high freq artifacts.
+    base_freq *= 0.99
+
+    # The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor)
+    # using the sinc interpolation formula:
+    #   x(t) = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - t))
+    # We can then sample the function x(t) with a different sample rate:
+    #    y[j] = x(j / new_freq)
+    # or,
+    #    y[j] = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))
+
+    # We see here that y[j] is the convolution of x[i] with a specific filter, for which
+    # we take an FIR approximation, stopping when we see at least `lowpass_filter_width` zeros crossing.
+    # But y[j+1] is going to have a different set of weights and so on, until y[j + new_freq].
+    # Indeed:
+    # y[j + new_freq] = sum_i x[i] sinc(pi * orig_freq * ((i / orig_freq - (j + new_freq) / new_freq))
+    #                 = sum_i x[i] sinc(pi * orig_freq * ((i - orig_freq) / orig_freq - j / new_freq))
+    #                 = sum_i x[i + orig_freq] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))
+    # so y[j+new_freq] uses the same filter as y[j], but on a shifted version of x by `orig_freq`.
+    # This will explain the F.conv1d after, with a stride of orig_freq.
+    width = math.ceil(lowpass_filter_width * orig_freq / base_freq)
+    # If orig_freq is still big after GCD reduction, most filters will be very unbalanced, i.e.,
+    # they will have a lot of almost zero values to the left or to the right...
+    # There is probably a way to evaluate those filters more efficiently, but this is kept for
+    # future work.
+    idx = torch.arange(-width, width + orig_freq, device=device, dtype=dtype)
+
+    for i in range(new_freq):
+        t = (-i / new_freq + idx / orig_freq) * base_freq
+        t = t.clamp_(-lowpass_filter_width, lowpass_filter_width)
+        t *= math.pi
+        # we do not use torch.hann_window here as we need to evaluate the window
+        # at spectifics positions, not over a regular grid.
+        window = torch.cos(t / lowpass_filter_width / 2)**2
+        kernel = torch.where(t == 0, torch.tensor(1.).to(t), torch.sin(t) / t)
+        kernel.mul_(window)
+        kernels.append(kernel)
+
+    scale = base_freq / orig_freq
+    return torch.stack(kernels).view(new_freq, 1, -1).mul_(scale), width
 
 
 def resample_waveform(waveform: Tensor,
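Note on the removed length helper: the tick-based count in `_get_num_LR_output_samples` is dropped here, and `resample_waveform` instead truncates to `math.ceil(new_freq * length / orig_freq)` on the GCD-reduced rates. The sketch below is a pure-Python check, not code from the commit, that the two counts appear to agree for integer rates; the function names `old_num_output_samples` and `new_target_length` are hypothetical and exist only for this illustration.

```python
import math


def old_num_output_samples(input_num_samp: int, samp_rate_in: int, samp_rate_out: int) -> int:
    # Tick-based count, following the logic of the removed helper:
    # time is measured in ticks of 1 / lcm(samp_rate_in, samp_rate_out).
    tick_freq = abs(samp_rate_in * samp_rate_out) // math.gcd(samp_rate_in, samp_rate_out)
    ticks_per_input_period = tick_freq // samp_rate_in
    interval_length_in_ticks = input_num_samp * ticks_per_input_period
    if interval_length_in_ticks <= 0:
        return 0
    ticks_per_output_period = tick_freq // samp_rate_out
    last_output_samp = interval_length_in_ticks // ticks_per_output_period
    # Half-open interval: drop the sample that lands exactly on the end.
    if last_output_samp * ticks_per_output_period == interval_length_in_ticks:
        last_output_samp -= 1
    return last_output_samp + 1


def new_target_length(length: int, orig_freq: int, new_freq: int) -> int:
    # Ceil-based count on GCD-reduced rates, as in the new resample_waveform body.
    gcd = math.gcd(orig_freq, new_freq)
    return int(math.ceil((new_freq // gcd) * length / (orig_freq // gcd)))


if __name__ == "__main__":
    # One second of 44.1 kHz audio resampled to 16 kHz.
    print(old_num_output_samples(44100, 44100, 16000), new_target_length(44100, 44100, 16000))
```

Both counts reduce to the ceiling of `length * new_freq / orig_freq`, which is presumably why the helper and `_lcm` could be deleted without changing the output length.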
@@ -912,72 +826,21 @@ def resample_waveform(waveform: Tensor,
     Returns:
         Tensor: The waveform at the new frequency
     """
-    device, dtype = waveform.device, waveform.dtype
-
     assert waveform.dim() == 2
     assert orig_freq > 0.0 and new_freq > 0.0
 
-    min_freq = min(orig_freq, new_freq)
-    lowpass_cutoff = 0.99 * 0.5 * min_freq
-
-    assert lowpass_cutoff * 2 <= min_freq
-
-    base_freq = math.gcd(int(orig_freq), int(new_freq))
-    input_samples_in_unit = int(orig_freq) // base_freq
-    output_samples_in_unit = int(new_freq) // base_freq
-
-    window_width = lowpass_filter_width / (2.0 * lowpass_cutoff)
-    first_indices, weights = _get_LR_indices_and_weights(
-        orig_freq, new_freq, output_samples_in_unit,
-        window_width, lowpass_cutoff, lowpass_filter_width, device, dtype)
-
-    assert first_indices.dim() == 1
-    # TODO figure a better way to do this. conv1d reaches every element i*stride + padding
-    # all the weights have the same stride but have different padding.
-    # Current implementation takes the input and applies the various padding before
-    # doing a conv1d for that specific weight.
-    conv_stride = input_samples_in_unit
-    conv_transpose_stride = output_samples_in_unit
-    num_channels, wave_len = waveform.size()
-    window_size = weights.size(1)
-    tot_output_samp = _get_num_LR_output_samples(wave_len, orig_freq, new_freq)
-    output = torch.zeros((num_channels, tot_output_samp),
-                         device=device, dtype=dtype)
-    # eye size: (num_channels, num_channels, 1)
-    eye = torch.eye(num_channels, device=device, dtype=dtype).unsqueeze(2)
-    for i in range(first_indices.size(0)):
-        wave_to_conv = waveform
-        first_index = int(first_indices[i].item())
-        if first_index >= 0:
-            # trim the signal as the filter will not be applied before the first_index
-            wave_to_conv = wave_to_conv[..., first_index:]
-
-        # pad the right of the signal to allow partial convolutions meaning compute
-        # values for partial windows (e.g. end of the window is outside the signal length)
-        max_unit_index = (tot_output_samp - 1) // output_samples_in_unit
-        end_index_of_last_window = max_unit_index * conv_stride + window_size
-        current_wave_len = wave_len - first_index
-        right_padding = max(0, end_index_of_last_window + 1 - current_wave_len)
-
-        left_padding = max(0, -first_index)
-        if left_padding != 0 or right_padding != 0:
-            wave_to_conv = torch.nn.functional.pad(wave_to_conv, (left_padding, right_padding))
-
-        conv_wave = torch.nn.functional.conv1d(
-            wave_to_conv.unsqueeze(0), weights[i].repeat(num_channels, 1, 1),
-            stride=conv_stride, groups=num_channels)
-
-        # we want conv_wave[:, i] to be at output[:, i + n*conv_transpose_stride]
-        dilated_conv_wave = torch.nn.functional.conv_transpose1d(
-            conv_wave, eye, stride=conv_transpose_stride).squeeze(0)
-
-        # pad dilated_conv_wave so it reaches the output length if needed.
-        dialated_conv_wave_len = dilated_conv_wave.size(-1)
-        left_padding = i
-        right_padding = max(0, tot_output_samp - (left_padding + dialated_conv_wave_len))
-        dilated_conv_wave = torch.nn.functional.pad(
-            dilated_conv_wave, (left_padding, right_padding))[..., :tot_output_samp]
-
-        output += dilated_conv_wave
-
-    return output
+    orig_freq = int(orig_freq)
+    new_freq = int(new_freq)
+    gcd = math.gcd(orig_freq, new_freq)
+    orig_freq = orig_freq // gcd
+    new_freq = new_freq // gcd
+
+    kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, lowpass_filter_width,
+                                              waveform.device, waveform.dtype)
+
+    num_wavs, length = waveform.shape
+    waveform = F.pad(waveform, (width, width + orig_freq))
+    resampled = F.conv1d(waveform[:, None], kernel, stride=orig_freq)
+    resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
+    target_length = int(math.ceil(new_freq * length / orig_freq))
+    return resampled[..., :target_length]
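Note on the new data flow: the comments in `_get_sinc_resample_kernel` argue that `y[j + new_freq]` reuses the filter of `y[j]` on an input shifted by `orig_freq`, which is what lets a single strided `F.conv1d` replace the old per-filter loop. The sketch below restates that idea in plain Python lists so the indexing can be followed without tensors. It is an illustration under stated assumptions, not the commit's code: the names `sinc_kernels` and `resample` are hypothetical, the default `lowpass_filter_width=6` is an assumption, and the 0.99 rolloff is copied from the diff.

```python
import math


def sinc_kernels(orig_freq: int, new_freq: int, lowpass_filter_width: int = 6):
    """Build one Hann-windowed sinc filter per output phase (rates already GCD-reduced)."""
    base_freq = min(orig_freq, new_freq) * 0.99  # rolloff factor taken from the diff
    width = math.ceil(lowpass_filter_width * orig_freq / base_freq)
    scale = base_freq / orig_freq
    kernels = []
    for i in range(new_freq):
        taps = []
        for idx in range(-width, width + orig_freq):
            t = (-i / new_freq + idx / orig_freq) * base_freq
            # Clamping puts out-of-range taps at the window's zero.
            t = max(-lowpass_filter_width, min(lowpass_filter_width, t)) * math.pi
            window = math.cos(t / lowpass_filter_width / 2) ** 2
            sinc = 1.0 if t == 0 else math.sin(t) / t
            taps.append(scale * sinc * window)
        kernels.append(taps)
    return kernels, width


def resample(x, orig_freq: int, new_freq: int, lowpass_filter_width: int = 6):
    """Resample a list of floats; mirrors pad -> strided correlation -> interleave -> trim."""
    gcd = math.gcd(orig_freq, new_freq)
    orig_freq, new_freq = orig_freq // gcd, new_freq // gcd
    kernels, width = sinc_kernels(orig_freq, new_freq, lowpass_filter_width)
    padded = [0.0] * width + list(x) + [0.0] * (width + orig_freq)
    kernel_len = 2 * width + orig_freq
    out = []
    # Each frame starts orig_freq input samples after the previous one (the conv stride)
    # and yields new_freq output samples, one per filter phase.
    for start in range(0, len(padded) - kernel_len + 1, orig_freq):
        frame = padded[start:start + kernel_len]
        for taps in kernels:
            out.append(sum(a * b for a, b in zip(taps, frame)))
    target_length = math.ceil(new_freq * len(x) / orig_freq)
    return out[:target_length]


if __name__ == "__main__":
    y = resample([1.0] * 200, 2, 1)  # halve the rate of a constant signal
    print(len(y), round(y[50], 3))
```

The inner loop over `kernels` plays the role of the `transpose(1, 2).reshape(num_wavs, -1)` step in the diff: the outputs of all phases for one frame are emitted before moving to the next frame.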