Commit 9fc0dcaa authored by Lakshmi Krishnan, committed by Facebook GitHub Bot

Improve RNN-T streaming decoding (#3295)

Summary:
This commit fixes the following issues affecting streaming decoding quality:
1. The `init_b` hypotheses are regenerated from the blank token only when no initial hypotheses are provided.
2. The decoder can now receive the top-K hypotheses to continue decoding from, instead of only the single best hypothesis at each decoding step. This dramatically improves decoding quality, especially for speech with long pauses and disfluencies (see the sketch below).
3. Minor errors in the shape checks for `length` are fixed.

This also means that the resulting output is the entire transcript up to that time step, rather than just the incremental change in the transcript.
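For reference, here is a minimal sketch of the updated streaming loop, mirroring the example scripts changed below. The chunk iterator (`streamer`) and the setup of `cacher`, `feature_extractor`, `token_processor`, and `decoder` are assumptions that follow the Emformer RNN-T examples:

    # Setup of streamer/cacher/feature_extractor/token_processor/decoder is
    # assumed to match the Emformer RNN-T example scripts.
    state, hypothesis = None, None
    for chunk in streamer:  # assumed iterator over audio chunks
        segment = cacher(chunk[:, 0])
        features, length = feature_extractor(segment)
        with torch.no_grad():
            # `hypothesis` now carries the full list of top-K hypotheses
            # returned by the previous call, not just the single best one.
            hypos, state = decoder.infer(features, length, 10, state=state, hypothesis=hypothesis)
        hypothesis = hypos
        # hypos[0][0] is the token sequence of the best hypothesis. Because
        # the transcript now covers the entire utterance so far, overwrite
        # the printed line with end="\r" instead of appending to it.
        transcript = token_processor(hypos[0][0], lstrip=False)
        print(transcript, end="\r", flush=True)

Feeding the full hypothesis list back into `infer` lets the beam recover from locally suboptimal paths across chunk boundaries, which is what previously degraded quality around long pauses.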

Pull Request resolved: https://github.com/pytorch/audio/pull/3295

Reviewed By: nateanl

Differential Revision: D46216113

Pulled By: hwangjeff

fbshipit-source-id: 8f7efae28dcca4a052f434ca55a2795c9e5ec0b0
parent c6624fa6
@@ -65,9 +65,9 @@ def run_eval_streaming(args):
         with torch.no_grad():
             features, length = streaming_feature_extractor(segment)
             hypos, state = decoder.infer(features, length, 10, state=state, hypothesis=hypothesis)
-            hypothesis = hypos[0]
-            transcript = token_processor(hypothesis[0], lstrip=False)
-            print(transcript, end="", flush=True)
+            hypothesis = hypos
+            transcript = token_processor(hypos[0][0], lstrip=True)
+            print(transcript, end="\r", flush=True)
     print()
     # Non-streaming decode.
...
@@ -39,6 +39,7 @@ to perform online speech recognition.
 # --------------
 #
+import os
 import torch
 import torchaudio
@@ -222,9 +223,9 @@ def run_inference(num_iter=100):
         segment = cacher(chunk[:, 0])
         features, length = feature_extractor(segment)
         hypos, state = decoder.infer(features, length, 10, state=state, hypothesis=hypothesis)
-        hypothesis = hypos[0]
-        transcript = token_processor(hypothesis[0], lstrip=False)
-        print(transcript, end="", flush=True)
+        hypothesis = hypos
+        transcript = token_processor(hypos[0][0], lstrip=False)
+        print(transcript, end="\r", flush=True)
         chunks.append(chunk)
         feats.append(features)
...
@@ -99,7 +99,7 @@ class RNNTBeamSearchTestImpl(TestBaseMixin):
         self.assertEqual(res, scripted_res)
         state = res[1]
-        hypo = res[0][0]
+        hypo = res[0]
         scripted_state = scripted_res[1]
-        scripted_hypo = scripted_res[0][0]
+        scripted_hypo = scripted_res[0]
@@ -109,11 +109,7 @@ class RNNTBeamSearch(torch.nn.Module):
         self.step_max_tokens = step_max_tokens

-    def _init_b_hypos(self, hypo: Optional[Hypothesis], device: torch.device) -> List[Hypothesis]:
-        if hypo is not None:
-            token = _get_hypo_tokens(hypo)[-1]
-            state = _get_hypo_state(hypo)
-        else:
-            token = self.blank
-            state = None
+    def _init_b_hypos(self, device: torch.device) -> List[Hypothesis]:
+        token = self.blank
+        state = None
@@ -230,14 +226,14 @@ class RNNTBeamSearch(torch.nn.Module):
     def _search(
         self,
         enc_out: torch.Tensor,
-        hypo: Optional[Hypothesis],
+        hypo: Optional[List[Hypothesis]],
         beam_width: int,
     ) -> List[Hypothesis]:
         n_time_steps = enc_out.shape[1]
         device = enc_out.device

         a_hypos: List[Hypothesis] = []
-        b_hypos = self._init_b_hypos(hypo, device)
+        b_hypos = self._init_b_hypos(device) if hypo is None else hypo
         for t in range(n_time_steps):
             a_hypos = b_hypos
             b_hypos = torch.jit.annotate(List[Hypothesis], [])
@@ -263,7 +259,7 @@ class RNNTBeamSearch(torch.nn.Module):
             if a_hypos:
                 symbols_current_t += 1

-        _, sorted_idx = torch.tensor([self.hypo_sort_key(hypo) for hypo in b_hypos]).topk(beam_width)
+        _, sorted_idx = torch.tensor([self.hypo_sort_key(hyp) for hyp in b_hypos]).topk(beam_width)
         b_hypos = [b_hypos[idx] for idx in sorted_idx]
         return b_hypos
@@ -290,8 +286,8 @@ class RNNTBeamSearch(torch.nn.Module):
         if length.shape != () and length.shape != (1,):
             raise ValueError("length must be of shape () or (1,)")
-        if input.dim() == 0:
-            input = input.unsqueeze(0)
+        if length.dim() == 0:
+            length = length.unsqueeze(0)
         enc_out, _ = self.model.transcribe(input, length)
         return self._search(enc_out, None, beam_width)
@@ -303,7 +299,7 @@ class RNNTBeamSearch(torch.nn.Module):
         length: torch.Tensor,
         beam_width: int,
         state: Optional[List[List[torch.Tensor]]] = None,
-        hypothesis: Optional[Hypothesis] = None,
+        hypothesis: Optional[List[Hypothesis]] = None,
     ) -> Tuple[List[Hypothesis], List[List[torch.Tensor]]]:
         r"""Performs beam search for the given input sequence in streaming mode.
...