Commit cfa5a383 authored by Xiaohui Zhang's avatar Xiaohui Zhang Committed by Facebook GitHub Bot
Browse files

make sure inputs live on CPU for ctc decoder (#2289)

Summary:
Addressing the issue https://github.com/pytorch/audio/issues/2274:
Raise Runtime errors when the input tensors to the CTC decoder are GPU tensors since the CTC decoder only runs on CPU. Also update the data type check to use "raise" rather than "assert".

 ---
Pull Request resolved: https://github.com/pytorch/audio/pull/2289

Reviewed By: mthrok

Differential Revision: D35255630

Pulled By: xiaohui-zhang

fbshipit-source-id: d6c6e88d9ad4b9690bb741557fa9a9504e60872e
parent 03badcd3
......@@ -126,16 +126,24 @@ class LexiconDecoder:
List[List[torchaudio.prototype.ctc_decoder.Hypothesis]]
Args:
emissions (torch.FloatTensor): tensor of shape `(batch, frame, num_tokens)` storing sequences of
probability distribution over labels; output of acoustic model
lengths (Tensor or None, optional): tensor of shape `(batch, )` storing the valid length of
in time axis of the output Tensor in each batch
emissions (torch.FloatTensor): CPU tensor of shape `(batch, frame, num_tokens)` storing sequences of
probability distribution over labels; output of acoustic model.
lengths (Tensor or None, optional): CPU tensor of shape `(batch, )` storing the valid length of
in time axis of the output Tensor in each batch.
Returns:
List[List[Hypothesis]]:
List of sorted best hypotheses for each audio sequence in the batch.
"""
assert emissions.dtype == torch.float32
if emissions.dtype != torch.float32:
raise ValueError("emissions must be float32.")
if emissions.is_cuda:
raise RuntimeError("emissions must be a CPU tensor.")
if lengths is not None and lengths.is_cuda:
raise RuntimeError("lengths must be a CPU tensor.")
B, T, N = emissions.size()
if lengths is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment