Commit 18601691 authored by moto, committed by Facebook GitHub Bot

Update forced alignment tutorial (#3440)

Summary:
* Fix backtrack visualization (the coordinate was off by one).
* Add note about the simplification and the new align API
* Explicitly handle SOS and EOS

Pull Request resolved: https://github.com/pytorch/audio/pull/3440

Reviewed By: xiaohui-zhang

Differential Revision: D46761282

Pulled By: mthrok

fbshipit-source-id: b0b6c9754674e8e23543e9f002e29b55102c92f8
parent 406e9c8d
......@@ -9,6 +9,17 @@ This tutorial shows how to align transcript to speech with
`CTC-Segmentation of Large Corpora for German End-to-end Speech
Recognition <https://arxiv.org/abs/2007.09127>`__.
.. note::
The implementation in this tutorial is simplified for
educational purposes.
If you are looking to align your corpus, we recommend using
:py:func:`torchaudio.functional.forced_align`, which is more
accurate and faster.
Please refer to `this tutorial <./ctc_forced_alignment_api_tutorial.html>`__
for details on :py:func:`~torchaudio.functional.forced_align`.
"""
import torch
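For orientation, here is a minimal, non-authoritative sketch of how the recommended API could be called; the linked tutorial is the authoritative reference, and `emission` (a `(num_frames, num_labels)` log-probability tensor) and `tokens` (transcript token indices, with blank at index 0) are assumed from the cells below.

    import torch
    import torchaudio.functional as F

    # Sketch only: batch the single utterance and its transcript tokens.
    log_probs = emission.unsqueeze(0)                    # (1, num_frames, num_labels)
    targets = torch.tensor([tokens], dtype=torch.int32)  # (1, num_tokens)
    # forced_align returns per-frame token indices and their scores.
    aligned_tokens, scores = F.forced_align(log_probs, targets, blank=0)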
......@@ -138,7 +149,9 @@ plt.show()
# [`distill.pub <https://distill.pub/2017/ctc/>`__])
#
transcript = "I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT"
# We enclose the transcript with space tokens, which represent SOS and EOS.
transcript = "|I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|"
dictionary = {c: i for i, c in enumerate(labels)}
tokens = [dictionary[c] for c in transcript]
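A quick sanity check of the explicit SOS/EOS handling, assuming `transcript`, `dictionary`, and `tokens` are as defined above:

    # The transcript is wrapped in "|" (the word-boundary token), which this
    # tutorial reuses as the SOS and EOS markers.
    assert transcript[0] == transcript[-1] == "|"
    assert tokens[0] == tokens[-1] == dictionary["|"]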
......@@ -149,21 +162,17 @@ def get_trellis(emission, tokens, blank_id=0):
num_frame = emission.size(0)
num_tokens = len(tokens)
# Trellis has extra dimensions for both time axis and tokens.
# The extra dim for tokens represents <SoS> (start-of-sentence)
# The extra dim for time axis is for simplification of the code.
trellis = torch.empty((num_frame + 1, num_tokens + 1))
trellis[0, 0] = 0
trellis[1:, 0] = torch.cumsum(emission[:, blank_id], 0)
trellis[0, -num_tokens:] = -float("inf")
trellis[-num_tokens:, 0] = float("inf")
trellis = torch.zeros((num_frame, num_tokens))
trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
trellis[0, 1:] = -float("inf")
trellis[-num_tokens + 1 :, 0] = float("inf")
for t in range(num_frame):
for t in range(num_frame - 1):
trellis[t + 1, 1:] = torch.maximum(
# Score for staying at the same token
trellis[t, 1:] + emission[t, blank_id],
# Score for changing to the next token
trellis[t, :-1] + emission[t, tokens],
trellis[t, :-1] + emission[t, tokens[1:]],
)
return trellis
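A small illustration of the revised trellis layout, assuming `emission` and `tokens` come from the earlier cells; the trellis now shares the emission's frame axis directly, with no extra row or column:

    trellis = get_trellis(emission, tokens)
    # One row per emission frame, one column per transcript token (incl. the SOS/EOS "|").
    print(trellis.shape)   # torch.Size([num_frame, num_tokens])
    print(trellis[0, 1:])  # all -inf: only the SOS column can start at frame 0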
......@@ -173,8 +182,9 @@ trellis = get_trellis(emission, tokens)
################################################################################
# Visualization
################################################################################
plt.imshow(trellis[1:, 1:].T, origin="lower")
plt.imshow(trellis.T, origin="lower")
plt.annotate("- Inf", (trellis.size(1) / 5, trellis.size(1) / 1.5))
plt.annotate("+ Inf", (trellis.size(0) - trellis.size(1) / 5, trellis.size(1) / 3))
plt.colorbar()
plt.show()
......@@ -214,38 +224,38 @@ class Point:
def backtrack(trellis, emission, tokens, blank_id=0):
# Note:
# j and t are indices for trellis, which has extra dimensions
# for time and tokens at the beginning.
# When referring to time frame index `T` in trellis,
# the corresponding index in emission is `T-1`.
# Similarly, when referring to token index `J` in trellis,
# the corresponding index in transcript is `J-1`.
j = trellis.size(1) - 1
t_start = torch.argmax(trellis[:, j]).item()
path = []
for t in range(t_start, 0, -1):
t, j = trellis.size(0) - 1, trellis.size(1) - 1
path = [Point(j, t, emission[t, blank_id].exp().item())]
while j > 0:
# Should not happen but just in case
assert t > 0
# 1. Figure out if the current position was stay or change
# Note (again):
# `emission[J-1]` is the emission at time frame `J` of trellis dimension.
# Score for token staying the same from time frame J-1 to T.
stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
# Score for token changing from J-1 at T-1 to J at T.
changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
# 2. Store the path with frame-wise probability.
prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item()
# Return token index and time index in non-trellis coordinate.
path.append(Point(j - 1, t - 1, prob))
# 3. Update the token
# Frame-wise score of stay vs change
p_stay = emission[t - 1, blank_id]
p_change = emission[t - 1, tokens[j]]
# Context-aware score for stay vs change
stayed = trellis[t - 1, j] + p_stay
changed = trellis[t - 1, j - 1] + p_change
# Update position
t -= 1
if changed > stayed:
j -= 1
if j == 0:
break
else:
raise ValueError("Failed to align")
# Store the path with frame-wise probability.
prob = (p_change if changed > stayed else p_stay).exp().item()
path.append(Point(j, t, prob))
# Now j == 0, which means it reached the SoS.
# Fill up the rest for the sake of visualization
while t > 0:
prob = emission[t - 1, blank_id].exp().item()
path.append(Point(j, t - 1, prob))
t -= 1
return path[::-1]
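A brief sketch of the path layout after the fix, assuming `trellis`, `emission`, and `tokens` are in scope; indices now refer directly to emission frames and transcript characters, with no off-by-one shift:

    path = backtrack(trellis, emission, tokens)
    print(path[0])    # Point(token_index=0, time_index=0, ...): SoS at the first frame
    print(path[-1])   # the last transcript token at the last frame
    assert len(path) == trellis.size(0)  # exactly one point per frame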
......@@ -262,7 +272,7 @@ def plot_trellis_with_path(trellis, path):
trellis_with_path = trellis.clone()
for _, p in enumerate(path):
trellis_with_path[p.time_index, p.token_index] = float("nan")
plt.imshow(trellis_with_path[1:, 1:].T, origin="lower")
plt.imshow(trellis_with_path.T, origin="lower")
plot_trellis_with_path(trellis, path)
......@@ -326,7 +336,7 @@ def plot_trellis_with_segments(trellis, segments, transcript):
trellis_with_path = trellis.clone()
for i, seg in enumerate(segments):
if seg.label != "|":
trellis_with_path[seg.start + 1 : seg.end + 1, i + 1] = float("nan")
trellis_with_path[seg.start : seg.end, i] = float("nan")
fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5))
ax1.set_title("Path, label and probability for each label")
......@@ -335,8 +345,8 @@ def plot_trellis_with_segments(trellis, segments, transcript):
for i, seg in enumerate(segments):
if seg.label != "|":
ax1.annotate(seg.label, (seg.start + 0.7, i + 0.3), weight="bold")
ax1.annotate(f"{seg.score:.2f}", (seg.start - 0.3, i + 4.3))
ax1.annotate(seg.label, (seg.start, i - 0.7), weight="bold")
ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3))
ax2.set_title("Label probability with and without repetition")
xs, hs, ws = [], [], []
......@@ -405,11 +415,11 @@ def plot_alignments(trellis, segments, word_segments, waveform):
trellis_with_path = trellis.clone()
for i, seg in enumerate(segments):
if seg.label != "|":
trellis_with_path[seg.start + 1 : seg.end + 1, i + 1] = float("nan")
trellis_with_path[seg.start : seg.end, i] = float("nan")
fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5))
ax1.imshow(trellis_with_path[1:, 1:].T, origin="lower")
ax1.imshow(trellis_with_path.T, origin="lower")
ax1.set_xticks([])
ax1.set_yticks([])
......@@ -419,11 +429,11 @@ def plot_alignments(trellis, segments, word_segments, waveform):
for i, seg in enumerate(segments):
if seg.label != "|":
ax1.annotate(seg.label, (seg.start, i + 0.3))
ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 4), fontsize=8)
ax1.annotate(seg.label, (seg.start, i - 0.7))
ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), fontsize=8)
# The original waveform
ratio = waveform.size(0) / (trellis.size(0) - 1)
ratio = waveform.size(0) / trellis.size(0)
ax2.plot(waveform)
for word in word_segments:
x0 = ratio * word.start
......@@ -452,13 +462,15 @@ plt.show()
################################################################################
# Audio Samples
# -------------
#
# A trick to embed the resulting audio in the generated file.
# `IPython.display.Audio` has to be the last call in a cell,
# and there should be only one call per cell.
def display_segment(i):
ratio = waveform.size(1) / (trellis.size(0) - 1)
ratio = waveform.size(1) / trellis.size(0)
word = word_segments[i]
x0 = int(ratio * word.start)
x1 = int(ratio * word.end)
......
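To make the corrected ratio concrete, a hedged sketch of converting a word segment to seconds; `waveform`, `trellis`, `word_segments`, and `bundle.sample_rate` are assumed from earlier cells of the tutorial:

    ratio = waveform.size(1) / trellis.size(0)  # samples per trellis frame (no "- 1" offset anymore)
    word = word_segments[0]
    start_sec = ratio * word.start / bundle.sample_rate
    end_sec = ratio * word.end / bundle.sample_rate
    print(f"{word.label}: {start_sec:.3f} - {end_sec:.3f} sec")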