"docs/vscode:/vscode.git/clone" did not exist on "cb40dd72e0d7ae95f2f558d4f9818c3cf85914a2"
Commit 9409824f authored by Oktai Tatanov's avatar Oktai Tatanov Committed by Vincent QB
Browse files

JIT resample waveform (#362)



* test with jit.

* test passed after adding annotation, and removing get_default_dtype

* fix conversion error.

* moving test to transform.

* reverting to original test.

* move type.

* math.gcd added in python 3.5.
Co-authored-by: default avatarVincent QB <vincentqb@users.noreply.github.com>
parent 4a934693
......@@ -312,6 +312,13 @@ class Tester(unittest.TestCase):
_test_librosa_consistency_helper(**kwargs2)
_test_librosa_consistency_helper(**kwargs3)
def test_scriptmodule_Resample(self):
tensor = torch.rand((2, 1000))
sample_rate = 100
sample_rate_2 = 50
_test_script_module(transforms.Spectrogram, tensor, sample_rate, sample_rate_2)
def test_resample_size(self):
input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
waveform, sample_rate = torchaudio.load(input_path)
......
from __future__ import absolute_import, division, print_function, unicode_literals
import math
import fractions
import random
import torch
import torchaudio
......@@ -20,7 +19,7 @@ __all__ = [
]
# numeric_limits<float>::epsilon() 1.1920928955078125e-07
EPSILON = torch.tensor(torch.finfo(torch.float).eps, dtype=torch.get_default_dtype())
EPSILON = torch.tensor(torch.finfo(torch.float).eps)
# 1 milliseconds = 0.001 seconds
MILLISECONDS_TO_SECONDS = 0.001
......@@ -33,6 +32,22 @@ BLACKMAN = 'blackman'
WINDOWS = [HAMMING, HANNING, POVEY, RECTANGULAR, BLACKMAN]
def gcd(a, b):
# type: (int, int) -> int
"""Calculate the Greatest Common Divisor of a and b.
Unless b==0, the result will have the same sign as b (so that when
b is divided by it, the result comes out positive).
"""
try:
return math.gcd(a, b)
except AttributeError:
while b:
a, b = b, a % b
return a
def _next_power_of_2(x):
r"""Returns the smallest power of 2 that is greater than x
"""
......@@ -92,10 +107,10 @@ def _feature_window_function(window_type, window_size, blackman_coeff):
# like hanning but goes to zero at edges
return torch.hann_window(window_size, periodic=False).pow(0.85)
elif window_type == RECTANGULAR:
return torch.ones(window_size, dtype=torch.get_default_dtype())
return torch.ones(window_size)
elif window_type == BLACKMAN:
a = 2 * math.pi / (window_size - 1)
window_function = torch.arange(window_size, dtype=torch.get_default_dtype())
window_function = torch.arange(window_size)
# can't use torch.blackman_window as they use different coefficients
return blackman_coeff - 0.5 * torch.cos(a * window_function) + \
(0.5 - blackman_coeff) * torch.cos(2 * a * window_function)
......@@ -111,7 +126,7 @@ def _get_log_energy(strided_input, epsilon, energy_floor):
return log_energy
else:
return torch.max(log_energy,
torch.tensor(math.log(energy_floor), dtype=torch.get_default_dtype()))
torch.tensor(math.log(energy_floor)))
def _get_waveform_and_window_properties(waveform, channel, sample_frequency, frame_shift,
......@@ -397,7 +412,7 @@ def get_mel_banks(num_bins, window_length_padded, sample_freq,
('Bad values in options: vtln-low %f and vtln-high %f, versus low-freq %f and high-freq %f' %
(vtln_low, vtln_high, low_freq, high_freq))
bin = torch.arange(num_bins, dtype=torch.get_default_dtype()).unsqueeze(1)
bin = torch.arange(num_bins).unsqueeze(1)
left_mel = mel_low_freq + bin * mel_freq_delta # size(num_bins, 1)
center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta # size(num_bins, 1)
right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # size(num_bins, 1)
......@@ -409,7 +424,7 @@ def get_mel_banks(num_bins, window_length_padded, sample_freq,
center_freqs = inverse_mel_scale(center_mel) # size (num_bins)
# size(1, num_fft_bins)
mel = mel_scale(fft_bin_width * torch.arange(num_fft_bins, dtype=torch.get_default_dtype())).unsqueeze(0)
mel = mel_scale(fft_bin_width * torch.arange(num_fft_bins)).unsqueeze(0)
# size (num_bins, num_fft_bins)
up_slope = (mel - left_mel) / (center_mel - left_mel)
......@@ -543,7 +558,7 @@ def _get_lifter_coeffs(num_ceps, cepstral_lifter):
# returns size (num_ceps)
# Compute liftering coefficients (scaling on cepstral coeffs)
# coeffs are numbered slightly differently from HTK: the zeroth index is C0, which is not affected.
i = torch.arange(num_ceps, dtype=torch.get_default_dtype())
i = torch.arange(num_ceps)
return 1.0 + 0.5 * cepstral_lifter * torch.sin(math.pi * i / cepstral_lifter)
......@@ -652,6 +667,7 @@ def mfcc(
def _get_LR_indices_and_weights(orig_freq, new_freq, output_samples_in_unit, window_width,
lowpass_cutoff, lowpass_filter_width):
# type: (float, float, int, float, float, int) -> Tuple[Tensor, Tensor]
r"""Based on LinearResample::SetIndexesAndWeights where it retrieves the weights for
resampling as well as the indices in which they are valid. LinearResample (LR) means
that the output signal is at linearly spaced intervals (i.e the output signal has a
......@@ -699,7 +715,7 @@ def _get_LR_indices_and_weights(orig_freq, new_freq, output_samples_in_unit, win
which correspond with min_input_index, size (``output_samples_in_unit``, ``max_weight_width``)).
"""
assert lowpass_cutoff < min(orig_freq, new_freq) / 2
output_t = torch.arange(0, output_samples_in_unit, dtype=torch.get_default_dtype()) / new_freq
output_t = torch.arange(0., output_samples_in_unit) / new_freq
min_t = output_t - window_width
max_t = output_t + window_width
......@@ -732,10 +748,12 @@ def _get_LR_indices_and_weights(orig_freq, new_freq, output_samples_in_unit, win
def _lcm(a, b):
return abs(a * b) // fractions.gcd(a, b)
# type: (int, int) -> int
return abs(a * b) // gcd(a, b)
def _get_num_LR_output_samples(input_num_samp, samp_rate_in, samp_rate_out):
# type: (int, float, float) -> int
r"""Based on LinearResample::GetNumOutputSamples. LinearResample (LR) means that
the output signal is at linearly spaced intervals (i.e the output signal has a
frequency of ``new_freq``). It uses sinc/bandlimited interpolation to upsample/downsample
......@@ -780,6 +798,7 @@ def _get_num_LR_output_samples(input_num_samp, samp_rate_in, samp_rate_out):
def resample_waveform(waveform, orig_freq, new_freq, lowpass_filter_width=6):
# type: (Tensor, float, float, int) -> Tensor
r"""Resamples the waveform at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform
which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample
a signal). LinearResample (LR) means that the output signal is at linearly spaced intervals (i.e
......@@ -807,7 +826,7 @@ def resample_waveform(waveform, orig_freq, new_freq, lowpass_filter_width=6):
assert lowpass_cutoff * 2 <= min_freq
base_freq = fractions.gcd(int(orig_freq), int(new_freq))
base_freq = gcd(int(orig_freq), int(new_freq))
input_samples_in_unit = int(orig_freq) // base_freq
output_samples_in_unit = int(new_freq) // base_freq
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment