"""
Online ASR with Emformer RNN-T
==============================

**Author**: `Jeff Hwang <jeffhwang@meta.com>`__, `Moto Hira <moto@meta.com>`__

This tutorial shows how to use Emformer RNN-T and streaming API
to perform online speech recognition.

"""

######################################################################
#
# .. note::
#
#    This tutorial requires FFmpeg libraries (>=4.1, <7) and SentencePiece.
#
#    There are multiple ways to install FFmpeg libraries.
#    If you are using the Anaconda Python distribution,
#    ``conda install -c conda-forge 'ffmpeg<7'`` will install
#    compatible FFmpeg libraries.
#
#    You can install SentencePiece by running ``pip install sentencepiece``.
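
######################################################################
# As an optional sanity check (a minimal sketch, not part of the original
# pipeline), we can confirm that SentencePiece is importable before
# proceeding. FFmpeg availability is exercised later, when
# :py:class:`torchaudio.io.StreamReader` is instantiated.

try:
    import sentencepiece

    print("SentencePiece version:", sentencepiece.__version__)
except ImportError:
    print("SentencePiece is not installed. Run `pip install sentencepiece`.")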

######################################################################
# 1. Overview
# -----------
#
# Performing online speech recognition consists of the following steps
# (a rough sketch of the resulting loop is shown after the list):
#
# 1. Build the inference pipeline.
#    Emformer RNN-T is composed of three components: feature extractor,
#    decoder, and token processor.
# 2. Format the waveform into chunks of the expected sizes.
# 3. Pass the data through the pipeline.
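
######################################################################
# For orientation, the core of the per-chunk inference loop that we will
# build looks roughly like this (a sketch only; ``feature_extractor``,
# ``decoder``, and ``token_processor`` are constructed in the following
# sections, and the beam width of 10 matches the value used later in this
# tutorial)::
#
#     features, length = feature_extractor(chunk_with_context)
#     hypos, state = decoder.infer(features, length, 10, state=state, hypothesis=hypothesis)
#     transcript = token_processor(hypos[0][0])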

######################################################################
# 2. Preparation
# --------------
#

import torch
import torchaudio

print(torch.__version__)
print(torchaudio.__version__)

import IPython
import matplotlib.pyplot as plt
from torchaudio.io import StreamReader

######################################################################
# 3. Construct the pipeline
# -------------------------
#
# Pre-trained model weights and related pipeline components are
# bundled as :py:class:`torchaudio.pipelines.RNNTBundle`.
#
# We use :py:data:`torchaudio.pipelines.EMFORMER_RNNT_BASE_LIBRISPEECH`,
# which is an Emformer RNN-T model trained on the LibriSpeech dataset.
#

bundle = torchaudio.pipelines.EMFORMER_RNNT_BASE_LIBRISPEECH

feature_extractor = bundle.get_streaming_feature_extractor()
decoder = bundle.get_decoder()
token_processor = bundle.get_token_processor()

######################################################################
# Streaming inference works on input data with overlap.
# The Emformer RNN-T model treats the newest portion of the input data
# as the "right context", a preview of future context.
# In each inference call, the model expects the main segment
# to begin with the portion that served as the right context in the
# previous inference call. The following figure illustrates this.
#
# .. image:: https://download.pytorch.org/torchaudio/tutorial-assets/emformer_rnnt_context.png
#
# The size of the main segment and the right context, along with
# the expected sample rate, can be retrieved from the bundle.
# Note that ``bundle.segment_length`` and ``bundle.right_context_length``
# are expressed in feature frames, so we multiply them by ``bundle.hop_length``
# to convert them to numbers of audio samples.
#

sample_rate = bundle.sample_rate
segment_length = bundle.segment_length * bundle.hop_length
context_length = bundle.right_context_length * bundle.hop_length

print(f"Sample rate: {sample_rate}")
print(f"Main segment: {segment_length} frames ({segment_length / sample_rate} seconds)")
print(f"Right context: {context_length} frames ({context_length / sample_rate} seconds)")

######################################################################
# 4. Configure the audio stream
# -----------------------------
#
# Next, we configure the input audio stream using :py:class:`torchaudio.io.StreamReader`.
#
# For the details of this API, please refer to the
# `StreamReader Basic Usage <./streamreader_basic_tutorial.html>`__ tutorial.
#

######################################################################
# The following audio file was originally published by the LibriVox project,
# and it is in the public domain.
#
# https://librivox.org/great-pirate-stories-by-joseph-lewis-french/
#
# It was re-uploaded for the sake of the tutorial.
#
src = "https://download.pytorch.org/torchaudio/tutorial-assets/greatpiratestories_00_various.mp3"

streamer = StreamReader(src)
streamer.add_basic_audio_stream(frames_per_chunk=segment_length, sample_rate=bundle.sample_rate)

print(streamer.get_src_stream_info(0))
print(streamer.get_out_stream_info(0))

######################################################################
# As previously explained, the Emformer RNN-T model expects input data with
# overlap; however, ``StreamReader`` iterates over the source media without
# overlap, so we make a helper structure that caches a part of the input data
# from ``StreamReader`` as the right context and then prepends it to the next
# input data from ``StreamReader``.
#
# The following figure illustrates this.
#
# .. image:: https://download.pytorch.org/torchaudio/tutorial-assets/emformer_rnnt_streamer_context.png
#


class ContextCacher:
    """Cache the end of input data and prepend the next input data with it.

    Args:
        segment_length (int): The size of the main segment.
            If the incoming segment is shorter, then the segment is padded.
        context_length (int): The size of the context, cached and prepended.
    """

    def __init__(self, segment_length: int, context_length: int):
        self.segment_length = segment_length
        self.context_length = context_length
        self.context = torch.zeros([context_length])

    def __call__(self, chunk: torch.Tensor):
        # pad the final chunk if it is shorter than the expected segment length
        if chunk.size(0) < self.segment_length:
            chunk = torch.nn.functional.pad(chunk, (0, self.segment_length - chunk.size(0)))
        # prepend the cached context, then cache the tail of this chunk
        # to serve as the context for the next call
        chunk_with_context = torch.cat((self.context, chunk))
        self.context = chunk[-self.context_length :]
        return chunk_with_context
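

######################################################################
# As a quick illustration of the cacher's behavior, we can feed it dummy
# tensors. The sizes below are hypothetical and unrelated to the model's
# requirements; they only show that each returned chunk starts with the
# tail of the previous one (and that the very first context is zeros).

_demo_cacher = ContextCacher(segment_length=4, context_length=2)
print(_demo_cacher(torch.arange(1.0, 5.0)))  # tensor([0., 0., 1., 2., 3., 4.])
print(_demo_cacher(torch.arange(5.0, 9.0)))  # tensor([3., 4., 5., 6., 7., 8.])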


######################################################################
# 5. Run stream inference
# -----------------------
#
# Finally, we run the recognition.
#
# First, we initialize the stream iterator, context cacher, and
# the state and hypothesis that are used by the decoder to carry over
# the decoding state between inference calls.
#

cacher = ContextCacher(segment_length, context_length)

state, hypothesis = None, None

######################################################################
# Next, we run the inference.
#
# For the sake of better display, we create a helper function that
# processes the source stream for a given number of chunks, and we call
# it repeatedly.
#

stream_iterator = streamer.stream()


def _plot(feats, num_iter, unit=25):
    unit_dur = segment_length / sample_rate * unit
    num_plots = num_iter // unit + (1 if num_iter % unit else 0)
    fig, axes = plt.subplots(num_plots, 1)
    t0 = 0
    for i, ax in enumerate(axes):
        feats_ = feats[i * unit : (i + 1) * unit]
        t1 = t0 + segment_length / sample_rate * len(feats_)
        feats_ = torch.cat([f[2:-2] for f in feats_])  # remove boundary effect and overlap
        ax.imshow(feats_.T, extent=[t0, t1, 0, 1], aspect="auto", origin="lower")
        ax.tick_params(which="both", left=False, labelleft=False)
        ax.set_xlim(t0, t0 + unit_dur)
        t0 = t1
    fig.suptitle("MelSpectrogram Feature")
    plt.tight_layout()


@torch.inference_mode()
def run_inference(num_iter=100):
    global state, hypothesis
    chunks = []
    feats = []
    for i, (chunk,) in enumerate(stream_iterator, start=1):
        # take the first channel and prepend the cached right context
        segment = cacher(chunk[:, 0])
        features, length = feature_extractor(segment)
        # beam search with a beam width of 10, carrying over the decoder
        # state and hypotheses from the previous call
        hypos, state = decoder.infer(features, length, 10, state=state, hypothesis=hypothesis)
        hypothesis = hypos
        transcript = token_processor(hypos[0][0], lstrip=False)
        print(transcript, end="\r", flush=True)

        chunks.append(chunk)
        feats.append(features)
        if i == num_iter:
            break

    # Plot the features
    _plot(feats, num_iter)
    return IPython.display.Audio(torch.cat(chunks).T.numpy(), rate=bundle.sample_rate)


######################################################################
#

run_inference()

######################################################################
#

run_inference()

######################################################################
#

run_inference()

######################################################################
#

run_inference()

######################################################################
#

run_inference()

######################################################################
#

run_inference()

######################################################################
#

run_inference()

######################################################################
#

run_inference()

######################################################################
#

run_inference()

######################################################################
#

run_inference()

######################################################################
#

run_inference()

######################################################################
#

run_inference()

######################################################################
#

run_inference()

######################################################################
#
# Tag: :obj:`torchaudio.io`