Commit bc264256 authored by moto, committed by Facebook GitHub Bot

Misc tutorial updates (#3546)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/3546

Reviewed By: huangruizhe

Differential Revision: D48219274

Pulled By: mthrok

fbshipit-source-id: 6881f039bf70cf7240fbcfeb48443471ef457bd4
parent 9f5fa84b
......@@ -13,7 +13,7 @@ This tutorial shows how to align transcripts to speech using
:py:func:`~torchaudio.functional.forced_align` has custom CPU and CUDA
implementations which are more performant than the vanilla Python
implementation above, and are more accurate.
It can also handle missing transcripts with the special <star> token.
It can also handle missing transcripts with the special ``<star>`` token.
There is also a high-level API, :py:class:`torchaudio.pipelines.Wav2Vec2FABundle`,
which wraps the pre/post-processing explained in this tutorial and makes it easy
......@@ -23,6 +23,10 @@ to run forced-alignments.
illustrate how to align non-English transcripts.
"""
######################################################################
# Preparation
# -----------
import torch
import torchaudio
......@@ -36,9 +40,8 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
######################################################################
# Preparation
# -----------
#
import IPython
import matplotlib.pyplot as plt
......@@ -138,19 +141,34 @@ aligned_tokens, alignment_scores = align(emission, tokenized_transcript)
######################################################################
# Now let's look at the output.
# Notice that the alignment is expressed in the frame coordinate of the
# emission, which is different from the original waveform.
for i, (ali, score) in enumerate(zip(aligned_tokens, alignment_scores)):
print(f"{i:3d}:\t{ali:2d} [{LABELS[ali]}], {score:.2f}")
######################################################################
#
# The ``Frame`` instance represents the most likely token at each frame
# with its confidence.
# .. note::
#
# The alignment is expressed in the frame coordinate of the emission,
# which is different from the original waveform.
#
# It contains blank tokens and repeated tokens. The following is the
# interpretation of the non-blank tokens.
#
# When interpreting it, one must remember that the meaning of blank
# and repeated tokens is context dependent.
# .. code-block::
#
# 31: 0 [-], 1.00
# 32: 2 [i], 1.00 "i" starts and ends
# 33: 0 [-], 1.00
# 34: 0 [-], 1.00
# 35: 15 [h], 1.00 "h" starts
# 36: 15 [h], 0.93 "h" ends
# 37: 1 [a], 1.00 "a" starts and ends
# 38: 0 [-], 0.96
# 39: 0 [-], 1.00
# 40: 0 [-], 1.00
# 41: 13 [d], 1.00 "d" starts and ends
# 42: 0 [-], 1.00
#
# .. note::
#
......@@ -165,37 +183,16 @@ for i, (ali, score) in enumerate(zip(aligned_tokens, alignment_scores)):
# a - a b -> a a b
# ^^^ ^^^
#
# .. code-block::
#
# 29: 0 [-], 1.00
# 30: 7 [I], 1.00 # "I" starts and ends
# 31: 0 [-], 0.98 #
# 32: 0 [-], 1.00 #
# 33: 1 [|], 0.85 # "|" (word boundary) starts
# 34: 1 [|], 1.00 # "|" ends
# 35: 0 [-], 0.61 #
# 36: 8 [H], 1.00 # "H" starts and ends
# 37: 0 [-], 1.00 #
# 38: 4 [A], 1.00 # "A" starts and ends
# 39: 0 [-], 0.99 #
# 40: 11 [D], 0.92 # "D" starts and ends
# 41: 0 [-], 0.93 #
# 42: 1 [|], 0.98 # "|" starts
# 43: 1 [|], 1.00 # "|" ends
# 44: 3 [T], 1.00 # "T" starts
# 45: 3 [T], 0.90 # "T" ends
# 46: 8 [H], 1.00 # "H" starts and ends
# 47: 0 [-], 1.00 #
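######################################################################
# As a toy illustration of the repetition rule above (this helper is not
# part of the tutorial), a CTC output can be collapsed by dropping blanks
# and merging immediate repetitions, while a repetition separated by a
# blank token is kept as a new occurrence:

def collapse_ctc(tokens, blank="-"):
    collapsed, previous = [], blank
    for token in tokens:
        if token != blank and token != previous:
            collapsed.append(token)
        previous = token
    return collapsed


print(collapse_ctc(["a", "-", "a", "a", "b"]))  # ['a', 'a', 'b']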
######################################################################
# Obtain token-level alignment
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The next step is to resolve the repetition, so that what each alignment
# represents does not depend on previous alignments.
# From the output ``alignment``, we compute the following ``Span`` object,
# which explains which token (in the transcript) is present at what time span.
# The next step is to resolve the repetition, so that each alignment does
# not depend on previous alignments.
# :py:func:`torchaudio.functional.merge_tokens` computes the
# :py:class:`~torchaudio.functional.TokenSpan` object, which represents
# which token from the transcript is present at what time span.
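######################################################################
# The cell that produces the ``token_spans`` printed below is not shown in
# this diff. Assuming it follows the API described above, the call is
# essentially a one-liner; ``aligned_tokens`` and ``alignment_scores`` come
# from the earlier alignment step.

import torchaudio.functional as F

token_spans_sketch = F.merge_tokens(aligned_tokens, alignment_scores)
print(token_spans_sketch[0])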
######################################################################
#
......@@ -206,12 +203,11 @@ print("Token\tTime\tScore")
for s in token_spans:
print(f"{LABELS[s.token]}\t[{s.start:3d}, {s.end:3d})\t{s.score:.2f}")
######################################################################
# Visualization
# ~~~~~~~~~~~~~
#
def plot_scores(spans, scores):
fig, ax = plt.subplots()
ax.set_title("frame-level and token-level confidence scores")
......@@ -241,8 +237,6 @@ plot_scores(token_spans, alignment_scores)
# Obtain word-level alignments and confidence scores
# --------------------------------------------------
#
######################################################################
# Now let’s merge the token-level alignments and confidence scores to get
# word-level alignments and confidence scores. Then, finally, we verify
# the quality of word alignments by 1) plotting the word-level alignments
......
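######################################################################
# Returning to the word-level step above: a sketch (not the diff's own
# code) of grouping ``token_spans`` into per-word spans by the number of
# tokens in each transcript word, and scoring a word by the length-weighted
# average of its token scores. ``TRANSCRIPT`` (the transcript as a list of
# words) and ``token_spans`` are assumed from the elided cells.

def unflatten(list_, lengths):
    assert len(list_) == sum(lengths)
    i, ret = 0, []
    for length in lengths:
        ret.append(list_[i : i + length])
        i += length
    return ret


word_spans = unflatten(token_spans, [len(word) for word in TRANSCRIPT])


def word_score(spans):
    # Average the token-level confidence scores, weighted by span length.
    total = sum(span.end - span.start for span in spans)
    return sum(span.score * (span.end - span.start) for span in spans) / total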
......@@ -11,15 +11,23 @@ Recognition <https://arxiv.org/abs/2007.09127>`__.
.. note::
The implementation in this tutorial is simplified for
educational purposes.
This tutorial was originally written to illustrate a use case
for the pre-trained Wav2Vec2 model.
If you are looking to align your corpus, we recommend using
:py:func:`torchaudio.functional.forced_align`, which is more
accurate and faster.
TorchAudio now has a set of APIs designed for forced alignment.
The `CTC forced alignment API tutorial
<./ctc_forced_alignment_api_tutorial.html>`__ illustrates the
usage of :py:func:`torchaudio.functional.forced_align`, which is
the core API.
Please refer to `this tutorial <./ctc_forced_alignment_api_tutorial.html>`__
for the details of :py:func:`~torchaudio.functional.forced_align`.
If you are looking to align your corpus, we recommend using
:py:class:`torchaudio.pipelines.Wav2Vec2FABundle`, which combines
:py:func:`~torchaudio.functional.forced_align` and other support
functions with a pre-trained model specifically trained for
forced-alignment. Please refer to the
`Forced alignment for multilingual data
<forced_alignment_for_multilingual_data_tutorial.html>`__ tutorial, which
illustrates its usage.
"""
import torch
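######################################################################
# For reference, a rough sketch of the recommended high-level pipeline
# mentioned in the note above. It assumes the ``MMS_FA`` bundle (an
# instance of :py:class:`torchaudio.pipelines.Wav2Vec2FABundle`) and its
# ``get_model`` / ``get_tokenizer`` / ``get_aligner`` helpers; see the
# linked tutorials for the authoritative usage.
#
# .. code-block:: python
#
#    import torch
#    import torchaudio
#
#    device = "cuda" if torch.cuda.is_available() else "cpu"
#    bundle = torchaudio.pipelines.MMS_FA
#    model = bundle.get_model().to(device)
#    tokenizer = bundle.get_tokenizer()  # words -> token ids
#    aligner = bundle.get_aligner()      # emission + token ids -> TokenSpans
#
#    def compute_alignments(waveform, transcript):
#        # transcript is a list of lower-cased words without punctuation.
#        with torch.inference_mode():
#            emission, _ = model(waveform.to(device))
#            token_spans = aligner(emission[0], tokenizer(transcript))
#        return emission, token_spans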
......@@ -102,11 +110,13 @@ print(labels)
def plot():
plt.imshow(emission.T)
plt.colorbar()
plt.title("Frame-wise class probability")
plt.xlabel("Time")
plt.ylabel("Labels")
fig, ax = plt.subplots()
img = ax.imshow(emission.T)
ax.set_title("Frame-wise class probability")
ax.set_xlabel("Time")
ax.set_ylabel("Labels")
fig.colorbar(img, ax=ax, shrink=0.6, location="bottom")
fig.tight_layout()
plot()
......@@ -185,10 +195,12 @@ trellis = get_trellis(emission, tokens)
def plot():
plt.imshow(trellis.T, origin="lower")
plt.annotate("- Inf", (trellis.size(1) / 5, trellis.size(1) / 1.5))
plt.annotate("+ Inf", (trellis.size(0) - trellis.size(1) / 5, trellis.size(1) / 3))
plt.colorbar()
fig, ax = plt.subplots()
img = ax.imshow(trellis.T, origin="lower")
ax.annotate("- Inf", (trellis.size(1) / 5, trellis.size(1) / 1.5))
ax.annotate("+ Inf", (trellis.size(0) - trellis.size(1) / 5, trellis.size(1) / 3))
fig.colorbar(img, ax=ax, shrink=0.6, location="bottom")
fig.tight_layout()
plot()
......@@ -280,10 +292,11 @@ def plot_trellis_with_path(trellis, path):
for _, p in enumerate(path):
trellis_with_path[p.time_index, p.token_index] = float("nan")
plt.imshow(trellis_with_path.T, origin="lower")
plt.title("The path found by backtracking")
plt.tight_layout()
plot_trellis_with_path(trellis, path)
plt.title("The path found by backtracking")
######################################################################
# Looking good.
......@@ -308,7 +321,7 @@ class Segment:
score: float
def __repr__(self):
return f"{self.label} ({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"
return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"
@property
def length(self):
......@@ -427,7 +440,7 @@ for word in word_segments:
################################################################################
# Visualization
# ~~~~~~~~~~~~~
def plot_alignments(trellis, segments, word_segments, waveform):
def plot_alignments(trellis, segments, word_segments, waveform, sample_rate=bundle.sample_rate):
trellis_with_path = trellis.clone()
for i, seg in enumerate(segments):
if seg.label != "|":
......@@ -436,12 +449,12 @@ def plot_alignments(trellis, segments, word_segments, waveform):
fig, [ax1, ax2] = plt.subplots(2, 1)
ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")
ax1.set_facecolor("lightgray")
ax1.set_xticks([])
ax1.set_yticks([])
for word in word_segments:
ax1.axvline(word.start - 0.5)
ax1.axvline(word.end - 0.5)
ax1.axvspan(word.start - 0.5, word.end - 0.5, edgecolor="white", facecolor="none")
for i, seg in enumerate(segments):
if seg.label != "|":
......@@ -449,23 +462,19 @@ def plot_alignments(trellis, segments, word_segments, waveform):
ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), size="small")
# The original waveform
ratio = waveform.size(0) / trellis.size(0)
ax2.plot(waveform)
ratio = waveform.size(0) / sample_rate / trellis.size(0)
ax2.specgram(waveform, Fs=sample_rate)
for word in word_segments:
x0 = ratio * word.start
x1 = ratio * word.end
ax2.axvspan(x0, x1, alpha=0.1, color="red")
ax2.annotate(f"{word.score:.2f}", (x0, 0.8))
ax2.axvspan(x0, x1, facecolor="none", edgecolor="white", hatch="/")
ax2.annotate(f"{word.score:.2f}", (x0, sample_rate * 0.51), annotation_clip=False)
for seg in segments:
if seg.label != "|":
ax2.annotate(seg.label, (seg.start * ratio, 0.9))
xticks = ax2.get_xticks()
plt.xticks(xticks, xticks / bundle.sample_rate)
ax2.annotate(seg.label, (seg.start * ratio, sample_rate * 0.55), annotation_clip=False)
ax2.set_xlabel("time [second]")
ax2.set_yticks([])
ax2.set_ylim(-1.0, 1.0)
ax2.set_xlim(0, waveform.size(-1))
fig.tight_layout()
......@@ -482,9 +491,7 @@ plot_alignments(
# -------------
#
# A trick to embed the resulting audio in the generated file.
# `IPython.display.Audio` has to be the last call in a cell,
# and there should be only one call per cell.
def display_segment(i):
ratio = waveform.size(1) / trellis.size(0)
word = word_segments[i]
......
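######################################################################
# An illustrative sketch of the trick described above, mirroring
# ``display_segment`` (``word_segments``, ``trellis``, ``waveform`` and
# ``bundle`` are assumed from the surrounding cells): slice the aligned
# word out of the waveform and make ``IPython.display.Audio`` the last
# expression of the cell so that the player is embedded in the page.

import IPython.display


def play_word(i):
    # Number of audio samples covered by one trellis frame.
    ratio = waveform.size(1) / trellis.size(0)
    word = word_segments[i]
    x0, x1 = int(ratio * word.start), int(ratio * word.end)
    return IPython.display.Audio(waveform[:, x0:x1].numpy(), rate=bundle.sample_rate)


play_word(0)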