Unverified Commit e43a8e76 authored by Alexandre Défossez's avatar Alexandre Défossez Committed by GitHub
Browse files

Make resampling simpler and faster (#1087)

parent f1d8d1e0
...@@ -3,6 +3,7 @@ from typing import Tuple ...@@ -3,6 +3,7 @@ from typing import Tuple
import math import math
import torch import torch
from torch import Tensor from torch import Tensor
from torch.nn import functional as F
import torchaudio import torchaudio
import torchaudio._internal.fft import torchaudio._internal.fft
...@@ -752,141 +753,54 @@ def mfcc( ...@@ -752,141 +753,54 @@ def mfcc(
return feature return feature
def _get_LR_indices_and_weights(orig_freq: float, def _get_sinc_resample_kernel(orig_freq: int, new_freq: int, lowpass_filter_width: int,
new_freq: float, device: torch.device, dtype: torch.dtype):
output_samples_in_unit: int, assert lowpass_filter_width > 0
window_width: float, kernels = []
lowpass_cutoff: float, base_freq = min(orig_freq, new_freq)
lowpass_filter_width: int, # This will perform antialiasing filtering by removing the highest frequencies.
device: torch.device, # At first I thought I only needed this when downsampling, but when upsampling
dtype: int) -> Tuple[Tensor, Tensor]: # you will get edge artifacts without this, as the edge is equivalent to zero padding,
r"""Based on LinearResample::SetIndexesAndWeights where it retrieves the weights for # which will add high freq artifacts.
resampling as well as the indices in which they are valid. LinearResample (LR) means base_freq *= 0.99
that the output signal is at linearly spaced intervals (i.e the output signal has a
frequency of ``new_freq``). It uses sinc/bandlimited interpolation to upsample/downsample # The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor)
the signal. # using the sinc interpolation formula:
# x(t) = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - t))
The reason why the same filter is not used for multiple convolutions is because the # We can then sample the function x(t) with a different sample rate:
sinc function could be sampled at different points in time. For example, suppose # y[j] = x(j / new_freq)
a signal is sampled at the timestamps (seconds) # or,
0 16 32 # y[j] = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))
and we want it to be sampled at the timestamps (seconds)
0 5 10 15 20 25 30 35 # We see here that y[j] is the convolution of x[i] with a specific filter, for which
at the timestamp of 16, the delta timestamps are # we take an FIR approximation, stopping when we see at least `lowpass_filter_width` zeros crossing.
16 11 6 1 4 9 14 19 # But y[j+1] is going to have a different set of weights and so on, until y[j + new_freq].
at the timestamp of 32, the delta timestamps are # Indeed:
32 27 22 17 12 8 2 3 # y[j + new_freq] = sum_i x[i] sinc(pi * orig_freq * ((i / orig_freq - (j + new_freq) / new_freq))
# = sum_i x[i] sinc(pi * orig_freq * ((i - orig_freq) / orig_freq - j / new_freq))
As we can see from deltas, the sinc function is sampled at different points of time # = sum_i x[i + orig_freq] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))
assuming the center of the sinc function is at 0, 16, and 32 (the deltas [..., 6, 1, 4, ....] # so y[j+new_freq] uses the same filter as y[j], but on a shifted version of x by `orig_freq`.
for 16 vs [...., 2, 3, ....] for 32) # This will explain the F.conv1d after, with a stride of orig_freq.
width = math.ceil(lowpass_filter_width * orig_freq / base_freq)
Example, one case is when the ``orig_freq`` and ``new_freq`` are multiples of each other then # If orig_freq is still big after GCD reduction, most filters will be very unbalanced, i.e.,
there needs to be one filter. # they will have a lot of almost zero values to the left or to the right...
# There is probably a way to evaluate those filters more efficiently, but this is kept for
A windowed filter function (i.e. Hanning * sinc) because the ideal case of sinc function # future work.
has infinite support (non-zero for all values) so instead it is truncated and multiplied by idx = torch.arange(-width, width + orig_freq, device=device, dtype=dtype)
a window function which gives it less-than-perfect rolloff [1].
for i in range(new_freq):
[1] Chapter 16: Windowed-Sinc Filters, https://www.dspguide.com/ch16/1.htm t = (-i / new_freq + idx / orig_freq) * base_freq
t = t.clamp_(-lowpass_filter_width, lowpass_filter_width)
Args: t *= math.pi
orig_freq (float): The original frequency of the signal # we do not use torch.hann_window here as we need to evaluate the window
new_freq (float): The desired frequency # at specific positions, not over a regular grid.
output_samples_in_unit (int): The number of output samples in the smallest repeating unit: window = torch.cos(t / lowpass_filter_width / 2)**2
num_samp_out = new_freq / Gcd(orig_freq, new_freq) kernel = torch.where(t == 0, torch.tensor(1.).to(t), torch.sin(t) / t)
window_width (float): The width of the window which is nonzero kernel.mul_(window)
lowpass_cutoff (float): The filter cutoff in Hz. The filter cutoff needs to be less kernels.append(kernel)
than samp_rate_in_hz/2 and less than samp_rate_out_hz/2.
lowpass_filter_width (int): Controls the sharpness of the filter, more == sharper but less scale = base_freq / orig_freq
efficient. We suggest around 4 to 10 for normal use return torch.stack(kernels).view(new_freq, 1, -1).mul_(scale), width
Returns:
(Tensor, Tensor): A tuple of ``min_input_index`` (which is the minimum indices
where the window is valid, size (``output_samples_in_unit``)) and ``weights`` (which is the weights
which correspond with min_input_index, size (``output_samples_in_unit``, ``max_weight_width``)).
"""
assert lowpass_cutoff < min(orig_freq, new_freq) / 2
output_t = torch.arange(0., output_samples_in_unit, device=device, dtype=dtype) / new_freq
min_t = output_t - window_width
max_t = output_t + window_width
min_input_index = torch.ceil(min_t * orig_freq) # size (output_samples_in_unit)
max_input_index = torch.floor(max_t * orig_freq) # size (output_samples_in_unit)
num_indices = max_input_index - min_input_index + 1 # size (output_samples_in_unit)
max_weight_width = num_indices.max()
# create a group of weights of size (output_samples_in_unit, max_weight_width)
j = torch.arange(max_weight_width, device=device, dtype=dtype).unsqueeze(0)
input_index = min_input_index.unsqueeze(1) + j
delta_t = (input_index / orig_freq) - output_t.unsqueeze(1)
weights = torch.zeros_like(delta_t)
inside_window_indices = delta_t.abs().lt(window_width)
# raised-cosine (Hanning) window with width `window_width`
weights[inside_window_indices] = 0.5 * (1 + torch.cos(2 * math.pi * lowpass_cutoff /
lowpass_filter_width * delta_t[inside_window_indices]))
t_eq_zero_indices = delta_t.eq(0.0)
t_not_eq_zero_indices = ~t_eq_zero_indices
# sinc filter function
weights[t_not_eq_zero_indices] *= torch.sin(
2 * math.pi * lowpass_cutoff * delta_t[t_not_eq_zero_indices]) / (math.pi * delta_t[t_not_eq_zero_indices])
# limit of the function at t = 0
weights[t_eq_zero_indices] *= 2 * lowpass_cutoff
weights /= orig_freq # size (output_samples_in_unit, max_weight_width)
return min_input_index, weights
def _lcm(a: int, b: int) -> int:
return abs(a * b) // math.gcd(a, b)
def _get_num_LR_output_samples(input_num_samp: int,
                               samp_rate_in: float,
                               samp_rate_out: float) -> int:
    r"""Based on LinearResample::GetNumOutputSamples. LinearResample (LR) means that
    the output signal is at linearly spaced intervals (i.e the output signal has a
    frequency of ``new_freq``). It uses sinc/bandlimited interpolation to
    upsample/downsample the signal.

    Args:
        input_num_samp (int): The number of samples in the input
        samp_rate_in (float): The original frequency of the signal
        samp_rate_out (float): The desired frequency

    Returns:
        int: The number of output samples
    """
    # For exact integer arithmetic, measure time in "ticks" of 1.0 / tick_freq,
    # where tick_freq is the least common multiple of the two sample rates.
    rate_in = int(samp_rate_in)
    rate_out = int(samp_rate_out)
    tick_freq = abs(rate_in * rate_out) // math.gcd(rate_in, rate_out)

    # Number of ticks in the half-open interval [0, input_num_samp / rate_in).
    total_ticks = input_num_samp * (tick_freq // rate_in)
    if total_ticks <= 0:
        return 0

    ticks_per_output = tick_freq // rate_out
    # Index of the last output sample inside the CLOSED interval [0, T]
    # (integer division rounds down; see
    # http://en.wikipedia.org/wiki/Interval_(mathematics) for the notation).
    last_output_samp = total_ticks // ticks_per_output
    # The interval is actually open on the right, so a sample that lands
    # exactly on the right endpoint does not count.
    if last_output_samp * ticks_per_output == total_ticks:
        last_output_samp -= 1
    # Output sample indices start at zero, so the count is the last index + 1.
    return last_output_samp + 1
def resample_waveform(waveform: Tensor, def resample_waveform(waveform: Tensor,
...@@ -912,72 +826,21 @@ def resample_waveform(waveform: Tensor, ...@@ -912,72 +826,21 @@ def resample_waveform(waveform: Tensor,
Returns: Returns:
Tensor: The waveform at the new frequency Tensor: The waveform at the new frequency
""" """
device, dtype = waveform.device, waveform.dtype
assert waveform.dim() == 2 assert waveform.dim() == 2
assert orig_freq > 0.0 and new_freq > 0.0 assert orig_freq > 0.0 and new_freq > 0.0
min_freq = min(orig_freq, new_freq) orig_freq = int(orig_freq)
lowpass_cutoff = 0.99 * 0.5 * min_freq new_freq = int(new_freq)
gcd = math.gcd(orig_freq, new_freq)
assert lowpass_cutoff * 2 <= min_freq orig_freq = orig_freq // gcd
new_freq = new_freq // gcd
base_freq = math.gcd(int(orig_freq), int(new_freq))
input_samples_in_unit = int(orig_freq) // base_freq kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, lowpass_filter_width,
output_samples_in_unit = int(new_freq) // base_freq waveform.device, waveform.dtype)
window_width = lowpass_filter_width / (2.0 * lowpass_cutoff) num_wavs, length = waveform.shape
first_indices, weights = _get_LR_indices_and_weights( waveform = F.pad(waveform, (width, width + orig_freq))
orig_freq, new_freq, output_samples_in_unit, resampled = F.conv1d(waveform[:, None], kernel, stride=orig_freq)
window_width, lowpass_cutoff, lowpass_filter_width, device, dtype) resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
target_length = int(math.ceil(new_freq * length / orig_freq))
assert first_indices.dim() == 1 return resampled[..., :target_length]
# TODO figure a better way to do this. conv1d reaches every element i*stride + padding
# all the weights have the same stride but have different padding.
# Current implementation takes the input and applies the various padding before
# doing a conv1d for that specific weight.
conv_stride = input_samples_in_unit
conv_transpose_stride = output_samples_in_unit
num_channels, wave_len = waveform.size()
window_size = weights.size(1)
tot_output_samp = _get_num_LR_output_samples(wave_len, orig_freq, new_freq)
output = torch.zeros((num_channels, tot_output_samp),
device=device, dtype=dtype)
# eye size: (num_channels, num_channels, 1)
eye = torch.eye(num_channels, device=device, dtype=dtype).unsqueeze(2)
for i in range(first_indices.size(0)):
wave_to_conv = waveform
first_index = int(first_indices[i].item())
if first_index >= 0:
# trim the signal as the filter will not be applied before the first_index
wave_to_conv = wave_to_conv[..., first_index:]
# pad the right of the signal to allow partial convolutions meaning compute
# values for partial windows (e.g. end of the window is outside the signal length)
max_unit_index = (tot_output_samp - 1) // output_samples_in_unit
end_index_of_last_window = max_unit_index * conv_stride + window_size
current_wave_len = wave_len - first_index
right_padding = max(0, end_index_of_last_window + 1 - current_wave_len)
left_padding = max(0, -first_index)
if left_padding != 0 or right_padding != 0:
wave_to_conv = torch.nn.functional.pad(wave_to_conv, (left_padding, right_padding))
conv_wave = torch.nn.functional.conv1d(
wave_to_conv.unsqueeze(0), weights[i].repeat(num_channels, 1, 1),
stride=conv_stride, groups=num_channels)
# we want conv_wave[:, i] to be at output[:, i + n*conv_transpose_stride]
dilated_conv_wave = torch.nn.functional.conv_transpose1d(
conv_wave, eye, stride=conv_transpose_stride).squeeze(0)
# pad dilated_conv_wave so it reaches the output length if needed.
dialated_conv_wave_len = dilated_conv_wave.size(-1)
left_padding = i
right_padding = max(0, tot_output_samp - (left_padding + dialated_conv_wave_len))
dilated_conv_wave = torch.nn.functional.pad(
dilated_conv_wave, (left_padding, right_padding))[..., :tot_output_samp]
output += dilated_conv_wave
return output
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment