"vscode:/vscode.git/clone" did not exist on "d48895bd53ffc90099beb76cf154c50a1ba23742"
Commit 7d37f69c authored by hwangjeff, committed by Facebook GitHub Bot

Fix decoder call in Device ASR/AVSR tutorials (#3572)

Summary:
Fixes decoder calls and related code in the Device ASR/AVSR tutorials to account for changes to the RNN-T decoder introduced in https://github.com/pytorch/audio/issues/3295.
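For context, here is a minimal before/after sketch of the calling-convention change the tutorials are being adapted to. It is inferred from the diffs below, not from the decoder's documented API: `decoder` stands in for the bundle's RNN-T beam-search decoder, and the indexing assumes a hypothesis's token sequence sits at position 0 of the hypothesis tuple.

    # Before pytorch/audio#3295, as the tutorials used it: infer() returns an
    # n-best list, and the single best hypothesis is fed back on the next chunk.
    hypos, state = decoder.infer(features, length, beam_width, state=state, hypothesis=hypothesis)
    hypothesis = hypos[0]   # keep only the top-scoring hypothesis
    tokens = hypothesis[0]  # its token sequence (assumed position in the tuple)

    # After #3295: the full n-best list is passed back in unchanged, and reading
    # the best token sequence takes one more level of indexing.
    hypotheses, state = decoder.infer(features, length, beam_width, state=state, hypothesis=hypotheses)
    tokens = hypotheses[0][0]  # best hypothesis, then its token sequence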

Pull Request resolved: https://github.com/pytorch/audio/pull/3572

Reviewed By: mthrok

Differential Revision: D48629428

Pulled By: hwangjeff

fbshipit-source-id: 63ede307fb4412aa28f88972d56dca8405607b7a
parent 6fbc1e68
@@ -206,16 +206,15 @@ class Pipeline:
         self.beam_width = beam_width
         self.state = None
-        self.hypothesis = None
+        self.hypotheses = None

     def infer(self, segment: torch.Tensor) -> str:
         """Perform streaming inference"""
         features, length = self.feature_extractor(segment)
-        hypos, self.state = self.decoder.infer(
-            features, length, self.beam_width, state=self.state, hypothesis=self.hypothesis
+        self.hypotheses, self.state = self.decoder.infer(
+            features, length, self.beam_width, state=self.state, hypothesis=self.hypotheses
         )
-        self.hypothesis = hypos[0]
-        transcript = self.token_processor(self.hypothesis[0], lstrip=False)
+        transcript = self.token_processor(self.hypotheses[0][0], lstrip=False)
         return transcript
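Judging from the new code, `decoder.infer` now returns the full n-best list: `self.hypotheses[0]` is the top-scoring hypothesis and `self.hypotheses[0][0]` its token sequence, and the whole list is fed back through `hypothesis=` so the beam search can resume across streaming chunks.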
@@ -291,7 +290,7 @@ def main(device, src, bundle):
             chunk = q.get()
             segment = cacher(chunk[:, 0])
             transcript = pipeline.infer(segment)
-            print(transcript, end="", flush=True)
+            print(transcript, end="\r", flush=True)

     import torch.multiprocessing as mp
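The second change swaps `end=""` for `end="\r"`: printing a carriage return after the transcript returns the cursor to the start of the line, so each refreshed transcript overwrites the previous one in place instead of piling up on the console. The remaining hunks apply the same two fixes to the Device AVSR tutorial.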
@@ -258,15 +258,14 @@ class InferencePipeline(torch.nn.Module):
         self.token_processor = token_processor

         self.state = None
-        self.hypothesis = None
+        self.hypotheses = None

     def forward(self, audio, video):
         audio, video = self.preprocessor(audio, video)
         feats = self.model(audio.unsqueeze(0), video.unsqueeze(0))
         length = torch.tensor([feats.size(1)], device=audio.device)
-        hypos, self.state = self.decoder.infer(feats, length, 10, state=self.state, hypothesis=self.hypothesis)
-        self.hypothesis = hypos[0]
-        transcript = self.token_processor(self.hypothesis[0], lstrip=False)
+        self.hypotheses, self.state = self.decoder.infer(feats, length, 10, state=self.state, hypothesis=self.hypotheses)
+        transcript = self.token_processor(self.hypotheses[0][0], lstrip=False)
         return transcript
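One difference from the ASR pipeline above: this `forward` hardcodes a beam width of 10 in the `decoder.infer` call rather than taking it from a `beam_width` attribute; the decoder-contract change itself is identical.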
@@ -370,7 +369,7 @@ def main(device, src, option=None):
             video, audio = cacher(video, audio)
             pipeline.state, pipeline.hypothesis = None, None
             transcript = pipeline(audio, video.float())
-            print(transcript, end="", flush=True)
+            print(transcript, end="\r", flush=True)
             num_video_frames = 0
             video_chunks = []
             audio_chunks = []