Commit f3365ecf authored by David Pollack, committed by Vincent QB

Improve lfilter functional (#374)



* Simplify lfilter functional

* use `torch.clamp` instead of `torch.min(..., torch.max(...))`
* remove unneeded creation of ones tensor for previous method

The current lfilter function emulates a clamp with nested `torch.min` and
`torch.max` calls.  I changed the code to call `torch.clamp` directly, which
is more readable than the previous version.

FYI, if you want to keep the previous approach, you could use a
broadcastable tensor of size 1 for the bounds instead of creating a tensor
the size of the input.
Signed-off-by: David Pollack <david@da3.net>
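
A minimal sketch of the equivalence described in this bullet, assuming a small hand-made tensor. The helper names `clamp_via_full_ones`, `clamp_via_size_one` and `clamp_direct` are hypothetical and exist only for this illustration; they are not part of torchaudio.

```python
import torch


def clamp_via_full_ones(x: torch.Tensor) -> torch.Tensor:
    # Previous approach: allocate a tensor of ones as large as the input
    # and use it as the upper bound (and its negation as the lower bound).
    ones = torch.ones_like(x)
    return torch.min(ones, torch.max(ones * -1, x))


def clamp_via_size_one(x: torch.Tensor) -> torch.Tensor:
    # Alternative mentioned in the commit message: a single-element bound
    # that broadcasts against the input instead of a full-size tensor.
    one = torch.ones(1, dtype=x.dtype, device=x.device)
    return torch.min(one, torch.max(-one, x))


def clamp_direct(x: torch.Tensor) -> torch.Tensor:
    # Approach adopted by the commit: one call, no helper tensor.
    return torch.clamp(x, min=-1.0, max=1.0)


x = torch.tensor([[-2.0, -0.5, 0.0], [0.5, 1.0, 3.0]])
clamped = clamp_direct(x)
```

All three helpers should return the same values for any floating-point input; the difference lies only in readability and in how much temporary memory is allocated.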

* Parallelize waveform windows calculation

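I've parallelized the calculation of the waveform windows: the input-signal
windows are now gathered and multiplied by the `b` coefficients in one batched
operation before the loop.  This also removes the inefficient per-sample
calculation on the input window from inside the for-loop.
Signed-off-by: David Pollack <david@da3.net>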
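
The window gathering described in this bullet can be sketched as follows. This is an illustration under assumptions, not the library code: the function name `gather_input_windows` and the toy tensor are hypothetical, and the final comparison against `Tensor.unfold` is a cross-check that the commit itself does not use.

```python
import torch


def gather_input_windows(padded_waveform: torch.Tensor, n_order: int) -> torch.Tensor:
    """Return all sliding windows at once, shape (n_channel, n_order, n_sample)."""
    n_channel, n_sample_padded = padded_waveform.shape
    n_sample = n_sample_padded - n_order + 1
    device = padded_waveform.device
    # (n_order, n_sample): entry [k, i] is the in-channel position i + k.
    window_idxs = (torch.arange(n_sample, device=device).unsqueeze(0)
                   + torch.arange(n_order, device=device).unsqueeze(1))
    # Repeat for every channel: (n_channel, n_order, n_sample).
    window_idxs = window_idxs.repeat(n_channel, 1, 1)
    # torch.take indexes the flattened input, so each channel's indices
    # are shifted by the start of that channel's row.
    channel_offsets = torch.arange(n_channel, device=device).view(-1, 1, 1)
    window_idxs = window_idxs + channel_offsets * n_sample_padded
    return torch.take(padded_waveform, window_idxs)


padded = torch.arange(10.0).reshape(2, 5)   # 2 channels, 5 padded samples
n_order = 3
windows = gather_input_windows(padded, n_order)

# (n_order,) matmul (n_channel, n_order, n_sample) -> (n_channel, n_sample)
b_flipped = torch.tensor([0.2, 0.3, 0.5])
batched = torch.matmul(b_flipped, windows)

# The same quantity computed one sample at a time, as a loop would do it.
looped = torch.stack(
    [torch.mv(padded[:, i:i + n_order], b_flipped) for i in range(3)], dim=1
)
```

Because the input windows do not depend on earlier outputs, they can all be computed up front; only the feedback term on the output still needs the sequential loop.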

* Refactoring and minor readability changes
Signed-off-by: David Pollack <david@da3.net>

* Remove one more creation of a temporary tensor
Signed-off-by: David Pollack <david@da3.net>
parent 774ebc78
......
@@ -567,39 +567,36 @@ def lfilter(waveform, a_coeffs, b_coeffs):
     dtype = waveform.dtype
     n_channel, n_sample = waveform.size()
     n_order = a_coeffs.size(0)
+    n_sample_padded = n_sample + n_order - 1
     assert(n_order > 0)
     # Pad the input and create output
-    padded_waveform = torch.zeros(n_channel, n_sample + n_order - 1, dtype=dtype, device=device)
+    padded_waveform = torch.zeros(n_channel, n_sample_padded, dtype=dtype, device=device)
     padded_waveform[:, (n_order - 1):] = waveform
-    padded_output_waveform = torch.zeros(n_channel, n_sample + n_order - 1, dtype=dtype, device=device)
+    padded_output_waveform = torch.zeros(n_channel, n_sample_padded, dtype=dtype, device=device)
     # Set up the coefficients matrix
-    # Flip order, repeat, and transpose
-    a_coeffs_filled = a_coeffs.flip(0).repeat(n_channel, 1).t()
-    b_coeffs_filled = b_coeffs.flip(0).repeat(n_channel, 1).t()
-    # Set up a few other utilities
-    a0_repeated = torch.ones(n_channel, dtype=dtype, device=device) * a_coeffs[0]
-    ones = torch.ones(n_channel, n_sample, dtype=dtype, device=device)
-    for i_sample in range(n_sample):
-        o0 = torch.zeros(n_channel, dtype=dtype, device=device)
-        windowed_input_signal = padded_waveform[:, i_sample:(i_sample + n_order)]
+    # Flip coefficients' order
+    a_coeffs_flipped = a_coeffs.flip(0)
+    b_coeffs_flipped = b_coeffs.flip(0)
+    # calculate windowed_input_signal in parallel
+    # create indices of original with shape (n_channel, n_order, n_sample)
+    window_idxs = torch.arange(n_sample, device=device).unsqueeze(0) + torch.arange(n_order, device=device).unsqueeze(1)
+    window_idxs = window_idxs.repeat(n_channel, 1, 1)
+    window_idxs += (torch.arange(n_channel, device=device).unsqueeze(-1).unsqueeze(-1) * n_sample_padded)
+    window_idxs = window_idxs.long()
+    # (n_order, ) matmul (n_channel, n_order, n_sample) -> (n_channel, n_sample)
+    input_signal_windows = torch.matmul(b_coeffs_flipped, torch.take(padded_waveform, window_idxs))
+    for i_sample, o0 in enumerate(input_signal_windows.t()):
         windowed_output_signal = padded_output_waveform[:, i_sample:(i_sample + n_order)]
-        o0.add_(torch.diag(torch.mm(windowed_input_signal, b_coeffs_filled)))
-        o0.sub_(torch.diag(torch.mm(windowed_output_signal, a_coeffs_filled)))
-        o0.div_(a0_repeated)
+        o0.sub_(torch.mv(windowed_output_signal, a_coeffs_flipped))
+        o0.div_(a_coeffs[0])
         padded_output_waveform[:, i_sample + n_order - 1] = o0
-    output = torch.min(
-        ones, torch.max(ones * -1, padded_output_waveform[:, (n_order - 1):])
-    )
+    output = torch.clamp(padded_output_waveform[:, (n_order - 1):], min=-1., max=1.)
     # unpack batch
     output = output.reshape(shape[:-1] + output.shape[-1:])
......
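
For readers who want to check the recursion in the hunk above, here is a hedged sketch that compares its loop structure with a plain-Python difference equation, a0*y[n] = sum_k b[k]*x[n-k] - sum_{k>=1} a[k]*y[n-k]. The names `lfilter_reference` and `lfilter_sketch` are hypothetical, the sketch only handles a 2-D `(n_channel, n_sample)` input, and the final clamp to [-1, 1] is deliberately left out so that the comparison is about the recursion alone.

```python
import torch


def lfilter_reference(x, a, b):
    """Direct-form difference equation on plain Python lists (one channel)."""
    y = []
    for n in range(len(x)):
        acc = 0.0
        for k in range(len(a)):
            if n - k >= 0:
                acc += b[k] * x[n - k]        # feed-forward term
                if k > 0:
                    acc -= a[k] * y[n - k]    # feedback term
        y.append(acc / a[0])
    return y


def lfilter_sketch(waveform, a_coeffs, b_coeffs):
    """Loop shaped like the diff: padded buffers and flipped coefficients."""
    n_channel, n_sample = waveform.size()
    n_order = a_coeffs.size(0)
    padded_in = torch.zeros(n_channel, n_sample + n_order - 1, dtype=waveform.dtype)
    padded_in[:, n_order - 1:] = waveform
    padded_out = torch.zeros_like(padded_in)
    a_flipped, b_flipped = a_coeffs.flip(0), b_coeffs.flip(0)
    for i in range(n_sample):
        o0 = torch.mv(padded_in[:, i:i + n_order], b_flipped)
        # The last column of the output window has not been written yet (it is
        # still zero), so a_coeffs[0] contributes nothing to this product.
        o0 = o0 - torch.mv(padded_out[:, i:i + n_order], a_flipped)
        padded_out[:, i + n_order - 1] = o0 / a_coeffs[0]
    return padded_out[:, n_order - 1:]


x = [0.5, -0.25, 0.1, 0.0, 0.3]
a = [1.0, -0.5]
b = [0.4, 0.2]
out = lfilter_sketch(
    torch.tensor([x], dtype=torch.float64),
    torch.tensor(a, dtype=torch.float64),
    torch.tensor(b, dtype=torch.float64),
)
ref = lfilter_reference(x, a, b)
```

The zero padding at the front of both buffers is what lets the first `n_order - 1` samples use the same window arithmetic as every later sample.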