Commit c5c4bbfd authored by moto, committed by Facebook GitHub Bot

Update online ASR tutorial (#2226)

Summary:
https://554729-90321822-gh.circle-artifacts.com/0/docs/tutorials/online_asr_tutorial.html

1. Add figure to explain the caching
2. Fix the initialization of stream iterator

Pull Request resolved: https://github.com/pytorch/audio/pull/2226

Reviewed By: carolineechen

Differential Revision: D34265971

Pulled By: mthrok

fbshipit-source-id: 243301e74c4040f4b8cd111b363e70da60e5dae4
parent 38569ef0
@@ -72,7 +72,11 @@ token_processor = bundle.get_token_processor()
######################################################################
# Streaming inference works on input data with overlap.
-# Emformer RNN-T expects right context like the following.
+# The Emformer RNN-T model treats the newest portion of the input data
+# as the "right context" — a preview of future context.
+# In each inference call, the model expects the main segment
+# to start with the right context from the previous inference call.
+# The following figure illustrates this.
#
# .. image:: https://download.pytorch.org/torchaudio/tutorial-assets/emformer_rnnt_context.png
#
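For a concrete sense of the segment and right-context sizes discussed above, the sketch below derives them in samples from the `EMFORMER_RNNT_BASE_LIBRISPEECH` bundle that this tutorial uses. This is an editorial illustration, not part of this commit; the attribute names follow `torchaudio.pipelines.RNNTBundle`.

import torchaudio

# Sketch only: derive the main-segment and right-context sizes in samples
# from the bundle's frame-level configuration.
bundle = torchaudio.pipelines.EMFORMER_RNNT_BASE_LIBRISPEECH

segment_length = bundle.segment_length * bundle.hop_length        # samples per main segment
context_length = bundle.right_context_length * bundle.hop_length  # samples of right context

print(f"Sample rate: {bundle.sample_rate} Hz")
print(f"Main segment: {segment_length} samples ({segment_length / bundle.sample_rate:.3f} sec)")
print(f"Right context: {context_length} samples ({context_length / bundle.sample_rate:.3f} sec)")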
@@ -115,14 +119,20 @@ print(streamer.get_src_stream_info(0))
print(streamer.get_out_stream_info(0))
######################################################################
-# `Streamer` iterate the source media without overlap, so we make a
-# helper structure that caches a chunk and return it with right context
-# appended when the next chunk is given.
+# As previously explained, the Emformer RNN-T model expects input data with
+# overlaps; however, `Streamer` iterates the source media without overlap,
+# so we make a helper structure that caches a part of the input data from
+# `Streamer` as right context and then prepends it to the next input data from
+# `Streamer`.
+#
+# The following figure illustrates this.
+#
+# .. image:: https://download.pytorch.org/torchaudio/tutorial-assets/emformer_rnnt_streamer_context.png
#
class ContextCacher:
"""Cache the previous chunk and combine it with the new chunk
"""Cache the end of input data and prepend the next input data with it.
Args:
segment_length (int): The size of main segment.
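The body of `ContextCacher` is collapsed in this hunk. A minimal sketch of the behavior its docstring describes (cache the tail of each chunk and prepend it to the next one, padding chunks that are shorter than a segment) might look like the following; this is an editorial illustration, not the code from this commit.

import torch

class SimpleContextCacher:
    """Illustrative sketch: cache the end of each chunk and prepend it
    to the next chunk, padding chunks that are shorter than a segment."""

    def __init__(self, segment_length: int, context_length: int):
        self.segment_length = segment_length
        self.context_length = context_length
        self.context = torch.zeros([context_length])

    def __call__(self, chunk: torch.Tensor) -> torch.Tensor:
        # Pad the final (possibly short) chunk up to the full segment length.
        if chunk.size(0) < self.segment_length:
            chunk = torch.nn.functional.pad(chunk, (0, self.segment_length - chunk.size(0)))
        # Prepend the cached context, then remember the tail of this chunk
        # so it can serve as the context for the next call.
        chunk_with_context = torch.cat((self.context, chunk))
        self.context = chunk[-self.context_length:]
        return chunk_with_context

Each chunk coming out of `Streamer` would pass through such a cacher before feature extraction, as the `run_inference` loop in the last hunk does with `cacher(chunk[:, 0])`.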
@@ -166,12 +176,14 @@ state, hypothesis = None, None
# repeatedly.
#
+stream_iterator = streamer.stream()
@torch.inference_mode()
def run_inference(num_iter=200):
    global state, hypothesis
    chunks = []
-    for i, (chunk,) in enumerate(streamer.stream(), start=1):
+    for i, (chunk,) in enumerate(stream_iterator, start=1):
        segment = cacher(chunk[:, 0])
        features, length = feature_extractor(segment)
        hypos, state = decoder.infer(features, length, 10, state=state, hypothesis=hypothesis)
......
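The fix in this hunk creates the stream iterator once, outside `run_inference`, so every call to the function consumes the same iterator and picks up where the previous call stopped. The toy sketch below uses a plain Python generator as a stand-in for `streamer.stream()`; it is an editorial illustration of that pattern, not code from this commit, and it simplifies the actual `Streamer` behavior.

def make_stream():
    # Stand-in for streamer.stream(): yields chunks one at a time.
    yield from range(10)

stream_iterator = make_stream()  # created once, as in the fixed tutorial code

def run(num_iter=3):
    # Each call resumes the shared iterator instead of building a new one.
    for i, chunk in enumerate(stream_iterator, start=1):
        print(chunk)
        if i == num_iter:
            break

run()  # prints 0, 1, 2
run()  # prints 3, 4, 5, continuing where the previous call stopped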