Commit ffbfe74a authored by Caroline Chen, committed by Facebook GitHub Bot

Add parameter usage to CTC inference tutorial (#2141)

Summary:
Add explanation and demonstration of different beam search decoder parameters.
Additionally, use a better sample audio file and load the tokens in as a list instead of from a tokens file.

Pull Request resolved: https://github.com/pytorch/audio/pull/2141

Reviewed By: mthrok

Differential Revision: D33463230

Pulled By: carolineechen

fbshipit-source-id: d3dd6452b03d4fc2e095d778189c66f7161e4c68
parent 565f8d41
@@ -36,7 +36,7 @@ using CTC loss.
# working with
#
-import os
+import time
import IPython
import torch
@@ -50,7 +50,7 @@ import torchaudio
# We use the pretrained `Wav2Vec 2.0 <https://arxiv.org/abs/2006.11477>`__
# Base model that is finetuned on 10 min of the `LibriSpeech
# dataset <http://www.openslr.org/12>`__, which can be loaded in using
-# py:func:`torchaudio.pipelines`. For more detail on running Wav2Vec 2.0 speech
+# :py:func:`torchaudio.pipelines`. For more detail on running Wav2Vec 2.0 speech
# recognition pipelines in torchaudio, please refer to `this
# tutorial <https://pytorch.org/audio/main/tutorials/speech_recognition_pipeline_tutorial.html>`__.
#
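The bundle-loading step itself is unchanged by this commit and therefore not shown in the hunks. As a minimal sketch of what it looks like, assuming the bundle in question is the 10-minute fine-tuned Base model ``WAV2VEC2_ASR_BASE_10M`` (the tutorial text only says "Base model finetuned on 10 min"):

import torchaudio

# Assumed bundle name for the Base model fine-tuned on 10 min of LibriSpeech
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_10M
acoustic_model = bundle.get_model()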
@@ -65,7 +65,7 @@ acoustic_model = bundle.get_model()
hub_dir = torch.hub.get_dir()
-speech_url = "https://pytorch.s3.amazonaws.com/torchaudio/tutorial-assets/ctc-decoding/8461-258277-0000.wav"
+speech_url = "https://pytorch.s3.amazonaws.com/torchaudio/tutorial-assets/ctc-decoding/1688-142285-0007.wav"
speech_file = f"{hub_dir}/speech.wav"
torch.hub.download_url_to_file(speech_url, speech_file)
@@ -75,7 +75,8 @@ IPython.display.Audio(speech_file)
######################################################################
# The transcript corresponding to this audio file is
-# ``"when it was the seven hundred and eighteenth night"``
+# ::
+#
+#    i really was very much afraid of showing him how much shocked i was at some parts of what he said
#
waveform, sample_rate = torchaudio.load(speech_file)
@@ -85,8 +86,8 @@ if sample_rate != bundle.sample_rate:
######################################################################
-# Files for Decoder
-# ~~~~~~~~~~~~~~~~~
+# Files and Data for Decoder
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Next, we load in our token, lexicon, and KenLM data, which are used by
# the decoder to predict words from the acoustic model output.
@@ -101,7 +102,9 @@ if sample_rate != bundle.sample_rate:
# ^^^^^^
#
# The tokens are the possible symbols that the acoustic model can predict,
-# including the blank and silent symbols.
+# including the blank and silent symbols. They can either be passed in as a
+# file, where each line consists of the tokens corresponding to the same
+# index, or as a list of tokens, each mapping to a unique index.
#
# ::
#
@@ -113,9 +116,8 @@ if sample_rate != bundle.sample_rate:
# ...
#
-token_url = "https://pytorch.s3.amazonaws.com/torchaudio/tutorial-assets/ctc-decoding/tokens-w2v2.txt"
-token_file = f"{hub_dir}/token.txt"
-torch.hub.download_url_to_file(token_url, token_file)
+tokens = [label.lower() for label in bundle.get_labels()]
+print(tokens)
######################################################################
@@ -151,6 +153,9 @@ torch.hub.download_url_to_file(lexicon_url, lexicon_file)
# the binarized ``.bin`` LM can be used, but the binary format is
# recommended for faster loading.
#
+# The language model used in this tutorial is a 4-gram KenLM trained using
+# `LibriSpeech <http://www.openslr.org/11>`__.
+#
kenlm_url = "https://pytorch.s3.amazonaws.com/torchaudio/tutorial-assets/ctc-decoding/4-gram-librispeech.bin"
kenlm_file = f"{hub_dir}/kenlm.bin"
@@ -161,26 +166,25 @@ torch.hub.download_url_to_file(kenlm_url, kenlm_file)
# Construct Beam Search Decoder
# -----------------------------
#
-# The decoder can be constructed using the
-# :py:func:`torchaudio.prototype.ctc_decoder.kenlm_lexicon_decoder`
-# factory function.
-# In addition to the previously mentioned components, it also takes in
-# various beam search decoding parameters and token/word parameters.
+# The decoder can be constructed using the factory function
+# :py:func:`kenlm_lexicon_decoder <torchaudio.prototype.ctc_decoder.kenlm_lexicon_decoder>`.
+# In addition to the previously mentioned components, it also takes in various beam
+# search decoding parameters and token/word parameters.
#
from torchaudio.prototype.ctc_decoder import kenlm_lexicon_decoder

+LM_WEIGHT = 3.23
+WORD_SCORE = -0.26
+
beam_search_decoder = kenlm_lexicon_decoder(
    lexicon=lexicon_file,
-    tokens=token_file,
+    tokens=tokens,
    kenlm=kenlm_file,
-    nbest=1,
+    nbest=3,
    beam_size=1500,
-    beam_size_token=50,
-    lm_weight=3.23,
-    word_score=-1.39,
-    unk_score=float("-inf"),
-    sil_score=0,
+    lm_weight=LM_WEIGHT,
+    word_score=WORD_SCORE,
)
@@ -192,6 +196,8 @@ beam_search_decoder = kenlm_lexicon_decoder(
# basic greedy decoder.
#
+from typing import List

class GreedyCTCDecoder(torch.nn.Module):
    def __init__(self, labels, blank=0):
@@ -199,21 +205,22 @@ class GreedyCTCDecoder(torch.nn.Module):
        self.labels = labels
        self.blank = blank

-    def forward(self, emission: torch.Tensor) -> str:
-        """Given a sequence emission over labels, get the best path string
+    def forward(self, emission: torch.Tensor) -> List[str]:
+        """Given a sequence emission over labels, get the best path
        Args:
            emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.
        Returns:
-            str: The resulting transcript
+            List[str]: The resulting transcript
        """
        indices = torch.argmax(emission, dim=-1)  # [num_seq,]
        indices = torch.unique_consecutive(indices, dim=-1)
        indices = [i for i in indices if i != self.blank]
-        return "".join([self.labels[i] for i in indices])
+        joined = "".join([self.labels[i] for i in indices])
+        return joined.replace("|", " ").strip().split()

-greedy_decoder = GreedyCTCDecoder(labels=bundle.get_labels())
+greedy_decoder = GreedyCTCDecoder(tokens)
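As a toy illustration of the CTC collapse rule the decoder implements (consecutive repeats merge, then blanks drop), here is a run over a hypothetical five-label set; the labels and frame indices below are made up for the example:

# index 0 is the blank token, "|" is the word delimiter
toy_decoder = GreedyCTCDecoder(labels=["-", "a", "c", "t", "|"])
# frame-wise argmax path "c c - a a - t |" collapses to "cat"
toy_emission = torch.nn.functional.one_hot(
    torch.tensor([2, 2, 0, 1, 1, 0, 3, 4]), num_classes=5
).float()
print(toy_decoder(toy_emission))  # ['cat']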
######################################################################
@@ -222,28 +229,210 @@ greedy_decoder = GreedyCTCDecoder(labels=bundle.get_labels())
#
# Now that we have the data, acoustic model, and decoder, we can perform
# inference. Recall the transcript corresponding to the waveform is
-# ``"when it was the seven hundred and eighteenth night"``
+# ::
+#
+#    i really was very much afraid of showing him how much shocked i was at some parts of what he said
#
+actual_transcript = "i really was very much afraid of showing him how much shocked i was at some parts of what he said"
+actual_transcript = actual_transcript.split()
emission, _ = acoustic_model(waveform)
+######################################################################
+# The greedy decoder gives the following result.
+#
+greedy_result = greedy_decoder(emission[0])
+greedy_transcript = " ".join(greedy_result)
+greedy_wer = torchaudio.functional.edit_distance(actual_transcript, greedy_result) / len(actual_transcript)
+print(f"Transcript: {greedy_transcript}")
+print(f"WER: {greedy_wer}")
######################################################################
# Using the beam search decoder:
#
beam_search_result = beam_search_decoder(emission)
-beam_search_transcript = " ".join(beam_search_result[0][0].words).lower().strip()
-print(beam_search_transcript)
+beam_search_transcript = " ".join(beam_search_result[0][0].words).strip()
+beam_search_wer = torchaudio.functional.edit_distance(actual_transcript, beam_search_result[0][0].words) / len(
+    actual_transcript
+)
+print(f"Transcript: {beam_search_transcript}")
+print(f"WER: {beam_search_wer}")
######################################################################
-# Using the greedy decoder:
-#
-greedy_result = greedy_decoder(emission[0])
-greedy_transcript = greedy_result.replace("|", " ").lower().strip()
-print(greedy_transcript)
-
-######################################################################
-# We see that the transcript with the lexicon-constrained beam search
-# decoder consists of real words, while the greedy decoder can predict
-# incorrectly spelled words like “hundrad”.
+# We see that the transcript with the lexicon-constrained beam search
+# decoder produces a more accurate result consisting of real words, while
+# the greedy decoder can predict incorrectly spelled words like “affrayd”
+# and “shoktd”.
+#
+######################################################################
+# Beam Search Decoder Parameters
+# ------------------------------
+#
+# In this section, we go a little more in depth on some of the different
+# parameters and tradeoffs. For the full list of customizable parameters,
+# please refer to the
+# :py:func:`documentation <torchaudio.prototype.ctc_decoder.kenlm_lexicon_decoder>`. # noqa
+#
+######################################################################
+# Helper Function
+# ~~~~~~~~~~~~~~~
+#
+def print_decoded(decoder, emission, param, param_value):
+    start_time = time.monotonic()
+    result = decoder(emission)
+    decode_time = time.monotonic() - start_time
+    transcript = " ".join(result[0][0].words).lower().strip()
+    score = result[0][0].score
+    print(f"{param} {param_value:<3}: {transcript} (score: {score:.2f}; {decode_time:.4f} secs)")
+######################################################################
+# nbest
+# ~~~~~
+#
+# This parameter indicates the number of best hypotheses to return, which
+# is functionality the greedy decoder cannot provide. For instance, by
+# setting ``nbest=3`` when constructing the beam search decoder earlier,
+# we can now access the hypotheses with the top 3 scores.
+#
+for i in range(3):
+    transcript = " ".join(beam_search_result[0][i].words).strip()
+    score = beam_search_result[0][i].score
+    print(f"{transcript} (score: {score})")
+######################################################################
+# beam size
+# ~~~~~~~~~
+#
+# The ``beam_size`` parameter determines the maximum number of best
+# hypotheses to hold after each decoding step. Using larger beam sizes
+# allows for exploring a larger range of possible hypotheses, which can
+# produce hypotheses with higher scores, but it is computationally more
+# expensive and does not provide additional gains beyond a certain point.
+#
+# In the example below, we see improvement in decoding quality as we
+# increase beam size from 1 to 5 to 50, but notice how using a beam size
+# of 500 provides the same output as beam size 50 while increasing the
+# computation time.
+#
+beam_sizes = [1, 5, 50, 500]
+for beam_size in beam_sizes:
+    beam_search_decoder = kenlm_lexicon_decoder(
+        lexicon=lexicon_file,
+        tokens=tokens,
+        kenlm=kenlm_file,
+        beam_size=beam_size,
+        lm_weight=LM_WEIGHT,
+        word_score=WORD_SCORE,
+    )
+    print_decoded(beam_search_decoder, emission, "beam size", beam_size)
+######################################################################
+# beam size token
+# ~~~~~~~~~~~~~~~
+#
+# The ``beam_size_token`` parameter corresponds to the number of tokens to
+# consider for expanding each hypothesis at the decoding step. Exploring a
+# larger number of next possible tokens increases the range of potential
+# hypotheses at the cost of computation.
+#
+num_tokens = len(tokens)
+beam_size_tokens = [1, 5, 10, num_tokens]
+for beam_size_token in beam_size_tokens:
+    beam_search_decoder = kenlm_lexicon_decoder(
+        lexicon=lexicon_file,
+        tokens=tokens,
+        kenlm=kenlm_file,
+        beam_size_token=beam_size_token,
+        lm_weight=LM_WEIGHT,
+        word_score=WORD_SCORE,
+    )
+    print_decoded(beam_search_decoder, emission, "beam size token", beam_size_token)
+######################################################################
+# beam threshold
+# ~~~~~~~~~~~~~~
+#
+# The ``beam_threshold`` parameter is used to prune the stored hypotheses
+# set at each decoding step, removing hypotheses whose scores are greater
+# than ``beam_threshold`` away from the highest scoring hypothesis. There
+# is a balance between choosing smaller thresholds to prune more
+# hypotheses and reduce the search space, and choosing a large enough
+# threshold such that plausible hypotheses are not pruned.
+#
+beam_thresholds = [1, 5, 10, 25]
+for beam_threshold in beam_thresholds:
+    beam_search_decoder = kenlm_lexicon_decoder(
+        lexicon=lexicon_file,
+        tokens=tokens,
+        kenlm=kenlm_file,
+        beam_threshold=beam_threshold,
+        lm_weight=LM_WEIGHT,
+        word_score=WORD_SCORE,
+    )
+    print_decoded(beam_search_decoder, emission, "beam threshold", beam_threshold)
+######################################################################
+# language model weight
+# ~~~~~~~~~~~~~~~~~~~~~
+#
+# The ``lm_weight`` parameter is the weight to assign to the language
+# model score, which is accumulated with the acoustic model score to
+# determine the overall scores. Larger weights encourage the model to
+# predict next words based on the language model, while smaller weights
+# give more weight to the acoustic model score instead.
+#
+lm_weights = [0, LM_WEIGHT, 15]
+for lm_weight in lm_weights:
+    beam_search_decoder = kenlm_lexicon_decoder(
+        lexicon=lexicon_file,
+        tokens=tokens,
+        kenlm=kenlm_file,
+        lm_weight=lm_weight,
+        word_score=WORD_SCORE,
+    )
+    print_decoded(beam_search_decoder, emission, "lm weight", lm_weight)
+######################################################################
+# additional parameters
+# ~~~~~~~~~~~~~~~~~~~~~
+#
+# Additional parameters that can be optimized include the following
+# (a brief sketch of passing them in appears after the list):
+#
+# - ``word_score``: score to add when word finishes
+# - ``unk_score``: unknown word appearance score to add
+# - ``sil_score``: silence appearance score to add
+# - ``log_add``: whether to use log add for lexicon Trie smearing
+#
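As an illustrative sketch only: the values below are not tuned recommendations. ``unk_score=float("-inf")`` and ``sil_score=0`` mirror the defaults that appear in the code this diff removes, and the ``log_add`` value is an assumption:

beam_search_decoder = kenlm_lexicon_decoder(
    lexicon=lexicon_file,
    tokens=tokens,
    kenlm=kenlm_file,
    lm_weight=LM_WEIGHT,
    word_score=WORD_SCORE,    # added each time a word finishes
    unk_score=float("-inf"),  # -inf disallows unknown words entirely
    sil_score=0,              # no extra score for silence
    log_add=False,            # assumed value; toggles log-add Trie smearing
)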