Commit 6ab1325a authored by Chin-Yun Yu, committed by Facebook GitHub Bot

Fix contiguous error when backpropagating through lfilter (#3080)

Summary:
I encountered the following error when using `lfilter` with gradients enabled.

```sh
Traceback (most recent call last):
  File "/home/ycy/working/audio/test_backward.py", line 20, in <module>
    loss.backward()
  File "/home/ycy/miniconda3/envs/nightly/lib/python3.10/site-packages/torch/_tensor.py", line 488, in backward
    torch.autograd.backward(
  File "/home/ycy/miniconda3/envs/nightly/lib/python3.10/site-packages/torch/autograd/__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Expected input_signal_windows.is_contiguous() && a_coeff_flipped.is_contiguous() && padded_output_waveform.is_contiguous() to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)
```
This can happen if the output of `lfilter` is consumed by other operations before the loss is computed.

### How to reproduce
The following script reproduces the error on both the stable and nightly versions.

```python
import torch
import torch.nn.functional as F
from torchaudio.functional import lfilter

# 250 independent filters, each with 26 coefficients
a = torch.rand(250, 26, requires_grad=True)
b = torch.ones(250, 26, requires_grad=True)
# 250 channels of 1024 samples
x = torch.rand(250, 1024, requires_grad=True)
# identity kernel bank: 1024 input channels, kernel size 1024
w = torch.eye(1024).unsqueeze(1)

# clamp=False: leave the filter output unclamped
y = lfilter(x, a, b, False)
# .t() creates a non-contiguous view; its backward hands a transposed
# (non-contiguous) gradient back toward lfilter
y = F.conv_transpose1d(
    y.t().unsqueeze(0),
    w,
    stride=256,
).squeeze()
print(y.shape)
target = torch.ones_like(y)
loss = torch.nn.functional.mse_loss(y, target)
loss.backward()  # raises the RuntimeError above without the fix
```
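As a side note, the non-contiguity here comes from autograd itself: the backward of a transpose transposes the incoming gradient, producing a non-contiguous view. A minimal sketch (plain PyTorch, independent of `lfilter`) using a gradient hook:

```python
import torch

x = torch.rand(3, 5, requires_grad=True)
# inspect the gradient that reaches x during the backward pass
x.register_hook(lambda g: print("grad contiguous:", g.is_contiguous()))

w = torch.rand(5, 3)
# the backward of .t() transposes the incoming gradient, so a
# non-contiguous view arrives at x
(x.t() * w).sum().backward()  # prints "grad contiguous: False"
```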

### Cause

The inner call to the differentiable IIR in the backward pass requires its input to be contiguous, but the incoming gradient `dy` is not guaranteed to be. Adding a `contiguous()` call solves the problem.
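For illustration, a minimal sketch of what the fix does, using a transposed view as a stand-in for a non-contiguous incoming gradient: materialize a contiguous copy before handing the tensor to a kernel that assumes a dense layout.

```python
import torch

# stand-in for a non-contiguous incoming gradient: a transposed view
dy = torch.rand(1, 3, 8).transpose(1, 2)
print(dy.is_contiguous())  # False

dy = dy.contiguous()       # copy the data into contiguous memory
print(dy.is_contiguous())  # True
```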

Pull Request resolved: https://github.com/pytorch/audio/pull/3080

Reviewed By: xiaohui-zhang

Differential Revision: D43466612

Pulled By: mthrok

fbshipit-source-id: 375e0a147988656da47ac8397f7de6eae512a655
```diff
@@ -176,7 +176,9 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
     }
     if (x.requires_grad()) {
-      dx = DifferentiableIIR::apply(dy.flip(2), a_coeffs_normalized).flip(2);
+      dx =
+          DifferentiableIIR::apply(dy.flip(2).contiguous(), a_coeffs_normalized)
+              .flip(2);
     }
     return {dx, da};
```
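Since `Tensor::contiguous()` returns the tensor itself when its layout is already contiguous, the added call only pays for a copy when the incoming gradient is actually non-contiguous; with it in place, the reproduction script above runs to completion.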