Commit c6f3d123 authored by Caroline Chen

Add custom lm example to decoder tutorial (#2762)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/2762

Reviewed By: mthrok

Differential Revision: D40332603

Pulled By: carolineechen

fbshipit-source-id: 2de51265adc81b4728f4d6798d287bd2eccf5251
parent 284e8f50
@@ -28,21 +28,22 @@ using CTC loss.
# a more detailed algorithm can be found in this `blog
# <https://towardsdatascience.com/boosting-your-sequence-generation-performance-with-beam-search-language-model-decoding-74ee64de435a>`__.
#
# Running ASR inference using a CTC Beam Search decoder with a language
# model and lexicon constraint requires the following components
#
# - Acoustic Model: model predicting phonetics from audio waveforms
# - Tokens: the possible predicted tokens from the acoustic model
# - Lexicon: mapping between possible words and their corresponding
#   token sequences
# - Language Model (LM): n-gram language model trained with the `KenLM
# library <https://kheafield.com/code/kenlm/>`__, or custom language
# model that inherits :py:class:`~torchaudio.models.decoder.CTCDecoderLM`
#
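# At decode time these components fit together roughly as in the sketch
# below (``acoustic_model`` and ``decoder`` are constructed later in this
# tutorial; the hypothesis access pattern follows the decoder's output
# structure):
#
# .. code-block:: python
#
#    emission, _ = acoustic_model(waveform)  # frame-level token scores
#    hypotheses = decoder(emission)          # lexicon-constrained beam search
#    transcript = " ".join(hypotheses[0][0].words).strip()
#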
######################################################################
# Acoustic Model and Set Up
# -------------------------
#
# First we import the necessary utilities and fetch the data that we are
# working with
@@ -66,8 +67,6 @@ from torchaudio.models.decoder import ctc_decoder
from torchaudio.utils import download_asset
######################################################################
#
# We use the pretrained `Wav2Vec 2.0 <https://arxiv.org/abs/2006.11477>`__
# Base model that is finetuned on 10 min of the `LibriSpeech
@@ -107,10 +106,10 @@ if sample_rate != bundle.sample_rate:
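# The pretrained bundle and acoustic model are obtained along these lines
# (a sketch; :py:data:`torchaudio.pipelines.WAV2VEC2_ASR_BASE_10M` is the
# bundle matching the description above):
#
# .. code-block:: python
#
#    bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_10M
#    acoustic_model = bundle.get_model()
#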
######################################################################
# Files and Data for Decoder
# --------------------------
#
# Next, we load in our token, lexicon, and language model data, which are used
# by the decoder to predict words from the acoustic model output. Pretrained
# files for the LibriSpeech dataset can be downloaded through torchaudio,
# or the user can provide their own files.
#
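# For example, the LibriSpeech files can be fetched in one call (a sketch
# using :py:func:`~torchaudio.models.decoder.download_pretrained_files`,
# which the "Downloading Pretrained Files" section below covers in full):
#
# .. code-block:: python
#
#    from torchaudio.models.decoder import download_pretrained_files
#
#    files = download_pretrained_files("librispeech-4-gram")
#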
@@ -118,7 +117,7 @@ if sample_rate != bundle.sample_rate:
######################################################################
# Tokens
# ~~~~~~
#
# The tokens are the possible symbols that the acoustic model can predict,
# including the blank and silent symbols. They can either be passed in as a
@@ -141,7 +140,7 @@ print(tokens)
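# For a character-based model, the token set might look like the
# hypothetical list below (the real set comes from the acoustic model's
# labels; ``-`` as the blank symbol and ``|`` as the word boundary are
# assumptions for illustration):
#
# .. code-block:: python
#
#    tokens = ["-", "|", "e", "t", "a", "o", "n", "i", "h", "s"]
#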
######################################################################
# Lexicon
# ~~~~~~~
#
# The lexicon is a mapping from words to their corresponding token
# sequences, and is used to restrict the search space of the decoder to
@@ -159,6 +158,24 @@ print(tokens)
#
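# For illustration, a few lexicon entries might look like the snippet
# below, one word per line followed by its token spelling (ending each
# entry with the word-boundary token ``|`` is an assumption here):
#
# .. code-block::
#
#    a a |
#    able a b l e |
#    about a b o u t |
#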
######################################################################
# Language Model
# ~~~~~~~~~~~~~~
#
# A language model can be used in decoding to improve the results, by
# factoring into the beam search computation a score that represents
# the likelihood of the sequence. Below, we outline the different
# forms of language models that are supported for decoding.
#
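# Concretely, each hypothesis in the beam is scored as (roughly) the
# acoustic score plus a weighted language model score plus a word
# insertion bonus; the weights are exposed through the decoder's
# ``lm_weight`` and ``word_score`` arguments.
#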
######################################################################
# No Language Model
# ^^^^^^^^^^^^^^^^^
#
# To create a decoder instance without a language model, set ``lm=None``
# when initializing the decoder.
#
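# A minimal sketch (the lexicon and token file paths are placeholders
# for the pretrained files downloaded below):
#
# .. code-block:: python
#
#    decoder_no_lm = ctc_decoder(
#        lexicon="lexicon.txt",
#        tokens="tokens.txt",
#        lm=None,
#    )
#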
######################################################################
# KenLM
# ^^^^^
@@ -172,6 +189,52 @@ print(tokens)
# `LibriSpeech <http://www.openslr.org/11>`__.
#
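# A trained KenLM file can be passed to the decoder as a path (a sketch,
# with hypothetical local file names; both binary and ``.arpa`` formats
# are accepted):
#
# .. code-block:: python
#
#    kenlm_decoder = ctc_decoder(
#        lexicon="lexicon.txt",
#        tokens="tokens.txt",
#        lm="4-gram.bin",
#    )
#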
######################################################################
# Custom Language Model
# ^^^^^^^^^^^^^^^^^^^^^
#
# Users can define their own custom language model in Python, whether
# it be a statistical or neural network language model, using
# :py:class:`~torchaudio.models.decoder.CTCDecoderLM` and
# :py:class:`~torchaudio.models.decoder.CTCDecoderLMState`.
#
# For instance, the following code creates a basic wrapper around a PyTorch
# ``torch.nn.Module`` language model.
#
from torchaudio.models.decoder import CTCDecoderLM, CTCDecoderLMState


class CustomLM(CTCDecoderLM):
    """Create a Python wrapper around `language_model` to feed to the decoder."""

    def __init__(self, language_model: torch.nn.Module):
        CTCDecoderLM.__init__(self)
        self.language_model = language_model
        self.sil = -1  # index for silent token in the language model
        self.states = {}

        language_model.eval()

    def start(self, start_with_nothing: bool = False):
        state = CTCDecoderLMState()
        with torch.no_grad():
            score = self.language_model(self.sil)

        self.states[state] = score
        return state

    def score(self, state: CTCDecoderLMState, token_index: int):
        outstate = state.child(token_index)
        if outstate not in self.states:
            score = self.language_model(token_index)
            self.states[outstate] = score
        score = self.states[outstate]

        return outstate, score

    def finish(self, state: CTCDecoderLMState):
        return self.score(state, self.sil)
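######################################################################
# To use the wrapper, instantiate it around a language model and pass
# the instance as the ``lm`` argument of
# :py:func:`~torchaudio.models.decoder.ctc_decoder`. A minimal usage
# sketch with a hypothetical stand-in module that scores every token
# equally:


class ZeroLM(torch.nn.Module):
    """Hypothetical placeholder LM that assigns every token the same score."""

    def forward(self, token_index: int):
        return 0.0


custom_lm = CustomLM(ZeroLM())
# ``custom_lm`` can now be supplied to the decoder, e.g.
# ``ctc_decoder(..., lm=custom_lm)``.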
######################################################################
# Downloading Pretrained Files
@@ -229,8 +292,6 @@ beam_search_decoder = ctc_decoder(
# Greedy Decoder
# ~~~~~~~~~~~~~~
#
class GreedyCTCDecoder(torch.nn.Module):
@@ -280,7 +341,7 @@ emission, _ = acoustic_model(waveform)
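# Conceptually, greedy decoding takes the most likely token per frame,
# collapses consecutive repeats, and removes blanks (a minimal sketch,
# assuming the blank token sits at index 0; ``labels`` stands for the
# token list and is an assumption here):
#
# .. code-block:: python
#
#    indices = torch.argmax(emission[0], dim=-1)
#    indices = torch.unique_consecutive(indices)
#    indices = [i for i in indices if i != 0]  # drop blanks
#    transcript = "".join(labels[i] for i in indices)
#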
######################################################################
# The greedy decoder gives the following result.
#
greedy_result = greedy_decoder(emission[0])