Unverified Commit 49860425 authored by chin yun yu, committed by GitHub

Replace indexing+matmul operation in lfilter with conv1d (#1318)

parent 33dc817c
@@ -855,42 +855,25 @@ def lfilter(
     assert waveform.device == a_coeffs.device
     assert b_coeffs.device == a_coeffs.device
     device = waveform.device
-    dtype = waveform.dtype
     n_channel, n_sample = waveform.size()
     n_order = a_coeffs.size(0)
-    n_sample_padded = n_sample + n_order - 1
     assert n_order > 0
     # Pad the input and create output
-    padded_waveform = torch.zeros(
-        n_channel, n_sample_padded, dtype=dtype, device=device
-    )
-    padded_waveform[:, n_order - 1:] = waveform
-    padded_output_waveform = torch.zeros(
-        n_channel, n_sample_padded, dtype=dtype, device=device
-    )
+    padded_waveform = torch.nn.functional.pad(waveform, [n_order - 1, 0])
+    padded_output_waveform = torch.zeros_like(padded_waveform)
     # Set up the coefficients matrix
     # Flip coefficients' order
     a_coeffs_flipped = a_coeffs.flip(0)
     b_coeffs_flipped = b_coeffs.flip(0)
-    # calculate windowed_input_signal in parallel
-    # create indices of original with shape (n_channel, n_order, n_sample)
-    window_idxs = torch.arange(n_sample, device=device).unsqueeze(0) + torch.arange(
-        n_order, device=device
-    ).unsqueeze(1)
-    window_idxs = window_idxs.repeat(n_channel, 1, 1)
-    window_idxs += (
-        torch.arange(n_channel, device=device).unsqueeze(-1).unsqueeze(-1)
-        * n_sample_padded
-    )
-    window_idxs = window_idxs.long()
-    # (n_order, ) matmul (n_channel, n_order, n_sample) -> (n_channel, n_sample)
-    input_signal_windows = torch.matmul(
-        b_coeffs_flipped, torch.take(padded_waveform, window_idxs)
-    )
+    # calculate windowed_input_signal in parallel using convolution
+    input_signal_windows = torch.nn.functional.conv1d(
+        padded_waveform.unsqueeze(1),
+        b_coeffs_flipped.view(1, 1, -1)
+    ).squeeze(1)
     input_signal_windows.div_(a_coeffs[0])
     a_coeffs_flipped.div_(a_coeffs[0])
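
For reference, below is a minimal standalone sketch (not part of the commit) of the equivalence this change relies on: the feed-forward term produced by the old index-gather + matmul code matches a single `torch.nn.functional.conv1d` call that uses the flipped `b_coeffs` as a `(1, 1, n_order)` kernel. The shapes, seed, and variable names here are illustrative only.

```python
import torch

torch.manual_seed(0)

n_channel, n_sample, n_order = 2, 16, 3
waveform = torch.randn(n_channel, n_sample)
b_coeffs = torch.randn(n_order)
b_coeffs_flipped = b_coeffs.flip(0)

# Left-pad with n_order - 1 zeros so every output sample sees a full window.
padded_waveform = torch.nn.functional.pad(waveform, [n_order - 1, 0])
n_sample_padded = padded_waveform.size(1)

# Old approach: build per-channel sliding-window indices, gather the windows
# with torch.take, then contract the window dimension with matmul.
window_idxs = torch.arange(n_sample).unsqueeze(0) + torch.arange(n_order).unsqueeze(1)
window_idxs = window_idxs.repeat(n_channel, 1, 1)
window_idxs += torch.arange(n_channel).unsqueeze(-1).unsqueeze(-1) * n_sample_padded
old_result = torch.matmul(b_coeffs_flipped, torch.take(padded_waveform, window_idxs))

# New approach: conv1d over the padded signal with the flipped taps as the kernel.
# conv1d computes cross-correlation, so flipping b_coeffs makes it a true convolution.
new_result = torch.nn.functional.conv1d(
    padded_waveform.unsqueeze(1),      # (n_channel, 1, n_sample_padded)
    b_coeffs_flipped.view(1, 1, -1),   # (out_channels=1, in_channels=1, n_order)
).squeeze(1)                           # (n_channel, n_sample)

print(torch.allclose(old_result, new_result))  # expected: True
```

Besides giving the same result, the conv1d path avoids materializing the `(n_channel, n_order, n_sample)` index tensor and the gathered copy of the padded input, which appears to be the motivation for the change.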