Commit bc264256 authored by moto, committed by Facebook GitHub Bot

Misc tutorial updates (#3546)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/3546

Reviewed By: huangruizhe

Differential Revision: D48219274

Pulled By: mthrok

fbshipit-source-id: 6881f039bf70cf7240fbcfeb48443471ef457bd4
parent 9f5fa84b
...@@ -13,7 +13,7 @@ This tutorial shows how to align transcripts to speech using
:py:func:`~torchaudio.functional.forced_align` has custom CPU and CUDA
implementations which are more performant than the vanilla Python
implementation above, and are more accurate.
It can also handle a missing transcript with the special ``<star>`` token.

There is also a high-level API, :py:class:`torchaudio.pipelines.Wav2Vec2FABundle`,
which wraps the pre/post-processing explained in this tutorial and makes it easy
...@@ -23,6 +23,10 @@ to run forced-alignments.
illustrate how to align non-English transcripts.
"""
######################################################################
# Preparation
# -----------
import torch
import torchaudio

...@@ -36,9 +40,8 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

######################################################################
#
import IPython
import matplotlib.pyplot as plt
...@@ -138,19 +141,34 @@ aligned_tokens, alignment_scores = align(emission, tokenized_transcript)

######################################################################
# Now let's look at the output.
for i, (ali, score) in enumerate(zip(aligned_tokens, alignment_scores)):
    print(f"{i:3d}:\t{ali:2d} [{LABELS[ali]}], {score:.2f}")
######################################################################
#
# .. note::
#
#    The alignment is expressed in the frame coordinate of the emission,
#    which is different from the original waveform.
#
# It contains blank tokens and repeated tokens. The following is the
# interpretation of the non-blank tokens.
#
# .. code-block::
#
#    31:  0 [-], 1.00
#    32:  2 [i], 1.00 "i" starts and ends
#    33:  0 [-], 1.00
#    34:  0 [-], 1.00
#    35: 15 [h], 1.00 "h" starts
#    36: 15 [h], 0.93 "h" ends
#    37:  1 [a], 1.00 "a" starts and ends
#    38:  0 [-], 0.96
#    39:  0 [-], 1.00
#    40:  0 [-], 1.00
#    41: 13 [d], 1.00 "d" starts and ends
#    42:  0 [-], 1.00
#
# .. note::
#
...@@ -165,37 +183,16 @@ for i, (ali, score) in enumerate(zip(aligned_tokens, alignment_scores)):
#       a - a b -> a a b
#         ^^^    ^^^
#
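To make the blank/repeat rules above concrete, here is a tiny pure-Python collapse function. It is illustrative only, not the torchaudio implementation:

def ctc_collapse(tokens, blank="-"):
    # Drop blanks and merge consecutive repeats; a repeat separated by a
    # blank counts as a new occurrence, matching the rules above.
    out, prev = [], blank
    for t in tokens:
        if t != blank and t != prev:
            out.append(t)
        prev = t
    return out

print(ctc_collapse(["a", "a", "a", "b"]))  # -> ['a', 'b']
print(ctc_collapse(["a", "-", "-", "b"]))  # -> ['a', 'b']
print(ctc_collapse(["a", "-", "a", "b"]))  # -> ['a', 'a', 'b']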
######################################################################
# Obtain token-level alignment
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The next step is to resolve the repetition, so that each alignment does
# not depend on previous alignments.
# :py:func:`torchaudio.functional.merge_tokens` computes
# :py:class:`~torchaudio.functional.TokenSpan` objects, which represent
# which token from the transcript is present at what time span.
######################################################################
#

...@@ -206,12 +203,11 @@ print("Token\tTime\tScore")
for s in token_spans:
    print(f"{LABELS[s.token]}\t[{s.start:3d}, {s.end:3d})\t{s.score:.2f}")

######################################################################
# Visualization
# ~~~~~~~~~~~~~
#


def plot_scores(spans, scores):
    fig, ax = plt.subplots()
    ax.set_title("frame-level and token-level confidence scores")
...@@ -241,8 +237,6 @@ plot_scores(token_spans, alignment_scores)
# Obtain word-level alignments and confidence scores
# --------------------------------------------------
#
# Now let’s merge the token-level alignments and confidence scores to get
# word-level alignments and confidence scores. Then, finally, we verify
# the quality of word alignments by 1) plotting the word-level alignments
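One way to perform that merge, sketched under the assumption that `TRANSCRIPT` is the list of words whose characters were tokenized earlier (the tutorial's own implementation is in the elided hunk):

def unflatten(list_, lengths):
    # Regroup the flat list of token spans into one sublist per word.
    assert len(list_) == sum(lengths)
    i, ret = 0, []
    for length in lengths:
        ret.append(list_[i : i + length])
        i += length
    return ret


word_spans = unflatten(token_spans, [len(word) for word in TRANSCRIPT])


def word_score(spans):
    # Confidence of a word: duration-weighted average of its token scores.
    total = sum(s.end - s.start for s in spans)
    return sum(s.score * (s.end - s.start) for s in spans) / total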
......
...@@ -11,15 +11,23 @@ Recognition <https://arxiv.org/abs/2007.09127>`__.

.. note::

   This tutorial was originally written to illustrate a use case
   for the pretrained Wav2Vec2 model.

   TorchAudio now has a set of APIs designed for forced alignment.
   The `CTC forced alignment API tutorial
   <./ctc_forced_alignment_api_tutorial.html>`__ illustrates the
   usage of :py:func:`torchaudio.functional.forced_align`, which is
   the core API.

   If you are looking to align your corpus, we recommend using
   :py:class:`torchaudio.pipelines.Wav2Vec2FABundle`, which combines
   :py:func:`~torchaudio.functional.forced_align` and other support
   functions with a pre-trained model specifically trained for
   forced alignment. Please refer to the
   `Forced alignment for multilingual data
   <forced_alignment_for_multilingual_data_tutorial.html>`__ tutorial,
   which illustrates its usage.
"""
import torch

...@@ -102,11 +110,13 @@ print(labels)


def plot():
    fig, ax = plt.subplots()
    img = ax.imshow(emission.T)
    ax.set_title("Frame-wise class probability")
    ax.set_xlabel("Time")
    ax.set_ylabel("Labels")
    fig.colorbar(img, ax=ax, shrink=0.6, location="bottom")
    fig.tight_layout()


plot()
...@@ -185,10 +195,12 @@ trellis = get_trellis(emission, tokens)


def plot():
    fig, ax = plt.subplots()
    img = ax.imshow(trellis.T, origin="lower")
    ax.annotate("- Inf", (trellis.size(1) / 5, trellis.size(1) / 1.5))
    ax.annotate("+ Inf", (trellis.size(0) - trellis.size(1) / 5, trellis.size(1) / 3))
    fig.colorbar(img, ax=ax, shrink=0.6, location="bottom")
    fig.tight_layout()


plot()
...@@ -280,10 +292,11 @@ def plot_trellis_with_path(trellis, path):
    for _, p in enumerate(path):
        trellis_with_path[p.time_index, p.token_index] = float("nan")
    plt.imshow(trellis_with_path.T, origin="lower")
    plt.title("The path found by backtracking")
    plt.tight_layout()


plot_trellis_with_path(trellis, path)

######################################################################
# Looking good.
...@@ -308,7 +321,7 @@ class Segment:
    score: float

    def __repr__(self):
        return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"

    @property
    def length(self):
...@@ -427,7 +440,7 @@ for word in word_segments:

################################################################################
# Visualization
# ~~~~~~~~~~~~~


def plot_alignments(trellis, segments, word_segments, waveform, sample_rate=bundle.sample_rate):
    trellis_with_path = trellis.clone()
    for i, seg in enumerate(segments):
        if seg.label != "|":

...@@ -436,12 +449,12 @@ def plot_alignments(trellis, segments, word_segments, waveform):
    fig, [ax1, ax2] = plt.subplots(2, 1)

    ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")
    ax1.set_facecolor("lightgray")
    ax1.set_xticks([])
    ax1.set_yticks([])

    for word in word_segments:
        ax1.axvspan(word.start - 0.5, word.end - 0.5, edgecolor="white", facecolor="none")

    for i, seg in enumerate(segments):
        if seg.label != "|":
...@@ -449,23 +462,19 @@ def plot_alignments(trellis, segments, word_segments, waveform):
            ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), size="small")

    # The original waveform
    ratio = waveform.size(0) / sample_rate / trellis.size(0)
    ax2.specgram(waveform, Fs=sample_rate)
    for word in word_segments:
        x0 = ratio * word.start
        x1 = ratio * word.end
        ax2.axvspan(x0, x1, facecolor="none", edgecolor="white", hatch="/")
        ax2.annotate(f"{word.score:.2f}", (x0, sample_rate * 0.51), annotation_clip=False)

    for seg in segments:
        if seg.label != "|":
            ax2.annotate(seg.label, (seg.start * ratio, sample_rate * 0.55), annotation_clip=False)
    ax2.set_xlabel("time [second]")
    ax2.set_yticks([])
    fig.tight_layout()
...@@ -482,9 +491,7 @@ plot_alignments(
# -------------
#
def display_segment(i):
    ratio = waveform.size(1) / trellis.size(0)
    word = word_segments[i]
......