"""
Device ASR with Emformer RNN-T
==============================

**Author** : `Moto Hira <moto@fb.com>`__, `Jeff Hwang <jeffhwang@fb.com>`__.

This tutorial shows how to use Emformer RNN-T and the streaming API
to perform speech recognition on a streaming device input, i.e. the
microphone on a laptop.

.. note::

   This tutorial requires Streaming API, FFmpeg libraries (>=4.1, <5),
   and SentencePiece.

   The Streaming API is available in the nightly builds.
   Please refer to https://pytorch.org/get-started/locally
   for instructions.

   There are multiple ways to install FFmpeg libraries.
   If you are using Anaconda Python distribution,
   ``conda install 'ffmpeg<5'`` will install
   the required FFmpeg libraries.

   You can install SentencePiece by running ``pip install sentencepiece``.

.. note::

   This tutorial was tested on a MacBook Pro and a Dynabook with Windows 10.

   This tutorial does NOT work on Google Colab because the server running
   this tutorial does not have a microphone that you can talk to.
"""

######################################################################
# 1. Overview
# -----------
#
# We use the streaming API to fetch audio from an audio device (microphone)
# chunk by chunk, then run inference on each chunk with Emformer RNN-T.
#
# For the basic usage of the streaming API and Emformer RNN-T,
# please refer to
# `Media Stream API tutorial <./streaming_api_tutorial.html>`__ and
# `Online ASR with Emformer RNN-T <./online_asr_tutorial.html>`__.
#

######################################################################
# 2. Checking the supported devices
# ---------------------------------
#
# Firstly, we need to check the devices that Streaming API can access,
# and figure out the arguments (``src`` and ``format``) we need to pass
# to the :py:class:`~torchaudio.io.StreamReader` class.
#
# We use the ``ffmpeg`` command for this. ``ffmpeg`` abstracts away the
# differences of the underlying hardware implementations, but the expected
# value for ``format`` varies across OS and each ``format`` defines
# a different syntax for ``src``.
#
# The details of the supported ``format`` values and ``src`` syntax can
# be found at https://ffmpeg.org/ffmpeg-devices.html.
#
# For macOS, the following command will list the available devices.
#
# .. code::
#
#    $ ffmpeg -f avfoundation -list_devices true -i dummy
#    ...
#    [AVFoundation indev @ 0x126e049d0] AVFoundation video devices:
#    [AVFoundation indev @ 0x126e049d0] [0] FaceTime HD Camera
#    [AVFoundation indev @ 0x126e049d0] [1] Capture screen 0
#    [AVFoundation indev @ 0x126e049d0] AVFoundation audio devices:
#    [AVFoundation indev @ 0x126e049d0] [0] ZoomAudioDevice
#    [AVFoundation indev @ 0x126e049d0] [1] MacBook Pro Microphone
#
# We will use the following values for Streaming API.
#
# .. code::
#
#    StreamReader(
#        src = ":1",  # no video, audio from device 1, "MacBook Pro Microphone"
#        format = "avfoundation",
#    )

######################################################################
#
# For Windows, ``dshow`` device should work.
#
# .. code::
#
#    > ffmpeg -f dshow -list_devices true -i dummy
#    ...
#    [dshow @ 000001adcabb02c0] DirectShow video devices (some may be both video and audio devices)
#    [dshow @ 000001adcabb02c0]  "TOSHIBA Web Camera - FHD"
#    [dshow @ 000001adcabb02c0]     Alternative name "@device_pnp_\\?\usb#vid_10f1&pid_1a42&mi_00#7&27d916e6&0&0000#{65e8773d-8f56-11d0-a3b9-00a0c9223196}\global"
#    [dshow @ 000001adcabb02c0] DirectShow audio devices
#    [dshow @ 000001adcabb02c0]  "... (Realtek High Definition Audio)"
#    [dshow @ 000001adcabb02c0]     Alternative name "@device_cm_{33D9A762-90C8-11D0-BD43-00A0C911CE86}\wave_{BF2B8AE1-10B8-4CA4-A0DC-D02E18A56177}"
#
# In the above case, the following value can be used to stream from the microphone.
#
# .. code::
#
#    StreamReader(
#        src = "audio=@device_cm_{33D9A762-90C8-11D0-BD43-00A0C911CE86}\wave_{BF2B8AE1-10B8-4CA4-A0DC-D02E18A56177}",
#        format = "dshow",
#    )
#
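
######################################################################
#
# For Linux, the ``alsa`` input device typically works with FFmpeg.
# The following is a minimal, untested sketch (Linux is not among the
# configurations this tutorial was tested on); the device name
# ``"default"`` is an assumption and depends on your ALSA setup.
#
# .. code::
#
#    StreamReader(
#        src = "default",
#        format = "alsa",
#    )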

######################################################################
# 3. Data acquisition
# -------------------
#
# Streaming audio from microphone input requires properly timing data
# acquisition. Failing to do so may introduce discontinuities in the
# data stream.
#
# For this reason, we will run the data acquisition in a subprocess.
#
# Firstly, we create a helper function that encapsulates the whole
# process executed in the subprocess.
#
# This function initializes the streaming API, acquires data, then
# puts it in a queue, which the main process watches.
#

import torch
import torchaudio


# The data acquisition process will stop after this number of steps.
# This eliminates the need for process synchronization and makes this
# tutorial simple.
NUM_ITER = 100


def stream(q, format, src, segment_length, sample_rate):
    from torchaudio.io import StreamReader

    print("Building StreamReader...")
    streamer = StreamReader(src, format=format)
    streamer.add_basic_audio_stream(frames_per_chunk=segment_length, sample_rate=sample_rate)

    print(streamer.get_src_stream_info(0))
    print(streamer.get_out_stream_info(0))

    print("Streaming...")
    print()
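    # Retry with back-off when the audio device has not prepared new data yet;
    # see the discussion of ``timeout`` and ``backoff`` below.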
    stream_iterator = streamer.stream(timeout=-1, backoff=1.0)
    for _ in range(NUM_ITER):
        (chunk,) = next(stream_iterator)
        q.put(chunk)


######################################################################
#
# The notable difference from non-device streaming is that
# we provide the ``timeout`` and ``backoff`` parameters to the ``stream`` method.
#
# When acquiring data, if the rate of acquisition requests is higher
# than the rate at which the hardware can prepare the data, then
# the underlying implementation reports a special error code and expects
# the client code to retry.
#
# Precise timing is key to smooth streaming. Reporting this error
# from the low-level implementation all the way back to the Python layer
# before retrying adds undesired overhead.
# For this reason, the retry behavior is implemented in the C++ layer, and
# the ``timeout`` and ``backoff`` parameters allow client code to control this
# behavior.
#
# For the details of the ``timeout`` and ``backoff`` parameters, please refer
# to the documentation of the
# :py:meth:`~torchaudio.io.StreamReader.stream` method.
#
# .. note::
#
#    The proper value of ``backoff`` depends on the system configuration.
#    One way to see if the ``backoff`` value is appropriate is to save the
#    series of acquired chunks as one continuous audio file and listen to it.
#    If the ``backoff`` value is too large, then the data stream is discontinuous,
#    and the resulting audio sounds sped up.
#    If the ``backoff`` value is too small or zero, the audio stream is fine,
#    but the data acquisition process enters a busy-waiting state, and
#    this increases the CPU consumption.
#
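#
# The following is a minimal sketch of that check, shown only for
# illustration (it is not part of the tutorial pipeline). It assumes the
# acquired chunks have been collected into a list called ``chunks``, for
# example by appending each chunk taken from the queue.
#
# .. code::
#
#    # Each chunk has shape (frames, channels); concatenate along time and
#    # transpose to (channels, frames), as expected by ``torchaudio.save``.
#    waveform = torch.cat(chunks).T
#    torchaudio.save("acquired.wav", waveform, sample_rate=16000)  # bundle sample rate
#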

######################################################################
# 4. Building inference pipeline
# ------------------------------
#
# The next step is to create components required for inference.
#
# This is the same process as
# `Online ASR with Emformer RNN-T <./online_asr_tutorial.html>`__.
#


class Pipeline:
    """Build inference pipeline from RNNTBundle.

    Args:
        bundle (torchaudio.pipelines.RNNTBundle): Bundle object
        beam_width (int): Beam size of beam search decoder.
    """

    def __init__(self, bundle: torchaudio.pipelines.RNNTBundle, beam_width: int = 10):
        self.bundle = bundle
        self.feature_extractor = bundle.get_streaming_feature_extractor()
        self.decoder = bundle.get_decoder()
        self.token_processor = bundle.get_token_processor()

        self.beam_width = beam_width

        self.state = None
        self.hypothesis = None

    def infer(self, segment: torch.Tensor) -> str:
        """Perform streaming inference"""
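        # Extract features from the waveform segment, then run one step of the
        # beam-search decoder. ``self.state`` and ``self.hypothesis`` carry the
        # decoder state and the best hypothesis over to the next call, so the
        # transcription continues across segments.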
        features, length = self.feature_extractor(segment)
        hypos, self.state = self.decoder.infer(
            features, length, self.beam_width, state=self.state, hypothesis=self.hypothesis
        )
        self.hypothesis = hypos[0]
        transcript = self.token_processor(self.hypothesis[0], lstrip=False)
        return transcript


######################################################################
#


class ContextCacher:
    """Cache the end of input data and prepend the next input data with it.

    Args:
        segment_length (int): The size of main segment.
            If the incoming segment is shorter, then the segment is padded.
        context_length (int): The size of the context, cached and prepended
            to the next segment.
    """

    def __init__(self, segment_length: int, context_length: int):
        self.segment_length = segment_length
        self.context_length = context_length
        self.context = torch.zeros([context_length])

    def __call__(self, chunk: torch.Tensor):
        if chunk.size(0) < self.segment_length:
            chunk = torch.nn.functional.pad(chunk, (0, self.segment_length - chunk.size(0)))
        chunk_with_context = torch.cat((self.context, chunk))
        self.context = chunk[-self.context_length :]
        return chunk_with_context
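
######################################################################
#
# For illustration, here is how the cacher behaves with a segment length
# of 4 and a context length of 2 (the numbers are made up for this example):
#
# .. code::
#
#    cacher = ContextCacher(segment_length=4, context_length=2)
#    cacher(torch.tensor([1.0, 2.0, 3.0, 4.0]))  # tensor([0., 0., 1., 2., 3., 4.])
#    cacher(torch.tensor([5.0, 6.0]))            # tensor([3., 4., 5., 6., 0., 0.])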


######################################################################
# 5. The main process
# -------------------
#
# The execution flow of the main process is as follows:
#
# 1. Initialize the inference pipeline.
# 2. Launch data acquisition subprocess.
# 3. Run inference.
# 4. Clean up.
#
# .. note::
#
#    As the data acquisition subprocess will be launched with the `"spawn"`
#    method, all the code at the global scope is executed in the subprocess
#    as well.
#
#    We want to instantiate the pipeline only in the main process,
#    so we put the instantiation in a function and invoke it within the
#    `__name__ == "__main__"` guard.
#


def main(device, src, bundle):
    print(torch.__version__)
    print(torchaudio.__version__)

    print("Building pipeline...")
    pipeline = Pipeline(bundle)

    sample_rate = bundle.sample_rate
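    # The bundle expresses segment and context lengths in feature frames;
    # multiplying by the hop length converts them to waveform samples.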
    segment_length = bundle.segment_length * bundle.hop_length
    context_length = bundle.right_context_length * bundle.hop_length

    print(f"Sample rate: {sample_rate}")
    print(f"Main segment: {segment_length} frames ({segment_length / sample_rate} seconds)")
    print(f"Right context: {context_length} frames ({context_length / sample_rate} seconds)")

    cacher = ContextCacher(segment_length, context_length)

    @torch.inference_mode()
    def infer():
        for _ in range(NUM_ITER):
            chunk = q.get()
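            # ``chunk`` has shape (frames, channels); use the first channel and
            # prepend the context cached from the previous iteration.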
            segment = cacher(chunk[:, 0])
            transcript = pipeline.infer(segment)
            print(transcript, end="", flush=True)

    import torch.multiprocessing as mp

    ctx = mp.get_context("spawn")
    q = ctx.Queue()
    p = ctx.Process(target=stream, args=(q, device, src, segment_length, sample_rate))
    p.start()
    infer()
    p.join()


if __name__ == "__main__":
    main(
        device="avfoundation",
        src=":1",
        bundle=torchaudio.pipelines.EMFORMER_RNNT_BASE_LIBRISPEECH,
    )

######################################################################
#
# .. code::
#
#    Building pipeline...
#    Sample rate: 16000
#    Main segment: 2560 frames (0.16 seconds)
#    Right context: 640 frames (0.04 seconds)
#    Building StreamReader...
#    SourceAudioStream(media_type='audio', codec='pcm_f32le', codec_long_name='PCM 32-bit floating point little-endian', format='flt', bit_rate=1536000, sample_rate=48000.0, num_channels=1)
#    OutputStream(source_index=0, filter_description='aresample=16000,aformat=sample_fmts=fltp')
#    Streaming...
#
#    hello world
#