Commit c6f3d123 authored by Caroline Chen's avatar Caroline Chen
Browse files

Add custom lm example to decoder tutorial (#2762)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/2762

Reviewed By: mthrok

Differential Revision: D40332603

Pulled By: carolineechen

fbshipit-source-id: 2de51265adc81b4728f4d6798d287bd2eccf5251
parent 284e8f50
...@@ -28,21 +28,22 @@ using CTC loss. ...@@ -28,21 +28,22 @@ using CTC loss.
# a more detailed algorithm can be found in this `blog # a more detailed algorithm can be found in this `blog
# <https://towardsdatascience.com/boosting-your-sequence-generation-performance-with-beam-search-language-model-decoding-74ee64de435a>`__. # <https://towardsdatascience.com/boosting-your-sequence-generation-performance-with-beam-search-language-model-decoding-74ee64de435a>`__.
# #
# Running ASR inference using a CTC Beam Search decoder with a KenLM # Running ASR inference using a CTC Beam Search decoder with a language
# language model and lexicon constraint requires the following components # model and lexicon constraint requires the following components
# #
# - Acoustic Model: model predicting phonetics from audio waveforms # - Acoustic Model: model predicting phonetics from audio waveforms
# - Tokens: the possible predicted tokens from the acoustic model # - Tokens: the possible predicted tokens from the acoustic model
# - Lexicon: mapping between possible words and their corresponding # - Lexicon: mapping between possible words and their corresponding
# tokens sequence # tokens sequence
# - KenLM: n-gram language model trained with the `KenLM # - Language Model (LM): n-gram language model trained with the `KenLM
# library <https://kheafield.com/code/kenlm/>`__ # library <https://kheafield.com/code/kenlm/>`__, or custom language
# model that inherits :py:class:`~torchaudio.models.decoder.CTCDecoderLM`
# #
###################################################################### ######################################################################
# Preparation # Acoustic Model and Set Up
# ----------- # -------------------------
# #
# First we import the necessary utilities and fetch the data that we are # First we import the necessary utilities and fetch the data that we are
# working with # working with
...@@ -66,8 +67,6 @@ from torchaudio.models.decoder import ctc_decoder ...@@ -66,8 +67,6 @@ from torchaudio.models.decoder import ctc_decoder
from torchaudio.utils import download_asset from torchaudio.utils import download_asset
###################################################################### ######################################################################
# Acoustic Model and Data
# ~~~~~~~~~~~~~~~~~~~~~~~
# #
# We use the pretrained `Wav2Vec 2.0 <https://arxiv.org/abs/2006.11477>`__ # We use the pretrained `Wav2Vec 2.0 <https://arxiv.org/abs/2006.11477>`__
# Base model that is finetuned on 10 min of the `LibriSpeech # Base model that is finetuned on 10 min of the `LibriSpeech
...@@ -107,10 +106,10 @@ if sample_rate != bundle.sample_rate: ...@@ -107,10 +106,10 @@ if sample_rate != bundle.sample_rate:
###################################################################### ######################################################################
# Files and Data for Decoder # Files and Data for Decoder
# ~~~~~~~~~~~~~~~~~~~~~~~~~~ # --------------------------
# #
# Next, we load in our token, lexicon, and KenLM data, which are used by # Next, we load in our token, lexicon, and language model data, which are used
# the decoder to predict words from the acoustic model output. Pretrained # by the decoder to predict words from the acoustic model output. Pretrained
# files for the LibriSpeech dataset can be downloaded through torchaudio, # files for the LibriSpeech dataset can be downloaded through torchaudio,
# or the user can provide their own files. # or the user can provide their own files.
# #
...@@ -118,7 +117,7 @@ if sample_rate != bundle.sample_rate: ...@@ -118,7 +117,7 @@ if sample_rate != bundle.sample_rate:
###################################################################### ######################################################################
# Tokens # Tokens
# ^^^^^^ # ~~~~~~
# #
# The tokens are the possible symbols that the acoustic model can predict, # The tokens are the possible symbols that the acoustic model can predict,
# including the blank and silent symbols. It can either be passed in as a # including the blank and silent symbols. It can either be passed in as a
...@@ -141,7 +140,7 @@ print(tokens) ...@@ -141,7 +140,7 @@ print(tokens)
###################################################################### ######################################################################
# Lexicon # Lexicon
# ^^^^^^^ # ~~~~~~~
# #
# The lexicon is a mapping from words to their corresponding tokens # The lexicon is a mapping from words to their corresponding tokens
# sequence, and is used to restrict the search space of the decoder to # sequence, and is used to restrict the search space of the decoder to
...@@ -159,6 +158,24 @@ print(tokens) ...@@ -159,6 +158,24 @@ print(tokens)
# #
######################################################################
# Language Model
# ~~~~~~~~~~~~~~
#
# A language model can be used in decoding to improve the results, by
# factoring in a language model score that represents the likelihood of
# the sequence into the beam search computation. Below, we outline the
# different forms of language models that are supported for decoding.
#
######################################################################
# No Language Model
# ^^^^^^^^^^^^^^^^^
#
# To create a decoder instance without a language model, set `lm=None`
# when initializing the decoder.
#
###################################################################### ######################################################################
# KenLM # KenLM
# ^^^^^ # ^^^^^
...@@ -172,6 +189,52 @@ print(tokens) ...@@ -172,6 +189,52 @@ print(tokens)
# `LibriSpeech <http://www.openslr.org/11>`__. # `LibriSpeech <http://www.openslr.org/11>`__.
# #
######################################################################
# Custom Language Model
# ^^^^^^^^^^^^^^^^^^^^^
#
# Users can define their own custom language model in Python, whether
# it be a statistical or neural network language model, using
# :py:class:`~torchaudio.models.decoder.CTCDecoderLM` and
# :py:class:`~torchaudio.models.decoder.CTCDecoderLMState`.
#
# For instance, the following code creates a basic wrapper around a PyTorch
# ``torch.nn.Module`` language model.
#
from torchaudio.models.decoder import CTCDecoderLM, CTCDecoderLMState
class CustomLM(CTCDecoderLM):
"""Create a Python wrapper around `language_model` to feed to the decoder."""
def __init__(self, language_model: torch.nn.Module):
CTCDecoderLM.__init__(self)
self.language_model = language_model
self.sil = -1 # index for silent token in the language model
self.states = {}
language_model.eval()
def start(self, start_with_nothing: bool = False):
state = CTCDecoderLMState()
with torch.no_grad():
score = self.language_model(self.sil)
self.states[state] = score
return state
def score(self, state: CTCDecoderLMState, token_index: int):
outstate = state.child(token_index)
if outstate not in self.states:
score = self.language_model(token_index)
self.states[outstate] = score
score = self.states[outstate]
return outstate, score
def finish(self, state: CTCDecoderLMState):
return self.score(state, self.sil)
###################################################################### ######################################################################
# Downloading Pretrained Files # Downloading Pretrained Files
...@@ -229,8 +292,6 @@ beam_search_decoder = ctc_decoder( ...@@ -229,8 +292,6 @@ beam_search_decoder = ctc_decoder(
# Greedy Decoder # Greedy Decoder
# ~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~
# #
#
#
class GreedyCTCDecoder(torch.nn.Module): class GreedyCTCDecoder(torch.nn.Module):
...@@ -280,7 +341,7 @@ emission, _ = acoustic_model(waveform) ...@@ -280,7 +341,7 @@ emission, _ = acoustic_model(waveform)
###################################################################### ######################################################################
# The greedy decoder give the following result. # The greedy decoder gives the following result.
# #
greedy_result = greedy_decoder(emission[0]) greedy_result = greedy_decoder(emission[0])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment