Commit 7d37f69c authored by hwangjeff, committed by Facebook GitHub Bot

Fix decoder call in Device ASR/AVSR tutorials (#3572)

Summary:
Fixes decoder calls and related code in the Device ASR/AVSR tutorials to account for changes to the RNN-T decoder introduced in https://github.com/pytorch/audio/issues/3295.
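In short, the decoder's `infer` method now takes and returns the full list of beam-search hypotheses instead of a single best hypothesis, so the tutorials thread that list back into the next streaming call and read the best transcript out of `hypotheses[0][0]`. Since the best hypothesis now covers everything decoded so far, the console output also switches from appending (`end=""`) to overwriting the current line (`end="\r"`). Below is a minimal sketch of the updated call pattern, mirroring the diffs in this commit; the bundle choice and the `audio_segments` iterable are illustrative assumptions, not part of this commit:

    import torchaudio

    # Assumed setup, as in the Device ASR tutorial; any RNNTBundle works the same way.
    bundle = torchaudio.pipelines.EMFORMER_RNNT_BASE_LIBRISPEECH
    feature_extractor = bundle.get_streaming_feature_extractor()
    decoder = bundle.get_decoder()
    token_processor = bundle.get_token_processor()

    state, hypotheses = None, None  # before this fix, callers kept a single `hypothesis`
    for segment in audio_segments:  # hypothetical stream of audio chunks (torch.Tensor)
        features, length = feature_extractor(segment)
        # infer() now returns the whole beam (best-first) plus the decoder state,
        # and the same list is passed back in through the `hypothesis` argument.
        hypotheses, state = decoder.infer(
            features, length, 10, state=state, hypothesis=hypotheses
        )
        # hypotheses[0] is the best hypothesis; its token sequence sits at index 0.
        transcript = token_processor(hypotheses[0][0], lstrip=False)
        print(transcript, end="\r", flush=True)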

Pull Request resolved: https://github.com/pytorch/audio/pull/3572

Reviewed By: mthrok

Differential Revision: D48629428

Pulled By: hwangjeff

fbshipit-source-id: 63ede307fb4412aa28f88972d56dca8405607b7a
parent 6fbc1e68
@@ -206,16 +206,15 @@ class Pipeline:
         self.beam_width = beam_width

         self.state = None
-        self.hypothesis = None
+        self.hypotheses = None

     def infer(self, segment: torch.Tensor) -> str:
         """Perform streaming inference"""
         features, length = self.feature_extractor(segment)
-        hypos, self.state = self.decoder.infer(
-            features, length, self.beam_width, state=self.state, hypothesis=self.hypothesis
+        self.hypotheses, self.state = self.decoder.infer(
+            features, length, self.beam_width, state=self.state, hypothesis=self.hypotheses
         )
-        self.hypothesis = hypos[0]
-        transcript = self.token_processor(self.hypothesis[0], lstrip=False)
+        transcript = self.token_processor(self.hypotheses[0][0], lstrip=False)
         return transcript
@@ -291,7 +290,7 @@ def main(device, src, bundle):
             chunk = q.get()
             segment = cacher(chunk[:, 0])
             transcript = pipeline.infer(segment)
-            print(transcript, end="", flush=True)
+            print(transcript, end="\r", flush=True)

     import torch.multiprocessing as mp
...
@@ -258,15 +258,14 @@ class InferencePipeline(torch.nn.Module):
         self.token_processor = token_processor

         self.state = None
-        self.hypothesis = None
+        self.hypotheses = None

     def forward(self, audio, video):
         audio, video = self.preprocessor(audio, video)
         feats = self.model(audio.unsqueeze(0), video.unsqueeze(0))
         length = torch.tensor([feats.size(1)], device=audio.device)
-        hypos, self.state = self.decoder.infer(feats, length, 10, state=self.state, hypothesis=self.hypothesis)
-        self.hypothesis = hypos[0]
-        transcript = self.token_processor(self.hypothesis[0], lstrip=False)
+        self.hypotheses, self.state = self.decoder.infer(feats, length, 10, state=self.state, hypothesis=self.hypotheses)
+        transcript = self.token_processor(self.hypotheses[0][0], lstrip=False)
         return transcript
@@ -370,7 +369,7 @@ def main(device, src, option=None):
             video, audio = cacher(video, audio)
             pipeline.state, pipeline.hypothesis = None, None
             transcript = pipeline(audio, video.float())
-            print(transcript, end="", flush=True)
+            print(transcript, end="\r", flush=True)
             num_video_frames = 0
             video_chunks = []
             audio_chunks = []
...