Commit eab8aa74 authored by moto, committed by Facebook GitHub Bot

Updating CTC FA tutorial (#3542)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/3542

Reviewed By: huangruizhe

Differential Revision: D48166025

Pulled By: mthrok

fbshipit-source-id: 29fee7dbf08394993972ec2967f94ce9fcb1c853
parent f7ab406a
@@ -2,50 +2,36 @@
CTC forced alignment API tutorial
=================================
**Author**: `Xiaohui Zhang <xiaohuizhang@meta.com>`__
This tutorial shows how to align transcripts to speech using
:py:func:`torchaudio.functional.forced_align`
which was developed along the work of
`Scaling Speech Technology to 1,000+ Languages <https://research.facebook.com/publications/scaling-speech-technology-to-1000-languages/>`__.
**Author**: `Xiaohui Zhang <xiaohuizhang@meta.com>`__, `Moto Hira <moto@meta.com>`__
Forced alignment is the process of aligning a transcript with speech.
We cover the basics of forced alignment in `Forced Alignment with
Wav2Vec2 <./forced_alignment_tutorial.html>`__ with simplified
step-by-step Python implementations.
This tutorial shows how to align transcripts to speech using
:py:func:`torchaudio.functional.forced_align` which was developed along the work of
`Scaling Speech Technology to 1,000+ Languages
<https://research.facebook.com/publications/scaling-speech-technology-to-1000-languages/>`__.
:py:func:`~torchaudio.functional.forced_align` has custom CPU and CUDA
implementations which are more performant and accurate than the vanilla
Python implementation above.
It can also handle missing transcripts with the special ``<star>`` token.
For examples of aligning multiple languages, please refer to
`Forced alignment for multilingual data <./forced_alignment_for_multilingual_data_tutorial.html>`__.
There is also a high-level API, :py:class:`torchaudio.pipelines.Wav2Vec2FABundle`,
which wraps the pre/post-processing explained in this tutorial and makes it easy
to run forced alignment (a minimal sketch appears right after the imports below).
`Forced alignment for multilingual data
<./forced_alignment_for_multilingual_data_tutorial.html>`__ uses this API to
illustrate how to align non-English transcripts.
"""
import torch
import torchaudio
print(torch.__version__)
print(torchaudio.__version__)
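######################################################################
# Before diving into the low-level API, here is a minimal sketch of the
# high-level API mentioned in the introduction. It assumes the
# tokenizer/aligner interface of
# :py:class:`torchaudio.pipelines.Wav2Vec2FABundle` (``get_model``,
# ``get_tokenizer``, ``get_aligner``); please check the pipeline
# documentation for the exact signatures. It is shown as a static code
# block because ``waveform`` and ``transcript`` are only defined later.
#
# .. code-block:: python
#
#    bundle = torchaudio.pipelines.MMS_FA
#    model = bundle.get_model()
#    tokenizer = bundle.get_tokenizer()
#    aligner = bundle.get_aligner()
#
#    with torch.inference_mode():
#        emission, _ = model(waveform)
#        token_spans = aligner(emission[0], tokenizer(transcript))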
######################################################################
#
from dataclasses import dataclass
from typing import List
import IPython
import matplotlib.pyplot as plt
######################################################################
#
from torchaudio.functional import forced_align
torch.random.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
@@ -53,16 +39,24 @@ print(device)
# Preparation
# -----------
#
import IPython
import matplotlib.pyplot as plt
import torchaudio.functional as F
######################################################################
# First we prepare the speech data and the transcript we are going
# to use.
#
SPEECH_FILE = torchaudio.utils.download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
TRANSCRIPT = "I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT"
waveform, _ = torchaudio.load(SPEECH_FILE)
TRANSCRIPT = "i had that curiosity beside me at this moment".split()
######################################################################
# Generating emissions and tokens
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Generating emissions
# ~~~~~~~~~~~~~~~~~~~~
#
# :py:func:`~torchaudio.functional.forced_align` takes emission and
# token sequences and outputs timestamps of the tokens and their scores.
@@ -70,30 +64,26 @@ TRANSCRIPT = "I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT"
# The emission represents the frame-wise probability distribution over
# tokens, and it can be obtained by passing the waveform to an acoustic
# model.
# Tokens are numerical expressions of the transcript. They can be obtained
# by simply mapping each character to its index in the token list.
# The emission and the token sequences must use the same set of tokens.
#
# We can use a pre-trained Wav2Vec2 model to obtain the emission from
# speech, and map the transcript to tokens.
# Here, we use :py:data:`~torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H`,
# which bundles pre-trained model weights with the associated labels.
# Tokens are numerical expressions of the transcript. There are many ways
# to tokenize transcripts, but here we simply map each alphabet character
# to an integer, which is how the labels were constructed when the
# acoustic model we are going to use was trained.
#
# We will use a pre-trained Wav2Vec2 model,
# :py:data:`torchaudio.pipelines.MMS_FA`, to obtain emission and tokenize
# the transcript.
#
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model().to(device)
bundle = torchaudio.pipelines.MMS_FA
model = bundle.get_model(with_star=False).to(device)
with torch.inference_mode():
    waveform, _ = torchaudio.load(SPEECH_FILE)
    emission, _ = model(waveform.to(device))
    emission = torch.log_softmax(emission, dim=-1)
num_frames = emission.size(1)
######################################################################
#
def plot_emission(emission):
    fig, ax = plt.subplots()
    ax.imshow(emission.cpu().T)
@@ -106,20 +96,24 @@ def plot_emission(emission):
plot_emission(emission[0])
######################################################################
# Tokenize the transcript
# ~~~~~~~~~~~~~~~~~~~~~~~
#
# We create a dictionary that maps each label to a token.
labels = bundle.get_labels()
DICTIONARY = {c: i for i, c in enumerate(labels)}
LABELS = bundle.get_labels(star=None)
DICTIONARY = bundle.get_dict(star=None)
for k, v in DICTIONARY.items():
    print(f"{k}: {v}")
######################################################################
# Converting the transcript to tokens is as simple as:
tokenized_transcript = [DICTIONARY[c] for c in TRANSCRIPT]
tokenized_transcript = [DICTIONARY[c] for word in TRANSCRIPT for c in word]
print(" ".join(str(t) for t in tokenized_transcript))
for t in tokenized_transcript:
    print(t, end=" ")
print()
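######################################################################
# As a sanity check, we can map the token IDs back to characters with
# ``LABELS`` and verify that the transcript is recovered (the word
# boundaries are gone, since we tokenized character by character):

print("".join(LABELS[t] for t in tokenized_transcript))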
######################################################################
# Computing frame-level alignments
@@ -129,17 +123,11 @@ print(" ".join(str(t) for t in tokenized_transcript))
# frame-level alignment. For details of the function signature, please
# refer to :py:func:`~torchaudio.functional.forced_align`.
#
#
def align(emission, tokens):
    alignments, scores = forced_align(
        emission,
        targets=torch.tensor([tokens], dtype=torch.int32, device=emission.device),
        input_lengths=torch.tensor([emission.size(1)], device=emission.device),
        target_lengths=torch.tensor([len(tokens)], device=emission.device),
        blank=0,
    )
    targets = torch.tensor([tokens], dtype=torch.int32, device=device)
    alignments, scores = F.forced_align(emission, targets, blank=0)

    alignments, scores = alignments[0], scores[0]  # remove batch dimension for simplicity
    scores = scores.exp()  # convert back to probability
@@ -154,7 +142,7 @@ aligned_tokens, alignment_scores = align(emission, tokenized_transcript)
# emission, which is different from the original waveform.
for i, (ali, score) in enumerate(zip(aligned_tokens, alignment_scores)):
    print(f"{i:3d}:\t{ali:2d} [{labels[ali]}], {score:.2f}")
    print(f"{i:3d}:\t{ali:2d} [{LABELS[ali]}], {score:.2f}")
######################################################################
#
@@ -209,46 +197,14 @@ for i, (ali, score) in enumerate(zip(aligned_tokens, alignment_scores)):
# which explains what token (in transcript) is present at what time span.
@dataclass
class TokenSpan:
    index: int  # index of token in transcript
    start: int  # start time (inclusive)
    end: int  # end time (exclusive)
    score: float

    def __len__(self) -> int:
        return self.end - self.start
######################################################################
#
def merge_tokens(tokens, scores, blank=0) -> List[TokenSpan]:
    prev_token = blank
    i = start = -1
    spans = []
    for t, token in enumerate(tokens):
        if token != prev_token:
            if prev_token != blank:
                spans.append(TokenSpan(i, start, t, scores[start:t].mean().item()))
            if token != blank:
                i += 1
                start = t
            prev_token = token
    if prev_token != blank:
        spans.append(TokenSpan(i, start, len(tokens), scores[start:].mean().item()))
    return spans
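######################################################################
# A toy illustration of the merging rule with hypothetical data (not
# part of the original tutorial): blanks (token ``0``) are dropped, and
# consecutive repeats of a token are fused into one span whose score is
# the mean of its frame scores.

print(merge_tokens([0, 3, 3, 0, 0, 5], torch.tensor([1.0, 0.8, 0.6, 1.0, 1.0, 0.9])))
# -> token 3 spans frames [1, 3) with score 0.70; token 5 spans [5, 6) with score 0.90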
######################################################################
#
token_spans = merge_tokens(aligned_tokens, alignment_scores)
token_spans = F.merge_tokens(aligned_tokens, alignment_scores)
print("Token\tTime\tScore")
for s in token_spans:
    print(f"{TRANSCRIPT[s.index]}\t[{s.start:3d}, {s.end:3d})\t{s.score:.2f}")
    print(f"{LABELS[s.token]}\t[{s.start:3d}, {s.end:3d})\t{s.score:.2f}")
######################################################################
# Visualization
@@ -256,18 +212,17 @@ for s in token_spans:
#
def plot_scores(spans, scores, transcript):
def plot_scores(spans, scores):
    fig, ax = plt.subplots()
    ax.set_title("frame-level and token-level confidence scores")
    span_xs, span_hs, span_ws = [], [], []
    frame_xs, frame_hs = [], []
    for span in spans:
        token = transcript[span.index]
        if token != "|":
        if LABELS[span.token] != "|":
            span_xs.append((span.end + span.start) / 2 + 0.4)
            span_hs.append(span.score)
            span_ws.append(span.end - span.start)
            ax.annotate(token, (span.start + 0.8, -0.07), weight="bold")
            ax.annotate(LABELS[span.token], (span.start + 0.8, -0.07), weight="bold")
            for t in range(span.start, span.end):
                frame_xs.append(t + 1)
                frame_hs.append(scores[t].item())
@@ -279,7 +234,7 @@ def plot_scores(spans, scores, transcript):
fig.tight_layout()
plot_scores(token_spans, alignment_scores, TRANSCRIPT)
plot_scores(token_spans, alignment_scores)
######################################################################
@@ -295,30 +250,18 @@ plot_scores(token_spans, alignment_scores, TRANSCRIPT)
# alignments and listening to them.
@dataclass
class WordSpan:
    token_spans: List[TokenSpan]
    score: float


# Obtain word alignments from token alignments
def merge_words(token_spans, transcript, separator="|") -> List[WordSpan]:
    def _score(t_spans):
        return sum(s.score * len(s) for s in t_spans) / sum(len(s) for s in t_spans)

    words = []
    i = 0
    for j, span in enumerate(token_spans):
        if transcript[span.index] == separator:
            words.append(WordSpan(token_spans[i:j], _score(token_spans[i:j])))
            i = j + 1
    if i < len(token_spans):
        words.append(WordSpan(token_spans[i:], _score(token_spans[i:])))
    return words


def unflatten(list_, lengths):
    assert len(list_) == sum(lengths)
    i = 0
    ret = []
    for l in lengths:
        ret.append(list_[i : i + l])
        i += l
    return ret
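######################################################################
# A quick illustration of ``unflatten`` with toy data (hypothetical,
# not part of the original tutorial): the flat list is regrouped into
# chunks of the given lengths, just as the flat token spans are
# regrouped into per-word lists below.

print(unflatten(list("ihadthat"), [1, 3, 4]))
# -> [['i'], ['h', 'a', 'd'], ['t', 'h', 'a', 't']]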
word_spans = merge_words(token_spans, TRANSCRIPT)
word_spans = unflatten(token_spans, [len(word) for word in TRANSCRIPT])
######################################################################
@@ -326,45 +269,50 @@ word_spans = merge_words(token_spans, TRANSCRIPT)
# ~~~~~~~~~~~~~
#
# Compute average score weighted by the span length
def _score(spans):
    return sum(s.score * len(s) for s in spans) / sum(len(s) for s in spans)
def plot_alignments(waveform, word_spans, num_frames, transcript, sample_rate=bundle.sample_rate):
    fig, ax = plt.subplots()
    ax.specgram(waveform[0], Fs=sample_rate)
    ratio = waveform.size(1) / sample_rate / num_frames
    for w_span in word_spans:
        t_spans = w_span.token_spans
        t0, t1 = t_spans[0].start, t_spans[-1].end
        ax.axvspan(ratio * t0, ratio * t1, facecolor="None", hatch="/", edgecolor="white")
        ax.annotate(f"{w_span.score:.2f}", (ratio * t0, sample_rate * 0.51), annotation_clip=False)

        for span in t_spans:
            token = transcript[span.index]
            ax.annotate(token, (span.start * ratio, sample_rate * 0.53), annotation_clip=False)

    ax.set_xlabel("time [second]")
    ax.set_xlim([0, None])
    fig.tight_layout()


plot_alignments(waveform, word_spans, num_frames, TRANSCRIPT)


def plot_alignments(waveform, token_spans, emission, transcript, sample_rate=bundle.sample_rate):
    ratio = waveform.size(1) / emission.size(1) / sample_rate

    fig, axes = plt.subplots(2, 1)
    axes[0].imshow(emission[0].detach().cpu().T, aspect="auto")
    axes[0].set_title("Emission")
    axes[0].set_xticks([])

    axes[1].specgram(waveform[0], Fs=sample_rate)
    for t_spans, chars in zip(token_spans, transcript):
        t0, t1 = t_spans[0].start + 0.1, t_spans[-1].end - 0.1
        axes[0].axvspan(t0 - 0.5, t1 - 0.5, facecolor="None", hatch="/", edgecolor="white")
        axes[1].axvspan(ratio * t0, ratio * t1, facecolor="None", hatch="/", edgecolor="white")
        axes[1].annotate(f"{_score(t_spans):.2f}", (ratio * t0, sample_rate * 0.51), annotation_clip=False)

        for span, char in zip(t_spans, chars):
            t0 = span.start * ratio
            axes[1].annotate(char, (t0, sample_rate * 0.55), annotation_clip=False)

    axes[1].set_xlabel("time [second]")
    axes[1].set_xlim([0, None])
    fig.tight_layout()
######################################################################
plot_alignments(waveform, word_spans, emission, TRANSCRIPT)
######################################################################


def preview_word(waveform, word_span, num_frames, transcript, sample_rate=bundle.sample_rate):
def preview_word(waveform, spans, num_frames, transcript, sample_rate=bundle.sample_rate):
    ratio = waveform.size(1) / num_frames
    t0 = word_span.token_spans[0].start
    t1 = word_span.token_spans[-1].end
    x0 = int(ratio * t0)
    x1 = int(ratio * t1)
    tokens = "".join(transcript[t.index] for t in word_span.token_spans)
    print(f"{tokens} ({word_span.score:.2f}): {x0 / sample_rate:.3f} - {x1 / sample_rate:.3f} sec")
    x0 = int(ratio * spans[0].start)
    x1 = int(ratio * spans[-1].end)
    print(f"{transcript} ({_score(spans):.2f}): {x0 / sample_rate:.3f} - {x1 / sample_rate:.3f} sec")
    segment = waveform[:, x0:x1]
    return IPython.display.Audio(segment.numpy(), rate=sample_rate)


num_frames = emission.size(1)
######################################################################
# Generate the audio for each segment
@@ -374,47 +322,47 @@ IPython.display.Audio(SPEECH_FILE)
######################################################################
#
preview_word(waveform, word_spans[0], num_frames, TRANSCRIPT)
preview_word(waveform, word_spans[0], num_frames, TRANSCRIPT[0])
######################################################################
#
preview_word(waveform, word_spans[1], num_frames, TRANSCRIPT)
preview_word(waveform, word_spans[1], num_frames, TRANSCRIPT[1])
######################################################################
#
preview_word(waveform, word_spans[2], num_frames, TRANSCRIPT)
preview_word(waveform, word_spans[2], num_frames, TRANSCRIPT[2])
######################################################################
#
preview_word(waveform, word_spans[3], num_frames, TRANSCRIPT)
preview_word(waveform, word_spans[3], num_frames, TRANSCRIPT[3])
######################################################################
#
preview_word(waveform, word_spans[4], num_frames, TRANSCRIPT)
preview_word(waveform, word_spans[4], num_frames, TRANSCRIPT[4])
######################################################################
#
preview_word(waveform, word_spans[5], num_frames, TRANSCRIPT)
preview_word(waveform, word_spans[5], num_frames, TRANSCRIPT[5])
######################################################################
#
preview_word(waveform, word_spans[6], num_frames, TRANSCRIPT)
preview_word(waveform, word_spans[6], num_frames, TRANSCRIPT[6])
######################################################################
#
preview_word(waveform, word_spans[7], num_frames, TRANSCRIPT)
preview_word(waveform, word_spans[7], num_frames, TRANSCRIPT[7])
######################################################################
#
preview_word(waveform, word_spans[8], num_frames, TRANSCRIPT)
preview_word(waveform, word_spans[8], num_frames, TRANSCRIPT[8])
######################################################################
@@ -442,7 +390,7 @@ DICTIONARY["*"] = len(DICTIONARY)
# corresponding to the ``<star>`` token.
#
star_dim = torch.zeros((1, num_frames, 1), device=device)
star_dim = torch.zeros((1, emission.size(1), 1), device=emission.device, dtype=emission.dtype)
emission = torch.cat((emission, star_dim), 2)
assert len(DICTIONARY) == emission.shape[2]
@@ -455,10 +403,10 @@ plot_emission(emission[0])
def compute_alignments(emission, transcript, dictionary):
    tokens = [dictionary[c] for c in transcript]
    tokens = [dictionary[char] for word in transcript for char in word]
    alignment, scores = align(emission, tokens)
    token_spans = merge_tokens(alignment, scores)
    word_spans = merge_words(token_spans, transcript)
    token_spans = F.merge_tokens(alignment, scores)
    word_spans = unflatten(token_spans, [len(word) for word in transcript])
    return word_spans
@@ -466,26 +414,31 @@ def compute_alignments(emission, transcript, dictionary):
# **Original**
word_spans = compute_alignments(emission, TRANSCRIPT, DICTIONARY)
plot_alignments(waveform, word_spans, num_frames, TRANSCRIPT)
plot_alignments(waveform, word_spans, emission, TRANSCRIPT)
######################################################################
# **With <star> token**
#
# Now we replace the first part of the transcript with the ``<star>`` token.
transcript = "*|THIS|MOMENT"
transcript = "* this moment".split()
word_spans = compute_alignments(emission, transcript, DICTIONARY)
plot_alignments(waveform, word_spans, num_frames, transcript)
plot_alignments(waveform, word_spans, emission, transcript)
######################################################################
#
preview_word(waveform, word_spans[0], num_frames, transcript[0])
######################################################################
#
preview_word(waveform, word_spans[1], num_frames, transcript)
preview_word(waveform, word_spans[1], num_frames, transcript[1])
######################################################################
#
preview_word(waveform, word_spans[2], num_frames, transcript)
preview_word(waveform, word_spans[2], num_frames, transcript[2])
######################################################################
#
@@ -497,9 +450,9 @@ preview_word(waveform, word_spans[2], num_frames, transcript)
# without using the ``<star>`` token.
# It demonstrates the effect of the ``<star>`` token for dealing with deletion errors.
transcript = "THIS|MOMENT"
transcript = "this moment".split()
word_spans = compute_alignments(emission, transcript, DICTIONARY)
plot_alignments(waveform, word_spans, num_frames, transcript)
plot_alignments(waveform, word_spans, emission, transcript)
######################################################################
# Conclusion
@@ -517,7 +470,6 @@ plot_alignments(waveform, word_spans, num_frames, transcript)
# ---------------
#
# Thanks to `Vineel Pratap <vineelkpratap@meta.com>`__ and `Zhaoheng
# Ni <zni@meta.com>`__ for working on the forced aligner API, and `Moto
# Hira <moto@meta.com>`__ for providing alignment merging and
# visualization utilities.
# Ni <zni@meta.com>`__ for developing and open-sourcing the
# forced aligner API.
#