Commit a49edea5 authored by hwangjeff, committed by Facebook GitHub Bot
Browse files

Enable broadcasting for inputs to convolve (#3061)

Summary:
Relaxes input dimension matching constraint on `convolve` to enable broadcasting for inputs.

Pull Request resolved: https://github.com/pytorch/audio/pull/3061

Reviewed By: mthrok

Differential Revision: D43298078

Pulled By: hwangjeff

fbshipit-source-id: a6cc36674754523b88390fac0a05f06562921319
parent fb932674
......@@ -937,26 +937,22 @@ class Functional(TestBaseMixin):
self.assertEqual(expected, actual)
@parameterized.expand(
[
# fmt: off
((5, 2, 3), (5, 1, 3)),
((5, 2, 3), (1, 2, 3)),
((5, 2, 3), (1, 1, 3)),
# fmt: on
]
@nested_params(
["convolve", "fftconvolve"],
[(5, 2, 3)],
[(5, 1, 3), (1, 2, 3), (1, 1, 3)],
)
def test_fftconvolve_broadcast(self, x_shape, y_shape):
"""fftconvolve works for Tensors for different shapes if they are broadcast-able"""
# 1. Test broad cast case
def test_convolve_broadcast(self, fn, x_shape, y_shape):
"""convolve works for Tensors for different shapes if they are broadcast-able"""
# 1. Test broadcast case
x = torch.rand(x_shape, dtype=self.dtype, device=self.device)
y = torch.rand(y_shape, dtype=self.dtype, device=self.device)
out1 = F.fftconvolve(x, y)
out1 = getattr(F, fn)(x, y)
# 2. Test without broadcast
y_clone = y.expand(x_shape).clone()
assert y is not y_clone
assert y_clone.shape == x.shape
out2 = F.fftconvolve(x, y_clone)
out2 = getattr(F, fn)(x, y_clone)
# check that they are same
self.assertEqual(out1, out2)
......@@ -972,28 +968,23 @@ class Functional(TestBaseMixin):
(0, F.fftconvolve, (4, 3, 1, 2), (2, 2, 2)),
(0, F.fftconvolve, (1, ), (10, 4)),
(0, F.fftconvolve, (1, ), (2, 2, 2)),
# incompatible shape except the last dim
# non-broadcastable leading dimensions
(1, F.convolve, (5, 2, 3), (5, 3, 3)),
(1, F.convolve, (5, 2, 3), (5, 3, 4)),
(1, F.convolve, (5, 2, 3), (5, 3, 5)),
(2, F.fftconvolve, (5, 2, 3), (5, 3, 3)),
(2, F.fftconvolve, (5, 2, 3), (5, 3, 4)),
(2, F.fftconvolve, (5, 2, 3), (5, 3, 5)),
# broadcast-able (only for convolve)
(1, F.convolve, (5, 2, 3), (5, 1, 3)),
(1, F.convolve, (5, 2, 3), (5, 1, 4)),
(1, F.convolve, (5, 2, 3), (5, 1, 5)),
(1, F.fftconvolve, (5, 2, 3), (5, 3, 3)),
(1, F.fftconvolve, (5, 2, 3), (5, 3, 4)),
(1, F.fftconvolve, (5, 2, 3), (5, 3, 5)),
# fmt: on
],
)
def test_convolve_input_leading_dim_check(self, case, fn, x_shape, y_shape):
"""Check that convolve properly rejects inputs with different leading dimensions."""
def test_convolve_input_dim_check(self, case, fn, x_shape, y_shape):
"""Check that convolve properly rejects inputs with incompatible dimensions."""
x = torch.rand(*x_shape, dtype=self.dtype, device=self.device)
y = torch.rand(*y_shape, dtype=self.dtype, device=self.device)
message = [
"The operands must be the same dimension",
"Leading dimensions of x and y don't match",
"Leading dimensions of x and y are not broadcastable",
][case]
with self.assertRaisesRegex(ValueError, message):
......
......@@ -2295,13 +2295,10 @@ def apply_beamforming(beamform_weights: Tensor, specgram: Tensor) -> Tensor:
return specgram_enhanced
def _check_shape_compatible(x: torch.Tensor, y: torch.Tensor, allow_broadcast: bool) -> None:
def _check_shape_compatible(x: torch.Tensor, y: torch.Tensor) -> None:
if x.ndim != y.ndim:
raise ValueError(f"The operands must be the same dimension (got {x.ndim} and {y.ndim}).")
if not allow_broadcast:
if x.shape[:-1] != y.shape[:-1]:
raise ValueError(f"Leading dimensions of x and y don't match (got {x.shape} and {y.shape}).")
else:
for i in range(x.ndim - 1):
xi = x.size(i)
yi = y.size(i)
......@@ -2346,7 +2343,7 @@ def fftconvolve(x: torch.Tensor, y: torch.Tensor, mode: str = "full") -> torch.T
Args:
x (torch.Tensor): First convolution operand, with shape `(..., N)`.
y (torch.Tensor): Second convolution operand, with shape `(..., M)`
(leading dimensions must be broadcast-able to those of ``x``).
(leading dimensions must be broadcast-able with those of ``x``).
mode (str, optional): Must be one of ("full", "valid", "same").
* "full": Returns the full convolution result, with shape `(..., N + M - 1)`. (Default)
......@@ -2361,7 +2358,7 @@ def fftconvolve(x: torch.Tensor, y: torch.Tensor, mode: str = "full") -> torch.T
.. _convolution:
https://en.wikipedia.org/wiki/Convolution
"""
_check_shape_compatible(x, y, allow_broadcast=True)
_check_shape_compatible(x, y)
_check_convolve_mode(mode)
n = x.size(-1) + y.size(-1) - 1
......@@ -2383,7 +2380,7 @@ def convolve(x: torch.Tensor, y: torch.Tensor, mode: str = "full") -> torch.Tens
Args:
x (torch.Tensor): First convolution operand, with shape `(..., N)`.
y (torch.Tensor): Second convolution operand, with shape `(..., M)`
(leading dimensions must match those of ``x``).
(leading dimensions must be broadcast-able with those of ``x``).
mode (str, optional): Must be one of ("full", "valid", "same").
* "full": Returns the full convolution result, with shape `(..., N + M - 1)`. (Default)
......@@ -2398,7 +2395,7 @@ def convolve(x: torch.Tensor, y: torch.Tensor, mode: str = "full") -> torch.Tens
.. _convolution:
https://en.wikipedia.org/wiki/Convolution
"""
_check_shape_compatible(x, y, allow_broadcast=False)
_check_shape_compatible(x, y)
_check_convolve_mode(mode)
x_size, y_size = x.size(-1), y.size(-1)
......@@ -2406,6 +2403,11 @@ def convolve(x: torch.Tensor, y: torch.Tensor, mode: str = "full") -> torch.Tens
if x.size(-1) < y.size(-1):
x, y = y, x
if x.shape[:-1] != y.shape[:-1]:
new_shape = [max(i, j) for i, j in zip(x.shape[:-1], y.shape[:-1])]
x = x.broadcast_to(new_shape + [x.shape[-1]])
y = y.broadcast_to(new_shape + [y.shape[-1]])
num_signals = torch.tensor(x.shape[:-1]).prod()
reshaped_x = x.reshape((int(num_signals), x.size(-1)))
reshaped_y = y.reshape((int(num_signals), y.size(-1)))
......
......@@ -1844,7 +1844,7 @@ class Convolve(torch.nn.Module):
Args:
x (torch.Tensor): First convolution operand, with shape `(..., N)`.
y (torch.Tensor): Second convolution operand, with shape `(..., M)`
(leading dimensions must match those of ``x``).
(leading dimensions must be broadcast-able with those of ``x``).
Returns:
torch.Tensor: Result of convolving ``x`` and ``y``, with shape `(..., L)`, where
......@@ -1889,7 +1889,7 @@ class FFTConvolve(torch.nn.Module):
Args:
x (torch.Tensor): First convolution operand, with shape `(..., N)`.
y (torch.Tensor): Second convolution operand, with shape `(..., M)`
(leading dimensions must be broadcast-able to those of ``x``).
(leading dimensions must be broadcast-able with those of ``x``).
Returns:
torch.Tensor: Result of convolving ``x`` and ``y``, with shape `(..., L)`, where
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment