You need to sign in or sign up before continuing.
Commit a25bcb6b authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Add detail about CTC peaky behavior (#3566)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/3566

Reviewed By: huangruizhe

Differential Revision: D48499338

Pulled By: mthrok

fbshipit-source-id: 7f837e1a1f8116d7d82411607c91628b729077d8
parent c5939616
@misc{zeyer2021does,
title={Why does CTC result in peaky behavior?},
author={Albert Zeyer and Ralf Schlüter and Hermann Ney},
year={2021},
eprint={2105.14849},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
@article{wavernn, @article{wavernn,
author = {Nal Kalchbrenner and author = {Nal Kalchbrenner and
Erich Elsen and Erich Elsen and
......
...@@ -119,8 +119,11 @@ for t in tokenized_transcript: ...@@ -119,8 +119,11 @@ for t in tokenized_transcript:
print() print()
###################################################################### ######################################################################
# Computing frame-level alignments # Computing alignments
# -------------------------------- # --------------------
#
# Frame-level alignments
# ~~~~~~~~~~~~~~~~~~~~~~
# #
# Now we call TorchAudio’s forced alignment API to compute the # Now we call TorchAudio’s forced alignment API to compute the
# frame-level alignment. For the detail of function signature, please # frame-level alignment. For the detail of function signature, please
...@@ -185,8 +188,8 @@ for i, (ali, score) in enumerate(zip(aligned_tokens, alignment_scores)): ...@@ -185,8 +188,8 @@ for i, (ali, score) in enumerate(zip(aligned_tokens, alignment_scores)):
# #
###################################################################### ######################################################################
# Obtain token-level alignment # Token-level alignments
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~
# #
# Next step is to resolve the repetation, so that each alignment does # Next step is to resolve the repetation, so that each alignment does
# not depend on previous alignments. # not depend on previous alignments.
...@@ -205,46 +208,12 @@ for s in token_spans: ...@@ -205,46 +208,12 @@ for s in token_spans:
###################################################################### ######################################################################
# Visualization # Word-level alignments
# ~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~
#
def plot_scores(spans, scores):
fig, ax = plt.subplots()
ax.set_title("frame-level and token-level confidence scores")
span_xs, span_hs, span_ws = [], [], []
frame_xs, frame_hs = [], []
for span in spans:
if LABELS[span.token] != "|":
span_xs.append((span.end + span.start) / 2 + 0.4)
span_hs.append(span.score)
span_ws.append(span.end - span.start)
ax.annotate(LABELS[span.token], (span.start + 0.8, -0.07), weight="bold")
for t in range(span.start, span.end):
frame_xs.append(t + 1)
frame_hs.append(scores[t].item())
ax.bar(span_xs, span_hs, width=span_ws, color="gray", alpha=0.5, edgecolor="black")
ax.bar(frame_xs, frame_hs, width=0.5, alpha=0.5)
ax.set_ylim(-0.1, 1.1)
ax.grid(True, axis="y")
fig.tight_layout()
plot_scores(token_spans, alignment_scores)
######################################################################
# Obtain word-level alignments and confidence scores
# --------------------------------------------------
# #
# Now let’s merge the token-level alignments and confidence scores to get # Now let’s group the token-level alignments into word-level alignments.
# word-level alignments and confidence scores. Then, finally, we verify
# the quality of word alignments by 1) plotting the word-level alignments
# and the waveform, 2) segmenting the original audio according to the
# alignments and listening to them.
# Obtain word alignments from token alignments
def unflatten(list_, lengths): def unflatten(list_, lengths):
assert len(list_) == sum(lengths) assert len(list_) == sum(lengths)
i = 0 i = 0
...@@ -259,8 +228,8 @@ word_spans = unflatten(token_spans, [len(word) for word in TRANSCRIPT]) ...@@ -259,8 +228,8 @@ word_spans = unflatten(token_spans, [len(word) for word in TRANSCRIPT])
###################################################################### ######################################################################
# Visualization # Audio previews
# ~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~
# #
# Compute average score weighted by the span length # Compute average score weighted by the span length
...@@ -268,34 +237,6 @@ def _score(spans): ...@@ -268,34 +237,6 @@ def _score(spans):
return sum(s.score * len(s) for s in spans) / sum(len(s) for s in spans) return sum(s.score * len(s) for s in spans) / sum(len(s) for s in spans)
def plot_alignments(waveform, token_spans, emission, transcript, sample_rate=bundle.sample_rate):
ratio = waveform.size(1) / emission.size(1) / sample_rate
fig, axes = plt.subplots(2, 1)
axes[0].imshow(emission[0].detach().cpu().T, aspect="auto")
axes[0].set_title("Emission")
axes[0].set_xticks([])
axes[1].specgram(waveform[0], Fs=sample_rate)
for t_spans, chars in zip(token_spans, transcript):
t0, t1 = t_spans[0].start + 0.1, t_spans[-1].end - 0.1
axes[0].axvspan(t0 - 0.5, t1 - 0.5, facecolor="None", hatch="/", edgecolor="white")
axes[1].axvspan(ratio * t0, ratio * t1, facecolor="None", hatch="/", edgecolor="white")
axes[1].annotate(f"{_score(t_spans):.2f}", (ratio * t0, sample_rate * 0.51), annotation_clip=False)
for span, char in zip(t_spans, chars):
t0 = span.start * ratio
axes[1].annotate(char, (t0, sample_rate * 0.55), annotation_clip=False)
axes[1].set_xlabel("time [second]")
axes[1].set_xlim([0, None])
fig.tight_layout()
plot_alignments(waveform, word_spans, emission, TRANSCRIPT)
######################################################################
def preview_word(waveform, spans, num_frames, transcript, sample_rate=bundle.sample_rate): def preview_word(waveform, spans, num_frames, transcript, sample_rate=bundle.sample_rate):
ratio = waveform.size(1) / num_frames ratio = waveform.size(1) / num_frames
x0 = int(ratio * spans[0].start) x0 = int(ratio * spans[0].start)
...@@ -358,6 +299,114 @@ preview_word(waveform, word_spans[7], num_frames, TRANSCRIPT[7]) ...@@ -358,6 +299,114 @@ preview_word(waveform, word_spans[7], num_frames, TRANSCRIPT[7])
preview_word(waveform, word_spans[8], num_frames, TRANSCRIPT[8]) preview_word(waveform, word_spans[8], num_frames, TRANSCRIPT[8])
######################################################################
# Visualization
# ~~~~~~~~~~~~~
#
# Now let's look at the alignment result and segment the original
# speech into words.
def plot_alignments(waveform, token_spans, emission, transcript, sample_rate=bundle.sample_rate):
ratio = waveform.size(1) / emission.size(1) / sample_rate
fig, axes = plt.subplots(2, 1)
axes[0].imshow(emission[0].detach().cpu().T, aspect="auto")
axes[0].set_title("Emission")
axes[0].set_xticks([])
axes[1].specgram(waveform[0], Fs=sample_rate)
for t_spans, chars in zip(token_spans, transcript):
t0, t1 = t_spans[0].start + 0.1, t_spans[-1].end - 0.1
axes[0].axvspan(t0 - 0.5, t1 - 0.5, facecolor="None", hatch="/", edgecolor="white")
axes[1].axvspan(ratio * t0, ratio * t1, facecolor="None", hatch="/", edgecolor="white")
axes[1].annotate(f"{_score(t_spans):.2f}", (ratio * t0, sample_rate * 0.51), annotation_clip=False)
for span, char in zip(t_spans, chars):
t0 = span.start * ratio
axes[1].annotate(char, (t0, sample_rate * 0.55), annotation_clip=False)
axes[1].set_xlabel("time [second]")
axes[1].set_xlim([0, None])
fig.tight_layout()
######################################################################
#
plot_alignments(waveform, word_spans, emission, TRANSCRIPT)
######################################################################
#
# Inconsistent treatment of ``blank`` token
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# When splitting the token-level alignments into words, you will
# notice that some blank tokens are treated differently, and this makes
# the interpretation of the result somehwat ambigious.
#
# This is easy to see when we plot the scores. The following figure
# shows word regions and non-word regions, with the frame-level scores
# of non-blank tokens.
def plot_scores(word_spans, scores):
fig, ax = plt.subplots()
span_xs, span_hs = [], []
ax.axvspan(word_spans[0][0].start -0.05, word_spans[-1][-1].end + 0.05, facecolor="paleturquoise", edgecolor="none", zorder=-1)
for t_span in word_spans:
for span in t_span:
for t in range(span.start, span.end):
span_xs.append(t + 0.5)
span_hs.append(scores[t].item())
ax.annotate(LABELS[span.token], (span.start, -0.07))
ax.axvspan(t_span[0].start -0.05, t_span[-1].end + 0.05, facecolor="mistyrose", edgecolor="none", zorder=-1)
ax.bar(span_xs, span_hs, color="lightsalmon", edgecolor="coral")
ax.set_title("Frame-level scores and word segments")
ax.set_ylim(-0.1, None)
ax.grid(True, axis="y")
ax.axhline(0, color="black")
fig.tight_layout()
plot_scores(word_spans, alignment_scores)
######################################################################
# In this plot, the blank tokens are those highlighted area without
# vertical bar.
# You can see that there are blank tokens which are interpreted as
# part of a word (highlighted red), while the others (highlighted blue)
# are not.
#
# One reason for this is because the model was trained without a
# label for the word boundary. The blank tokens are treated not just
# as repeatation but also as silence between words.
#
# But then, a question arises. Should frames immediately after or
# near the end of a word be silent or repeat?
#
# In the above example, if you go back to the previous plot of
# spectrogram and word regions, you see that after "y" in "curiosity",
# there is still some activities in multiple frequency buckets.
#
# Would it be more accurate if that frame was included in the word?
#
# Unfortunately, CTC does not provide a comprehensive solution to this.
# Models trained with CTC are known to exhibit "peaky" response,
# that is, they tend to spike for an aoccurance of a label, but the
# spike does not last for the duration of the label.
# (Note: Pre-trained Wav2Vec2 models tend to spike at the beginning of
# label occurances, but this not always the case.)
#
# :cite:`zeyer2021does` has in-depth alanysis on the peaky behavior of
# CTC.
# We encourage those who are interested understanding more to refer
# to the paper.
# The following is a quote from the paper, which is the exact issue we
# are facing here.
#
# *Peaky behavior can be problematic in certain cases,*
# *e.g. when an application requires to not use the blank label,*
# *e.g. to get meaningful time accurate alignments of phonemes*
# *to a transcription.*
###################################################################### ######################################################################
# Advanced: Handling transcripts with ``<star>`` token # Advanced: Handling transcripts with ``<star>`` token
...@@ -405,13 +454,15 @@ def compute_alignments(emission, transcript, dictionary): ...@@ -405,13 +454,15 @@ def compute_alignments(emission, transcript, dictionary):
###################################################################### ######################################################################
# **Original** # Full Transcript
# ~~~~~~~~~~~~~~~
word_spans = compute_alignments(emission, TRANSCRIPT, DICTIONARY) word_spans = compute_alignments(emission, TRANSCRIPT, DICTIONARY)
plot_alignments(waveform, word_spans, emission, TRANSCRIPT) plot_alignments(waveform, word_spans, emission, TRANSCRIPT)
###################################################################### ######################################################################
# **With <star> token** # Partial Transcript with ``<star>`` token
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# #
# Now we replace the first part of the transcript with the ``<star>`` token. # Now we replace the first part of the transcript with the ``<star>`` token.
...@@ -435,10 +486,8 @@ preview_word(waveform, word_spans[1], num_frames, transcript[1]) ...@@ -435,10 +486,8 @@ preview_word(waveform, word_spans[1], num_frames, transcript[1])
preview_word(waveform, word_spans[2], num_frames, transcript[2]) preview_word(waveform, word_spans[2], num_frames, transcript[2])
###################################################################### ######################################################################
# # Partial Transcript without ``<star>`` token
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
######################################################################
# **Without <star> token**
# #
# As a comparison, the following aligns the partial transcript # As a comparison, the following aligns the partial transcript
# without using ``<star>`` token. # without using ``<star>`` token.
...@@ -466,4 +515,3 @@ plot_alignments(waveform, word_spans, emission, transcript) ...@@ -466,4 +515,3 @@ plot_alignments(waveform, word_spans, emission, transcript)
# Thanks to `Vineel Pratap <vineelkpratap@meta.com>`__ and `Zhaoheng # Thanks to `Vineel Pratap <vineelkpratap@meta.com>`__ and `Zhaoheng
# Ni <zni@meta.com>`__ for developing and open-sourcing the # Ni <zni@meta.com>`__ for developing and open-sourcing the
# forced aligner API. # forced aligner API.
#
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment