Commit a49edea5 authored by hwangjeff's avatar hwangjeff Committed by Facebook GitHub Bot
Browse files

Enable broadcasting for inputs to convolve (#3061)

Summary:
Relaxes input dimension matching constraint on `convolve` to enable broadcasting for inputs.

Pull Request resolved: https://github.com/pytorch/audio/pull/3061

Reviewed By: mthrok

Differential Revision: D43298078

Pulled By: hwangjeff

fbshipit-source-id: a6cc36674754523b88390fac0a05f06562921319
parent fb932674
...@@ -937,26 +937,22 @@ class Functional(TestBaseMixin): ...@@ -937,26 +937,22 @@ class Functional(TestBaseMixin):
self.assertEqual(expected, actual) self.assertEqual(expected, actual)
@parameterized.expand( @nested_params(
[ ["convolve", "fftconvolve"],
# fmt: off [(5, 2, 3)],
((5, 2, 3), (5, 1, 3)), [(5, 1, 3), (1, 2, 3), (1, 1, 3)],
((5, 2, 3), (1, 2, 3)),
((5, 2, 3), (1, 1, 3)),
# fmt: on
]
) )
def test_fftconvolve_broadcast(self, x_shape, y_shape): def test_convolve_broadcast(self, fn, x_shape, y_shape):
"""fftconvolve works for Tensors for different shapes if they are broadcast-able""" """convolve works for Tensors for different shapes if they are broadcast-able"""
# 1. Test broad cast case # 1. Test broadcast case
x = torch.rand(x_shape, dtype=self.dtype, device=self.device) x = torch.rand(x_shape, dtype=self.dtype, device=self.device)
y = torch.rand(y_shape, dtype=self.dtype, device=self.device) y = torch.rand(y_shape, dtype=self.dtype, device=self.device)
out1 = F.fftconvolve(x, y) out1 = getattr(F, fn)(x, y)
# 2. Test without broadcast # 2. Test without broadcast
y_clone = y.expand(x_shape).clone() y_clone = y.expand(x_shape).clone()
assert y is not y_clone assert y is not y_clone
assert y_clone.shape == x.shape assert y_clone.shape == x.shape
out2 = F.fftconvolve(x, y_clone) out2 = getattr(F, fn)(x, y_clone)
# check that they are same # check that they are same
self.assertEqual(out1, out2) self.assertEqual(out1, out2)
...@@ -972,28 +968,23 @@ class Functional(TestBaseMixin): ...@@ -972,28 +968,23 @@ class Functional(TestBaseMixin):
(0, F.fftconvolve, (4, 3, 1, 2), (2, 2, 2)), (0, F.fftconvolve, (4, 3, 1, 2), (2, 2, 2)),
(0, F.fftconvolve, (1, ), (10, 4)), (0, F.fftconvolve, (1, ), (10, 4)),
(0, F.fftconvolve, (1, ), (2, 2, 2)), (0, F.fftconvolve, (1, ), (2, 2, 2)),
# incompatible shape except the last dim # non-broadcastable leading dimensions
(1, F.convolve, (5, 2, 3), (5, 3, 3)), (1, F.convolve, (5, 2, 3), (5, 3, 3)),
(1, F.convolve, (5, 2, 3), (5, 3, 4)), (1, F.convolve, (5, 2, 3), (5, 3, 4)),
(1, F.convolve, (5, 2, 3), (5, 3, 5)), (1, F.convolve, (5, 2, 3), (5, 3, 5)),
(2, F.fftconvolve, (5, 2, 3), (5, 3, 3)), (1, F.fftconvolve, (5, 2, 3), (5, 3, 3)),
(2, F.fftconvolve, (5, 2, 3), (5, 3, 4)), (1, F.fftconvolve, (5, 2, 3), (5, 3, 4)),
(2, F.fftconvolve, (5, 2, 3), (5, 3, 5)), (1, F.fftconvolve, (5, 2, 3), (5, 3, 5)),
# broadcast-able (only for convolve)
(1, F.convolve, (5, 2, 3), (5, 1, 3)),
(1, F.convolve, (5, 2, 3), (5, 1, 4)),
(1, F.convolve, (5, 2, 3), (5, 1, 5)),
# fmt: on # fmt: on
], ],
) )
def test_convolve_input_leading_dim_check(self, case, fn, x_shape, y_shape): def test_convolve_input_dim_check(self, case, fn, x_shape, y_shape):
"""Check that convolve properly rejects inputs with different leading dimensions.""" """Check that convolve properly rejects inputs with incompatible dimensions."""
x = torch.rand(*x_shape, dtype=self.dtype, device=self.device) x = torch.rand(*x_shape, dtype=self.dtype, device=self.device)
y = torch.rand(*y_shape, dtype=self.dtype, device=self.device) y = torch.rand(*y_shape, dtype=self.dtype, device=self.device)
message = [ message = [
"The operands must be the same dimension", "The operands must be the same dimension",
"Leading dimensions of x and y don't match",
"Leading dimensions of x and y are not broadcastable", "Leading dimensions of x and y are not broadcastable",
][case] ][case]
with self.assertRaisesRegex(ValueError, message): with self.assertRaisesRegex(ValueError, message):
......
...@@ -2295,19 +2295,16 @@ def apply_beamforming(beamform_weights: Tensor, specgram: Tensor) -> Tensor: ...@@ -2295,19 +2295,16 @@ def apply_beamforming(beamform_weights: Tensor, specgram: Tensor) -> Tensor:
return specgram_enhanced return specgram_enhanced
def _check_shape_compatible(x: torch.Tensor, y: torch.Tensor, allow_broadcast: bool) -> None: def _check_shape_compatible(x: torch.Tensor, y: torch.Tensor) -> None:
if x.ndim != y.ndim: if x.ndim != y.ndim:
raise ValueError(f"The operands must be the same dimension (got {x.ndim} and {y.ndim}).") raise ValueError(f"The operands must be the same dimension (got {x.ndim} and {y.ndim}).")
if not allow_broadcast:
if x.shape[:-1] != y.shape[:-1]: for i in range(x.ndim - 1):
raise ValueError(f"Leading dimensions of x and y don't match (got {x.shape} and {y.shape}).") xi = x.size(i)
else: yi = y.size(i)
for i in range(x.ndim - 1): if xi == yi or xi == 1 or yi == 1:
xi = x.size(i) continue
yi = y.size(i) raise ValueError(f"Leading dimensions of x and y are not broadcastable (got {x.shape} and {y.shape}).")
if xi == yi or xi == 1 or yi == 1:
continue
raise ValueError(f"Leading dimensions of x and y are not broadcastable (got {x.shape} and {y.shape}).")
def _check_convolve_mode(mode: str) -> None: def _check_convolve_mode(mode: str) -> None:
...@@ -2346,7 +2343,7 @@ def fftconvolve(x: torch.Tensor, y: torch.Tensor, mode: str = "full") -> torch.T ...@@ -2346,7 +2343,7 @@ def fftconvolve(x: torch.Tensor, y: torch.Tensor, mode: str = "full") -> torch.T
Args: Args:
x (torch.Tensor): First convolution operand, with shape `(..., N)`. x (torch.Tensor): First convolution operand, with shape `(..., N)`.
y (torch.Tensor): Second convolution operand, with shape `(..., M)` y (torch.Tensor): Second convolution operand, with shape `(..., M)`
(leading dimensions must be broadcast-able to those of ``x``). (leading dimensions must be broadcast-able with those of ``x``).
mode (str, optional): Must be one of ("full", "valid", "same"). mode (str, optional): Must be one of ("full", "valid", "same").
* "full": Returns the full convolution result, with shape `(..., N + M - 1)`. (Default) * "full": Returns the full convolution result, with shape `(..., N + M - 1)`. (Default)
...@@ -2361,7 +2358,7 @@ def fftconvolve(x: torch.Tensor, y: torch.Tensor, mode: str = "full") -> torch.T ...@@ -2361,7 +2358,7 @@ def fftconvolve(x: torch.Tensor, y: torch.Tensor, mode: str = "full") -> torch.T
.. _convolution: .. _convolution:
https://en.wikipedia.org/wiki/Convolution https://en.wikipedia.org/wiki/Convolution
""" """
_check_shape_compatible(x, y, allow_broadcast=True) _check_shape_compatible(x, y)
_check_convolve_mode(mode) _check_convolve_mode(mode)
n = x.size(-1) + y.size(-1) - 1 n = x.size(-1) + y.size(-1) - 1
...@@ -2383,7 +2380,7 @@ def convolve(x: torch.Tensor, y: torch.Tensor, mode: str = "full") -> torch.Tens ...@@ -2383,7 +2380,7 @@ def convolve(x: torch.Tensor, y: torch.Tensor, mode: str = "full") -> torch.Tens
Args: Args:
x (torch.Tensor): First convolution operand, with shape `(..., N)`. x (torch.Tensor): First convolution operand, with shape `(..., N)`.
y (torch.Tensor): Second convolution operand, with shape `(..., M)` y (torch.Tensor): Second convolution operand, with shape `(..., M)`
(leading dimensions must match those of ``x``). (leading dimensions must be broadcast-able with those of ``x``).
mode (str, optional): Must be one of ("full", "valid", "same"). mode (str, optional): Must be one of ("full", "valid", "same").
* "full": Returns the full convolution result, with shape `(..., N + M - 1)`. (Default) * "full": Returns the full convolution result, with shape `(..., N + M - 1)`. (Default)
...@@ -2398,7 +2395,7 @@ def convolve(x: torch.Tensor, y: torch.Tensor, mode: str = "full") -> torch.Tens ...@@ -2398,7 +2395,7 @@ def convolve(x: torch.Tensor, y: torch.Tensor, mode: str = "full") -> torch.Tens
.. _convolution: .. _convolution:
https://en.wikipedia.org/wiki/Convolution https://en.wikipedia.org/wiki/Convolution
""" """
_check_shape_compatible(x, y, allow_broadcast=False) _check_shape_compatible(x, y)
_check_convolve_mode(mode) _check_convolve_mode(mode)
x_size, y_size = x.size(-1), y.size(-1) x_size, y_size = x.size(-1), y.size(-1)
...@@ -2406,6 +2403,11 @@ def convolve(x: torch.Tensor, y: torch.Tensor, mode: str = "full") -> torch.Tens ...@@ -2406,6 +2403,11 @@ def convolve(x: torch.Tensor, y: torch.Tensor, mode: str = "full") -> torch.Tens
if x.size(-1) < y.size(-1): if x.size(-1) < y.size(-1):
x, y = y, x x, y = y, x
if x.shape[:-1] != y.shape[:-1]:
new_shape = [max(i, j) for i, j in zip(x.shape[:-1], y.shape[:-1])]
x = x.broadcast_to(new_shape + [x.shape[-1]])
y = y.broadcast_to(new_shape + [y.shape[-1]])
num_signals = torch.tensor(x.shape[:-1]).prod() num_signals = torch.tensor(x.shape[:-1]).prod()
reshaped_x = x.reshape((int(num_signals), x.size(-1))) reshaped_x = x.reshape((int(num_signals), x.size(-1)))
reshaped_y = y.reshape((int(num_signals), y.size(-1))) reshaped_y = y.reshape((int(num_signals), y.size(-1)))
......
...@@ -1844,7 +1844,7 @@ class Convolve(torch.nn.Module): ...@@ -1844,7 +1844,7 @@ class Convolve(torch.nn.Module):
Args: Args:
x (torch.Tensor): First convolution operand, with shape `(..., N)`. x (torch.Tensor): First convolution operand, with shape `(..., N)`.
y (torch.Tensor): Second convolution operand, with shape `(..., M)` y (torch.Tensor): Second convolution operand, with shape `(..., M)`
(leading dimensions must match those of ``x``). (leading dimensions must be broadcast-able with those of ``x``).
Returns: Returns:
torch.Tensor: Result of convolving ``x`` and ``y``, with shape `(..., L)`, where torch.Tensor: Result of convolving ``x`` and ``y``, with shape `(..., L)`, where
...@@ -1889,7 +1889,7 @@ class FFTConvolve(torch.nn.Module): ...@@ -1889,7 +1889,7 @@ class FFTConvolve(torch.nn.Module):
Args: Args:
x (torch.Tensor): First convolution operand, with shape `(..., N)`. x (torch.Tensor): First convolution operand, with shape `(..., N)`.
y (torch.Tensor): Second convolution operand, with shape `(..., M)` y (torch.Tensor): Second convolution operand, with shape `(..., M)`
(leading dimensions must be broadcast-able to those of ``x``). (leading dimensions must be broadcast-able with those of ``x``).
Returns: Returns:
torch.Tensor: Result of convolving ``x`` and ``y``, with shape `(..., L)`, where torch.Tensor: Result of convolving ``x`` and ``y``, with shape `(..., L)`, where
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment