Commit 94f4ef0f authored by Caroline Chen, committed by Facebook GitHub Bot

Add timesteps visualization to CTC decoder tutorial (#2188)

Summary:
resulting tutorial: https://538358-90321822-gh.circle-artifacts.com/0/docs/tutorials/asr_inference_with_ctc_decoder_tutorial.html
- add visualization for timestep alignments
- modify section organization for decoder construction

Pull Request resolved: https://github.com/pytorch/audio/pull/2188

Reviewed By: mthrok

Differential Revision: D33954937

Pulled By: carolineechen

fbshipit-source-id: 8f397229d74c994b8793a30623e1de4c19ebd401
parent 612de66b
@@ -47,8 +47,10 @@ using CTC loss.
 #
 import time
+from typing import List
 import IPython
+import matplotlib.pyplot as plt
 import torch
 import torchaudio
@@ -173,9 +175,16 @@ torch.hub.download_url_to_file(kenlm_url, kenlm_file)
 ######################################################################
-# Construct Beam Search Decoder
-# -----------------------------
+# Construct Decoders
+# ------------------
+# In this tutorial, we construct both a beam search decoder and a greedy decoder
+# for comparison.
 #
+
+######################################################################
+# Beam Search Decoder
+# ~~~~~~~~~~~~~~~~~~~
 # The decoder can be constructed using the factory function
 # :py:func:`lexicon_decoder <torchaudio.prototype.ctc_decoder.lexicon_decoder>`.
 # In addition to the previously mentioned components, it also takes in various beam
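Note: for context, the factory call this hunk refers to looks roughly like the sketch below. The file variables are the lexicon, token, and KenLM files downloaded earlier in the tutorial; the parameter values are illustrative, and the prototype API may change.

```python
from torchaudio.prototype.ctc_decoder import lexicon_decoder

# Sketch: construct the beam search decoder (values are illustrative).
beam_search_decoder = lexicon_decoder(
    lexicon=lexicon_file,  # lexicon file mapping words to token spellings
    tokens=tokens_file,    # tokens the acoustic model can emit
    lm=kenlm_file,         # path to the KenLM language model
    nbest=3,               # number of best hypotheses to return
    beam_size=1500,        # beam width during search
    lm_weight=3.23,        # weight on the language model score
    word_score=-0.26,      # score added at each word boundary
)
```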
@@ -203,13 +212,10 @@ beam_search_decoder = lexicon_decoder(
 ######################################################################
 # Greedy Decoder
-# --------------
+# ~~~~~~~~~~~~~~
-#
 #
-# For comparison against the beam search decoder, we also construct a
-# basic greedy decoder.
 #
-from typing import List
 class GreedyCTCDecoder(torch.nn.Module):
@@ -241,7 +247,11 @@ greedy_decoder = GreedyCTCDecoder(tokens)
 # -------------
 #
 # Now that we have the data, acoustic model, and decoder, we can perform
-# inference. Recall the transcript corresponding to the waveform is
+# inference. The output of the beam search decoder is of type
+# :py:func:`torchaudio.prototype.ctc_decoder.Hypothesis`, consisting of the
+# predicted token IDs, corresponding words, hypothesis score, and timesteps
+# corresponding to the token IDs. Recall the transcript corresponding to the
+# waveform is
 # ::
 #    i really was very much afraid of showing him how much shocked i was at some parts of what he said
 #
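Note: as a quick sketch of consuming that output, assuming `emission` holds the acoustic model's output as set up earlier in the tutorial. The `Hypothesis` fields are those described above, and indexing follows the usage shown in the next hunk.

```python
# Sketch: run the beam search decoder and read out the best hypothesis.
beam_search_result = beam_search_decoder(emission)
best = beam_search_result[0][0]  # top hypothesis for the first (only) utterance

print(" ".join(best.words).strip())  # predicted transcript
print(best.score)                    # hypothesis score
```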
@@ -286,6 +296,51 @@ print(f"WER: {beam_search_wer}")
 #
+
+######################################################################
+# Timestep Alignments
+# -------------------
+# Recall that one of the components of the resulting Hypotheses is timesteps
+# corresponding to the token IDs.
+#
+
+timesteps = beam_search_result[0][0].timesteps
+predicted_tokens = beam_search_decoder.idxs_to_tokens(beam_search_result[0][0].tokens)
+
+print(predicted_tokens, len(predicted_tokens))
+print(timesteps, timesteps.shape[0])
+
+######################################################################
+# Below, we visualize the token timestep alignments relative to the original waveform.
+#
+
+
+def plot_alignments(waveform, emission, tokens, timesteps):
+    fig, ax = plt.subplots(figsize=(32, 10))
+    ax.plot(waveform)
+
+    ratio = waveform.shape[0] / emission.shape[1]
+    word_start = 0
+
+    for i in range(len(tokens)):
+        if i != 0 and tokens[i - 1] == "|":
+            word_start = timesteps[i]
+        if tokens[i] != "|":
+            plt.annotate(tokens[i].upper(), (timesteps[i] * ratio, waveform.max() * 1.02), size=14)
+        elif i != 0:
+            word_end = timesteps[i]
+            ax.axvspan(word_start * ratio, word_end * ratio, alpha=0.1, color="red")
+
+    xticks = ax.get_xticks()
+    plt.xticks(xticks, xticks / bundle.sample_rate)
+    ax.set_xlabel("time (sec)")
+    ax.set_xlim(0, waveform.shape[0])
+
+
+plot_alignments(waveform[0], emission, predicted_tokens, timesteps)
 ######################################################################
 # Beam Search Decoder Parameters
 # ------------------------------
......
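Note on the conversion inside `plot_alignments`: the decoder's timesteps index emission frames, so multiplying by `ratio` (waveform samples per frame) places each token on the waveform, and dividing by the sample rate gives seconds. A sketch, reusing the variables from the hunk above; the stride figure is an assumption for wav2vec2-style models.

```python
# Sketch: map decoder timesteps to seconds (illustrative only).
ratio = waveform[0].shape[0] / emission.shape[1]  # waveform samples per emission frame
token_times = timesteps * ratio / bundle.sample_rate
print(token_times)  # e.g., at 16 kHz with a 320-sample stride, timestep 50 -> 1.0 s
```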