Commit 69e8dbb2 authored by moto, committed by Facebook GitHub Bot

Add filter_waveform (#2928)

Summary:
This commit adds the "filter_waveform" prototype function.

This function applies non-stationary filters across time.
It also crops the result at the end to compensate for the delay introduced by filtering.
The figure below illustrates this.

See [subtractive_synthesis_tutorial](https://output.circle-artifacts.com/output/job/5233fda9-dadb-4710-9389-7e8ac20a062f/artifacts/0/docs/tutorials/subtractive_synthesis_tutorial.html) for example usages.

![figure](https://download.pytorch.org/torchaudio/doc-assets/filter_waveform.png)
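
A minimal usage sketch (shapes chosen for illustration, not taken from this commit):

```python
import torch
from torchaudio.prototype.functional import filter_waveform

waveform = torch.randn(1, 16000)  # (..., time)
kernels = torch.randn(3, 65)      # (num_filters, filter_length)

# The waveform is split into 3 chunks and chunk i is filtered with kernels[i];
# the trailing crop keeps the output the same length as the input.
filtered = filter_waveform(waveform, kernels)
assert filtered.shape == waveform.shape
```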

Pull Request resolved: https://github.com/pytorch/audio/pull/2928

Reviewed By: carolineechen

Differential Revision: D42199955

Pulled By: mthrok

fbshipit-source-id: e822510ab8df98393919bea33768f288f4d661b2
parent 1706a72f
@@ -47,6 +47,7 @@ DSP
:nosignatures:
adsr_envelope
extend_pitch
filter_waveform
oscillator_bank
sinc_impulse_response
......
@@ -91,3 +91,8 @@ class AutogradTestImpl(TestBaseMixin):
def test_freq_ir(self):
mags = torch.tensor([0, 0.5, 1.0], device=self.device, dtype=self.dtype, requires_grad=True)
assert gradcheck(F.frequency_impulse_response, (mags,))
def test_filter_waveform(self):
waveform = torch.rand(3, 1, 2, 10, device=self.device, dtype=self.dtype, requires_grad=True)
filters = torch.rand(3, 2, device=self.device, dtype=self.dtype, requires_grad=True)
assert gradcheck(F.filter_waveform, (waveform, filters))
@@ -486,6 +486,110 @@ class FunctionalTestImpl(TestBaseMixin):
self.assertEqual(hyp, ref)
@parameterized.expand(
[
# fmt: off
# INPUT: single-dim waveform and 2D filter
# The number of frames is divisible by the number of filters (15 % 3 == 0),
# thus the waveform is split into chunks without padding
((15, ), (3, 3)), # filter size (3) is shorter than chunk size (15 // 3 == 5)
((15, ), (3, 5)), # filter size (5) matches the chunk size
((15, ), (3, 7)), # filter size (7) is longer than chunk size
# INPUT: single-dim waveform and 2D filter
# The number of frames is NOT divisible by the number of filters (15 % 4 != 0),
# thus the waveform must be padded before splitting
((15, ), (4, 3)), # filter size (3) is shorter than chunk size (16 // 4 == 4)
((15, ), (4, 4)), # filter size (4) matches the chunk size
((15, ), (4, 5)), # filter size (5) is longer than chunk size
# INPUT: multi-dim waveform and 2D filter
# The number of frames is divisible by the number of filters (15 % 3 == 0),
# thus the waveform is split into chunks without padding
((7, 2, 15), (3, 3)),
((7, 2, 15), (3, 5)),
((7, 2, 15), (3, 7)),
# INPUT: multi-dim waveform and 2D filter
# The number of frames is NOT divisible by the number of filters (15 % 4 != 0),
# thus the waveform must be padded before splitting
((7, 2, 15), (4, 3)),
((7, 2, 15), (4, 4)),
((7, 2, 15), (4, 5)),
# INPUT: multi-dim waveform and multi-dim filter
# The number of frames is divisible by the number of filters (15 % 3 == 0),
# thus the waveform is split into chunks without padding
((7, 2, 15), (7, 2, 3, 3)),
((7, 2, 15), (7, 2, 3, 5)),
((7, 2, 15), (7, 2, 3, 7)),
# INPUT: multi-dim waveform and multi-dim filter
# The number of frames is NOT divisible by the number of filters (15 % 4 != 0),
# thus the waveform must be padded before splitting
((7, 2, 15), (7, 2, 4, 3)),
((7, 2, 15), (7, 2, 4, 4)),
((7, 2, 15), (7, 2, 4, 5)),
# INPUT: multi-dim waveform and (broadcast) multi-dim filter
# The number of frames is divisible by the number of filters (15 % 3 == 0),
# thus the waveform is split into chunks without padding
((7, 2, 15), (1, 1, 3, 3)),
((7, 2, 15), (1, 1, 3, 5)),
((7, 2, 15), (1, 1, 3, 7)),
# INPUT: multi-dim waveform and (broadcast) multi-dim filter
# The number of frames is NOT divisible by the number of filters (15 % 4 != 0),
# thus the waveform must be padded before splitting
((7, 2, 15), (1, 1, 4, 3)),
((7, 2, 15), (1, 1, 4, 4)),
((7, 2, 15), (1, 1, 4, 5)),
# fmt: on
]
)
def test_filter_waveform_shape(self, waveform_shape, filter_shape):
"""filter_waveform returns the waveform with the same number of samples"""
waveform = torch.randn(waveform_shape, dtype=self.dtype, device=self.device)
filters = torch.randn(filter_shape, dtype=self.dtype, device=self.device)
filtered = F.filter_waveform(waveform, filters)
assert filtered.shape == waveform.shape
@nested_params([1, 3, 5], [3, 5, 7, 4, 6, 8])
def test_filter_waveform_delta(self, num_filters, kernel_size):
"""Applying delta kernel preserves the origianl waveform"""
waveform = torch.arange(-10, 10, dtype=self.dtype, device=self.device)
kernel = torch.zeros((num_filters, kernel_size), dtype=self.dtype, device=self.device)
kernel[:, kernel_size // 2] = 1
result = F.filter_waveform(waveform, kernel)
self.assertEqual(waveform, result)
def test_filter_waveform_same(self, kernel_size=5):
"""Applying the same filter returns the original waveform"""
waveform = torch.arange(-10, 10, dtype=self.dtype, device=self.device)
kernel = torch.randn((1, kernel_size), dtype=self.dtype, device=self.device)
kernels = torch.cat([kernel] * 3)
out1 = F.filter_waveform(waveform, kernel)
out2 = F.filter_waveform(waveform, kernels)
self.assertEqual(out1, out2)
def test_filter_waveform_diff(self):
"""Filters are applied from the first to the last"""
kernel_size = 3
waveform = torch.arange(-10, 10, dtype=self.dtype, device=self.device)
kernels = torch.randn((2, kernel_size), dtype=self.dtype, device=self.device)
# use both filters.
mix = F.filter_waveform(waveform, kernels)
# use only one of them
ref1 = F.filter_waveform(waveform[:10], kernels[0:1])
ref2 = F.filter_waveform(waveform[10:], kernels[1:2])
print("mix:", mix)
print("ref1:", ref1)
print("ref2:", ref2)
# The first filter is effective in the first half
self.assertEqual(mix[:10], ref1[:10])
# The second filter is effective in the second half
self.assertEqual(mix[-9:], ref2[-9:])
# The middle portion is affected by both filters, so it is not compared
class Functional64OnlyTestImpl(TestBaseMixin):
@nested_params(
......
from ._dsp import adsr_envelope, extend_pitch, frequency_impulse_response, oscillator_bank, sinc_impulse_response
from ._dsp import (
adsr_envelope,
extend_pitch,
filter_waveform,
frequency_impulse_response,
oscillator_bank,
sinc_impulse_response,
)
from .functional import add_noise, barkscale_fbanks, convolve, deemphasis, fftconvolve, preemphasis, speed
@@ -10,6 +17,7 @@ __all__ = [
"deemphasis",
"extend_pitch",
"fftconvolve",
"filter_waveform",
"frequency_impulse_response",
"oscillator_bank",
"preemphasis",
......
@@ -3,6 +3,8 @@ from typing import List, Optional, Union
import torch
from .functional import fftconvolve
def oscillator_bank(
frequencies: torch.Tensor,
@@ -306,3 +308,99 @@ def frequency_impulse_response(magnitudes):
device, dtype = magnitudes.device, magnitudes.dtype
window = torch.hann_window(ir.size(-1), periodic=False, device=device, dtype=dtype).expand_as(ir)
return ir * window
def _overlap_and_add(waveform, stride):
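"""Overlap-add frames of shape `(..., num_frames, frame_size)` that are placed
`stride` samples apart, summing the overlapping regions into a waveform of
length `(num_frames - 1) * stride + frame_size`."""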
num_frames, frame_size = waveform.shape[-2:]
numel = (num_frames - 1) * stride + frame_size
buffer = torch.zeros(waveform.shape[:-2] + (numel,), device=waveform.device, dtype=waveform.dtype)
for i in range(num_frames):
start = i * stride
end = start + frame_size
buffer[..., start:end] += waveform[..., i, :]
return buffer
def filter_waveform(waveform: torch.Tensor, kernels: torch.Tensor, delay_compensation: int = -1):
"""Applies filters along time axis of the given waveform.
This function applies the given filters along time axis in the following manner:
1. Split the given waveform into chunks. The number of chunks is equal to the number of given filters.
2. Filter each chunk with the corresponding filter.
3. Place the filtered chunks at the original indices while adding up the overlapping parts.
4. Crop the resulting waveform so that the delay introduced by the filter is removed and its length
matches that of the input waveform.
The following figure illustrates this.
.. image:: https://download.pytorch.org/torchaudio/doc-assets/filter_waveform.png
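For example, a 15-sample waveform with three 5-tap filters is split into three
5-sample chunks; each chunk is convolved with its filter (9 samples each), the
results are overlap-added at offsets 0, 5 and 10 (19 samples total), and the
output is cropped back to 15 samples.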
.. note::
If the number of filters is one, then the operation becomes stationary,
i.e. the same filter is applied across the entire time axis.
Args:
waveform (Tensor): Shape `(..., time)`.
kernels (Tensor): Impulse responses.
Valid inputs are 2D tensor with shape `(num_filters, filter_length)` or
`(N+1)`-D tensor with shape `(..., num_filters, filter_length)`, where `N` is
the dimension of waveform.
In case of 2D input, the same set of filters is used across channels and batches.
Otherwise, different sets of filters are applied. In this case, the shape of
the first `N-1` dimensions of filters must match (or be broadcastable to) that of waveform.
delay_compensation (int): Controls how the waveform is cropped after the full convolution.
If the value is zero or positive, it is interpreted as the length of crop at the
beginning of the waveform. The value cannot be larger than the size of filter kernel.
Otherwise the initial crop is ``filter_size // 2``.
When cropping happens, the waveform is also cropped from the end so that the
length of the resulting waveform matches the input waveform.
Returns:
Tensor: `(..., time)`.
"""
if kernels.ndim not in [2, waveform.ndim + 1]:
raise ValueError(
"`kernels` must be 2 or N+1 dimension where "
f"N is the dimension of waveform. Found: {kernels.ndim} (N={waveform.ndim})"
)
num_filters, filter_size = kernels.shape[-2:]
num_frames = waveform.size(-1)
if delay_compensation > filter_size:
raise ValueError(
"When `delay_compenstation` is provided, it cannot be larger than the size of filters."
f"Found: delay_compensation={delay_compensation}, filter_size={filter_size}"
)
# Transform waveform's time axis into (num_filters x chunk_length) with optional padding
chunk_length = num_frames // num_filters
if num_frames % num_filters > 0:
chunk_length += 1
num_pad = chunk_length * num_filters - num_frames
waveform = torch.nn.functional.pad(waveform, [0, num_pad], "constant", 0)
chunked = waveform.unfold(-1, chunk_length, chunk_length)
assert chunked.numel() >= waveform.numel()
# Broadcast kernels
if waveform.ndim + 1 > kernels.ndim:
expand_shape = waveform.shape[:-1] + kernels.shape
kernels = kernels.expand(expand_shape)
convolved = fftconvolve(chunked, kernels)
restored = _overlap_and_add(convolved, chunk_length)
# Trim so that the number of samples matches that of the input
# and the filter delay is compensated
if delay_compensation >= 0:
start = delay_compensation
else:
start = filter_size // 2
# Keep `num_frames` samples starting at `start`; this crops `start` samples
# from the beginning and the rest from the end, and avoids an empty slice
# when there is nothing to crop at the end (e.g. filter_size == 1).
result = restored[..., start : start + num_frames]
return result
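
A quick way to see the stationary case noted in the docstring (a sketch, not part of this commit): with a single filter, `filter_waveform` reduces to plain full convolution followed by the default `filter_size // 2` crop.

```python
import torch
from torchaudio.prototype.functional import fftconvolve, filter_waveform

waveform = torch.randn(200)
kernel = torch.randn(1, 9)

out = filter_waveform(waveform, kernel)  # one filter: stationary filtering
ref = fftconvolve(waveform, kernel[0])   # full convolution: 200 + 9 - 1 samples
start = kernel.size(-1) // 2             # default delay compensation
torch.testing.assert_close(out, ref[start:start + waveform.numel()])
```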