Commit 90e4959d authored by Guo Liyong's avatar Guo Liyong Committed by Facebook GitHub Bot
Browse files

Fix bug with unsqueezing length tensor in RNNTBeamSearch (#2344)

Summary:
This PR amends `RNNTBeamSearch`'s streaming decoding method to correctly unsqueeze `length` when its dimension is 0.

Original comment: Is "input.dim() == 0" unreachable as it could only be 2 or 3 in assertion of Line 329?

Pull Request resolved: https://github.com/pytorch/audio/pull/2344

Reviewed By: carolineechen, nateanl

Differential Revision: D35899740

Pulled By: hwangjeff

fbshipit-source-id: 84c1692b8cc9e5d35798d87f4a1bd052d94af9fb
parent 97ed428d
...@@ -333,8 +333,8 @@ class RNNTBeamSearch(torch.nn.Module): ...@@ -333,8 +333,8 @@ class RNNTBeamSearch(torch.nn.Module):
input = input.unsqueeze(0) input = input.unsqueeze(0)
assert length.shape == () or length.shape == (1,), "length must be of shape () or (1,)" assert length.shape == () or length.shape == (1,), "length must be of shape () or (1,)"
if input.dim() == 0: if length.dim() == 0:
input = input.unsqueeze(0) length = length.unsqueeze(0)
enc_out, _, state = self.model.transcribe_streaming(input, length, state) enc_out, _, state = self.model.transcribe_streaming(input, length, state)
return self._search(enc_out, hypothesis, beam_width), state return self._search(enc_out, hypothesis, beam_width), state
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment