Commit 9409824f authored by Oktai Tatanov, committed by Vincent QB

JIT resample waveform (#362)

* Test with JIT.

* Test passes after adding type annotations and removing get_default_dtype.

* Fix conversion error.

* Move test to transforms.

* Revert to original test.

* Move type annotation.

* Add a gcd fallback since math.gcd was only added in Python 3.5.

Co-authored-by: Vincent QB <vincentqb@users.noreply.github.com>
parent 4a934693
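For context, here is a minimal sketch of what this change enables: the Resample transform should now compile under TorchScript. The sample rates and the dummy input shape below are illustrative assumptions, not taken from the diff (floats are used to match the annotations added in this change).

```python
import torch
import torchaudio

# Sketch only: script the Resample transform and run it on a dummy waveform.
transform = torchaudio.transforms.Resample(orig_freq=16000., new_freq=8000.)
scripted = torch.jit.script(transform)

waveform = torch.rand(2, 1000)  # assumed (channel, time) dummy input
resampled = scripted(waveform)
print(resampled.shape)  # roughly half as many samples along the time axis
```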
@@ -312,6 +312,13 @@ class Tester(unittest.TestCase):
         _test_librosa_consistency_helper(**kwargs2)
         _test_librosa_consistency_helper(**kwargs3)
 
+    def test_scriptmodule_Resample(self):
+        tensor = torch.rand((2, 1000))
+        sample_rate = 100
+        sample_rate_2 = 50
+
+        _test_script_module(transforms.Resample, tensor, sample_rate, sample_rate_2)
+
     def test_resample_size(self):
         input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
         waveform, sample_rate = torchaudio.load(input_path)
...
 from __future__ import absolute_import, division, print_function, unicode_literals
 import math
-import fractions
 import random
 import torch
 import torchaudio
@@ -20,7 +19,7 @@ __all__ = [
 ]
 
 # numeric_limits<float>::epsilon() 1.1920928955078125e-07
-EPSILON = torch.tensor(torch.finfo(torch.float).eps, dtype=torch.get_default_dtype())
+EPSILON = torch.tensor(torch.finfo(torch.float).eps)
 
 # 1 milliseconds = 0.001 seconds
 MILLISECONDS_TO_SECONDS = 0.001
@@ -33,6 +32,22 @@ BLACKMAN = 'blackman'
 WINDOWS = [HAMMING, HANNING, POVEY, RECTANGULAR, BLACKMAN]
 
 
+def gcd(a, b):
+    # type: (int, int) -> int
+    """Calculate the Greatest Common Divisor of a and b.
+
+    Unless b==0, the result will have the same sign as b (so that when
+    b is divided by it, the result comes out positive).
+    """
+    try:
+        return math.gcd(a, b)
+    except AttributeError:
+        while b:
+            a, b = b, a % b
+        return a
+
+
 def _next_power_of_2(x):
     r"""Returns the smallest power of 2 that is greater than x
     """
@@ -92,10 +107,10 @@ def _feature_window_function(window_type, window_size, blackman_coeff):
         # like hanning but goes to zero at edges
         return torch.hann_window(window_size, periodic=False).pow(0.85)
     elif window_type == RECTANGULAR:
-        return torch.ones(window_size, dtype=torch.get_default_dtype())
+        return torch.ones(window_size)
     elif window_type == BLACKMAN:
         a = 2 * math.pi / (window_size - 1)
-        window_function = torch.arange(window_size, dtype=torch.get_default_dtype())
+        window_function = torch.arange(window_size)
         # can't use torch.blackman_window as they use different coefficients
         return blackman_coeff - 0.5 * torch.cos(a * window_function) + \
             (0.5 - blackman_coeff) * torch.cos(2 * a * window_function)
@@ -111,7 +126,7 @@ def _get_log_energy(strided_input, epsilon, energy_floor):
         return log_energy
     else:
         return torch.max(log_energy,
-                         torch.tensor(math.log(energy_floor), dtype=torch.get_default_dtype()))
+                         torch.tensor(math.log(energy_floor)))
 
 
 def _get_waveform_and_window_properties(waveform, channel, sample_frequency, frame_shift,
@@ -397,7 +412,7 @@ def get_mel_banks(num_bins, window_length_padded, sample_freq,
         ('Bad values in options: vtln-low %f and vtln-high %f, versus low-freq %f and high-freq %f' %
          (vtln_low, vtln_high, low_freq, high_freq))
 
-    bin = torch.arange(num_bins, dtype=torch.get_default_dtype()).unsqueeze(1)
+    bin = torch.arange(num_bins).unsqueeze(1)
     left_mel = mel_low_freq + bin * mel_freq_delta  # size(num_bins, 1)
     center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta  # size(num_bins, 1)
     right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta  # size(num_bins, 1)
@@ -409,7 +424,7 @@ def get_mel_banks(num_bins, window_length_padded, sample_freq,
     center_freqs = inverse_mel_scale(center_mel)  # size (num_bins)
     # size(1, num_fft_bins)
-    mel = mel_scale(fft_bin_width * torch.arange(num_fft_bins, dtype=torch.get_default_dtype())).unsqueeze(0)
+    mel = mel_scale(fft_bin_width * torch.arange(num_fft_bins)).unsqueeze(0)
 
     # size (num_bins, num_fft_bins)
     up_slope = (mel - left_mel) / (center_mel - left_mel)
@@ -543,7 +558,7 @@ def _get_lifter_coeffs(num_ceps, cepstral_lifter):
     # returns size (num_ceps)
     # Compute liftering coefficients (scaling on cepstral coeffs)
    # coeffs are numbered slightly differently from HTK: the zeroth index is C0, which is not affected.
-    i = torch.arange(num_ceps, dtype=torch.get_default_dtype())
+    i = torch.arange(num_ceps)
     return 1.0 + 0.5 * cepstral_lifter * torch.sin(math.pi * i / cepstral_lifter)
@@ -652,6 +667,7 @@ def mfcc(
 
 def _get_LR_indices_and_weights(orig_freq, new_freq, output_samples_in_unit, window_width,
                                 lowpass_cutoff, lowpass_filter_width):
+    # type: (float, float, int, float, float, int) -> Tuple[Tensor, Tensor]
     r"""Based on LinearResample::SetIndexesAndWeights where it retrieves the weights for
     resampling as well as the indices in which they are valid. LinearResample (LR) means
     that the output signal is at linearly spaced intervals (i.e the output signal has a
@@ -699,7 +715,7 @@ def _get_LR_indices_and_weights(orig_freq, new_freq, output_samples_in_unit, win
             which correspond with min_input_index, size (``output_samples_in_unit``, ``max_weight_width``)).
     """
     assert lowpass_cutoff < min(orig_freq, new_freq) / 2
-    output_t = torch.arange(0, output_samples_in_unit, dtype=torch.get_default_dtype()) / new_freq
+    output_t = torch.arange(0., output_samples_in_unit) / new_freq
     min_t = output_t - window_width
     max_t = output_t + window_width
@@ -732,10 +748,12 @@ def _get_LR_indices_and_weights(orig_freq, new_freq, output_samples_in_unit, win
 
 def _lcm(a, b):
-    return abs(a * b) // fractions.gcd(a, b)
+    # type: (int, int) -> int
+    return abs(a * b) // gcd(a, b)
 
 
 def _get_num_LR_output_samples(input_num_samp, samp_rate_in, samp_rate_out):
+    # type: (int, float, float) -> int
     r"""Based on LinearResample::GetNumOutputSamples. LinearResample (LR) means that
     the output signal is at linearly spaced intervals (i.e the output signal has a
     frequency of ``new_freq``). It uses sinc/bandlimited interpolation to upsample/downsample
@@ -780,6 +798,7 @@ def _get_num_LR_output_samples(input_num_samp, samp_rate_in, samp_rate_out):
 
 def resample_waveform(waveform, orig_freq, new_freq, lowpass_filter_width=6):
+    # type: (Tensor, float, float, int) -> Tensor
     r"""Resamples the waveform at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform
     which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample
     a signal). LinearResample (LR) means that the output signal is at linearly spaced intervals (i.e
@@ -807,7 +826,7 @@ def resample_waveform(waveform, orig_freq, new_freq, lowpass_filter_width=6):
     assert lowpass_cutoff * 2 <= min_freq
 
-    base_freq = fractions.gcd(int(orig_freq), int(new_freq))
+    base_freq = gcd(int(orig_freq), int(new_freq))
     input_samples_in_unit = int(orig_freq) // base_freq
     output_samples_in_unit = int(new_freq) // base_freq
...
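With the type comments above, the kaldi resampling entry point should also be scriptable as a free function. A minimal sketch, using a dummy 16 kHz input and passing the frequencies as floats to match the annotation (both the input and the rates are assumptions for illustration):

```python
import torch
import torchaudio

# Sketch only: compile the annotated resampling function with TorchScript.
scripted_resample = torch.jit.script(torchaudio.compliance.kaldi.resample_waveform)

waveform = torch.rand(1, 16000)  # assumed (channel, time) dummy input at 16 kHz
downsampled = scripted_resample(waveform, 16000., 8000.)
print(downsampled.shape)  # roughly (1, 8000)
```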