Commit 612de66b authored by hwangjeff's avatar hwangjeff Committed by Facebook GitHub Bot
Browse files

Revise RNN-T pipeline streaming decoding logic (#2192)

Summary:
Rather than apply SentencePiece's `decode` to directly convert each hypothesis's token id sequence to an output string, we convert each token id sequence to word pieces and then manually join the word pieces ourselves. This allows us to preserve leading whitespaces on output strings and therefore account for word breaks and continuations across token processor invocations, which is particularly useful when performing streaming ASR.

https://user-images.githubusercontent.com/8345689/152093668-11fb775a-bf7b-4b1d-9516-9f8d5a9b6683.mov

Versus the previous behavior visualized in https://github.com/pytorch/audio/issues/2093, the scheme here properly constructs words comprising multiple pieces.

Pull Request resolved: https://github.com/pytorch/audio/pull/2192

Reviewed By: mthrok

Differential Revision: D33936622

Pulled By: hwangjeff

fbshipit-source-id: e550980c7d4cac9e982315508f793a6b816752e9
parent 7a3e262d
@@ -41,9 +41,8 @@ def cli_main():
         features, length = streaming_feature_extractor(segment)
         hypos, state = decoder.infer(features, length, 10, state=state, hypothesis=hypothesis)
         hypothesis = hypos[0]
-        transcript = token_processor(hypothesis.tokens)
-        if transcript:
-            print(transcript, end=" ", flush=True)
+        transcript = token_processor(hypothesis.tokens, lstrip=False)
+        print(transcript, end="", flush=True)
     print()
     # Non-streaming decode.
......
@@ -79,7 +79,7 @@ class _FeatureExtractor(ABC):
 class _TokenProcessor(ABC):
     @abstractmethod
-    def __call__(self, tokens: List[int]) -> str:
+    def __call__(self, tokens: List[int], **kwargs) -> str:
         """Decodes given list of tokens to text sequence.

         Args:
@@ -140,11 +140,13 @@ class _SentencePieceTokenProcessor(_TokenProcessor):
             self.sp_model.pad_id(),
         }

-    def __call__(self, tokens: List[int]) -> str:
+    def __call__(self, tokens: List[int], lstrip: bool = True) -> str:
         """Decodes given list of tokens to text sequence.

         Args:
             tokens (List[int]): list of tokens to decode.
+            lstrip (bool, optional): if ``True``, returns text sequence with leading whitespace
+                removed. (Default: ``True``).

         Returns:
             str:
@@ -153,7 +155,12 @@ class _SentencePieceTokenProcessor(_TokenProcessor):
         filtered_hypo_tokens = [
             token_index for token_index in tokens[1:] if token_index not in self.post_process_remove_list
         ]
-        return self.sp_model.decode(filtered_hypo_tokens)
+        output_string = "".join(self.sp_model.id_to_piece(filtered_hypo_tokens)).replace("\u2581", " ")
+
+        if lstrip:
+            return output_string.lstrip()
+        else:
+            return output_string

 @dataclass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment