Commit 9fc0dcaa authored by Lakshmi Krishnan, committed by Facebook GitHub Bot

Improve RNN-T streaming decoding (#3295)

Summary:
This commit fixes the following issues affecting streaming decoding quality:
1. The `init_b` hypotheses are regenerated from the blank token only if no initial hypotheses are provided.
2. Allows the decoder to receive the top-K hypotheses to continue decoding from, instead of using just the top hypothesis at each decoding step. This dramatically improves decoding quality, especially for speech with long pauses and disfluencies.
3. Fixes some minor errors in the shape checking for `length`.

This also means that the resulting output at each step is the entire transcript up to that time step, rather than just the incremental change to the transcript; a sketch of the updated calling convention follows.
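For illustration, here is a minimal sketch of the new streaming loop, reusing names from the demo code changed below (`streaming_feature_extractor`, `decoder`, `token_processor`); the `stream` iterable of audio chunks is an assumption, and this is a sketch of the calling convention rather than the exact demo code:

    state, hypothesis = None, None
    for segment in stream:  # assumption: an iterable of audio chunks
        features, length = streaming_feature_extractor(segment)
        # beam_width=10; feed the previous top-K hypotheses back in
        hypos, state = decoder.infer(features, length, 10, state=state, hypothesis=hypothesis)
        hypothesis = hypos  # keep the full top-K list, not just hypos[0]
        # hypos[0][0] is the best hypothesis' token sequence; it now decodes
        # to the entire transcript so far, so "\r" overwrites the line
        transcript = token_processor(hypos[0][0], lstrip=True)
        print(transcript, end="\r", flush=True)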

Pull Request resolved: https://github.com/pytorch/audio/pull/3295

Reviewed By: nateanl

Differential Revision: D46216113

Pulled By: hwangjeff

fbshipit-source-id: 8f7efae28dcca4a052f434ca55a2795c9e5ec0b0
parent c6624fa6
@@ -65,9 +65,9 @@ def run_eval_streaming(args):
         with torch.no_grad():
             features, length = streaming_feature_extractor(segment)
             hypos, state = decoder.infer(features, length, 10, state=state, hypothesis=hypothesis)
-            hypothesis = hypos[0]
-            transcript = token_processor(hypothesis[0], lstrip=False)
-            print(transcript, end="", flush=True)
+            hypothesis = hypos
+            transcript = token_processor(hypos[0][0], lstrip=True)
+            print(transcript, end="\r", flush=True)
     print()
     # Non-streaming decode.
@@ -39,6 +39,7 @@ to perform online speech recognition.
 # --------------
 #
+import os
 import torch
 import torchaudio
@@ -222,9 +223,9 @@ def run_inference(num_iter=100):
         segment = cacher(chunk[:, 0])
         features, length = feature_extractor(segment)
         hypos, state = decoder.infer(features, length, 10, state=state, hypothesis=hypothesis)
-        hypothesis = hypos[0]
-        transcript = token_processor(hypothesis[0], lstrip=False)
-        print(transcript, end="", flush=True)
+        hypothesis = hypos
+        transcript = token_processor(hypos[0][0], lstrip=False)
+        print(transcript, end="\r", flush=True)
         chunks.append(chunk)
         feats.append(features)
@@ -99,7 +99,7 @@ class RNNTBeamSearchTestImpl(TestBaseMixin):
         self.assertEqual(res, scripted_res)
 
         state = res[1]
-        hypo = res[0][0]
+        hypo = res[0]
         scripted_state = scripted_res[1]
-        scripted_hypo = scripted_res[0][0]
+        scripted_hypo = scripted_res[0]
@@ -109,11 +109,7 @@ class RNNTBeamSearch(torch.nn.Module):
         self.step_max_tokens = step_max_tokens
 
-    def _init_b_hypos(self, hypo: Optional[Hypothesis], device: torch.device) -> List[Hypothesis]:
-        if hypo is not None:
-            token = _get_hypo_tokens(hypo)[-1]
-            state = _get_hypo_state(hypo)
-        else:
+    def _init_b_hypos(self, device: torch.device) -> List[Hypothesis]:
         token = self.blank
         state = None
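With this change, the helper only ever seeds the beam from the blank token; hypotheses carried over from a previous chunk bypass it entirely (see the `_search` change below). A hedged sketch of the simplified behavior, with a placeholder `Hypothesis` alias and zeroed predictor output standing in for the real one-step predictor call (both are assumptions made to keep the sketch self-contained):

    from typing import List, Tuple

    import torch

    # Rough stand-in for torchaudio's Hypothesis alias:
    # (tokens, predictor output, predictor state, score)
    Hypothesis = Tuple[List[int], torch.Tensor, List[List[torch.Tensor]], float]

    def init_b_hypos_sketch(blank: int, device: torch.device) -> List[Hypothesis]:
        # Always root the beam at the blank token with empty predictor
        # state; the real method runs one predictor step here.
        pred_out = torch.zeros(1, device=device)
        return [([blank], pred_out, [], 0.0)]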
@@ -230,14 +226,14 @@ class RNNTBeamSearch(torch.nn.Module):
     def _search(
         self,
         enc_out: torch.Tensor,
-        hypo: Optional[Hypothesis],
+        hypo: Optional[List[Hypothesis]],
         beam_width: int,
     ) -> List[Hypothesis]:
         n_time_steps = enc_out.shape[1]
         device = enc_out.device
 
         a_hypos: List[Hypothesis] = []
-        b_hypos = self._init_b_hypos(hypo, device)
+        b_hypos = self._init_b_hypos(device) if hypo is None else hypo
         for t in range(n_time_steps):
             a_hypos = b_hypos
             b_hypos = torch.jit.annotate(List[Hypothesis], [])
@@ -263,7 +259,7 @@ class RNNTBeamSearch(torch.nn.Module):
                 if a_hypos:
                     symbols_current_t += 1
 
-            _, sorted_idx = torch.tensor([self.hypo_sort_key(hypo) for hypo in b_hypos]).topk(beam_width)
+            _, sorted_idx = torch.tensor([self.hypo_sort_key(hyp) for hyp in b_hypos]).topk(beam_width)
             b_hypos = [b_hypos[idx] for idx in sorted_idx]
 
         return b_hypos
@@ -290,8 +286,8 @@ class RNNTBeamSearch(torch.nn.Module):
         if length.shape != () and length.shape != (1,):
             raise ValueError("length must be of shape () or (1,)")
 
-        if input.dim() == 0:
-            input = input.unsqueeze(0)
+        if length.dim() == 0:
+            length = length.unsqueeze(0)
 
         enc_out, _ = self.model.transcribe(input, length)
         return self._search(enc_out, None, beam_width)
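The corrected branch normalizes `length` rather than `input`: a scalar `length` passes the guard above with shape `()` and is then promoted to shape `(1,)` before being handed to `transcribe`. A quick sketch of that normalization in isolation:

    import torch

    length = torch.tensor(5)          # 0-dim scalar tensor, shape ()
    if length.dim() == 0:
        length = length.unsqueeze(0)  # now shape (1,)
    assert length.shape == (1,)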
@@ -303,7 +299,7 @@ class RNNTBeamSearch(torch.nn.Module):
         length: torch.Tensor,
         beam_width: int,
         state: Optional[List[List[torch.Tensor]]] = None,
-        hypothesis: Optional[Hypothesis] = None,
+        hypothesis: Optional[List[Hypothesis]] = None,
     ) -> Tuple[List[Hypothesis], List[List[torch.Tensor]]]:
         r"""Performs beam search for the given input sequence in streaming mode.