Commit eab8aa74 authored by moto, committed by Facebook GitHub Bot

Updating CTC FA tutorial (#3542)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/3542

Reviewed By: huangruizhe

Differential Revision: D48166025

Pulled By: mthrok

fbshipit-source-id: 29fee7dbf08394993972ec2967f94ce9fcb1c853
parent f7ab406a
@@ -2,50 +2,36 @@
CTC forced alignment API tutorial
=================================

**Author**: `Xiaohui Zhang <xiaohuizhang@meta.com>`__, `Moto Hira <moto@meta.com>`__

Forced alignment is the process of aligning a transcript with speech.
This tutorial shows how to align transcripts to speech using
:py:func:`torchaudio.functional.forced_align`, which was developed as part of the work on
`Scaling Speech Technology to 1,000+ Languages
<https://research.facebook.com/publications/scaling-speech-technology-to-1000-languages/>`__.

:py:func:`~torchaudio.functional.forced_align` has custom CPU and CUDA
implementations which are more performant than a pure Python
implementation, and are more accurate.
It can also handle missing portions of the transcript with the special ``<star>`` token.

There is also a high-level API, :py:class:`torchaudio.pipelines.Wav2Vec2FABundle`,
which wraps the pre/post-processing explained in this tutorial and makes it easy
to run forced alignment.
`Forced alignment for multilingual data
<./forced_alignment_for_multilingual_data_tutorial.html>`__ uses this API to
illustrate how to align non-English transcripts.
"""

import torch
import torchaudio

print(torch.__version__)
print(torchaudio.__version__)

######################################################################
#

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
@@ -53,16 +39,24 @@ print(device)
# Preparation
# -----------
#

import IPython
import matplotlib.pyplot as plt

import torchaudio.functional as F

######################################################################
# First we prepare the speech data and the transcript we are going
# to use.
#

SPEECH_FILE = torchaudio.utils.download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
waveform, _ = torchaudio.load(SPEECH_FILE)
TRANSCRIPT = "i had that curiosity beside me at this moment".split()
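######################################################################
# Optional sanity check (added for this write-up, not part of the original
# tutorial): the acoustic model used below expects 16 kHz audio, so a
# mismatched recording should be resampled first, for example with
# :py:func:`torchaudio.functional.resample`.

sample_rate = torchaudio.info(SPEECH_FILE).sample_rate
if sample_rate != 16000:
    # Resample to the rate the alignment model was trained on.
    waveform = F.resample(waveform, sample_rate, 16000)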
######################################################################
# Generating emissions
# ~~~~~~~~~~~~~~~~~~~~
#
# :py:func:`~torchaudio.functional.forced_align` takes emission and
# token sequences and outputs timestamps of the tokens and their scores.
@@ -70,30 +64,26 @@ TRANSCRIPT = "I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT"
# Emission represents the frame-wise probability distribution over
# tokens, and it can be obtained by passing waveform to an acoustic
# model.
#
# Tokens are numerical expressions of transcripts. There are many ways to
# tokenize transcripts, but here, we simply map each alphabet character to an
# integer, which is how the labels were constructed when the acoustic model
# we are going to use was trained.
#
# We will use a pre-trained Wav2Vec2 model,
# :py:data:`torchaudio.pipelines.MMS_FA`, to obtain emission and tokenize
# the transcript.
#

bundle = torchaudio.pipelines.MMS_FA

model = bundle.get_model(with_star=False).to(device)

with torch.inference_mode():
    emission, _ = model(waveform.to(device))
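######################################################################
# A quick, optional check (added note, not from the original tutorial): the
# emission tensor has shape ``(batch, num_frames, num_labels)``, and the size
# of the last dimension matches the label dictionary inspected below.

print(emission.shape)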
######################################################################
#


def plot_emission(emission):
    fig, ax = plt.subplots()
    ax.imshow(emission.cpu().T)
@@ -106,20 +96,24 @@ def plot_emission(emission):


plot_emission(emission[0])
######################################################################
# Tokenize the transcript
# ~~~~~~~~~~~~~~~~~~~~~~~
#
# We create a dictionary, which maps each label to a token.

LABELS = bundle.get_labels(star=None)
DICTIONARY = bundle.get_dict(star=None)

for k, v in DICTIONARY.items():
    print(f"{k}: {v}")

######################################################################
# Converting the transcript to tokens is as simple as

tokenized_transcript = [DICTIONARY[c] for word in TRANSCRIPT for c in word]

for t in tokenized_transcript:
    print(t, end=" ")
print()
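######################################################################
# An added aside (not in the original tutorial): real-world transcripts often
# contain uppercase letters or punctuation that are not in ``DICTIONARY``.
# One simple workaround is to normalize the text and drop unknown characters
# before the lookup. The helper below is purely illustrative.

def normalize_word(word, dictionary=DICTIONARY):
    # Keep only characters the acoustic model knows about.
    return "".join(c for c in word.lower() if c in dictionary)

print([normalize_word(w) for w in ["Hello,", "world!"]])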
######################################################################
# Computing frame-level alignments
@@ -129,17 +123,11 @@ print(" ".join(str(t) for t in tokenized_transcript))
# frame-level alignment. For details of the function signature, please
# refer to :py:func:`~torchaudio.functional.forced_align`.
#


def align(emission, tokens):
    targets = torch.tensor([tokens], dtype=torch.int32, device=device)
    alignments, scores = F.forced_align(emission, targets, blank=0)

    alignments, scores = alignments[0], scores[0]  # remove batch dimension for simplicity
    scores = scores.exp()  # convert back to probability
@@ -154,7 +142,7 @@ aligned_tokens, alignment_scores = align(emission, tokenized_transcript)
# emission, which is different from the original waveform.

for i, (ali, score) in enumerate(zip(aligned_tokens, alignment_scores)):
    print(f"{i:3d}:\t{ali:2d} [{LABELS[ali]}], {score:.2f}")
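######################################################################
# As a rough orientation (added note, not from the original tutorial): each
# emission frame covers ``waveform.size(1) / emission.size(1)`` audio samples,
# so a frame index can be converted to seconds as follows.

seconds_per_frame = waveform.size(1) / emission.size(1) / bundle.sample_rate
print(f"{seconds_per_frame * 1000:.1f} ms per frame")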
######################################################################
#
@@ -209,46 +197,14 @@ for i, (ali, score) in enumerate(zip(aligned_tokens, alignment_scores)):
# which explains what token (in transcript) is present at what time span.

######################################################################
#

token_spans = F.merge_tokens(aligned_tokens, alignment_scores)

print("Token\tTime\tScore")
for s in token_spans:
    print(f"{LABELS[s.token]}\t[{s.start:3d}, {s.end:3d})\t{s.score:.2f}")
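######################################################################
# Conceptually, :py:func:`~torchaudio.functional.merge_tokens` collapses runs
# of identical frame labels and drops blank frames, averaging the frame scores
# of each run. The pure-Python sketch below (an illustrative approximation
# added for this write-up, based on the step-by-step version this tutorial
# used to include) shows the idea; use the library function in practice.

def merge_tokens_sketch(tokens, scores, blank=0):
    spans = []  # (token_id, start_frame, end_frame, mean_score)
    prev, start = blank, -1
    for t, tok in enumerate(tokens):
        tok = int(tok)
        if tok != prev:
            if prev != blank:
                spans.append((prev, start, t, scores[start:t].mean().item()))
            prev, start = tok, t
    if prev != blank:
        spans.append((prev, start, len(tokens), scores[start:].mean().item()))
    return spans


print(merge_tokens_sketch(aligned_tokens, alignment_scores)[:3])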
######################################################################
# Visualization
@@ -256,18 +212,17 @@ for s in token_spans:
#


def plot_scores(spans, scores):
    fig, ax = plt.subplots()
    ax.set_title("frame-level and token-level confidence scores")
    span_xs, span_hs, span_ws = [], [], []
    frame_xs, frame_hs = [], []
    for span in spans:
        if LABELS[span.token] != "|":
            span_xs.append((span.end + span.start) / 2 + 0.4)
            span_hs.append(span.score)
            span_ws.append(span.end - span.start)
            ax.annotate(LABELS[span.token], (span.start + 0.8, -0.07), weight="bold")
        for t in range(span.start, span.end):
            frame_xs.append(t + 1)
            frame_hs.append(scores[t].item())
@@ -279,7 +234,7 @@ def plot_scores(spans, scores, transcript):
    fig.tight_layout()


plot_scores(token_spans, alignment_scores)
######################################################################
@@ -295,30 +250,18 @@ plot_scores(token_spans, alignment_scores, TRANSCRIPT)
# alignments and listening to them.


# Obtain word alignments from token alignments
def unflatten(list_, lengths):
    assert len(list_) == sum(lengths)
    i = 0
    ret = []
    for l in lengths:
        ret.append(list_[i : i + l])
        i += l
    return ret


word_spans = unflatten(token_spans, [len(word) for word in TRANSCRIPT])
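######################################################################
# For illustration (added example, not in the original tutorial), ``unflatten``
# simply regroups a flat list according to the given lengths:

print(unflatten(["a", "b", "c", "d", "e", "f"], [1, 2, 3]))  # [['a'], ['b', 'c'], ['d', 'e', 'f']]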
######################################################################
@@ -326,45 +269,50 @@ word_spans = merge_words(token_spans, TRANSCRIPT)
# ~~~~~~~~~~~~~
#


# Compute average score weighted by the span length
def _score(spans):
    return sum(s.score * len(s) for s in spans) / sum(len(s) for s in spans)


def plot_alignments(waveform, token_spans, emission, transcript, sample_rate=bundle.sample_rate):
    ratio = waveform.size(1) / emission.size(1) / sample_rate

    fig, axes = plt.subplots(2, 1)
    axes[0].imshow(emission[0].detach().cpu().T, aspect="auto")
    axes[0].set_title("Emission")
    axes[0].set_xticks([])

    axes[1].specgram(waveform[0], Fs=sample_rate)
    for t_spans, chars in zip(token_spans, transcript):
        t0, t1 = t_spans[0].start + 0.1, t_spans[-1].end - 0.1
        axes[0].axvspan(t0 - 0.5, t1 - 0.5, facecolor="None", hatch="/", edgecolor="white")
        axes[1].axvspan(ratio * t0, ratio * t1, facecolor="None", hatch="/", edgecolor="white")
        axes[1].annotate(f"{_score(t_spans):.2f}", (ratio * t0, sample_rate * 0.51), annotation_clip=False)

        for span, char in zip(t_spans, chars):
            t0 = span.start * ratio
            axes[1].annotate(char, (t0, sample_rate * 0.55), annotation_clip=False)

    axes[1].set_xlabel("time [second]")
    axes[1].set_xlim([0, None])
    fig.tight_layout()


plot_alignments(waveform, word_spans, emission, TRANSCRIPT)
######################################################################


def preview_word(waveform, spans, num_frames, transcript, sample_rate=bundle.sample_rate):
    ratio = waveform.size(1) / num_frames
    x0 = int(ratio * spans[0].start)
    x1 = int(ratio * spans[-1].end)
    print(f"{transcript} ({_score(spans):.2f}): {x0 / sample_rate:.3f} - {x1 / sample_rate:.3f} sec")
    segment = waveform[:, x0:x1]
    return IPython.display.Audio(segment.numpy(), rate=sample_rate)


num_frames = emission.size(1)
######################################################################
# Generate the audio for each segment
@@ -374,47 +322,47 @@ IPython.display.Audio(SPEECH_FILE)

######################################################################
#

preview_word(waveform, word_spans[0], num_frames, TRANSCRIPT[0])

######################################################################
#

preview_word(waveform, word_spans[1], num_frames, TRANSCRIPT[1])

######################################################################
#

preview_word(waveform, word_spans[2], num_frames, TRANSCRIPT[2])

######################################################################
#

preview_word(waveform, word_spans[3], num_frames, TRANSCRIPT[3])

######################################################################
#

preview_word(waveform, word_spans[4], num_frames, TRANSCRIPT[4])

######################################################################
#

preview_word(waveform, word_spans[5], num_frames, TRANSCRIPT[5])

######################################################################
#

preview_word(waveform, word_spans[6], num_frames, TRANSCRIPT[6])

######################################################################
#

preview_word(waveform, word_spans[7], num_frames, TRANSCRIPT[7])

######################################################################
#

preview_word(waveform, word_spans[8], num_frames, TRANSCRIPT[8])

######################################################################
@@ -442,7 +390,7 @@ DICTIONARY["*"] = len(DICTIONARY)
# corresponding to the ``<star>`` token.
#

star_dim = torch.zeros((1, emission.size(1), 1), device=emission.device, dtype=emission.dtype)
emission = torch.cat((emission, star_dim), 2)

assert len(DICTIONARY) == emission.shape[2]
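######################################################################
# A brief explanatory note (added for this write-up): the emission is in the
# log-probability domain (the alignment scores are converted back with ``exp``
# above), so a column of zeros gives the ``<star>`` token log-probability 0,
# i.e. probability 1, at every frame. It can therefore absorb any stretch of
# audio that the remaining transcript does not account for.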
@@ -455,10 +403,10 @@ plot_emission(emission[0])


def compute_alignments(emission, transcript, dictionary):
    tokens = [dictionary[char] for word in transcript for char in word]
    alignment, scores = align(emission, tokens)
    token_spans = F.merge_tokens(alignment, scores)
    word_spans = unflatten(token_spans, [len(word) for word in transcript])
    return word_spans
@@ -466,26 +414,31 @@ def compute_alignments(emission, transcript, dictionary):
# **Original**

word_spans = compute_alignments(emission, TRANSCRIPT, DICTIONARY)
plot_alignments(waveform, word_spans, emission, TRANSCRIPT)

######################################################################
# **With <star> token**
#
# Now we replace the first part of the transcript with the ``<star>`` token.

transcript = "* this moment".split()
word_spans = compute_alignments(emission, transcript, DICTIONARY)
plot_alignments(waveform, word_spans, emission, transcript)

######################################################################
#

preview_word(waveform, word_spans[0], num_frames, transcript[0])

######################################################################
#

preview_word(waveform, word_spans[1], num_frames, transcript[1])

######################################################################
#

preview_word(waveform, word_spans[2], num_frames, transcript[2])

######################################################################
#
@@ -497,9 +450,9 @@ preview_word(waveform, word_spans[2], num_frames, transcript)
# without using the ``<star>`` token.
# It demonstrates the effect of the ``<star>`` token for dealing with deletion errors.

transcript = "this moment".split()
word_spans = compute_alignments(emission, transcript, DICTIONARY)
plot_alignments(waveform, word_spans, emission, transcript)
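######################################################################
# As mentioned in the introduction, the high-level
# :py:class:`torchaudio.pipelines.Wav2Vec2FABundle` API wraps these steps.
# The following is a hedged sketch of how the same alignment can be obtained
# with the pipeline's tokenizer and aligner objects; refer to the
# multilingual-data tutorial linked above for the authoritative usage.

fa_model = bundle.get_model().to(device)  # pipeline model, with the <star> dimension included
tokenizer = bundle.get_tokenizer()
aligner = bundle.get_aligner()

with torch.inference_mode():
    fa_emission, _ = fa_model(waveform.to(device))
    # The aligner returns token spans grouped per word of the transcript.
    fa_word_spans = aligner(fa_emission[0], tokenizer(TRANSCRIPT))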
######################################################################
# Conclusion
@@ -517,7 +470,6 @@ plot_alignments(waveform, word_spans, num_frames, transcript)
# ---------------
#
# Thanks to `Vineel Pratap <vineelkpratap@meta.com>`__ and `Zhaoheng
# Ni <zni@meta.com>`__ for developing and open-sourcing the
# forced aligner API.
#