Commit 8273c3f4 authored by engineerchuan's avatar engineerchuan Committed by Vincent QB
Browse files

Make lfilter, and related filters, available (#275)

* Add basic low pass filtering
* Add highpass filtering
* More tests of IIR vs FIR
* Implement convolve function, add tests
* Move lfilter and convolve into functional, more tests
* added additional documentation for convolve and lfilter, renamed functional_filtering to functional_sox_convenience
* Follow naming convention for sample rate in functional
* fix failing vctk manifest test to account for adding more test audios into assets
* Adding documentation for lfilter, biquad, highpass_biquad, lowpass_biquad
* added matrix based implementation of lfilter
* adding python lfilter implementation
* factor out biquad, lowpass, highpass to sox compatibility
parent 4e80df79
...@@ -62,3 +62,23 @@ Functions to perform common audio operations. ...@@ -62,3 +62,23 @@ Functions to perform common audio operations.
~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: phase_vocoder .. autofunction:: phase_vocoder
:hidden:`lfilter`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: lfilter
:hidden:`biquad`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: biquad
:hidden:`lowpass_biquad`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: lowpass_biquad
:hidden:`highpass_biquad`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: highpass_biquad
...@@ -23,10 +23,13 @@ class TestVCTK(unittest.TestCase): ...@@ -23,10 +23,13 @@ class TestVCTK(unittest.TestCase):
def test_make_manifest(self): def test_make_manifest(self):
audios = vctk.make_manifest(self.test_dirpath) audios = vctk.make_manifest(self.test_dirpath)
files = ['kaldi_file.wav', 'kaldi_file_8000.wav', files = ['kaldi_file.wav', 'kaldi_file_8000.wav',
'sinewave.wav', 'steam-train-whistle-daniel_simon.mp3'] 'sinewave.wav', 'steam-train-whistle-daniel_simon.mp3',
'dtmf_30s_stereo.mp3', 'whitenoise_1min.mp3', 'whitenoise.mp3']
files = [self.get_full_path(file) for file in files] files = [self.get_full_path(file) for file in files]
files.sort()
audios.sort() audios.sort()
self.assertEqual(files, audios, msg='files %s did not match audios %s' % (files, audios)) self.assertEqual(files, audios, msg='files %s did not match audios %s' % (files, audios))
def test_read_audio_downsample_false(self): def test_read_audio_downsample_false(self):
......
from __future__ import absolute_import, division, print_function, unicode_literals
import math
import os
import torch
import torchaudio
import torchaudio.functional as F
import unittest
import common_utils
import time
class TestFunctionalFiltering(unittest.TestCase):
test_dirpath, test_dir = common_utils.create_temp_assets_dir()
def test_lfilter_basic(self):
"""
Create a very basic signal,
Then make a simple 4th order delay
The output should be same as the input but shifted
"""
torch.random.manual_seed(42)
waveform = torch.rand(2, 44100 * 10)
b_coeffs = torch.tensor([0, 0, 0, 1], dtype=torch.float32)
a_coeffs = torch.tensor([1, 0, 0, 0], dtype=torch.float32)
output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs)
assert torch.allclose(
waveform[:, 0:-3], output_waveform[:, 3:], atol=1e-5
)
def test_lfilter(self):
"""
Design an IIR lowpass filter using scipy.signal filter design
https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.iirdesign.html#scipy.signal.iirdesign
Example
>>> from scipy.signal import iirdesign
>>> b, a = iirdesign(0.2, 0.3, 1, 60)
"""
b_coeffs = torch.tensor(
[
0.00299893,
-0.0051152,
0.00841964,
-0.00747802,
0.00841964,
-0.0051152,
0.00299893,
]
)
a_coeffs = torch.tensor(
[
1.0,
-4.8155751,
10.2217618,
-12.14481273,
8.49018171,
-3.3066882,
0.56088705,
]
)
filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.mp3")
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs)
assert len(output_waveform.size()) == 2
assert output_waveform.size(0) == waveform.size(0)
assert output_waveform.size(1) == waveform.size(1)
def test_lowpass(self):
"""
Test biquad lowpass filter, compare to SoX implementation
"""
CUTOFF_FREQ = 3000
noise_filepath = os.path.join(
self.test_dirpath, "assets", "whitenoise.mp3"
)
E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(noise_filepath)
E.append_effect_to_chain("lowpass", [CUTOFF_FREQ])
sox_output_waveform, sr = E.sox_build_flow_effects()
waveform, sample_rate = torchaudio.load(
noise_filepath, normalization=True
)
output_waveform = F.lowpass_biquad(waveform, sample_rate, CUTOFF_FREQ)
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
def test_highpass(self):
"""
Test biquad highpass filter, compare to SoX implementation
"""
CUTOFF_FREQ = 2000
noise_filepath = os.path.join(
self.test_dirpath, "assets", "whitenoise.mp3"
)
E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(noise_filepath)
E.append_effect_to_chain("highpass", [CUTOFF_FREQ])
sox_output_waveform, sr = E.sox_build_flow_effects()
waveform, sample_rate = torchaudio.load(
noise_filepath, normalization=True
)
output_waveform = F.highpass_biquad(waveform, sample_rate, CUTOFF_FREQ)
# TBD - this fails at the 1e-4 level, debug why
assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-3)
def test_perf_biquad_filtering(self):
fn_sine = os.path.join(self.test_dirpath, "assets", "whitenoise.mp3")
b0 = 0.4
b1 = 0.2
b2 = 0.9
a0 = 0.7
a1 = 0.2
a2 = 0.6
# SoX method
E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(fn_sine)
_timing_sox = time.time()
E.append_effect_to_chain("biquad", [b0, b1, b2, a0, a1, a2])
waveform_sox_out, sr = E.sox_build_flow_effects()
_timing_sox_run_time = time.time() - _timing_sox
_timing_lfilter_filtering = time.time()
waveform, sample_rate = torchaudio.load(fn_sine, normalization=True)
waveform_lfilter_out = F.lfilter(
waveform, torch.tensor([a0, a1, a2]), torch.tensor([b0, b1, b2])
)
_timing_lfilter_run_time = time.time() - _timing_lfilter_filtering
assert torch.allclose(waveform_sox_out, waveform_lfilter_out, atol=1e-4)
if __name__ == "__main__":
unittest.main()
...@@ -2,40 +2,63 @@ from __future__ import absolute_import, division, print_function, unicode_litera ...@@ -2,40 +2,63 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import math import math
import torch import torch
__all__ = [ __all__ = [
'istft', "istft",
'spectrogram', "spectrogram",
'amplitude_to_DB', "amplitude_to_DB",
'create_fb_matrix', "create_fb_matrix",
'create_dct', "create_dct",
'mu_law_encoding', "mu_law_encoding",
'mu_law_decoding', "mu_law_decoding",
'complex_norm', "complex_norm",
'angle', "angle",
'magphase', "magphase",
'phase_vocoder', "phase_vocoder",
"lfilter",
"lowpass_biquad",
"highpass_biquad",
"biquad",
] ]
# TODO: remove this once https://github.com/pytorch/pytorch/issues/21478 gets solved # TODO: remove this once https://github.com/pytorch/pytorch/issues/21478 gets solved
@torch.jit.ignore @torch.jit.ignore
def _stft(waveform, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided): def _stft(
waveform,
n_fft,
hop_length,
win_length,
window,
center,
pad_mode,
normalized,
onesided,
):
# type: (Tensor, int, Optional[int], Optional[int], Optional[Tensor], bool, str, bool, bool) -> Tensor # type: (Tensor, int, Optional[int], Optional[int], Optional[Tensor], bool, str, bool, bool) -> Tensor
return torch.stft(waveform, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided) return torch.stft(
waveform,
n_fft,
def istft(stft_matrix, # type: Tensor hop_length,
n_fft, # type: int win_length,
hop_length=None, # type: Optional[int] window,
win_length=None, # type: Optional[int] center,
window=None, # type: Optional[Tensor] pad_mode,
center=True, # type: bool normalized,
pad_mode='reflect', # type: str onesided,
normalized=False, # type: bool )
onesided=True, # type: bool
length=None # type: Optional[int]
): def istft(
stft_matrix, # type: Tensor
n_fft, # type: int
hop_length=None, # type: Optional[int]
win_length=None, # type: Optional[int]
window=None, # type: Optional[Tensor]
center=True, # type: bool
pad_mode="reflect", # type: str
normalized=False, # type: bool
onesided=True, # type: bool
length=None, # type: Optional[int]
):
# type: (...) -> Tensor # type: (...) -> Tensor
r"""Inverse short time Fourier Transform. This is expected to be the inverse of torch.stft. r"""Inverse short time Fourier Transform. This is expected to be the inverse of torch.stft.
It has the same parameters (+ additional optional parameter of ``length``) and it should return the It has the same parameters (+ additional optional parameter of ``length``) and it should return the
...@@ -90,7 +113,7 @@ def istft(stft_matrix, # type: Tensor ...@@ -90,7 +113,7 @@ def istft(stft_matrix, # type: Tensor
(channel, signal_length) or (signal_length) (channel, signal_length) or (signal_length)
""" """
stft_matrix_dim = stft_matrix.dim() stft_matrix_dim = stft_matrix.dim()
assert 3 <= stft_matrix_dim <= 4, ('Incorrect stft dimension: %d' % (stft_matrix_dim)) assert 3 <= stft_matrix_dim <= 4, "Incorrect stft dimension: %d" % (stft_matrix_dim)
if stft_matrix_dim == 3: if stft_matrix_dim == 3:
# add a channel dimension # add a channel dimension
...@@ -99,9 +122,13 @@ def istft(stft_matrix, # type: Tensor ...@@ -99,9 +122,13 @@ def istft(stft_matrix, # type: Tensor
dtype = stft_matrix.dtype dtype = stft_matrix.dtype
device = stft_matrix.device device = stft_matrix.device
fft_size = stft_matrix.size(1) fft_size = stft_matrix.size(1)
assert (onesided and n_fft // 2 + 1 == fft_size) or (not onesided and n_fft == fft_size), ( assert (onesided and n_fft // 2 + 1 == fft_size) or (
'one_sided implies that n_fft // 2 + 1 == fft_size and not one_sided implies n_fft == fft_size. ' + not onesided and n_fft == fft_size
'Given values were onesided: %s, n_fft: %d, fft_size: %d' % ('True' if onesided else False, n_fft, fft_size)) ), (
"one_sided implies that n_fft // 2 + 1 == fft_size and not one_sided implies n_fft == fft_size. "
+ "Given values were onesided: %s, n_fft: %d, fft_size: %d"
% ("True" if onesided else False, n_fft, fft_size)
)
# use stft defaults for Optionals # use stft defaults for Optionals
if win_length is None: if win_length is None:
...@@ -127,8 +154,9 @@ def istft(stft_matrix, # type: Tensor ...@@ -127,8 +154,9 @@ def istft(stft_matrix, # type: Tensor
# win_length and n_fft are synonymous from here on # win_length and n_fft are synonymous from here on
stft_matrix = stft_matrix.transpose(1, 2) # size (channel, n_frames, fft_size, 2) stft_matrix = stft_matrix.transpose(1, 2) # size (channel, n_frames, fft_size, 2)
stft_matrix = torch.irfft(stft_matrix, 1, normalized, stft_matrix = torch.irfft(
onesided, signal_sizes=(n_fft,)) # size (channel, n_frames, n_fft) stft_matrix, 1, normalized, onesided, signal_sizes=(n_fft,)
) # size (channel, n_frames, n_fft)
assert stft_matrix.size(2) == n_fft assert stft_matrix.size(2) == n_fft
n_frames = stft_matrix.size(1) n_frames = stft_matrix.size(1)
...@@ -137,18 +165,23 @@ def istft(stft_matrix, # type: Tensor ...@@ -137,18 +165,23 @@ def istft(stft_matrix, # type: Tensor
# each column of a channel is a frame which needs to be overlap added at the right place # each column of a channel is a frame which needs to be overlap added at the right place
ytmp = ytmp.transpose(1, 2) # size (channel, n_fft, n_frames) ytmp = ytmp.transpose(1, 2) # size (channel, n_fft, n_frames)
eye = torch.eye(n_fft, requires_grad=False, eye = torch.eye(n_fft, requires_grad=False, device=device, dtype=dtype).unsqueeze(
device=device, dtype=dtype).unsqueeze(1) # size (n_fft, 1, n_fft) 1
) # size (n_fft, 1, n_fft)
# this does overlap add where the frames of ytmp are added such that the i'th frame of # this does overlap add where the frames of ytmp are added such that the i'th frame of
# ytmp is added starting at i*hop_length in the output # ytmp is added starting at i*hop_length in the output
y = torch.nn.functional.conv_transpose1d( y = torch.nn.functional.conv_transpose1d(
ytmp, eye, stride=hop_length, padding=0) # size (channel, 1, expected_signal_len) ytmp, eye, stride=hop_length, padding=0
) # size (channel, 1, expected_signal_len)
# do the same for the window function # do the same for the window function
window_sq = window.pow(2).view(n_fft, 1).repeat((1, n_frames)).unsqueeze(0) # size (1, n_fft, n_frames) window_sq = (
window.pow(2).view(n_fft, 1).repeat((1, n_frames)).unsqueeze(0)
) # size (1, n_fft, n_frames)
window_envelop = torch.nn.functional.conv_transpose1d( window_envelop = torch.nn.functional.conv_transpose1d(
window_sq, eye, stride=hop_length, padding=0) # size (1, 1, expected_signal_len) window_sq, eye, stride=hop_length, padding=0
) # size (1, 1, expected_signal_len)
expected_signal_len = n_fft + hop_length * (n_frames - 1) expected_signal_len = n_fft + hop_length * (n_frames - 1)
assert y.size(2) == expected_signal_len assert y.size(2) == expected_signal_len
...@@ -164,7 +197,9 @@ def istft(stft_matrix, # type: Tensor ...@@ -164,7 +197,9 @@ def istft(stft_matrix, # type: Tensor
# check NOLA non-zero overlap condition # check NOLA non-zero overlap condition
window_envelop_lowest = window_envelop.abs().min() window_envelop_lowest = window_envelop.abs().min()
assert window_envelop_lowest > 1e-11, ('window overlap add min: %f' % (window_envelop_lowest)) assert window_envelop_lowest > 1e-11, "window overlap add min: %f" % (
window_envelop_lowest
)
y = (y / window_envelop).squeeze(1) # size (channel, expected_signal_len) y = (y / window_envelop).squeeze(1) # size (channel, expected_signal_len)
...@@ -174,7 +209,9 @@ def istft(stft_matrix, # type: Tensor ...@@ -174,7 +209,9 @@ def istft(stft_matrix, # type: Tensor
@torch.jit.script @torch.jit.script
def spectrogram(waveform, pad, window, n_fft, hop_length, win_length, power, normalized): def spectrogram(
waveform, pad, window, n_fft, hop_length, win_length, power, normalized
):
# type: (Tensor, int, Tensor, int, int, int, int, bool) -> Tensor # type: (Tensor, int, Tensor, int, int, int, int, bool) -> Tensor
r"""Create a spectrogram from a raw audio signal. r"""Create a spectrogram from a raw audio signal.
...@@ -201,8 +238,9 @@ def spectrogram(waveform, pad, window, n_fft, hop_length, win_length, power, nor ...@@ -201,8 +238,9 @@ def spectrogram(waveform, pad, window, n_fft, hop_length, win_length, power, nor
waveform = torch.nn.functional.pad(waveform, (pad, pad), "constant") waveform = torch.nn.functional.pad(waveform, (pad, pad), "constant")
# default values are consistent with librosa.core.spectrum._spectrogram # default values are consistent with librosa.core.spectrum._spectrogram
spec_f = _stft(waveform, n_fft, hop_length, win_length, window, spec_f = _stft(
True, 'reflect', False, True) waveform, n_fft, hop_length, win_length, window, True, "reflect", False, True
)
if normalized: if normalized:
spec_f /= window.pow(2).sum().sqrt() spec_f /= window.pow(2).sum().sqrt()
...@@ -234,8 +272,9 @@ def amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None): ...@@ -234,8 +272,9 @@ def amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None):
x_db -= multiplier * db_multiplier x_db -= multiplier * db_multiplier
if top_db is not None: if top_db is not None:
new_x_db_max = torch.tensor(float(x_db.max()) - top_db, new_x_db_max = torch.tensor(
dtype=x_db.dtype, device=x_db.device) float(x_db.max()) - top_db, dtype=x_db.dtype, device=x_db.device
)
x_db = torch.max(x_db, new_x_db_max) x_db = torch.max(x_db, new_x_db_max)
return x_db return x_db
...@@ -263,17 +302,17 @@ def create_fb_matrix(n_freqs, f_min, f_max, n_mels): ...@@ -263,17 +302,17 @@ def create_fb_matrix(n_freqs, f_min, f_max, n_mels):
freqs = torch.linspace(f_min, f_max, n_freqs) freqs = torch.linspace(f_min, f_max, n_freqs)
# calculate mel freq bins # calculate mel freq bins
# hertz to mel(f) is 2595. * math.log10(1. + (f / 700.)) # hertz to mel(f) is 2595. * math.log10(1. + (f / 700.))
m_min = 0. if f_min == 0 else 2595. * math.log10(1. + (f_min / 700.)) m_min = 0.0 if f_min == 0 else 2595.0 * math.log10(1.0 + (f_min / 700.0))
m_max = 2595. * math.log10(1. + (f_max / 700.)) m_max = 2595.0 * math.log10(1.0 + (f_max / 700.0))
m_pts = torch.linspace(m_min, m_max, n_mels + 2) m_pts = torch.linspace(m_min, m_max, n_mels + 2)
# mel to hertz(mel) is 700. * (10**(mel / 2595.) - 1.) # mel to hertz(mel) is 700. * (10**(mel / 2595.) - 1.)
f_pts = 700. * (10**(m_pts / 2595.) - 1.) f_pts = 700.0 * (10 ** (m_pts / 2595.0) - 1.0)
# calculate the difference between each mel point and each stft freq point in hertz # calculate the difference between each mel point and each stft freq point in hertz
f_diff = f_pts[1:] - f_pts[:-1] # (n_mels + 1) f_diff = f_pts[1:] - f_pts[:-1] # (n_mels + 1)
slopes = f_pts.unsqueeze(0) - freqs.unsqueeze(1) # (n_freqs, n_mels + 2) slopes = f_pts.unsqueeze(0) - freqs.unsqueeze(1) # (n_freqs, n_mels + 2)
# create overlapping triangles # create overlapping triangles
zero = torch.zeros(1) zero = torch.zeros(1)
down_slopes = (-1. * slopes[:, :-2]) / f_diff[:-1] # (n_freqs, n_mels) down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1] # (n_freqs, n_mels)
up_slopes = slopes[:, 2:] / f_diff[1:] # (n_freqs, n_mels) up_slopes = slopes[:, 2:] / f_diff[1:] # (n_freqs, n_mels)
fb = torch.max(zero, torch.min(down_slopes, up_slopes)) fb = torch.max(zero, torch.min(down_slopes, up_slopes))
return fb return fb
...@@ -301,7 +340,7 @@ def create_dct(n_mfcc, n_mels, norm): ...@@ -301,7 +340,7 @@ def create_dct(n_mfcc, n_mels, norm):
if norm is None: if norm is None:
dct *= 2.0 dct *= 2.0
else: else:
assert norm == 'ortho' assert norm == "ortho"
dct[0] *= 1.0 / math.sqrt(2.0) dct[0] *= 1.0 / math.sqrt(2.0)
dct *= math.sqrt(2.0 / float(n_mels)) dct *= math.sqrt(2.0 / float(n_mels))
return dct.t() return dct.t()
...@@ -323,12 +362,11 @@ def mu_law_encoding(x, quantization_channels): ...@@ -323,12 +362,11 @@ def mu_law_encoding(x, quantization_channels):
Returns: Returns:
torch.Tensor: Input after mu-law encoding torch.Tensor: Input after mu-law encoding
""" """
mu = quantization_channels - 1. mu = quantization_channels - 1.0
if not x.is_floating_point(): if not x.is_floating_point():
x = x.to(torch.float) x = x.to(torch.float)
mu = torch.tensor(mu, dtype=x.dtype) mu = torch.tensor(mu, dtype=x.dtype)
x_mu = torch.sign(x) * torch.log1p(mu * x_mu = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / torch.log1p(mu)
torch.abs(x)) / torch.log1p(mu)
x_mu = ((x_mu + 1) / 2 * mu + 0.5).to(torch.int64) x_mu = ((x_mu + 1) / 2 * mu + 0.5).to(torch.int64)
return x_mu return x_mu
...@@ -349,12 +387,12 @@ def mu_law_decoding(x_mu, quantization_channels): ...@@ -349,12 +387,12 @@ def mu_law_decoding(x_mu, quantization_channels):
Returns: Returns:
torch.Tensor: Input after mu-law decoding torch.Tensor: Input after mu-law decoding
""" """
mu = quantization_channels - 1. mu = quantization_channels - 1.0
if not x_mu.is_floating_point(): if not x_mu.is_floating_point():
x_mu = x_mu.to(torch.float) x_mu = x_mu.to(torch.float)
mu = torch.tensor(mu, dtype=x_mu.dtype) mu = torch.tensor(mu, dtype=x_mu.dtype)
x = ((x_mu) / mu) * 2 - 1. x = ((x_mu) / mu) * 2 - 1.0
x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.) / mu x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.0) / mu
return x return x
...@@ -385,7 +423,7 @@ def angle(complex_tensor): ...@@ -385,7 +423,7 @@ def angle(complex_tensor):
return torch.atan2(complex_tensor[..., 1], complex_tensor[..., 0]) return torch.atan2(complex_tensor[..., 1], complex_tensor[..., 0])
def magphase(complex_tensor, power=1.): def magphase(complex_tensor, power=1.0):
r"""Separate a complex-valued spectrogram with shape `(*, 2)` into its magnitude and phase. r"""Separate a complex-valued spectrogram with shape `(*, 2)` into its magnitude and phase.
Args: Args:
...@@ -428,13 +466,15 @@ def phase_vocoder(complex_specgrams, rate, phase_advance): ...@@ -428,13 +466,15 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):
ndim = complex_specgrams.dim() ndim = complex_specgrams.dim()
time_slice = [slice(None)] * (ndim - 2) time_slice = [slice(None)] * (ndim - 2)
time_steps = torch.arange(0, time_steps = torch.arange(
complex_specgrams.size(-2), 0,
rate, complex_specgrams.size(-2),
device=complex_specgrams.device, rate,
dtype=complex_specgrams.dtype) device=complex_specgrams.device,
dtype=complex_specgrams.dtype,
)
alphas = time_steps % 1. alphas = time_steps % 1.0
phase_0 = angle(complex_specgrams[time_slice + [slice(1)]]) phase_0 = angle(complex_specgrams[time_slice + [slice(1)]])
# Time Padding # Time Padding
...@@ -466,3 +506,149 @@ def phase_vocoder(complex_specgrams, rate, phase_advance): ...@@ -466,3 +506,149 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):
complex_specgrams_stretch = torch.stack([real_stretch, imag_stretch], dim=-1) complex_specgrams_stretch = torch.stack([real_stretch, imag_stretch], dim=-1)
return complex_specgrams_stretch return complex_specgrams_stretch
def lfilter(waveform, a_coeffs, b_coeffs):
# type: (Tensor, Tensor, Tensor) -> Tensor
r"""
Performs an IIR filter by evaluating difference equation.
Args:
waveform (torch.Tensor): audio waveform of dimension of `(n_channel, n_frames)`. Must be normalized to -1 to 1.
a_coeffs (torch.Tensor): denominator coefficients of difference equation of dimension of `(n_order + 1)`.
Lower delays coefficients are first, e.g. `[a0, a1, a2, ...]`.
Must be same size as b_coeffs (pad with 0's as necessary).
b_coeffs (torch.Tensor): numerator coefficients of difference equation of dimension of `(n_order + 1)`.
Lower delays coefficients are first, e.g. `[b0, b1, b2, ...]`.
Must be same size as a_coeffs (pad with 0's as necessary).
Returns:
output_waveform (torch.Tensor): Dimension of `(n_channel, n_frames)`. Output will be clipped to -1 to 1.
"""
assert(waveform.dtype == torch.float32)
assert(a_coeffs.size(0) == b_coeffs.size(0))
assert(len(waveform.size()) == 2)
n_channels, n_frames = waveform.size()
n_order = a_coeffs.size(0)
assert(n_order > 0)
# Pad the input and create output
padded_waveform = torch.zeros(n_channels, n_frames + n_order - 1)
padded_waveform[:, (n_order - 1):] = waveform
padded_output_waveform = torch.zeros(n_channels, n_frames + n_order - 1)
# Set up the coefficients matrix
# Flip order, repeat, and transpose
a_coeffs_filled = a_coeffs.flip(0).repeat(n_channels, 1).t()
b_coeffs_filled = b_coeffs.flip(0).repeat(n_channels, 1).t()
# Set up a few other utilities
a0_repeated = torch.ones(n_channels) * a_coeffs[0]
ones = torch.ones(n_channels, n_frames)
for i_frame in range(n_frames):
o0 = torch.zeros(n_channels)
windowed_input_signal = padded_waveform[:, i_frame:(i_frame + n_order)]
windowed_output_signal = padded_output_waveform[:, i_frame:(i_frame + n_order)]
o0.add_(torch.diag(torch.mm(windowed_input_signal, b_coeffs_filled)))
o0.sub_(torch.diag(torch.mm(windowed_output_signal, a_coeffs_filled)))
o0.div_(a0_repeated)
padded_output_waveform[:, i_frame + n_order - 1] = o0
return torch.min(ones, torch.max(ones * -1, padded_output_waveform[:, (n_order - 1):]))
def biquad(waveform, b0, b1, b2, a0, a1, a2):
# type: (Tensor, float, float, float, float, float, float) -> Tensor
r"""Performs a biquad filter of input tensor. Initial conditions set to 0.
https://en.wikipedia.org/wiki/Digital_biquad_filter
Args:
waveform (torch.Tensor): audio waveform of dimension of `(n_channel, n_frames)`
b0 (float): numerator coefficient of current input, x[n]
b1 (float): numerator coefficient of input one time step ago x[n-1]
b2 (float): numerator coefficient of input two time steps ago x[n-2]
a0 (float): denominator coefficient of current output y[n], typically 1
a1 (float): denominator coefficient of current output y[n-1]
a2 (float): denominator coefficient of current output y[n-2]
Returns:
output_waveform (torch.Tensor): Dimension of `(n_channel, n_frames)`
"""
assert(waveform.dtype == torch.float32)
output_waveform = lfilter(
waveform, torch.tensor([a0, a1, a2]), torch.tensor([b0, b1, b2])
)
return output_waveform
def _dB2Linear(x):
return math.exp(x * math.log(10) / 20.0)
def highpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
# type: (Tensor, int, float, Optional[float]) -> Tensor
r"""Designs biquad highpass filter and performs filtering. Similar to SoX implementation.
Args:
waveform (torch.Tensor): audio waveform of dimension of `(n_channel, n_frames)`
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
cutoff_freq (float): filter cutoff frequency
Q (float): https://en.wikipedia.org/wiki/Q_factor
Returns:
output_waveform (torch.Tensor): Dimension of `(n_channel, n_frames)`
"""
GAIN = 1 # TBD - add as a parameter
w0 = 2 * math.pi * cutoff_freq / sample_rate
A = math.exp(GAIN / 40.0 * math.log(10))
alpha = math.sin(w0) / 2 / Q
mult = _dB2Linear(max(GAIN, 0))
b0 = (1 + math.cos(w0)) / 2
b1 = -1 - math.cos(w0)
b2 = b0
a0 = 1 + alpha
a1 = -2 * math.cos(w0)
a2 = 1 - alpha
return biquad(waveform, b0, b1, b2, a0, a1, a2)
def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
# type: (Tensor, int, float, Optional[float]) -> Tensor
r"""Designs biquad lowpass filter and performs filtering. Similar to SoX implementation.
Args:
waveform (torch.Tensor): audio waveform of dimension of `(n_channel, n_frames)`
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
cutoff_freq (float): filter cutoff frequency
Q (float): https://en.wikipedia.org/wiki/Q_factor
Returns:
output_waveform (torch.Tensor): Dimension of `(n_channel, n_frames)`
"""
GAIN = 1
w0 = 2 * math.pi * cutoff_freq / sample_rate
A = math.exp(GAIN / 40.0 * math.log(10))
alpha = math.sin(w0) / 2 / Q
mult = _dB2Linear(max(GAIN, 0))
b0 = (1 - math.cos(w0)) / 2
b1 = 1 - math.cos(w0)
b2 = b0
a0 = 1 + alpha
a1 = -2 * math.cos(w0)
a2 = 1 - alpha
return biquad(waveform, b0, b1, b2, a0, a1, a2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment