Commit 7a05622e authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Extend fftconvolve to support broadcast-able shapes (#2874)

Summary:
Currently, fftconvolve only accepts tensors whose leading dimensions are exactly the same.
This commit loosens that restriction to allow shapes that are broadcast-able.

This makes the fftconvolve operation more efficient for cases like signal filtering where one operand (waveform) is larger than the other (filter kernel) and the same filter kernels are applied across channels and batches.

Pull Request resolved: https://github.com/pytorch/audio/pull/2874

Reviewed By: carolineechen

Differential Revision: D41581588

Pulled By: mthrok

fbshipit-source-id: c0117e11b979fb53236cc307a970a461b0e50134
parent 8bde6a54
......@@ -54,16 +54,66 @@ class FunctionalTestImpl(TestBaseMixin):
self.assertEqual(expected, actual)
@parameterized.expand(
    [
        # fmt: off
        ((5, 2, 3), (5, 1, 3)),
        ((5, 2, 3), (1, 2, 3)),
        ((5, 2, 3), (1, 1, 3)),
        # fmt: on
    ]
)
def test_fftconvolve_broadcast(self, x_shape, y_shape):
    """fftconvolve gives the same result whether broadcasting is implicit or materialized."""
    x = torch.rand(x_shape, dtype=self.dtype, device=self.device)
    y = torch.rand(y_shape, dtype=self.dtype, device=self.device)

    # Result with y's leading dimensions broadcast implicitly by fftconvolve.
    broadcast_result = F.fftconvolve(x, y)

    # Result with y expanded up-front to x's full shape, so no broadcasting occurs.
    y_expanded = y.expand(x_shape).clone()
    assert y_expanded is not y
    assert y_expanded.shape == x.shape
    full_result = F.fftconvolve(x, y_expanded)

    # Both code paths must agree.
    self.assertEqual(broadcast_result, full_result)
@parameterized.expand(
    [
        # fmt: off
        # different ndim
        (0, F.convolve, (4, 3, 1, 2), (10, 4)),
        (0, F.convolve, (4, 3, 1, 2), (2, 2, 2)),
        (0, F.convolve, (1, ), (10, 4)),
        (0, F.convolve, (1, ), (2, 2, 2)),
        (0, F.fftconvolve, (4, 3, 1, 2), (10, 4)),
        (0, F.fftconvolve, (4, 3, 1, 2), (2, 2, 2)),
        (0, F.fftconvolve, (1, ), (10, 4)),
        (0, F.fftconvolve, (1, ), (2, 2, 2)),
        # incompatible shape except the last dim
        (1, F.convolve, (5, 2, 3), (5, 3, 3)),
        (1, F.convolve, (5, 2, 3), (5, 3, 4)),
        (1, F.convolve, (5, 2, 3), (5, 3, 5)),
        (2, F.fftconvolve, (5, 2, 3), (5, 3, 3)),
        (2, F.fftconvolve, (5, 2, 3), (5, 3, 4)),
        (2, F.fftconvolve, (5, 2, 3), (5, 3, 5)),
        # broadcast-able (only for convolve)
        (1, F.convolve, (5, 2, 3), (5, 1, 3)),
        (1, F.convolve, (5, 2, 3), (5, 1, 4)),
        (1, F.convolve, (5, 2, 3), (5, 1, 5)),
        # fmt: on
    ],
)
def test_convolve_input_leading_dim_check(self, case, fn, x_shape, y_shape):
    """convolve/fftconvolve reject operands with incompatible shapes with the right message."""
    x = torch.rand(*x_shape, dtype=self.dtype, device=self.device)
    y = torch.rand(*y_shape, dtype=self.dtype, device=self.device)

    # ``case`` selects which rejection message the pair of shapes must trigger.
    expected_messages = (
        "The operands must be the same dimension",
        "Leading dimensions of x and y don't match",
        "Leading dimensions of x and y are not broadcastable",
    )
    with self.assertRaisesRegex(ValueError, expected_messages[case]):
        fn(x, y)
def test_add_noise_broadcast(self):
......
......@@ -26,7 +26,7 @@ class TorchScriptConsistencyTestImpl(TestBaseMixin):
self.assertEqual(ts_output, output)
@nested_params(
[F.convolve, F.fftconvolve],
["convolve", "fftconvolve"],
["full", "valid", "same"],
)
def test_convolve(self, fn, mode):
......@@ -35,7 +35,7 @@ class TorchScriptConsistencyTestImpl(TestBaseMixin):
x = torch.rand(*leading_dims, L_x, dtype=self.dtype, device=self.device)
y = torch.rand(*leading_dims, L_y, dtype=self.dtype, device=self.device)
self._assert_consistency(fn, (x, y, mode))
self._assert_consistency(getattr(F, fn), (x, y, mode))
def test_add_noise(self):
leading_dims = (2, 3)
......
......@@ -5,7 +5,7 @@ from torchaudio_unittest.common_utils import nested_params, TestBaseMixin, torch
class Transforms(TestBaseMixin):
@nested_params(
[T.Convolve, T.FFTConvolve],
["Convolve", "FFTConvolve"],
["full", "valid", "same"],
)
def test_Convolve(self, cls, mode):
......@@ -14,7 +14,7 @@ class Transforms(TestBaseMixin):
x = torch.rand(*leading_dims, L_x, dtype=self.dtype, device=self.device)
y = torch.rand(*leading_dims, L_y, dtype=self.dtype, device=self.device)
convolve = cls(mode=mode).to(device=self.device, dtype=self.dtype)
convolve = getattr(T, cls)(mode=mode).to(device=self.device, dtype=self.dtype)
output = convolve(x, y)
ts_output = torch_script(convolve)(x, y)
self.assertEqual(ts_output, output)
......@@ -6,18 +6,27 @@ import torch
from torchaudio.functional.functional import _create_triangular_filterbank
def _check_shape_compatible(x: torch.Tensor, y: torch.Tensor, allow_broadcast: bool) -> None:
if x.ndim != y.ndim:
raise ValueError(f"The operands must be the same dimension (got {x.ndim} and {y.ndim}).")
if not allow_broadcast:
if x.shape[:-1] != y.shape[:-1]:
raise ValueError(f"Leading dimensions of x and y don't match (got {x.shape} and {y.shape}).")
else:
for i in range(x.ndim - 1):
xi = x.size(i)
yi = y.size(i)
if xi == yi or xi == 1 or yi == 1:
continue
raise ValueError(f"Leading dimensions of x and y are not broadcastable (got {x.shape} and {y.shape}).")
def _check_convolve_mode(mode: str) -> None:
valid_convolve_modes = ["full", "valid", "same"]
if mode not in valid_convolve_modes:
raise ValueError(f"Unrecognized mode value '{mode}'. Please specify one of {valid_convolve_modes}.")
def _check_convolve_inputs(x: torch.Tensor, y: torch.Tensor, mode: str) -> None:
    """Validate operand shapes and the ``mode`` string for convolve/fftconvolve.

    Args:
        x (torch.Tensor): First operand.
        y (torch.Tensor): Second operand.
        mode (str): Convolution mode, checked by ``_check_convolve_mode``.

    Raises:
        ValueError: If the leading dimensions differ or ``mode`` is unsupported.
    """
    leading_x, leading_y = x.shape[:-1], y.shape[:-1]
    if leading_x != leading_y:
        raise ValueError(f"Leading dimensions of x and y don't match (got {x.shape} and {y.shape}).")
    _check_convolve_mode(mode)
def _apply_convolve_mode(conv_result: torch.Tensor, x_length: int, y_length: int, mode: str) -> torch.Tensor:
valid_convolve_modes = ["full", "valid", "same"]
if mode == "full":
......@@ -48,7 +57,7 @@ def fftconvolve(x: torch.Tensor, y: torch.Tensor, mode: str = "full") -> torch.T
Args:
x (torch.Tensor): First convolution operand, with shape `(..., N)`.
y (torch.Tensor): Second convolution operand, with shape `(..., M)`
(leading dimensions must match those of ``x``).
(leading dimensions must be broadcast-able to those of ``x``).
mode (str, optional): Must be one of ("full", "valid", "same").
* "full": Returns the full convolution result, with shape `(..., N + M - 1)`. (Default)
......@@ -63,7 +72,8 @@ def fftconvolve(x: torch.Tensor, y: torch.Tensor, mode: str = "full") -> torch.T
.. _convolution:
https://en.wikipedia.org/wiki/Convolution
"""
_check_convolve_inputs(x, y, mode)
_check_shape_compatible(x, y, allow_broadcast=True)
_check_convolve_mode(mode)
n = x.size(-1) + y.size(-1) - 1
fresult = torch.fft.rfft(x, n=n) * torch.fft.rfft(y, n=n)
......@@ -99,7 +109,8 @@ def convolve(x: torch.Tensor, y: torch.Tensor, mode: str = "full") -> torch.Tens
.. _convolution:
https://en.wikipedia.org/wiki/Convolution
"""
_check_convolve_inputs(x, y, mode)
_check_shape_compatible(x, y, allow_broadcast=False)
_check_convolve_mode(mode)
x_size, y_size = x.size(-1), y.size(-1)
......
......@@ -85,7 +85,7 @@ class FFTConvolve(torch.nn.Module):
Args:
x (torch.Tensor): First convolution operand, with shape `(..., N)`.
y (torch.Tensor): Second convolution operand, with shape `(..., M)`
(leading dimensions must match those of ``x``).
(leading dimensions must be broadcast-able to those of ``x``).
Returns:
torch.Tensor: Result of convolving ``x`` and ``y``, with shape `(..., L)`, where
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment