Commit 86d596d3 authored by hwangjeff, committed by Facebook GitHub Bot

Introduce argument 'mode' for convolution functions (#2801)

Summary:
Introduces argument 'mode' for convolution functions, following SciPy's convention.

Pull Request resolved: https://github.com/pytorch/audio/pull/2801

Reviewed By: nateanl

Differential Revision: D40805405

Pulled By: hwangjeff

fbshipit-source-id: 8f0006ffe9e3945b4b17f44c4cfa1adb265c20ef
parent e6bd346e
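In practice, the new argument is passed as a third (positional or keyword) argument to the prototype convolution functions. A minimal usage sketch, with arbitrary tensor shapes; the output lengths follow the docstrings updated below:

import torch
import torchaudio.prototype.functional as F

x = torch.rand(3, 100)  # (..., N) with N = 100
y = torch.rand(3, 43)   # (..., M) with M = 43; leading dimensions must match x's

full = F.convolve(x, y, mode="full")     # shape (3, 142): N + M - 1
valid = F.convolve(x, y, mode="valid")   # shape (3, 58):  max(N, M) - min(N, M) + 1
same = F.fftconvolve(x, y, mode="same")  # shape (3, 100): N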
 import torch
 import torchaudio.prototype.functional as F
-from parameterized import parameterized
 from torch.autograd import gradcheck, gradgradcheck
-from torchaudio_unittest.common_utils import TestBaseMixin
+from torchaudio_unittest.common_utils import nested_params, TestBaseMixin


 class AutogradTestImpl(TestBaseMixin):
-    @parameterized.expand(
-        [
-            (F.convolve,),
-            (F.fftconvolve,),
-        ]
+    @nested_params(
+        [F.convolve, F.fftconvolve],
+        ["full", "valid", "same"],
     )
-    def test_convolve(self, fn):
+    def test_convolve(self, fn, mode):
         leading_dims = (4, 3, 2)
         L_x, L_y = 23, 40
         x = torch.rand(*leading_dims, L_x, dtype=self.dtype, device=self.device, requires_grad=True)
         y = torch.rand(*leading_dims, L_y, dtype=self.dtype, device=self.device, requires_grad=True)
-        self.assertTrue(gradcheck(fn, (x, y)))
-        self.assertTrue(gradgradcheck(fn, (x, y)))
+        self.assertTrue(gradcheck(fn, (x, y, mode)))
+        self.assertTrue(gradgradcheck(fn, (x, y, mode)))

     def test_add_noise(self):
         leading_dims = (5, 2, 3)
...
@@ -6,17 +6,20 @@ from torchaudio_unittest.common_utils import nested_params, TorchaudioTestCase

 class BatchConsistencyTest(TorchaudioTestCase):
     @nested_params(
         [F.convolve, F.fftconvolve],
+        ["full", "valid", "same"],
     )
-    def test_convolve(self, fn):
+    def test_convolve(self, fn, mode):
         leading_dims = (2, 3)
         L_x, L_y = 89, 43
         x = torch.rand(*leading_dims, L_x, dtype=self.dtype, device=self.device)
         y = torch.rand(*leading_dims, L_y, dtype=self.dtype, device=self.device)

-        actual = fn(x, y)
+        actual = fn(x, y, mode)
         expected = torch.stack(
             [
-                torch.stack([fn(x[i, j].unsqueeze(0), y[i, j].unsqueeze(0)).squeeze(0) for j in range(leading_dims[1])])
+                torch.stack(
+                    [fn(x[i, j].unsqueeze(0), y[i, j].unsqueeze(0), mode).squeeze(0) for j in range(leading_dims[1])]
+                )
                 for i in range(leading_dims[0])
             ]
         )
...
@@ -10,21 +10,22 @@ class FunctionalTestImpl(TestBaseMixin):
     @nested_params(
         [(10, 4), (4, 3, 1, 2), (2,), ()],
         [(100, 43), (21, 45)],
+        ["full", "valid", "same"],
     )
-    def test_convolve_numerics(self, leading_dims, lengths):
+    def test_convolve_numerics(self, leading_dims, lengths, mode):
         """Check that convolve returns values identical to those that SciPy produces."""
         L_x, L_y = lengths

         x = torch.rand(*(leading_dims + (L_x,)), dtype=self.dtype, device=self.device)
         y = torch.rand(*(leading_dims + (L_y,)), dtype=self.dtype, device=self.device)

-        actual = F.convolve(x, y)
+        actual = F.convolve(x, y, mode=mode)

         num_signals = torch.tensor(leading_dims).prod() if leading_dims else 1
         x_reshaped = x.reshape((num_signals, L_x))
         y_reshaped = y.reshape((num_signals, L_y))
         expected = [
-            signal.convolve(x_reshaped[i].detach().cpu().numpy(), y_reshaped[i].detach().cpu().numpy())
+            signal.convolve(x_reshaped[i].detach().cpu().numpy(), y_reshaped[i].detach().cpu().numpy(), mode=mode)
             for i in range(num_signals)
         ]
         expected = torch.tensor(np.array(expected))

@@ -35,17 +36,18 @@ class FunctionalTestImpl(TestBaseMixin):
     @nested_params(
         [(10, 4), (4, 3, 1, 2), (2,), ()],
         [(100, 43), (21, 45)],
+        ["full", "valid", "same"],
     )
-    def test_fftconvolve_numerics(self, leading_dims, lengths):
+    def test_fftconvolve_numerics(self, leading_dims, lengths, mode):
         """Check that fftconvolve returns values identical to those that SciPy produces."""
         L_x, L_y = lengths

         x = torch.rand(*(leading_dims + (L_x,)), dtype=self.dtype, device=self.device)
         y = torch.rand(*(leading_dims + (L_y,)), dtype=self.dtype, device=self.device)

-        actual = F.fftconvolve(x, y)
+        actual = F.fftconvolve(x, y, mode=mode)

-        expected = signal.fftconvolve(x.detach().cpu().numpy(), y.detach().cpu().numpy(), axes=-1)
+        expected = signal.fftconvolve(x.detach().cpu().numpy(), y.detach().cpu().numpy(), axes=-1, mode=mode)
         expected = torch.tensor(expected)

         self.assertEqual(expected, actual)
...
 import torch
 import torchaudio.prototype.functional as F
-from parameterized import parameterized
-from torchaudio_unittest.common_utils import TestBaseMixin, torch_script
+from torchaudio_unittest.common_utils import nested_params, TestBaseMixin, torch_script


 class TorchScriptConsistencyTestImpl(TestBaseMixin):
@@ -24,19 +23,17 @@ class TorchScriptConsistencyTestImpl(TestBaseMixin):
         output = output.shape
         self.assertEqual(ts_output, output)

-    @parameterized.expand(
-        [
-            (F.convolve,),
-            (F.fftconvolve,),
-        ]
+    @nested_params(
+        [F.convolve, F.fftconvolve],
+        ["full", "valid", "same"],
     )
-    def test_convolve(self, fn):
+    def test_convolve(self, fn, mode):
         leading_dims = (2, 3, 2)
         L_x, L_y = 32, 55
         x = torch.rand(*leading_dims, L_x, dtype=self.dtype, device=self.device)
         y = torch.rand(*leading_dims, L_y, dtype=self.dtype, device=self.device)

-        self._assert_consistency(fn, (x, y))
+        self._assert_consistency(fn, (x, y, mode))

     def test_add_noise(self):
         leading_dims = (2, 3)
...
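The TorchScript consistency exercised above can be reproduced outside the test harness roughly as follows. This sketch uses torch.jit.script directly rather than the suite's torch_script helper, which is an assumption about equivalent usage, not part of this change:

import torch
import torchaudio.prototype.functional as F

x = torch.rand(2, 3, 2, 32)
y = torch.rand(2, 3, 2, 55)

# Script the function directly; 'mode' is an annotated str argument, so it scripts.
scripted = torch.jit.script(F.convolve)
torch.testing.assert_close(scripted(x, y, "same"), F.convolve(x, y, "same"))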
 import torch


-def _check_convolve_inputs(x: torch.Tensor, y: torch.Tensor) -> None:
+def _check_convolve_inputs(x: torch.Tensor, y: torch.Tensor, mode: str) -> None:
     if x.shape[:-1] != y.shape[:-1]:
         raise ValueError(f"Leading dimensions of x and y don't match (got {x.shape} and {y.shape}).")
+
+    valid_convolve_modes = ["full", "valid", "same"]
+    if mode not in valid_convolve_modes:
+        raise ValueError(f"Unrecognized mode value '{mode}'. Please specify one of {valid_convolve_modes}.")
+
+
+def _apply_convolve_mode(conv_result: torch.Tensor, x_length: int, y_length: int, mode: str) -> torch.Tensor:
+    valid_convolve_modes = ["full", "valid", "same"]
+    if mode == "full":
+        return conv_result
+    elif mode == "valid":
+        target_length = max(x_length, y_length) - min(x_length, y_length) + 1
+        start_idx = (conv_result.size(-1) - target_length) // 2
+        return conv_result[..., start_idx : start_idx + target_length]
+    elif mode == "same":
+        start_idx = (conv_result.size(-1) - x_length) // 2
+        return conv_result[..., start_idx : start_idx + x_length]
+    else:
+        raise ValueError(f"Unrecognized mode value '{mode}'. Please specify one of {valid_convolve_modes}.")


-def fftconvolve(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+def fftconvolve(x: torch.Tensor, y: torch.Tensor, mode: str = "full") -> torch.Tensor:
     r"""
     Convolves inputs along their last dimension using FFT. For inputs with large last dimensions, this function
     is generally much faster than :meth:`convolve`.
@@ -22,22 +40,29 @@ def fftconvolve(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         x (torch.Tensor): First convolution operand, with shape `(..., N)`.
         y (torch.Tensor): Second convolution operand, with shape `(..., M)`
             (leading dimensions must match those of ``x``).
+        mode (str, optional): Must be one of ("full", "valid", "same").
+
+            * "full": Returns the full convolution result, with shape `(..., N + M - 1)`. (Default)
+            * "valid": Returns the segment of the full convolution result corresponding to where
+              the two inputs overlap completely, with shape `(..., max(N, M) - min(N, M) + 1)`.
+            * "same": Returns the center segment of the full convolution result, with shape `(..., N)`.

     Returns:
-        torch.Tensor: Result of convolving ``x`` and ``y``, with shape `(..., N + M - 1)`, where
-        the leading dimensions match those of ``x``.
+        torch.Tensor: Result of convolving ``x`` and ``y``, with shape `(..., L)`, where
+        the leading dimensions match those of ``x`` and `L` is dictated by ``mode``.

     .. _convolution:
         https://en.wikipedia.org/wiki/Convolution
     """
-    _check_convolve_inputs(x, y)
+    _check_convolve_inputs(x, y, mode)

     n = x.size(-1) + y.size(-1) - 1
     fresult = torch.fft.rfft(x, n=n) * torch.fft.rfft(y, n=n)
-    return torch.fft.irfft(fresult, n=n)
+    result = torch.fft.irfft(fresult, n=n)
+    return _apply_convolve_mode(result, x.size(-1), y.size(-1), mode)


-def convolve(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+def convolve(x: torch.Tensor, y: torch.Tensor, mode: str = "full") -> torch.Tensor:
     r"""
     Convolves inputs along their last dimension using the direct method.

     Note that, in contrast to :meth:`torch.nn.functional.conv1d`, which actually applies the valid cross-correlation
@@ -51,15 +76,23 @@ def convolve(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         x (torch.Tensor): First convolution operand, with shape `(..., N)`.
         y (torch.Tensor): Second convolution operand, with shape `(..., M)`
             (leading dimensions must match those of ``x``).
+        mode (str, optional): Must be one of ("full", "valid", "same").
+
+            * "full": Returns the full convolution result, with shape `(..., N + M - 1)`. (Default)
+            * "valid": Returns the segment of the full convolution result corresponding to where
+              the two inputs overlap completely, with shape `(..., max(N, M) - min(N, M) + 1)`.
+            * "same": Returns the center segment of the full convolution result, with shape `(..., N)`.

     Returns:
-        torch.Tensor: Result of convolving ``x`` and ``y``, with shape `(..., N + M - 1)`, where
-        the leading dimensions match those of ``x``.
+        torch.Tensor: Result of convolving ``x`` and ``y``, with shape `(..., L)`, where
+        the leading dimensions match those of ``x`` and `L` is dictated by ``mode``.

     .. _convolution:
         https://en.wikipedia.org/wiki/Convolution
     """
-    _check_convolve_inputs(x, y)
+    _check_convolve_inputs(x, y, mode)
+
+    x_size, y_size = x.size(-1), y.size(-1)

     if x.size(-1) < y.size(-1):
         x, y = y, x

@@ -75,7 +108,8 @@ def convolve(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         padding=reshaped_y.size(-1) - 1,
     )
     output_shape = x.shape[:-1] + (-1,)
-    return output.reshape(output_shape)
+    result = output.reshape(output_shape)
+    return _apply_convolve_mode(result, x_size, y_size, mode)


 def add_noise(waveform: torch.Tensor, noise: torch.Tensor, lengths: torch.Tensor, snr: torch.Tensor) -> torch.Tensor:
...
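To make the slicing in _apply_convolve_mode concrete, here is a standalone sketch of the index arithmetic for example lengths (the tensor contents are placeholders; only the computed indices matter):

import torch

x_length, y_length = 10, 4
conv_result = torch.arange(x_length + y_length - 1)  # stands in for a "full" result, length 13

# "valid": the region where the two signals overlap completely
target_length = max(x_length, y_length) - min(x_length, y_length) + 1  # 7
start_idx = (conv_result.size(-1) - target_length) // 2                # 3
valid = conv_result[..., start_idx : start_idx + target_length]        # elements 3..9, length 7

# "same": the center segment whose length matches x
start_idx = (conv_result.size(-1) - x_length) // 2                     # 1
same = conv_result[..., start_idx : start_idx + x_length]              # elements 1..10, length 10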