Commit 2d99fee2 authored by hwangjeff, committed by Facebook GitHub Bot

Add convolution transforms (#2811)

Summary:
Adds `torch.nn.Module`-based implementations for convolution and FFT convolution.
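
A minimal usage sketch of the new modules (illustrative only; shapes are chosen for the example):

    import torch
    import torchaudio.prototype.transforms as T

    x = torch.rand(3, 2, 100)  # first operand, shape (..., N)
    y = torch.rand(3, 2, 40)   # second operand, shape (..., M); leading dims must match x

    out = T.Convolve(mode="full")(x, y)         # shape (3, 2, 139), i.e. (..., N + M - 1)
    out_fft = T.FFTConvolve(mode="full")(x, y)  # same values, computed via FFT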

Pull Request resolved: https://github.com/pytorch/audio/pull/2811

Reviewed By: carolineechen

Differential Revision: D40881937

Pulled By: hwangjeff

fbshipit-source-id: bfe8969e6178ad4f58981efd4b2720ac006be8de
parent 6bd38512
@@ -90,6 +90,7 @@ model implementations and application components.
   prototype.functional
   prototype.models
   prototype.pipelines
   prototype.transforms
.. toctree::
   :maxdepth: 1
...
@@ -20,3 +20,4 @@ imported explicitly, e.g.
   prototype.functional
   prototype.models
   prototype.pipelines
prototype.transforms
.. py:module:: torchaudio.prototype.transforms
torchaudio.prototype.transforms
===============================
.. currentmodule:: torchaudio.prototype.transforms
.. autosummary::
:toctree: generated
:nosignatures:
Convolve
FFTConvolve
from torchaudio_unittest.common_utils import PytorchTestCase
from .autograd_test_impl import Autograd
class AutogradCPUTest(Autograd, PytorchTestCase):
device = "cpu"
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from .autograd_test_impl import Autograd
@skipIfNoCuda
class AutogradCUDATest(Autograd, PytorchTestCase):
device = "cuda"
from typing import List
import torch
import torchaudio.prototype.transforms as T
from torch.autograd import gradcheck, gradgradcheck
from torchaudio_unittest.common_utils import nested_params, TestBaseMixin
class Autograd(TestBaseMixin):
def assert_grad(
self,
transform: torch.nn.Module,
inputs: List[torch.Tensor],
*,
nondet_tol: float = 0.0,
):
transform = transform.to(dtype=torch.float64, device=self.device)
# gradcheck and gradgradcheck only pass if the input tensors are of dtype `torch.double` or
# `torch.cdouble`, when the default eps and tolerance values are used.
inputs_ = []
for i in inputs:
if torch.is_tensor(i):
i = i.to(dtype=torch.cdouble if i.is_complex() else torch.double, device=self.device)
i.requires_grad = True
inputs_.append(i)
assert gradcheck(transform, inputs_)
assert gradgradcheck(transform, inputs_, nondet_tol=nondet_tol)
@nested_params(
[T.Convolve, T.FFTConvolve],
["full", "valid", "same"],
)
def test_Convolve(self, cls, mode):
leading_dims = (4, 3, 2)
L_x, L_y = 23, 40
x = torch.rand(*leading_dims, L_x, dtype=self.dtype, device=self.device)
y = torch.rand(*leading_dims, L_y, dtype=self.dtype, device=self.device)
convolve = cls(mode=mode).to(dtype=self.dtype, device=self.device)
self.assert_grad(convolve, [x, y])
import torch
import torchaudio.prototype.transforms as T
from torchaudio_unittest.common_utils import nested_params, TorchaudioTestCase
class BatchConsistencyTest(TorchaudioTestCase):
@nested_params(
[T.Convolve, T.FFTConvolve],
["full", "valid", "same"],
)
def test_Convolve(self, cls, mode):
leading_dims = (2, 3)
L_x, L_y = 89, 43
x = torch.rand(*leading_dims, L_x, dtype=self.dtype, device=self.device)
y = torch.rand(*leading_dims, L_y, dtype=self.dtype, device=self.device)
convolve = cls(mode=mode)
actual = convolve(x, y)
expected = torch.stack(
[
torch.stack(
[convolve(x[i, j].unsqueeze(0), y[i, j].unsqueeze(0)).squeeze(0) for j in range(leading_dims[1])]
)
for i in range(leading_dims[0])
]
)
self.assertEqual(expected, actual)
import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from .torchscript_consistency_impl import Transforms
class TestTransformsFloat32(Transforms, PytorchTestCase):
dtype = torch.float32
device = torch.device("cpu")
class TestTransformsFloat64(Transforms, PytorchTestCase):
dtype = torch.float64
device = torch.device("cpu")
import torch
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from .torchscript_consistency_impl import Transforms
@skipIfNoCuda
class TestTransformsFloat32(Transforms, PytorchTestCase):
dtype = torch.float32
device = torch.device("cuda")
@skipIfNoCuda
class TestTransformsFloat64(Transforms, PytorchTestCase):
dtype = torch.float64
device = torch.device("cuda")
import torch
import torchaudio.prototype.transforms as T
from torchaudio_unittest.common_utils import nested_params, TestBaseMixin, torch_script
class Transforms(TestBaseMixin):
@nested_params(
[T.Convolve, T.FFTConvolve],
["full", "valid", "same"],
)
def test_Convolve(self, cls, mode):
leading_dims = (2, 3, 2)
L_x, L_y = 32, 55
x = torch.rand(*leading_dims, L_x, dtype=self.dtype, device=self.device)
y = torch.rand(*leading_dims, L_y, dtype=self.dtype, device=self.device)
convolve = cls(mode=mode).to(device=self.device, dtype=self.dtype)
output = convolve(x, y)
ts_output = torch_script(convolve)(x, y)
self.assertEqual(ts_output, output)
import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from .transforms_test_impl import TransformsTestImpl
class TransformsFloat32CPUTest(TransformsTestImpl, PytorchTestCase):
dtype = torch.float32
device = torch.device("cpu")
class TransformsFloat64CPUTest(TransformsTestImpl, PytorchTestCase):
dtype = torch.float64
device = torch.device("cpu")
import torch
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from .transforms_test_impl import TransformsTestImpl
@skipIfNoCuda
class TransformsFloat32CUDATest(TransformsTestImpl, PytorchTestCase):
dtype = torch.float32
device = torch.device("cuda")
@skipIfNoCuda
class TransformsFloat64CUDATest(TransformsTestImpl, PytorchTestCase):
dtype = torch.float64
device = torch.device("cuda")
import numpy as np
import torch
import torchaudio.prototype.transforms as T
from scipy import signal
from torchaudio_unittest.common_utils import nested_params, TestBaseMixin
class TransformsTestImpl(TestBaseMixin):
@nested_params(
[(10, 4), (4, 3, 1, 2), (2,), ()],
[(100, 43), (21, 45)],
["full", "valid", "same"],
)
def test_Convolve(self, leading_dims, lengths, mode):
"""Check that convolve returns values identical to those that SciPy produces."""
L_x, L_y = lengths
x = torch.rand(*(leading_dims + (L_x,)), dtype=self.dtype, device=self.device)
y = torch.rand(*(leading_dims + (L_y,)), dtype=self.dtype, device=self.device)
convolve = T.Convolve(mode=mode).to(self.device)
actual = convolve(x, y)
num_signals = torch.tensor(leading_dims).prod() if leading_dims else 1
x_reshaped = x.reshape((num_signals, L_x))
y_reshaped = y.reshape((num_signals, L_y))
expected = [
signal.convolve(x_reshaped[i].detach().cpu().numpy(), y_reshaped[i].detach().cpu().numpy(), mode=mode)
for i in range(num_signals)
]
expected = torch.tensor(np.array(expected))
expected = expected.reshape(leading_dims + (-1,))
self.assertEqual(expected, actual)
@nested_params(
[(10, 4), (4, 3, 1, 2), (2,), ()],
[(100, 43), (21, 45)],
["full", "valid", "same"],
)
def test_FFTConvolve(self, leading_dims, lengths, mode):
"""Check that fftconvolve returns values identical to those that SciPy produces."""
L_x, L_y = lengths
x = torch.rand(*(leading_dims + (L_x,)), dtype=self.dtype, device=self.device)
y = torch.rand(*(leading_dims + (L_y,)), dtype=self.dtype, device=self.device)
convolve = T.FFTConvolve(mode=mode).to(self.device)
actual = convolve(x, y)
expected = signal.fftconvolve(x.detach().cpu().numpy(), y.detach().cpu().numpy(), axes=-1, mode=mode)
expected = torch.tensor(expected)
self.assertEqual(expected, actual)
import torch

def _check_convolve_mode(mode: str) -> None:
    valid_convolve_modes = ["full", "valid", "same"]
    if mode not in valid_convolve_modes:
        raise ValueError(f"Unrecognized mode value '{mode}'. Please specify one of {valid_convolve_modes}.")

def _check_convolve_inputs(x: torch.Tensor, y: torch.Tensor, mode: str) -> None:
    if x.shape[:-1] != y.shape[:-1]:
        raise ValueError(f"Leading dimensions of x and y don't match (got {x.shape} and {y.shape}).")
    _check_convolve_mode(mode)

def _apply_convolve_mode(conv_result: torch.Tensor, x_length: int, y_length: int, mode: str) -> torch.Tensor:
    valid_convolve_modes = ["full", "valid", "same"]
    if mode == "full":
...
from ._transforms import Convolve, FFTConvolve
__all__ = ["Convolve", "FFTConvolve"]
import torch
from torchaudio.prototype.functional import convolve, fftconvolve
from torchaudio.prototype.functional.functional import _check_convolve_mode
class Convolve(torch.nn.Module):
r"""
Convolves inputs along their last dimension using the direct method.
Note that, in contrast to :class:`torch.nn.Conv1d`, which actually applies the valid cross-correlation
operator, this module applies the true `convolution`_ operator.
.. devices:: CPU CUDA
.. properties:: Autograd TorchScript
Args:
mode (str, optional): Must be one of ("full", "valid", "same").
* "full": Returns the full convolution result, with shape `(..., N + M - 1)`, where
`N` and `M` are the trailing dimensions of the two inputs. (Default)
* "valid": Returns the segment of the full convolution result corresponding to where
the two inputs overlap completely, with shape `(..., max(N, M) - min(N, M) + 1)`.
* "same": Returns the center segment of the full convolution result, with shape `(..., N)`.
.. _convolution:
https://en.wikipedia.org/wiki/Convolution
"""
def __init__(self, mode: str = "full") -> None:
_check_convolve_mode(mode)
super().__init__()
self.mode = mode
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
r"""
Args:
x (torch.Tensor): First convolution operand, with shape `(..., N)`.
y (torch.Tensor): Second convolution operand, with shape `(..., M)`
(leading dimensions must match those of ``x``).
Returns:
torch.Tensor: Result of convolving ``x`` and ``y``, with shape `(..., L)`, where
the leading dimensions match those of ``x`` and `L` is dictated by ``mode``.
"""
return convolve(x, y, mode=self.mode)
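# To make the mode semantics above concrete, an illustrative example
# (not part of the library source): for 1-D inputs with N = 5 and M = 3,
#
#   x, y = torch.rand(5), torch.rand(3)
#   Convolve(mode="full")(x, y).shape   # torch.Size([7])  -> N + M - 1
#   Convolve(mode="valid")(x, y).shape  # torch.Size([3])  -> max(N, M) - min(N, M) + 1
#   Convolve(mode="same")(x, y).shape   # torch.Size([5])  -> N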
class FFTConvolve(torch.nn.Module):
r"""
Convolves inputs along their last dimension using FFT. For inputs with large last dimensions, this module
is generally much faster than :class:`Convolve`.
Note that, in contrast to :class:`torch.nn.Conv1d`, which actually applies the valid cross-correlation
operator, this module applies the true `convolution`_ operator.
Also note that this module can only output float tensors (int tensor inputs will be cast to float).
.. devices:: CPU CUDA
.. properties:: Autograd TorchScript
Args:
mode (str, optional): Must be one of ("full", "valid", "same").
* "full": Returns the full convolution result, with shape `(..., N + M - 1)`, where
`N` and `M` are the trailing dimensions of the two inputs. (Default)
* "valid": Returns the segment of the full convolution result corresponding to where
the two inputs overlap completely, with shape `(..., max(N, M) - min(N, M) + 1)`.
* "same": Returns the center segment of the full convolution result, with shape `(..., N)`.
.. _convolution:
https://en.wikipedia.org/wiki/Convolution
"""
def __init__(self, mode: str = "full") -> None:
_check_convolve_mode(mode)
super().__init__()
self.mode = mode
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
r"""
Args:
x (torch.Tensor): First convolution operand, with shape `(..., N)`.
y (torch.Tensor): Second convolution operand, with shape `(..., M)`
(leading dimensions must match those of ``x``).
Returns:
torch.Tensor: Result of convolving ``x`` and ``y``, with shape `(..., L)`, where
the leading dimensions match those of ``x`` and `L` is dictated by ``mode``.
"""
return fftconvolve(x, y, mode=self.mode)
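
A closing usage note, as an illustrative sketch rather than part of the diff: ``Convolve`` and ``FFTConvolve`` compute the same operation, so their outputs agree up to floating-point error, with the FFT path generally faster for long signals.

    import torch
    import torchaudio.prototype.transforms as T

    x = torch.rand(2, 4096)  # a long last dimension favors the FFT path
    y = torch.rand(2, 512)

    direct = T.Convolve(mode="same")(x, y)
    via_fft = T.FFTConvolve(mode="same")(x, y)
    print(torch.allclose(direct, via_fft, rtol=1e-4, atol=1e-4))  # True within float32 tolerance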