Commit 616663ff authored by jamarshon's avatar jamarshon Committed by cpuhrsch
Browse files

Kaldi Resample (#134)

parent 0902494e
......@@ -47,8 +47,9 @@ def extract_window(window, wave, f, frame_length, frame_shift, snip_edges):
class Test_Kaldi(unittest.TestCase):
test_dirpath, test_dir = test.common_utils.create_temp_assets_dir()
test_filepath = os.path.join(test_dirpath, 'assets', 'kaldi_file.wav')
test_8000_filepath = os.path.join(test_dirpath, 'assets', 'kaldi_file_8000.wav')
kaldi_output_dir = os.path.join(test_dirpath, 'assets', 'kaldi')
test_filepaths = {'spec': [], 'fbank': []}
test_filepaths = {prefix: [] for prefix in test.compliance.utils.TEST_PREFIX}
# separating test files by their types (e.g 'spec', 'fbank', etc.)
for f in os.listdir(kaldi_output_dir):
......@@ -118,16 +119,20 @@ class Test_Kaldi(unittest.TestCase):
print('abs_mse:', abs_mse.item(), 'abs_max_error:', abs_max_error.item())
print('relative_mse:', relative_mse.item(), 'relative_max_error:', relative_max_error.item())
def _compliance_test_helper(self, filepath_key, expected_num_files, expected_num_args, get_output_fn):
def _compliance_test_helper(self, sound_filepath, filepath_key, expected_num_files,
expected_num_args, get_output_fn, atol=1e-5, rtol=1e-8):
"""
Inputs:
sound_filepath (str): The location of the sound file
filepath_key (str): A key to `test_filepaths` which matches which files to use
expected_num_files (int): The expected number of kaldi files to read
expected_num_args (int): The expected number of arguments used in a kaldi configuration
get_output_fn (Callable[[Tensor, List], Tensor]): A function that takes in a sound signal
and a configuration and returns an output
atol (float): absolute tolerance
rtol (float): relative tolerance
"""
sound, sample_rate = torchaudio.load_wav(self.test_filepath)
sound, sample_rate = torchaudio.load_wav(sound_filepath)
files = self.test_filepaths[filepath_key]
assert len(files) == expected_num_files, ('number of kaldi %s file changed to %d' % (filepath_key, len(files)))
......@@ -152,7 +157,7 @@ class Test_Kaldi(unittest.TestCase):
self._print_diagnostic(output, kaldi_output)
self.assertTrue(output.shape, kaldi_output.shape)
self.assertTrue(torch.allclose(output, kaldi_output, atol=1e-3, rtol=1e-1))
self.assertTrue(torch.allclose(output, kaldi_output, atol=atol, rtol=rtol))
def test_spectrogram(self):
def get_output_fn(sound, args):
......@@ -172,7 +177,7 @@ class Test_Kaldi(unittest.TestCase):
window_type=args[12])
return output
self._compliance_test_helper('spec', 131, 13, get_output_fn)
self._compliance_test_helper(self.test_filepath, 'spec', 131, 13, get_output_fn, atol=1e-3, rtol=0)
def test_fbank(self):
def get_output_fn(sound, args):
......@@ -202,7 +207,65 @@ class Test_Kaldi(unittest.TestCase):
window_type=args[21])
return output
self._compliance_test_helper('fbank', 97, 22, get_output_fn)
self._compliance_test_helper(self.test_filepath, 'fbank', 97, 22, get_output_fn, atol=1e-3, rtol=1e-1)
def test_resample_waveform(self):
    """Run the shared compliance helper on the stored 'resample' Kaldi outputs."""
    def get_output_fn(sound, args):
        # the second and third configuration entries are forwarded as the
        # original and the desired frequency
        orig_freq, new_freq = args[1], args[2]
        return kaldi.resample_waveform(sound, orig_freq, new_freq)

    self._compliance_test_helper(
        self.test_8000_filepath, 'resample', 32, 3, get_output_fn, atol=1e-2, rtol=1e-5)
def test_resample_waveform_upsample_size(self):
    """Resampling to twice the sample rate must yield twice as many samples."""
    sound, sample_rate = torchaudio.load_wav(self.test_8000_filepath)
    upsample_sound = kaldi.resample_waveform(sound, sample_rate, sample_rate * 2)
    # assertEqual reports both lengths when the check fails, unlike assertTrue(a == b)
    self.assertEqual(upsample_sound.size(-1), sound.size(-1) * 2)
def test_resample_waveform_downsample_size(self):
    """Resampling to half the sample rate must yield half as many samples (floored)."""
    sound, sample_rate = torchaudio.load_wav(self.test_8000_filepath)
    downsample_sound = kaldi.resample_waveform(sound, sample_rate, sample_rate // 2)
    # assertEqual reports both lengths when the check fails, unlike assertTrue(a == b)
    self.assertEqual(downsample_sound.size(-1), sound.size(-1) // 2)
def test_resample_waveform_identity_size(self):
    """Resampling to the same sample rate must keep the number of samples."""
    sound, sample_rate = torchaudio.load_wav(self.test_8000_filepath)
    same_rate_sound = kaldi.resample_waveform(sound, sample_rate, sample_rate)
    # assertEqual reports both lengths when the check fails, unlike assertTrue(a == b)
    self.assertEqual(same_rate_sound.size(-1), sound.size(-1))
def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_factor=None,
                                     atol=1e-1, rtol=1e-4):
    """Resample a synthetic cosine and compare it with the same cosine
    evaluated directly at the new sample times.

    Inputs:
        up_scale_factor: if given, the 1000 Hz base rate is multiplied by it
        down_scale_factor: if given, the rate is then floor-divided by it
        atol (float): absolute tolerance
        rtol (float): relative tolerance
    """
    def tone(timestamps):
        # 3 Hz cosine with amplitude 123 (timestamps are in seconds)
        return 123 * torch.cos(2 * math.pi * 3 * timestamps)

    duration = 5  # seconds
    edge = 20  # samples dropped at each end because of boundary effects
    sample_rate = 1000

    new_sample_rate = sample_rate
    if up_scale_factor is not None:
        new_sample_rate *= up_scale_factor
    if down_scale_factor is not None:
        new_sample_rate //= down_scale_factor

    source_times = torch.arange(0, duration, 1.0 / sample_rate)
    sound = tone(source_times).unsqueeze(0)
    estimate = kaldi.resample_waveform(sound, sample_rate, new_sample_rate).squeeze()

    # only keep as many reference timestamps as there are resampled values
    target_times = torch.arange(0, duration, 1.0 / new_sample_rate)[:estimate.size(0)]
    ground_truth = tone(target_times)

    interior = slice(edge, -edge)
    ground_truth = ground_truth[..., interior]
    estimate = estimate[..., interior]

    self.assertTrue(torch.allclose(ground_truth, estimate, atol=atol, rtol=rtol))
def test_resample_waveform_downsample_accuracy(self):
    """Check the accuracy for every even down-scale factor from 2 to 38."""
    for factor in range(2, 40, 2):
        self._test_resample_waveform_accuracy(down_scale_factor=factor)
def test_resample_waveform_upsample_accuracy(self):
    """Check the accuracy for up-scale factors 1.05, 1.10, ..., 1.95."""
    for step in range(1, 20):
        factor = 1.0 + step / 20.0
        self._test_resample_waveform_accuracy(up_scale_factor=factor)
if __name__ == '__main__':
......
import random
import torchaudio
TEST_PREFIX = ['fbank', 'spec']
TEST_PREFIX = ['fbank', 'spec', 'resample']
def generate_rand_boolean():
......
......@@ -306,6 +306,29 @@ class Tester(unittest.TestCase):
_test_librosa_consistency_helper(**kwargs2)
_test_librosa_consistency_helper(**kwargs3)
def test_resample_size(self):
    """Check the output length of `transforms.Resample` and that an unknown
    resampling method is rejected when the transform is applied."""
    input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
    sound, sample_rate = torchaudio.load(input_path)

    upsample_rate = sample_rate * 2
    downsample_rate = sample_rate // 2

    # construction succeeds; the ValueError is only expected when the transform is called
    invalid_resample = torchaudio.transforms.Resample(sample_rate, upsample_rate, resampling_method='foo')
    with self.assertRaises(ValueError):
        invalid_resample(sound)

    upsample_resample = torchaudio.transforms.Resample(
        sample_rate, upsample_rate, resampling_method='sinc_interpolation')
    up_sampled = upsample_resample(sound)

    # we expect the upsampled signal to have twice as many samples
    # (assertEqual reports both lengths on failure, unlike assertTrue(a == b))
    self.assertEqual(up_sampled.size(-1), sound.size(-1) * 2)

    downsample_resample = torchaudio.transforms.Resample(
        sample_rate, downsample_rate, resampling_method='sinc_interpolation')
    down_sampled = downsample_resample(sound)

    # we expect the downsampled signal to have half as many samples
    self.assertEqual(down_sampled.size(-1), sound.size(-1) // 2)
if __name__ == '__main__':
unittest.main()
......@@ -13,6 +13,7 @@ __all__ = [
'spectrogram',
'vtln_warp_freq',
'vtln_warp_mel_freq',
'resample_waveform',
]
# numeric_limits<float>::epsilon() 1.1920928955078125e-07
......@@ -512,3 +513,217 @@ def fbank(
mel_energies = mel_energies - col_means
return mel_energies
def _get_LR_indices_and_weights(orig_freq, new_freq, output_samples_in_unit, window_width,
                                lowpass_cutoff, lowpass_filter_width):
    """Based on LinearResample::SetIndexesAndWeights: computes the resampling weights
    together with the first input index each group of weights applies to.
    LinearResample (LR) means that the output signal is at linearly spaced intervals
    (i.e. the output signal has a frequency of `new_freq`). Sinc/bandlimited
    interpolation is used to upsample/downsample the signal.

    One filter is built per output sample of the repeating unit because the sinc
    function can be sampled at different points in time for each of them. For
    example, suppose a signal is sampled at the timestamps (seconds)
    0         16        32
    and we want it to be sampled at the timestamps (seconds)
    0 5 10 15 20 25 30  35
    at the timestamp of 16, the delta timestamps are
    16 11 6 1 4 9 14 19
    at the timestamp of 32, the delta timestamps are
    32 27 22 17 12 7 2 3

    The deltas show that the sinc function, centered at 0, 16 and 32, is sampled at
    different offsets (the deltas [..., 6, 1, 4, ...] for 16 versus [..., 2, 3] for 32).
    When `output_samples_in_unit` is 1 a single filter is produced.

    The filter is a windowed sinc (Hanning * sinc): the ideal sinc has infinite
    support (non-zero for all values), so it is truncated and multiplied by a window
    function, which gives it a less-than-perfect rolloff [1].

    [1] Chapter 16: Windowed-Sinc Filters, https://www.dspguide.com/ch16/1.htm

    Args:
        orig_freq (float): the original frequency of the signal
        new_freq (float): the desired frequency
        output_samples_in_unit (int): the number of output samples in the smallest repeating unit:
            num_samp_out = new_freq / Gcd(orig_freq, new_freq)
        window_width (float): the width of the window which is nonzero
        lowpass_cutoff (float): the filter cutoff in Hz. The filter cutoff needs to be less
            than samp_rate_in_hz/2 and less than samp_rate_out_hz/2.
        lowpass_filter_width (int): controls the sharpness of the filter, more == sharper but less
            efficient. We suggest around 4 to 10 for normal use

    Returns:
        min_input_index (Tensor): the minimum indices where the window is valid. size (output_samples_in_unit)
        weights (Tensor): the weights which correspond with min_input_index. size (
            output_samples_in_unit, max_weight_width)
    """
    assert lowpass_cutoff < min(orig_freq, new_freq) / 2

    # time of every output sample inside one repeating unit
    out_times = torch.arange(0, output_samples_in_unit, dtype=torch.get_default_dtype()) / new_freq

    # first/last input sample falling inside the window of each output sample
    first_input = torch.ceil((out_times - window_width) * orig_freq)  # size (output_samples_in_unit)
    last_input = torch.floor((out_times + window_width) * orig_freq)  # size (output_samples_in_unit)
    widest = (last_input - first_input + 1).max()

    # one row of taps per output sample, size (output_samples_in_unit, widest);
    # rows with a narrower window keep zeros in their unused trailing taps
    offsets = torch.arange(widest).unsqueeze(0)
    tap_index = first_input.unsqueeze(1) + offsets
    delta_t = (tap_index / orig_freq) - out_times.unsqueeze(1)

    in_window = delta_t.abs().lt(window_width)
    at_centre = delta_t.eq(0.0)
    off_centre = ~at_centre

    weights = torch.zeros_like(delta_t)
    # raised-cosine (Hanning) window with width `window_width`
    weights[in_window] = 0.5 * (1 + torch.cos(2 * math.pi * lowpass_cutoff /
                                              lowpass_filter_width * delta_t[in_window]))

    # sinc filter function
    nonzero_dt = delta_t[off_centre]
    weights[off_centre] *= torch.sin(
        2 * math.pi * lowpass_cutoff * nonzero_dt) / (math.pi * nonzero_dt)
    # limit of the sinc function at t = 0
    weights[at_centre] *= 2 * lowpass_cutoff

    weights /= orig_freq  # size (output_samples_in_unit, widest)
    return first_input, weights
def _lcm(a, b):
    """Return the (non-negative) least common multiple of the integers `a` and `b`."""
    product = a * b
    divisor = math.gcd(a, b)
    return abs(product) // divisor
def _get_num_LR_output_samples(input_num_samp, samp_rate_in, samp_rate_out):
    """Based on LinearResample::GetNumOutputSamples. LinearResample (LR) means that
    the output signal is at linearly spaced intervals (i.e. the output signal has a
    frequency of `new_freq`). It uses sinc/bandlimited interpolation to
    upsample/downsample the signal.

    Args:
        input_num_samp (int): the number of samples in the input
        samp_rate_in (float): the original frequency of the signal
        samp_rate_out (float): the desired frequency

    Returns:
        int: the number of output samples
    """
    # Both rates are truncated to integers so that time can be measured exactly
    # in "ticks" of 1.0 / tick_freq, where tick_freq is the least common
    # multiple of the two rates.
    rate_in = int(samp_rate_in)
    rate_out = int(samp_rate_out)
    tick_freq = _lcm(rate_in, rate_out)
    ticks_per_input_period = tick_freq // rate_in

    # Length, in ticks, of the half-open interval [0, input_num_samp / samp_rate_in).
    total_ticks = input_num_samp * ticks_per_input_period
    if total_ticks <= 0:
        return 0
    ticks_per_output_period = tick_freq // rate_out

    # Index of the last output sample in the *closed* interval (integer
    # division rounds down).
    last_index = total_ticks // ticks_per_output_period

    # The interval is open on the right: an output sample landing exactly on
    # its end does not count. Output indices start at zero, so otherwise the
    # count is the last index plus one.
    lands_on_end = last_index * ticks_per_output_period == total_ticks
    return last_index if lands_on_end else last_index + 1
def resample_waveform(wave, orig_freq, new_freq, lowpass_filter_width=6):
    r"""Resamples the wave at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform
    which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample
    a signal). LinearResample (LR) means that the output signal is at linearly spaced intervals (i.e
    the output signal has a frequency of `new_freq`). It uses sinc/bandlimited interpolation to
    upsample/downsample the signal.

    https://ccrma.stanford.edu/~jos/resample/Theory_Ideal_Bandlimited_Interpolation.html
    https://github.com/kaldi-asr/kaldi/blob/master/src/feat/resample.h#L56

    Every channel is filtered independently with the same weights. The result is
    created on the device of `wave`, and in its dtype when that dtype is floating point.

    Args:
        wave (Tensor): the input signal of size (c, n)
        orig_freq (float): the original frequency of the signal
        new_freq (float): the desired frequency
        lowpass_filter_width (int): controls the sharpness of the filter, more == sharper
            but less efficient. We suggest around 4 to 10 for normal use

    Returns:
        Tensor: the signal at the new frequency, of size (c, m)
    """
    assert wave.dim() == 2
    assert orig_freq > 0.0 and new_freq > 0.0

    min_freq = min(orig_freq, new_freq)
    lowpass_cutoff = 0.99 * 0.5 * min_freq

    assert lowpass_cutoff * 2 <= min_freq

    base_freq = math.gcd(int(orig_freq), int(new_freq))
    input_samples_in_unit = int(orig_freq) // base_freq
    output_samples_in_unit = int(new_freq) // base_freq

    window_width = lowpass_filter_width / (2.0 * lowpass_cutoff)
    first_indices, weights = _get_LR_indices_and_weights(orig_freq, new_freq, output_samples_in_unit,
                                                         window_width, lowpass_cutoff, lowpass_filter_width)
    assert first_indices.dim() == 1

    # Keep every tensor created here on the device of the input (and in its
    # dtype when it is a floating type) so the convolutions and the final
    # accumulation do not mix devices or dtypes.
    device = wave.device
    dtype = wave.dtype if wave.is_floating_point() else weights.dtype
    weights = weights.to(device=device, dtype=dtype)

    # TODO figure a better way to do this. conv1d reaches every element i*stride + padding
    # all the weights have the same stride but have different padding.
    # Current implementation takes the input and applies the various padding before
    # doing a conv1d for that specific weight.
    conv_stride = input_samples_in_unit
    conv_transpose_stride = output_samples_in_unit
    num_channels, wave_len = wave.size()
    window_size = weights.size(1)
    tot_output_samp = _get_num_LR_output_samples(wave_len, orig_freq, new_freq)
    output = torch.zeros((num_channels, tot_output_samp), device=device, dtype=dtype)
    if tot_output_samp == 0:
        # empty input: there is nothing to filter
        return output
    eye = torch.eye(num_channels, device=device, dtype=dtype).unsqueeze(2)  # size (num_channels, num_channels, 1)

    # The end of the last window does not depend on which filter is applied,
    # so it is computed once outside the loop.
    max_unit_index = (tot_output_samp - 1) // output_samples_in_unit
    end_index_of_last_window = max_unit_index * conv_stride + window_size

    for i in range(first_indices.size(0)):
        wave_to_conv = wave
        first_index = int(first_indices[i].item())
        if first_index >= 0:
            # trim the signal as the filter will not be applied before the first_index
            wave_to_conv = wave_to_conv[..., first_index:]

        # pad the right of the signal to allow partial convolutions meaning compute
        # values for partial windows (e.g. end of the window is outside the signal length)
        current_wave_len = wave_len - first_index
        right_padding = max(0, end_index_of_last_window + 1 - current_wave_len)

        left_padding = max(0, -first_index)
        if left_padding != 0 or right_padding != 0:
            wave_to_conv = torch.nn.functional.pad(wave_to_conv, (left_padding, right_padding))

        # Grouped convolution: each channel is convolved on its own with a copy
        # of the same filter. A single (1, 1, window_size) kernel would only be
        # accepted by conv1d for single-channel input.
        conv_wave = torch.nn.functional.conv1d(
            wave_to_conv.unsqueeze(0), weights[i].repeat(num_channels, 1, 1),
            stride=conv_stride, groups=num_channels)

        # we want conv_wave[:, i] to be at output[:, i + n*conv_transpose_stride]
        dilated_conv_wave = torch.nn.functional.conv_transpose1d(
            conv_wave, eye, stride=conv_transpose_stride).squeeze(0)

        # pad dilated_conv_wave so it reaches the output length if needed.
        dilated_conv_wave_len = dilated_conv_wave.size(-1)
        left_padding = i
        right_padding = max(0, tot_output_samp - (left_padding + dilated_conv_wave_len))
        dilated_conv_wave = torch.nn.functional.pad(
            dilated_conv_wave, (left_padding, right_padding))[..., :tot_output_samp]

        output += dilated_conv_wave

    return output
......@@ -4,6 +4,7 @@ import math
import torch
from typing import Optional
from . import functional as F
from .compliance import kaldi
# TODO remove this class
......@@ -497,3 +498,33 @@ class MuLawExpanding(torch.jit.ScriptModule):
def __repr__(self):
return self.__class__.__name__ + '()'
class Resample(torch.nn.Module):
    """Module that converts a signal from one sampling frequency to another
    using a selectable resampling method.

    Args:
        orig_freq (float): the original frequency of the signal
        new_freq (float): the desired frequency
        resampling_method (str): the resampling method (Default: 'sinc_interpolation',
            which delegates to `kaldi.resample_waveform`)
    """
    def __init__(self, orig_freq, new_freq, resampling_method='sinc_interpolation'):
        super(Resample, self).__init__()
        self.orig_freq, self.new_freq = orig_freq, new_freq
        self.resampling_method = resampling_method

    def forward(self, sig):
        """
        Args:
            sig (Tensor): the input signal of size (c, n)

        Returns:
            Tensor: output signal of size (c, m)

        Raises:
            ValueError: if `resampling_method` is not 'sinc_interpolation'
        """
        # the method name is only checked here, when the module is applied
        if self.resampling_method != 'sinc_interpolation':
            raise ValueError('Invalid resampling method: %s' % (self.resampling_method))
        return kaldi.resample_waveform(sig, self.orig_freq, self.new_freq)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment