Commit 06b1cc9d authored by Daniel Walker's avatar Daniel Walker Committed by Facebook GitHub Bot
Browse files

Add precodition check for contiguous emissions tensor (#3074)

Summary:
This PR adds a precondition check to the `CTCDecoder` that raises a helpful exception when called on a noncontiguous emissions tensor.

Currently, noncontiguous tensors can be passed into the CTCDecoder, which in turn passes the tensors to the backing Flashlight C++ library and results in undefined behavior, since Flashlight requires the tensors to be laid out in contiguous memory. The following code demonstrates the problem:

```
import torch
from torchaudio.models.decoder import ctc_decoder

tokens = ['a', '-', '|']
decoder = ctc_decoder(lexicon=None, tokens=tokens)

emissions = torch.rand(len(tokens), 2)  # N x T contiguous
emissions = emissions.t()  # T x N noncontiguous

batch = emissions.unsqueeze(0)
result = decoder(batch)  # undefined behavior!!!
```

I stumbled on the issue accidentally when I noticed the decoder wasn't giving the expected results on my input only to realize, finally, that the tensor I had passed in was noncontiguous. In my case, Flashlight was iterating over unrelated segments of memory where it had expected to find a contiguous tensor. A precondition check will hopefully save others from making the same mistake.

Pull Request resolved: https://github.com/pytorch/audio/pull/3074

Reviewed By: nateanl, xiaohui-zhang

Differential Revision: D43376011

Pulled By: mthrok

fbshipit-source-id: 7c95aa8016d8f9f2d65b5b816a859b28ea4629f5
parent 85f8fc54
...@@ -305,6 +305,9 @@ class CTCDecoder: ...@@ -305,6 +305,9 @@ class CTCDecoder:
if emissions.is_cuda: if emissions.is_cuda:
raise RuntimeError("emissions must be a CPU tensor.") raise RuntimeError("emissions must be a CPU tensor.")
if not emissions.is_contiguous():
raise RuntimeError("emissions must be contiguous.")
if lengths is not None and lengths.is_cuda: if lengths is not None and lengths.is_cuda:
raise RuntimeError("lengths must be a CPU tensor.") raise RuntimeError("lengths must be a CPU tensor.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment