Commit 86d596d3 authored by hwangjeff, committed by Facebook GitHub Bot

Introduce argument 'mode' for convolution functions (#2801)

Summary:
Introduces a 'mode' argument ("full", "valid", "same") for the convolution functions, following SciPy's convention for the output size.

Pull Request resolved: https://github.com/pytorch/audio/pull/2801

Reviewed By: nateanl

Differential Revision: D40805405

Pulled By: hwangjeff

fbshipit-source-id: 8f0006ffe9e3945b4b17f44c4cfa1adb265c20ef
parent e6bd346e
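To illustrate the new argument, here is a minimal usage sketch (not part of the commit itself; it assumes the prototype import path used in the diff below, and the output lengths follow the SciPy convention documented in the docstrings):

import torch
import torchaudio.prototype.functional as F

x = torch.rand(3, 10)  # shape (..., N) with N = 10
y = torch.rand(3, 4)   # shape (..., M) with M = 4

F.convolve(x, y, mode="full").shape     # (3, 13): N + M - 1 (default)
F.convolve(x, y, mode="valid").shape    # (3, 7):  max(N, M) - min(N, M) + 1
F.fftconvolve(x, y, mode="same").shape  # (3, 10): N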
import torch
import torchaudio.prototype.functional as F
from parameterized import parameterized
from torch.autograd import gradcheck, gradgradcheck
from torchaudio_unittest.common_utils import TestBaseMixin
from torchaudio_unittest.common_utils import nested_params, TestBaseMixin
class AutogradTestImpl(TestBaseMixin):
@parameterized.expand(
[
(F.convolve,),
(F.fftconvolve,),
]
@nested_params(
[F.convolve, F.fftconvolve],
["full", "valid", "same"],
)
def test_convolve(self, fn):
def test_convolve(self, fn, mode):
leading_dims = (4, 3, 2)
L_x, L_y = 23, 40
x = torch.rand(*leading_dims, L_x, dtype=self.dtype, device=self.device, requires_grad=True)
y = torch.rand(*leading_dims, L_y, dtype=self.dtype, device=self.device, requires_grad=True)
self.assertTrue(gradcheck(fn, (x, y)))
self.assertTrue(gradgradcheck(fn, (x, y)))
self.assertTrue(gradcheck(fn, (x, y, mode)))
self.assertTrue(gradgradcheck(fn, (x, y, mode)))
def test_add_noise(self):
leading_dims = (5, 2, 3)
......
......
@@ -6,17 +6,20 @@ from torchaudio_unittest.common_utils import nested_params, TorchaudioTestCase
class BatchConsistencyTest(TorchaudioTestCase):
@nested_params(
[F.convolve, F.fftconvolve],
["full", "valid", "same"],
)
def test_convolve(self, fn):
def test_convolve(self, fn, mode):
leading_dims = (2, 3)
L_x, L_y = 89, 43
x = torch.rand(*leading_dims, L_x, dtype=self.dtype, device=self.device)
y = torch.rand(*leading_dims, L_y, dtype=self.dtype, device=self.device)
actual = fn(x, y)
actual = fn(x, y, mode)
expected = torch.stack(
[
torch.stack([fn(x[i, j].unsqueeze(0), y[i, j].unsqueeze(0)).squeeze(0) for j in range(leading_dims[1])])
torch.stack(
[fn(x[i, j].unsqueeze(0), y[i, j].unsqueeze(0), mode).squeeze(0) for j in range(leading_dims[1])]
)
for i in range(leading_dims[0])
]
)
......
......
@@ -10,21 +10,22 @@ class FunctionalTestImpl(TestBaseMixin):
@nested_params(
[(10, 4), (4, 3, 1, 2), (2,), ()],
[(100, 43), (21, 45)],
["full", "valid", "same"],
)
def test_convolve_numerics(self, leading_dims, lengths):
def test_convolve_numerics(self, leading_dims, lengths, mode):
"""Check that convolve returns values identical to those that SciPy produces."""
L_x, L_y = lengths
x = torch.rand(*(leading_dims + (L_x,)), dtype=self.dtype, device=self.device)
y = torch.rand(*(leading_dims + (L_y,)), dtype=self.dtype, device=self.device)
actual = F.convolve(x, y)
actual = F.convolve(x, y, mode=mode)
num_signals = torch.tensor(leading_dims).prod() if leading_dims else 1
x_reshaped = x.reshape((num_signals, L_x))
y_reshaped = y.reshape((num_signals, L_y))
expected = [
signal.convolve(x_reshaped[i].detach().cpu().numpy(), y_reshaped[i].detach().cpu().numpy())
signal.convolve(x_reshaped[i].detach().cpu().numpy(), y_reshaped[i].detach().cpu().numpy(), mode=mode)
for i in range(num_signals)
]
expected = torch.tensor(np.array(expected))
......
@@ -35,17 +36,18 @@ class FunctionalTestImpl(TestBaseMixin):
@nested_params(
[(10, 4), (4, 3, 1, 2), (2,), ()],
[(100, 43), (21, 45)],
["full", "valid", "same"],
)
def test_fftconvolve_numerics(self, leading_dims, lengths):
def test_fftconvolve_numerics(self, leading_dims, lengths, mode):
"""Check that fftconvolve returns values identical to those that SciPy produces."""
L_x, L_y = lengths
x = torch.rand(*(leading_dims + (L_x,)), dtype=self.dtype, device=self.device)
y = torch.rand(*(leading_dims + (L_y,)), dtype=self.dtype, device=self.device)
actual = F.fftconvolve(x, y)
actual = F.fftconvolve(x, y, mode=mode)
expected = signal.fftconvolve(x.detach().cpu().numpy(), y.detach().cpu().numpy(), axes=-1)
expected = signal.fftconvolve(x.detach().cpu().numpy(), y.detach().cpu().numpy(), axes=-1, mode=mode)
expected = torch.tensor(expected)
self.assertEqual(expected, actual)
......
import torch
import torchaudio.prototype.functional as F
from parameterized import parameterized
from torchaudio_unittest.common_utils import TestBaseMixin, torch_script
from torchaudio_unittest.common_utils import nested_params, TestBaseMixin, torch_script
class TorchScriptConsistencyTestImpl(TestBaseMixin):
......
@@ -24,19 +23,17 @@ class TorchScriptConsistencyTestImpl(TestBaseMixin):
output = output.shape
self.assertEqual(ts_output, output)
@parameterized.expand(
[
(F.convolve,),
(F.fftconvolve,),
]
@nested_params(
[F.convolve, F.fftconvolve],
["full", "valid", "same"],
)
def test_convolve(self, fn):
def test_convolve(self, fn, mode):
leading_dims = (2, 3, 2)
L_x, L_y = 32, 55
x = torch.rand(*leading_dims, L_x, dtype=self.dtype, device=self.device)
y = torch.rand(*leading_dims, L_y, dtype=self.dtype, device=self.device)
self._assert_consistency(fn, (x, y))
self._assert_consistency(fn, (x, y, mode))
def test_add_noise(self):
leading_dims = (2, 3)
......
import torch
def _check_convolve_inputs(x: torch.Tensor, y: torch.Tensor) -> None:
def _check_convolve_inputs(x: torch.Tensor, y: torch.Tensor, mode: str) -> None:
if x.shape[:-1] != y.shape[:-1]:
raise ValueError(f"Leading dimensions of x and y don't match (got {x.shape} and {y.shape}).")
def fftconvolve(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
valid_convolve_modes = ["full", "valid", "same"]
if mode not in valid_convolve_modes:
raise ValueError(f"Unrecognized mode value '{mode}'. Please specify one of {valid_convolve_modes}.")
def _apply_convolve_mode(conv_result: torch.Tensor, x_length: int, y_length: int, mode: str) -> torch.Tensor:
valid_convolve_modes = ["full", "valid", "same"]
if mode == "full":
return conv_result
elif mode == "valid":
target_length = max(x_length, y_length) - min(x_length, y_length) + 1
start_idx = (conv_result.size(-1) - target_length) // 2
return conv_result[..., start_idx : start_idx + target_length]
elif mode == "same":
start_idx = (conv_result.size(-1) - x_length) // 2
return conv_result[..., start_idx : start_idx + x_length]
else:
raise ValueError(f"Unrecognized mode value '{mode}'. Please specify one of {valid_convolve_modes}.")
def fftconvolve(x: torch.Tensor, y: torch.Tensor, mode: str = "full") -> torch.Tensor:
r"""
Convolves inputs along their last dimension using FFT. For inputs with large last dimensions, this function
is generally much faster than :meth:`convolve`.
......
@@ -22,22 +40,29 @@ def fftconvolve(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
x (torch.Tensor): First convolution operand, with shape `(..., N)`.
y (torch.Tensor): Second convolution operand, with shape `(..., M)`
(leading dimensions must match those of ``x``).
mode (str, optional): Must be one of ("full", "valid", "same").
* "full": Returns the full convolution result, with shape `(..., N + M - 1)`. (Default)
* "valid": Returns the segment of the full convolution result corresponding to where
the two inputs overlap completely, with shape `(..., max(N, M) - min(N, M) + 1)`.
* "same": Returns the center segment of the full convolution result, with shape `(..., N)`.
Returns:
torch.Tensor: Result of convolving ``x`` and ``y``, with shape `(..., N + M - 1)`, where
the leading dimensions match those of ``x``.
torch.Tensor: Result of convolving ``x`` and ``y``, with shape `(..., L)`, where
the leading dimensions match those of ``x`` and `L` is dictated by ``mode``.
.. _convolution:
https://en.wikipedia.org/wiki/Convolution
"""
_check_convolve_inputs(x, y)
_check_convolve_inputs(x, y, mode)
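# Editor's note: zero-padding both operands to n = N + M - 1 makes the circular
# convolution computed via rfft/irfft equal to the linear "full" convolution, which
# _apply_convolve_mode then trims according to ``mode``.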
n = x.size(-1) + y.size(-1) - 1
fresult = torch.fft.rfft(x, n=n) * torch.fft.rfft(y, n=n)
return torch.fft.irfft(fresult, n=n)
result = torch.fft.irfft(fresult, n=n)
return _apply_convolve_mode(result, x.size(-1), y.size(-1), mode)
def convolve(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
def convolve(x: torch.Tensor, y: torch.Tensor, mode: str = "full") -> torch.Tensor:
r"""
Convolves inputs along their last dimension using the direct method.
Note that, in contrast to :meth:`torch.nn.functional.conv1d`, which actually applies the valid cross-correlation
......
@@ -51,15 +76,23 @@ def convolve(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
x (torch.Tensor): First convolution operand, with shape `(..., N)`.
y (torch.Tensor): Second convolution operand, with shape `(..., M)`
(leading dimensions must match those of ``x``).
mode (str, optional): Must be one of ("full", "valid", "same").
* "full": Returns the full convolution result, with shape `(..., N + M - 1)`. (Default)
* "valid": Returns the segment of the full convolution result corresponding to where
the two inputs overlap completely, with shape `(..., max(N, M) - min(N, M) + 1)`.
* "same": Returns the center segment of the full convolution result, with shape `(..., N)`.
Returns:
torch.Tensor: Result of convolving ``x`` and ``y``, with shape `(..., N + M - 1)`, where
the leading dimensions match those of ``x``.
torch.Tensor: Result of convolving ``x`` and ``y``, with shape `(..., L)`, where
the leading dimensions match those of ``x`` and `L` is dictated by ``mode``.
.. _convolution:
https://en.wikipedia.org/wiki/Convolution
"""
_check_convolve_inputs(x, y)
_check_convolve_inputs(x, y, mode)
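# Editor's note: the original lengths are recorded before the potential swap below,
# since _apply_convolve_mode slices relative to the input lengths as given.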
x_size, y_size = x.size(-1), y.size(-1)
if x.size(-1) < y.size(-1):
x, y = y, x
......
@@ -75,7 +108,8 @@ def convolve(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
padding=reshaped_y.size(-1) - 1,
)
output_shape = x.shape[:-1] + (-1,)
return output.reshape(output_shape)
result = output.reshape(output_shape)
return _apply_convolve_mode(result, x_size, y_size, mode)
def add_noise(waveform: torch.Tensor, noise: torch.Tensor, lengths: torch.Tensor, snr: torch.Tensor) -> torch.Tensor:
......