Commit b396157d authored by hwangjeff's avatar hwangjeff Committed by Facebook GitHub Bot
Browse files

Add convolution operator (#2602)

Summary:
Adds functions `convolve` and `fftconvolve`, which compute the convolution of two tensors along their trailing dimension. The former performs the convolution directly, whereas the latter performs it using FFT.

Pull Request resolved: https://github.com/pytorch/audio/pull/2602

Reviewed By: nateanl, mthrok

Differential Revision: D38450771

Pulled By: hwangjeff

fbshipit-source-id: b2d1e063ba21eafeddf317d60749e7120b14292b
parent 33485b8c
......@@ -61,6 +61,7 @@ Prototype API References
prototype
prototype.ctc_decoder
prototype.functional
prototype.models
prototype.pipelines
......
torchaudio.prototype.functional
===============================
.. py:module:: torchaudio.prototype.functional
.. currentmodule:: torchaudio.prototype.functional
convolve
~~~~~~~~
.. autofunction:: convolve
fftconvolve
~~~~~~~~~~~
.. autofunction:: fftconvolve
......@@ -17,5 +17,6 @@ imported explicitly, e.g.
import torchaudio.prototype.models
.. toctree::
prototype.models
prototype.pipelines
prototype.functional
prototype.models
prototype.pipelines
import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from .autograd_test_impl import AutogradTestImpl
class TestAutogradCPUFloat64(AutogradTestImpl, PytorchTestCase):
    """Runs the shared autograd test suite on CPU with float64 precision."""

    # float64 is required for reliable gradcheck finite-difference comparisons.
    dtype = torch.float64
    device = torch.device("cpu")
import torch
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from .autograd_test_impl import AutogradTestImpl
@skipIfNoCuda
class TestAutogradCUDAFloat64(AutogradTestImpl, PytorchTestCase):
    """Runs the shared autograd test suite on CUDA with float64 precision.

    Skipped entirely when no CUDA device is available.
    """

    # float64 is required for reliable gradcheck finite-difference comparisons.
    dtype = torch.float64
    device = torch.device("cuda")
import torch
import torchaudio.prototype.functional as F
from parameterized import parameterized
from torch.autograd import gradcheck, gradgradcheck
from torchaudio_unittest.common_utils import TestBaseMixin
class AutogradTestImpl(TestBaseMixin):
    """Shared autograd checks for the prototype convolution operators.

    Subclasses provide ``dtype`` and ``device`` (via ``TestBaseMixin``).
    """

    @parameterized.expand(
        [
            (F.convolve,),
            (F.fftconvolve,),
        ]
    )
    def test_convolve(self, fn):
        # Operands with several leading dims and distinct trailing lengths;
        # both require grad so gradcheck differentiates w.r.t. each input.
        batch_shape = (4, 3, 2)
        len_x, len_y = 23, 40
        x = torch.rand(*batch_shape, len_x, dtype=self.dtype, device=self.device, requires_grad=True)
        y = torch.rand(*batch_shape, len_y, dtype=self.dtype, device=self.device, requires_grad=True)
        self.assertTrue(gradcheck(fn, (x, y)))
        self.assertTrue(gradgradcheck(fn, (x, y)))
import torch
import torchaudio.prototype.functional as F
from torchaudio_unittest.common_utils import nested_params, TorchaudioTestCase
class BatchConsistencyTest(TorchaudioTestCase):
    """Verifies that batched convolution agrees with convolving each signal alone."""

    @nested_params(
        [F.convolve, F.fftconvolve],
    )
    def test_convolve(self, fn):
        batch, channels = 2, 3
        len_x, len_y = 89, 43
        x = torch.rand(batch, channels, len_x, dtype=self.dtype, device=self.device)
        y = torch.rand(batch, channels, len_y, dtype=self.dtype, device=self.device)

        actual = fn(x, y)

        # Convolve every (i, j) signal pair individually, then reassemble the batch.
        rows = []
        for i in range(batch):
            per_channel = [fn(x[i, j].unsqueeze(0), y[i, j].unsqueeze(0)).squeeze(0) for j in range(channels)]
            rows.append(torch.stack(per_channel))
        expected = torch.stack(rows)

        self.assertEqual(expected, actual)
import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from .functional_test_impl import FunctionalTestImpl
class FunctionalFloat32CPUTest(FunctionalTestImpl, PytorchTestCase):
    """Runs the shared functional tests on CPU with float32 precision."""

    dtype = torch.float32
    device = torch.device("cpu")
class FunctionalFloat64CPUTest(FunctionalTestImpl, PytorchTestCase):
    """Runs the shared functional tests on CPU with float64 precision."""

    dtype = torch.float64
    device = torch.device("cpu")
import torch
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from .functional_test_impl import FunctionalTestImpl
@skipIfNoCuda
class FunctionalFloat32CUDATest(FunctionalTestImpl, PytorchTestCase):
    """Runs the shared functional tests on CUDA with float32 precision; skipped without CUDA."""

    dtype = torch.float32
    device = torch.device("cuda")
@skipIfNoCuda
class FunctionalFloat64CUDATest(FunctionalTestImpl, PytorchTestCase):
    """Runs the shared functional tests on CUDA with float64 precision; skipped without CUDA."""

    dtype = torch.float64
    device = torch.device("cuda")
import numpy as np
import torch
import torchaudio.prototype.functional as F
from scipy import signal
from torchaudio_unittest.common_utils import nested_params, TestBaseMixin
class FunctionalTestImpl(TestBaseMixin):
    """Numerical (vs. SciPy) and input-validation tests for the prototype convolution ops."""

    @nested_params(
        [(10, 4), (4, 3, 1, 2), (2,), ()],
        [(100, 43), (21, 45)],
    )
    def test_convolve_numerics(self, leading_dims, lengths):
        """Check that convolve returns values identical to those that SciPy produces."""
        len_x, len_y = lengths
        x = torch.rand(*(leading_dims + (len_x,)), dtype=self.dtype, device=self.device)
        y = torch.rand(*(leading_dims + (len_y,)), dtype=self.dtype, device=self.device)

        actual = F.convolve(x, y)

        # Flatten the leading dimensions and convolve each 1-D signal pair with SciPy.
        num_signals = int(torch.tensor(leading_dims).prod()) if leading_dims else 1
        flat_x = x.reshape((num_signals, len_x))
        flat_y = y.reshape((num_signals, len_y))
        per_signal = [
            signal.convolve(flat_x[i].detach().cpu().numpy(), flat_y[i].detach().cpu().numpy())
            for i in range(num_signals)
        ]
        expected = torch.tensor(np.array(per_signal)).reshape(leading_dims + (-1,))

        self.assertEqual(expected, actual)

    @nested_params(
        [(10, 4), (4, 3, 1, 2), (2,), ()],
        [(100, 43), (21, 45)],
    )
    def test_fftconvolve_numerics(self, leading_dims, lengths):
        """Check that fftconvolve returns values identical to those that SciPy produces."""
        len_x, len_y = lengths
        x = torch.rand(*(leading_dims + (len_x,)), dtype=self.dtype, device=self.device)
        y = torch.rand(*(leading_dims + (len_y,)), dtype=self.dtype, device=self.device)

        actual = F.fftconvolve(x, y)

        # SciPy's fftconvolve broadcasts over leading axes itself, so no reshaping is needed.
        expected = torch.tensor(signal.fftconvolve(x.detach().cpu().numpy(), y.detach().cpu().numpy(), axes=-1))

        self.assertEqual(expected, actual)

    @nested_params(
        [F.convolve, F.fftconvolve],
        [(4, 3, 1, 2), (1,)],
        [(10, 4), (2, 2, 2)],
    )
    def test_convolve_input_leading_dim_check(self, fn, x_shape, y_shape):
        """Check that convolve properly rejects inputs with different leading dimensions."""
        x = torch.rand(*x_shape, dtype=self.dtype, device=self.device)
        y = torch.rand(*y_shape, dtype=self.dtype, device=self.device)
        with self.assertRaisesRegex(ValueError, "Leading dimensions"):
            fn(x, y)
import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from .torchscript_consistency_test_impl import TorchScriptConsistencyTestImpl
class TorchScriptConsistencyCPUFloat32Test(TorchScriptConsistencyTestImpl, PytorchTestCase):
    """Runs the shared TorchScript consistency tests on CPU with float32 precision."""

    dtype = torch.float32
    device = torch.device("cpu")
class TorchScriptConsistencyCPUFloat64Test(TorchScriptConsistencyTestImpl, PytorchTestCase):
    """Runs the shared TorchScript consistency tests on CPU with float64 precision."""

    dtype = torch.float64
    device = torch.device("cpu")
import torch
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from .torchscript_consistency_test_impl import TorchScriptConsistencyTestImpl
@skipIfNoCuda
class TorchScriptConsistencyCUDAFloat32Test(TorchScriptConsistencyTestImpl, PytorchTestCase):
    """Runs the shared TorchScript consistency tests on CUDA with float32 precision; skipped without CUDA."""

    dtype = torch.float32
    device = torch.device("cuda")
@skipIfNoCuda
class TorchScriptConsistencyCUDAFloat64Test(TorchScriptConsistencyTestImpl, PytorchTestCase):
    """Runs the shared TorchScript consistency tests on CUDA with float64 precision; skipped without CUDA."""

    dtype = torch.float64
    device = torch.device("cuda")
import torch
import torchaudio.prototype.functional as F
from parameterized import parameterized
from torchaudio_unittest.common_utils import TestBaseMixin, torch_script
class TorchScriptConsistencyTestImpl(TestBaseMixin):
    """Checks that eager and TorchScript executions of the prototype functions agree."""

    def _assert_consistency(self, func, inputs, shape_only=False):
        # Move tensor arguments onto the test's device/dtype; pass others through unchanged.
        prepared = []
        for item in inputs:
            if torch.is_tensor(item):
                item = item.to(device=self.device, dtype=self.dtype)
            prepared.append(item)

        ts_func = torch_script(func)

        # Re-seed before each call so any randomness inside func is identical.
        torch.random.manual_seed(40)
        output = func(*prepared)
        torch.random.manual_seed(40)
        ts_output = ts_func(*prepared)

        if shape_only:
            # Some ops only guarantee shape parity, not exact values.
            ts_output = ts_output.shape
            output = output.shape
        self.assertEqual(ts_output, output)

    @parameterized.expand(
        [
            (F.convolve,),
            (F.fftconvolve,),
        ]
    )
    def test_convolve(self, fn):
        batch_shape = (2, 3, 2)
        len_x, len_y = 32, 55
        x = torch.rand(*batch_shape, len_x, dtype=self.dtype, device=self.device)
        y = torch.rand(*batch_shape, len_y, dtype=self.dtype, device=self.device)
        self._assert_consistency(fn, (x, y))
# Re-export the prototype convolution operators as the package's public API.
from .functional import convolve, fftconvolve

__all__ = ["convolve", "fftconvolve"]
import torch
def _check_convolve_inputs(x: torch.Tensor, y: torch.Tensor) -> None:
if x.shape[:-1] != y.shape[:-1]:
raise ValueError(f"Leading dimensions of x and y don't match (got {x.shape} and {y.shape}).")
def fftconvolve(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    r"""
    Convolves inputs along their last dimension using FFT. For inputs with large last dimensions, this function
    is generally much faster than :meth:`convolve`.

    Note that, in contrast to :meth:`torch.nn.functional.conv1d`, which actually applies the valid cross-correlation
    operator, this function applies the true `convolution`_ operator.
    Also note that this function can only output float tensors (int tensor inputs will be cast to float).

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Args:
        x (torch.Tensor): First convolution operand, with shape `(*, N)`.
        y (torch.Tensor): Second convolution operand, with shape `(*, M)`
            (leading dimensions must match those of ``x``).

    Returns:
        torch.Tensor: Result of convolving ``x`` and ``y``, with shape `(*, N + M - 1)`, where
        the leading dimensions match those of ``x``.

    .. _convolution:
        https://en.wikipedia.org/wiki/Convolution
    """
    _check_convolve_inputs(x, y)

    # Zero-pad both operands to the full linear-convolution length so the
    # circular convolution implied by the FFT equals the linear one.
    full_length = x.size(-1) + y.size(-1) - 1
    spectrum = torch.fft.rfft(x, n=full_length) * torch.fft.rfft(y, n=full_length)
    return torch.fft.irfft(spectrum, n=full_length)
def convolve(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    r"""
    Convolves inputs along their last dimension using the direct method.

    Note that, in contrast to :meth:`torch.nn.functional.conv1d`, which actually applies the valid cross-correlation
    operator, this function applies the true `convolution`_ operator.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Args:
        x (torch.Tensor): First convolution operand, with shape `(*, N)`.
        y (torch.Tensor): Second convolution operand, with shape `(*, M)`
            (leading dimensions must match those of ``x``).

    Returns:
        torch.Tensor: Result of convolving ``x`` and ``y``, with shape `(*, N + M - 1)`, where
        the leading dimensions match those of ``x``.

    .. _convolution:
        https://en.wikipedia.org/wiki/Convolution
    """
    _check_convolve_inputs(x, y)

    # Convolution is commutative; make x the longer operand so conv1d's
    # input is at least as long as its kernel.
    if x.size(-1) < y.size(-1):
        x, y = y, x

    batch_shape = x.shape[:-1]
    num_signals = int(torch.tensor(batch_shape).prod())
    flat_x = x.reshape((num_signals, x.size(-1)))
    flat_y = y.reshape((num_signals, y.size(-1)))

    # Flip the kernel so conv1d's cross-correlation becomes true convolution;
    # one group per signal keeps each (x, y) pair independent, and "full"
    # padding yields the N + M - 1 output length.
    result = torch.nn.functional.conv1d(
        input=flat_x,
        weight=flat_y.flip(-1).unsqueeze(1),
        stride=1,
        groups=num_signals,
        padding=flat_y.size(-1) - 1,
    )
    return result.reshape(batch_shape + (-1,))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment