Commit 84b12306 authored by moto's avatar moto Committed by Facebook GitHub Bot

Set and tweak global matplotlib configuration in tutorials (#3515)

Summary:
- Set global matplotlib rc params
- Fix style check
- Fix and update FA tutorial plots
- Add av-asr index card

Pull Request resolved: https://github.com/pytorch/audio/pull/3515

Reviewed By: huangruizhe

Differential Revision: D47894156

Pulled By: mthrok

fbshipit-source-id: b40d8d31f12ffc2b337e35e632afc216e9d59a6e
parent 8497ee91
......@@ -127,6 +127,22 @@ def _get_pattern():
return ret
def reset_mpl(gallery_conf, fname):
from sphinx_gallery.scrapers import _reset_matplotlib
_reset_matplotlib(gallery_conf, fname)
import matplotlib
matplotlib.rcParams.update(
{
"image.interpolation": "none",
"figure.figsize": (9.6, 4.8),
"font.size": 8.0,
"axes.axisbelow": True,
}
)
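# Effects of the values above:
# - "image.interpolation": "none" disables smoothing of image/spectrogram pixels.
# - "figure.figsize": (9.6, 4.8) widens matplotlib's default 6.4 x 4.8 canvas.
# - "font.size": 8.0 reduces the default font size (normally 10.0).
# - "axes.axisbelow": True draws grid lines and ticks behind the plotted data.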
sphinx_gallery_conf = {
"examples_dirs": [
"../../examples/tutorials",
......@@ -139,6 +155,7 @@ sphinx_gallery_conf = {
"promote_jupyter_magic": True,
"first_notebook_cell": None,
"doc_module": ("torchaudio",),
"reset_modules": (reset_mpl, "seaborn"),
}
autosummary_generate = True
......
......@@ -71,8 +71,8 @@ model implementations and application components.
tutorials/online_asr_tutorial
tutorials/device_asr
tutorials/device_avsr
tutorials/forced_alignment_for_multilingual_data_tutorial
tutorials/forced_alignment_tutorial
tutorials/forced_alignment_for_multilingual_data_tutorial
tutorials/tacotron2_pipeline_tutorial
tutorials/mvdr_tutorial
tutorials/hybrid_demucs_tutorial
......@@ -147,6 +147,13 @@ Tutorials
.. customcardstart::
.. customcarditem::
:header: On device audio-visual automatic speech recognition
:card_description: Learn how to stream audio and video from a laptop webcam and perform audio-visual automatic speech recognition using the Emformer-RNNT model.
:image: https://download.pytorch.org/torchaudio/doc-assets/avsr/transformed.gif
:link: tutorials/device_avsr.html
:tags: I/O,Pipelines,RNNT
.. customcarditem::
:header: Loading waveform Tensors from files and saving them
:card_description: Learn how to query/load audio files and save waveform tensors to files, using <code>torchaudio.info</code>, <code>torchaudio.load</code> and <code>torchaudio.save</code> functions.
......
......@@ -85,7 +85,7 @@ NUM_FRAMES = int(DURATION * SAMPLE_RATE)
#
def show(freq, amp, waveform, sample_rate, zoom=None, vol=0.1):
def plot(freq, amp, waveform, sample_rate, zoom=None, vol=0.1):
t = torch.arange(waveform.size(0)) / sample_rate
fig, axes = plt.subplots(4, 1, sharex=True)
......@@ -101,7 +101,7 @@ def show(freq, amp, waveform, sample_rate, zoom=None, vol=0.1):
for i in range(4):
axes[i].grid(True)
pos = axes[2].get_position()
plt.tight_layout()
fig.tight_layout()
if zoom is not None:
ax = fig.add_axes([pos.x0 + 0.02, pos.y0 + 0.03, pos.width / 2.5, pos.height / 2.0])
......@@ -168,7 +168,7 @@ def sawtooth_wave(freq0, amp0, num_pitches, sample_rate):
freq0 = torch.full((NUM_FRAMES, 1), F0)
amp0 = torch.ones((NUM_FRAMES, 1))
freq, amp, waveform = sawtooth_wave(freq0, amp0, int(SAMPLE_RATE / F0), SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
plot(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
######################################################################
#
......@@ -183,7 +183,7 @@ phase = torch.linspace(0, fm * PI2 * DURATION, NUM_FRAMES)
freq0 = F0 + f_dev * torch.sin(phase).unsqueeze(-1)
freq, amp, waveform = sawtooth_wave(freq0, amp0, int(SAMPLE_RATE / F0), SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
plot(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
######################################################################
# Square wave
......@@ -220,7 +220,7 @@ def square_wave(freq0, amp0, num_pitches, sample_rate):
freq0 = torch.full((NUM_FRAMES, 1), F0)
amp0 = torch.ones((NUM_FRAMES, 1))
freq, amp, waveform = square_wave(freq0, amp0, int(SAMPLE_RATE / F0 / 2), SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
plot(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
######################################################################
# Triangle wave
......@@ -256,7 +256,7 @@ def triangle_wave(freq0, amp0, num_pitches, sample_rate):
#
freq, amp, waveform = triangle_wave(freq0, amp0, int(SAMPLE_RATE / F0 / 2), SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
plot(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
######################################################################
# Inharmonic Partials
......@@ -296,7 +296,7 @@ amp = torch.stack([amp * (0.5**i) for i in range(num_tones)], dim=-1)
waveform = oscillator_bank(freq, amp, sample_rate=SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE, vol=0.4)
plot(freq, amp, waveform, SAMPLE_RATE, vol=0.4)
######################################################################
#
......@@ -308,7 +308,7 @@ show(freq, amp, waveform, SAMPLE_RATE, vol=0.4)
freq = extend_pitch(freq0, num_tones)
waveform = oscillator_bank(freq, amp, sample_rate=SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE)
plot(freq, amp, waveform, SAMPLE_RATE)
######################################################################
# References
......
......@@ -407,30 +407,45 @@ print(timesteps, timesteps.shape[0])
#
def plot_alignments(waveform, emission, tokens, timesteps):
fig, ax = plt.subplots(figsize=(32, 10))
ax.plot(waveform)
ratio = waveform.shape[0] / emission.shape[1]
word_start = 0
for i in range(len(tokens)):
if i != 0 and tokens[i - 1] == "|":
word_start = timesteps[i]
if tokens[i] != "|":
plt.annotate(tokens[i].upper(), (timesteps[i] * ratio, waveform.max() * 1.02), size=14)
elif i != 0:
word_end = timesteps[i]
ax.axvspan(word_start * ratio, word_end * ratio, alpha=0.1, color="red")
xticks = ax.get_xticks()
plt.xticks(xticks, xticks / bundle.sample_rate)
ax.set_xlabel("time (sec)")
ax.set_xlim(0, waveform.shape[0])
plot_alignments(waveform[0], emission, predicted_tokens, timesteps)
def plot_alignments(waveform, emission, tokens, timesteps, sample_rate):
t = torch.arange(waveform.size(0)) / sample_rate
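# seconds of audio per emission frame; multiplying a frame index by
# ratio converts it into a position on the time axis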
ratio = waveform.size(0) / emission.size(1) / sample_rate
chars = []
words = []
word_start = None
for token, timestep in zip(tokens, timesteps * ratio):
if token == "|":
if word_start is not None:
words.append((word_start, timestep))
word_start = None
else:
chars.append((token, timestep))
if word_start is None:
word_start = timestep
fig, axes = plt.subplots(3, 1)
def _plot(ax, xlim):
ax.plot(t, waveform)
for token, timestep in chars:
ax.annotate(token.upper(), (timestep, 0.5))
for word_start, word_end in words:
ax.axvspan(word_start, word_end, alpha=0.1, color="red")
ax.set_ylim(-0.6, 0.7)
ax.set_yticks([0])
ax.grid(True, axis="y")
ax.set_xlim(xlim)
_plot(axes[0], (0.3, 2.5))
_plot(axes[1], (2.5, 4.7))
_plot(axes[2], (4.7, 6.9))
axes[2].set_xlabel("time (sec)")
fig.tight_layout()
plot_alignments(waveform[0], emission, predicted_tokens, timesteps, bundle.sample_rate)
######################################################################
......
......@@ -100,7 +100,6 @@ def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None):
if xlim:
axes[c].set_xlim(xlim)
figure.suptitle(title)
plt.show(block=False)
######################################################################
......@@ -122,7 +121,6 @@ def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
if xlim:
axes[c].set_xlim(xlim)
figure.suptitle(title)
plt.show(block=False)
######################################################################
......
# -*- coding: utf-8 -*-
"""
Audio Datasets
==============
......@@ -10,10 +9,6 @@ datasets. Please refer to the official documentation for the list of
available datasets.
"""
# When running this tutorial in Google Colab, install the required packages
# with the following.
# !pip install torchaudio
import torch
import torchaudio
......@@ -21,22 +16,13 @@ print(torch.__version__)
print(torchaudio.__version__)
######################################################################
# Preparing data and utility functions (skip this section)
# --------------------------------------------------------
#
# @title Prepare data and utility functions. {display-mode: "form"}
# @markdown
# @markdown You do not need to look into this cell.
# @markdown Just execute once and you are good to go.
# -------------------------------------------------------------------------------
# Preparation of data and helper functions.
# -------------------------------------------------------------------------------
import os
import IPython
import matplotlib.pyplot as plt
from IPython.display import Audio, display
_SAMPLE_DIR = "_assets"
......@@ -44,34 +30,13 @@ YESNO_DATASET_PATH = os.path.join(_SAMPLE_DIR, "yes_no")
os.makedirs(YESNO_DATASET_PATH, exist_ok=True)
def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
def plot_specgram(waveform, sample_rate, title="Spectrogram"):
waveform = waveform.numpy()
num_channels, _ = waveform.shape
figure, axes = plt.subplots(num_channels, 1)
if num_channels == 1:
axes = [axes]
for c in range(num_channels):
axes[c].specgram(waveform[c], Fs=sample_rate)
if num_channels > 1:
axes[c].set_ylabel(f"Channel {c+1}")
if xlim:
axes[c].set_xlim(xlim)
figure, ax = plt.subplots()
ax.specgram(waveform[0], Fs=sample_rate)
figure.suptitle(title)
plt.show(block=False)
def play_audio(waveform, sample_rate):
waveform = waveform.numpy()
num_channels, _ = waveform.shape
if num_channels == 1:
display(Audio(waveform[0], rate=sample_rate))
elif num_channels == 2:
display(Audio((waveform[0], waveform[1]), rate=sample_rate))
else:
raise ValueError("Waveforms with more than 2 channels are not supported.")
figure.tight_layout()
######################################################################
......@@ -79,10 +44,25 @@ def play_audio(waveform, sample_rate):
# :py:class:`torchaudio.datasets.YESNO` dataset.
#
dataset = torchaudio.datasets.YESNO(YESNO_DATASET_PATH, download=True)
for i in [1, 3, 5]:
waveform, sample_rate, label = dataset[i]
plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}")
play_audio(waveform, sample_rate)
######################################################################
#
i = 1
waveform, sample_rate, label = dataset[i]
plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}")
IPython.display.Audio(waveform, rate=sample_rate)
######################################################################
#
i = 3
waveform, sample_rate, label = dataset[i]
plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}")
IPython.display.Audio(waveform, rate=sample_rate)
######################################################################
#
i = 5
waveform, sample_rate, label = dataset[i]
plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}")
IPython.display.Audio(waveform, rate=sample_rate)
......@@ -19,25 +19,19 @@ print(torch.__version__)
print(torchaudio.__version__)
######################################################################
# Preparing data and utility functions (skip this section)
# --------------------------------------------------------
# Preparation
# -----------
#
# @title Prepare data and utility functions. {display-mode: "form"}
# @markdown
# @markdown You do not need to look into this cell.
# @markdown Just execute once and you are good to go.
# @markdown
# @markdown In this tutorial, we will use speech data from [VOiCES dataset](https://iqtlabs.github.io/voices/),
# @markdown which is licensed under Creative Commons BY 4.0.
# -------------------------------------------------------------------------------
# Preparation of data and helper functions.
# -------------------------------------------------------------------------------
import librosa
import matplotlib.pyplot as plt
from torchaudio.utils import download_asset
######################################################################
# In this tutorial, we will use speech data from
# `VOiCES dataset <https://iqtlabs.github.io/voices/>`__,
# which is licensed under Creative Commons BY 4.0.
SAMPLE_WAV_SPEECH_PATH = download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
......@@ -75,16 +69,9 @@ def get_spectrogram(
return spectrogram(waveform)
def plot_spectrogram(spec, title=None, ylabel="freq_bin", aspect="auto", xmax=None):
fig, axs = plt.subplots(1, 1)
axs.set_title(title or "Spectrogram (db)")
axs.set_ylabel(ylabel)
axs.set_xlabel("frame")
im = axs.imshow(librosa.power_to_db(spec), origin="lower", aspect=aspect)
if xmax:
axs.set_xlim((0, xmax))
fig.colorbar(im, ax=axs)
plt.show(block=False)
def plot_spec(ax, spec, title, ylabel="freq_bin"):
ax.set_title(title)
ax.imshow(librosa.power_to_db(spec), origin="lower", aspect="auto")
######################################################################
......@@ -108,43 +95,47 @@ def plot_spectrogram(spec, title=None, ylabel="freq_bin", aspect="auto", xmax=No
spec = get_spectrogram(power=None)
stretch = T.TimeStretch()
rate = 1.2
spec_ = stretch(spec, rate)
plot_spectrogram(torch.abs(spec_[0]), title=f"Stretched x{rate}", aspect="equal", xmax=304)
plot_spectrogram(torch.abs(spec[0]), title="Original", aspect="equal", xmax=304)
rate = 0.9
spec_ = stretch(spec, rate)
plot_spectrogram(torch.abs(spec_[0]), title=f"Stretched x{rate}", aspect="equal", xmax=304)
spec_12 = stretch(spec, overriding_rate=1.2)
spec_09 = stretch(spec, overriding_rate=0.9)
######################################################################
# TimeMasking
# -----------
#
torch.random.manual_seed(4)
spec = get_spectrogram()
plot_spectrogram(spec[0], title="Original")
def plot():
fig, axes = plt.subplots(3, 1, sharex=True, sharey=True)
plot_spec(axes[0], torch.abs(spec_12[0]), title="Stretched x1.2")
plot_spec(axes[1], torch.abs(spec[0]), title="Original")
plot_spec(axes[2], torch.abs(spec_09[0]), title="Stretched x0.9")
fig.tight_layout()
masking = T.TimeMasking(time_mask_param=80)
spec = masking(spec)
plot_spectrogram(spec[0], title="Masked along time axis")
plot()
######################################################################
# FrequencyMasking
# ----------------
# Time and Frequency Masking
# --------------------------
#
torch.random.manual_seed(4)
time_masking = T.TimeMasking(time_mask_param=80)
freq_masking = T.FrequencyMasking(freq_mask_param=80)
spec = get_spectrogram()
plot_spectrogram(spec[0], title="Original")
time_masked = time_masking(spec)
freq_masked = freq_masking(spec)
######################################################################
#
def plot():
fig, axes = plt.subplots(3, 1, sharex=True, sharey=True)
plot_spec(axes[0], spec[0], title="Original")
plot_spec(axes[1], time_masked[0], title="Masked along time axis")
plot_spec(axes[2], freq_masked[0], title="Masked along frequency axis")
fig.tight_layout()
masking = T.FrequencyMasking(freq_mask_param=80)
spec = masking(spec)
plot_spectrogram(spec[0], title="Masked along frequency axis")
plot()
......@@ -75,7 +75,6 @@ def plot_waveform(waveform, sr, title="Waveform", ax=None):
ax.grid(True)
ax.set_xlim([0, time_axis[-1]])
ax.set_title(title)
plt.show(block=False)
def plot_spectrogram(specgram, title=None, ylabel="freq_bin", ax=None):
......@@ -85,7 +84,6 @@ def plot_spectrogram(specgram, title=None, ylabel="freq_bin", ax=None):
ax.set_title(title)
ax.set_ylabel(ylabel)
ax.imshow(librosa.power_to_db(specgram), origin="lower", aspect="auto", interpolation="nearest")
plt.show(block=False)
def plot_fbank(fbank, title=None):
......@@ -94,7 +92,6 @@ def plot_fbank(fbank, title=None):
axs.imshow(fbank, aspect="auto")
axs.set_ylabel("frequency bin")
axs.set_xlabel("mel bin")
plt.show(block=False)
######################################################################
......@@ -486,7 +483,6 @@ def plot_pitch(waveform, sr, pitch):
axis2.plot(time_axis, pitch[0], linewidth=2, label="Pitch", color="green")
axis2.legend(loc=0)
plt.show(block=False)
plot_pitch(SPEECH_WAVEFORM, SAMPLE_RATE, pitch)
......@@ -181,7 +181,6 @@ def plot_waveform(waveform, sample_rate):
if num_channels > 1:
axes[c].set_ylabel(f"Channel {c+1}")
figure.suptitle("waveform")
plt.show(block=False)
######################################################################
......@@ -204,7 +203,6 @@ def plot_specgram(waveform, sample_rate, title="Spectrogram"):
if num_channels > 1:
axes[c].set_ylabel(f"Channel {c+1}")
figure.suptitle(title)
plt.show(block=False)
######################################################################
......
......@@ -105,7 +105,6 @@ def plot_sweep(
axis.yaxis.grid(True, alpha=0.67)
figure.suptitle(f"{title} (sample rate: {sample_rate} Hz)")
plt.colorbar(cax)
plt.show(block=True)
######################################################################
......
......@@ -5,254 +5,277 @@ CTC forced alignment API tutorial
**Author**: `Xiaohui Zhang <xiaohuizhang@meta.com>`__
This tutorial shows how to align transcripts to speech with
``torchaudio``'s CTC forced alignment API proposed in the paper
`“Scaling Speech Technology to 1,000+
Languages” <https://research.facebook.com/publications/scaling-speech-technology-to-1000-languages/>`__,
and one advanced usage, i.e. dealing with transcription errors with a <star> token.
Though there’s some overlap in visualization
diagrams, the scope here is different from the `“Forced Alignment with
Wav2Vec2” <https://pytorch.org/audio/stable/tutorials/forced_alignment_tutorial.html>`__
tutorial, which focuses on a step-by-step demonstration of the forced
alignment generation algorithm (without using an API) described in the
`paper <https://arxiv.org/abs/2007.09127>`__ with a Wav2Vec2 model.
This tutorial shows how to align transcripts to speech using
:py:func:`torchaudio.functional.forced_align`
which was developed along the work of
`Scaling Speech Technology to 1,000+ Languages <https://research.facebook.com/publications/scaling-speech-technology-to-1000-languages/>`__.
Forced alignment is the process of aligning a transcript with speech.
We cover the basics of forced alignment in `Forced Alignment with
Wav2Vec2 <./forced_alignment_tutorial.html>`__ with simplified
step-by-step Python implementations.
:py:func:`~torchaudio.functional.forced_align` has custom CPU and CUDA
implementations which are more performant and more accurate than the
vanilla Python implementation above.
It can also handle missing transcripts with the special ``<star>`` token.
For examples of aligning multiple languages, please refer to
`Forced alignment for multilingual data <./forced_alignment_for_multilingual_data_tutorial.html>`__.
"""
import torch
import torchaudio
print(torch.__version__)
print(torchaudio.__version__)
######################################################################
#
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
from dataclasses import dataclass
from typing import List
try:
from torchaudio.functional import forced_align
except ModuleNotFoundError:
print(
"Failed to import the forced alignment API. "
"Please install torchaudio nightly builds. "
"Please refer to https://pytorch.org/get-started/locally "
"for instructions to install a nightly build."
)
raise
import IPython
import matplotlib.pyplot as plt
######################################################################
# Basic usages
# ------------
#
# In this section, we cover the following content:
#
# 1. Generate frame-wise class probabilities from an audio waveform using a CTC
# acoustic model.
# 2. Compute frame-level alignments using TorchAudio’s forced alignment
# API.
# 3. Obtain token-level alignments from frame-level alignments.
# 4. Obtain word-level alignments from token-level alignments.
#
from torchaudio.functional import forced_align
torch.random.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
######################################################################
# Preparation
# ~~~~~~~~~~~
# -----------
#
# First we import the necessary packages, and fetch data that we work on.
# First we prepare the speech data and the transcript we are going
# to use.
#
# %matplotlib inline
from dataclasses import dataclass
import IPython
import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams["figure.figsize"] = [16.0, 4.8]
torch.random.manual_seed(0)
SPEECH_FILE = torchaudio.utils.download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
sample_rate = 16000
TRANSCRIPT = "I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT"
######################################################################
# Generate frame-wise class posteriors from a CTC acoustic model
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Generating emissions and tokens
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# :py:func:`~torchaudio.functional.forced_align` takes emission and
# token sequences and outputs timestamps of the tokens and their scores.
#
# The first step is to generate the class probabilities (i.e. posteriors)
# of each audio frame using a CTC model.
# Here we use :py:func:`torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H`.
# The emission represents the frame-wise probability distribution over
# tokens, and it can be obtained by passing a waveform to an acoustic
# model.
# Tokens are the numerical representation of the transcript. They can be
# obtained by simply mapping each character to its index in the token list.
# The emission and the token sequences must use the same set of tokens.
#
# We can use a pre-trained Wav2Vec2 model to obtain the emission from speech,
# and map the transcript to tokens.
# Here, we use :py:data:`~torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H`,
# which bundles pre-trained model weights with the associated labels.
#
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model().to(device)
labels = bundle.get_labels()
with torch.inference_mode():
waveform, _ = torchaudio.load(SPEECH_FILE)
emissions, _ = model(waveform.to(device))
emissions = torch.log_softmax(emissions, dim=-1)
emission, _ = model(waveform.to(device))
emission = torch.log_softmax(emission, dim=-1)
######################################################################
#
emission = emissions.cpu().detach()
dictionary = {c: i for i, c in enumerate(labels)}
def plot_emission(emission):
plt.imshow(emission.cpu().T)
plt.title("Frame-wise class probabilities")
plt.xlabel("Time")
plt.ylabel("Labels")
plt.tight_layout()
print(dictionary)
plot_emission(emission[0])
######################################################################
# Visualization
# ^^^^^^^^^^^^^
#
# We create a dictionary, which maps each label to its token index.
labels = bundle.get_labels()
DICTIONARY = {c: i for i, c in enumerate(labels)}
for k, v in DICTIONARY.items():
print(f"{k}: {v}")
######################################################################
# Converting the transcript to tokens is as simple as:
plt.imshow(emission[0].T)
plt.colorbar()
plt.title("Frame-wise class probabilities")
plt.xlabel("Time")
plt.ylabel("Labels")
plt.show()
tokens = [DICTIONARY[c] for c in TRANSCRIPT]
print(" ".join(str(t) for t in tokens))
######################################################################
# Computing frame-level alignments
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# --------------------------------
#
# Then we call TorchAudio’s forced alignment API to compute the
# frame-level alignment between each audio frame and each token in the
# transcript. We first explain the inputs and outputs of the API
# ``functional.forced_align``. Note that this API works on both CPU and
# GPU. In the current tutorial we demonstrate it on CPU.
# Now we call TorchAudio’s forced alignment API to compute the
# frame-level alignment. For the detail of function signature, please
# refer to :py:func:`~torchaudio.functional.forced_align`.
#
# **Inputs**:
#
# ``emission``: a 2D tensor of size :math:`T \times N`, where :math:`T` is
# the number of frames (after sub-sampling by the acoustic model, if any),
# and :math:`N` is the vocabulary size.
#
# ``targets``: a 1D tensor vector of size :math:`M`, where :math:`M` is
# the length of the transcript, and each element is a token ID looked up
# from the vocabulary. For example, the ``targets`` tensor representing the
# transcript “i had…” is :math:`[5, 18, 4, 16, ...]`.
#
# ``input lengths``: :math:`T`.
#
# ``target lengths``: :math:`M`.
#
# **Outputs**:
#
# ``frame_alignment``: a 1D tensor of size :math:`T` storing the aligned
# token index (looked up from the vocabulary) of each frame, e.g. for the
# segment corresponding to “i had” in the given example , the
# frame_alignment is
# :math:`[...0, 0, 5, 0, 0, 18, 18, 4, 0, 0, 0, 16,...]`, where :math:`0`
# represents the blank symbol.
def align(emission, tokens):
alignments, scores = forced_align(
emission,
targets=torch.tensor([tokens], dtype=torch.int32, device=emission.device),
input_lengths=torch.tensor([emission.size(1)], device=emission.device),
target_lengths=torch.tensor([len(tokens)], device=emission.device),
blank=0,
)
scores = scores.exp() # convert back to probability
alignments, scores = alignments[0], scores[0] # remove batch dimension for simplicity
return alignments.tolist(), scores.tolist()
frame_alignment, frame_scores = align(emission, tokens)
######################################################################
# Now let's look at the output.
# Notice that the alignment is expressed in the frame coordinate of the
# emission, which is different from that of the original waveform.
for i, (ali, score) in enumerate(zip(frame_alignment, frame_scores)):
print(f"{i:3d}: {ali:2d} [{labels[ali]}], {score:.2f}")
######################################################################
#
# ``frame_scores``: a 1D tensor of size :math:`T` storing the confidence
# score (0 to 1) for each frame. For each frame, the score should be
# close to one if the alignment quality is good.
# The ``Frame`` instance represents the most likely token at each frame
# with its confidence.
#
# When interpreting it, one must remember that the meaning of the blank
# token and repeated tokens is context dependent.
#
# .. note::
#
# When the same token occurs after a blank token, it is not treated as
# a repeat, but as a new occurrence.
#
# .. code-block::
#
# a a a b -> a b
# a - - b -> a b
# a a - b -> a b
# a - a b -> a a b
# ^^^ ^^^
#
# .. code-block::
#
# 29: 0 [-], 1.00
# 30: 7 [I], 1.00 # Start of "I"
# 31: 0 [-], 0.98 # repeat (blank token)
# 32: 0 [-], 1.00 # repeat (blank token)
# 33: 1 [|], 0.85 # Start of "|" (word boundary)
# 34: 1 [|], 1.00 # repeat (same token)
# 35: 0 [-], 0.61 # repeat (blank token)
# 36: 8 [H], 1.00 # Start of "H"
# 37: 0 [-], 1.00 # repeat (blank token)
# 38: 4 [A], 1.00 # Start of "A"
# 39: 0 [-], 0.99 # repeat (blank token)
# 40: 11 [D], 0.92 # Start of "D"
# 41: 0 [-], 0.93 # repeat (blank token)
# 42: 1 [|], 0.98 # Start of "|"
# 43: 1 [|], 1.00 # repeat (same token)
# 44: 3 [T], 1.00 # Start of "T"
# 45: 3 [T], 0.90 # repeat (same token)
# 46: 8 [H], 1.00 # Start of "H"
# 47: 0 [-], 1.00 # repeat (blank token)
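######################################################################
# To make the collapsing rule above concrete, here is a tiny standalone
# sketch (an illustrative helper, not used elsewhere in this tutorial) that
# collapses a toy label sequence the way the rule describes, with "-"
# standing for the blank token.


def collapse_ctc(labels, blank="-"):
    collapsed = []
    prev = blank
    for label in labels:
        # A non-blank label is kept only when it differs from the previous
        # label; a blank resets the marker, so the same label re-appearing
        # after a blank counts as a new occurrence.
        if label != blank and label != prev:
            collapsed.append(label)
        prev = label
    return collapsed


print(collapse_ctc(list("aaab")))  # ['a', 'b']
print(collapse_ctc(list("a--b")))  # ['a', 'b']
print(collapse_ctc(list("aa-b")))  # ['a', 'b']
print(collapse_ctc(list("a-ab")))  # ['a', 'a', 'b']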
######################################################################
# Resolve blank and repeated tokens
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# From the outputs ``frame_alignment`` and ``frame_scores``, we generate a
# The next step is to resolve the repetition, so that what each alignment
# represents does not depend on previous alignments.
# From the outputs ``alignment`` and ``scores``, we generate a
# list called ``frames`` storing information of all frames aligned to
# non-blank tokens. Each element contains 1) token_index: the aligned
# token’s index in the transcript 2) time_index: the current frame’s index
# in the input audio (or more precisely, the row dimension of the emission
# matrix) 3) the confidence scores of the current frame.
#
# For the given example, the first few elements of the list ``frames``
# corresponding to “i had” looks as the following:
#
# ``Frame(token_index=0, time_index=32, score=0.9994410872459412)``
# non-blank tokens.
#
# ``Frame(token_index=1, time_index=35, score=0.9980823993682861)``
# Each element contains the following
#
# ``Frame(token_index=1, time_index=36, score=0.9295750260353088)``
#
# ``Frame(token_index=2, time_index=37, score=0.9997448325157166)``
#
# ``Frame(token_index=3, time_index=41, score=0.9991760849952698)``
#
# ``...``
#
# The interpretation is:
#
# The token with index :math:`0` in the transcript, i.e. “i”, is aligned
# to the :math:`32`\ th audio frame, with confidence :math:`0.9994`. The
# token with index :math:`1` in the transcript, i.e. “h”, is aligned to
# the :math:`35`\ th and :math:`36`\ th audio frames, with confidence
# :math:`0.9981` and :math:`0.9296` respectively. The token with index
# :math:`2` in the transcript, i.e. “a”, is aligned to the :math:`35`\ th
# and :math:`36`\ th audio frames, with confidence :math:`0.9997`. The
# token with index :math:`3` in the transcript, i.e. “d”, is aligned to
# the :math:`41`\ th audio frame, with confidence :math:`0.9992`.
#
# From such information stored in the ``frames`` list, we’ll compute
# token-level and word-level alignments easily.
# - ``token_index``: the aligned token’s index **in the transcript**
# - ``time_index``: the current frame’s index in the emission
# - ``score``: the confidence score of the current frame.
#
# ``token_index`` is the index of each token in the transcript,
# i.e. the current frame aligns to the N-th character from the transcript.
@dataclass
class Frame:
# This is the index of each token in the transcript,
# i.e. the current frame aligns to the N-th character from the transcript.
token_index: int
time_index: int
score: float
def compute_alignments(transcript, dictionary, emission):
frames = []
tokens = [dictionary[c] for c in transcript.replace(" ", "")]
targets = torch.tensor(tokens, dtype=torch.int32).unsqueeze(0)
input_lengths = torch.tensor([emission.shape[1]])
target_lengths = torch.tensor([targets.shape[1]])
######################################################################
#
# This is the key step, where we call the forced alignment API functional.forced_align to compute alignments.
frame_alignment, frame_scores = forced_align(emission, targets, input_lengths, target_lengths, 0)
assert frame_alignment.shape[1] == input_lengths[0].item()
assert targets.shape[1] == target_lengths[0].item()
def obtain_token_level_alignments(alignments, scores) -> List[Frame]:
assert len(alignments) == len(scores)
token_index = -1
prev_hyp = 0
for i in range(frame_alignment.shape[1]):
if frame_alignment[0][i].item() == 0:
frames = []
for i, (ali, score) in enumerate(zip(alignments, scores)):
if ali == 0:
prev_hyp = 0
continue
if frame_alignment[0][i].item() != prev_hyp:
if ali != prev_hyp:
token_index += 1
frames.append(Frame(token_index, i, frame_scores[0][i].exp().item()))
prev_hyp = frame_alignment[0][i].item()
return frames, frame_alignment, frame_scores
transcript = "I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT"
frames, frame_alignment, frame_scores = compute_alignments(transcript, dictionary, emission)
frames.append(Frame(token_index, i, score))
prev_hyp = ali
return frames
######################################################################
# Obtain token-level alignments and confidence scores
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
frames = obtain_token_level_alignments(frame_alignment, frame_scores)
print("Time\tLabel\tScore")
for f in frames:
print(f"{f.time_index:3d}\t{TRANSCRIPT[f.token_index]}\t{f.score:.2f}")
######################################################################
# Obtain token-level alignments and confidence scores
# ---------------------------------------------------
#
# The frame-level alignments contain repetitions of the same labels.
# Another format “token-level alignment”, which specifies the aligned
# frame ranges for each transcript token, contains the same information,
# while being more convenient to apply to some downstream tasks
# (e.g. computing word-level alignments).
#
# Now we demonstrate how to obtain token-level alignments and confidence
# scores by simply merging frame-level alignments and averaging
# frame-level confidence scores.
#
######################################################################
# The following class represents the label, its score and the time span
# of its occurrence.
#
# Merge the labels
@dataclass
class Segment:
label: str
......@@ -261,13 +284,16 @@ class Segment:
score: float
def __repr__(self):
return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"
return f"{self.label:2s} ({self.score:4.2f}): [{self.start:4d}, {self.end:4d})"
@property
def length(self):
def __len__(self):
return self.end - self.start
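######################################################################
# For instance (with hypothetical values), a segment covering emission
# frames 36-40 with an averaged score of 0.97 would be represented as:

print(Segment("A", 36, 40, 0.97))
print("length:", len(Segment("A", 36, 40, 0.97)))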
######################################################################
#
def merge_repeats(frames, transcript):
transcript_nospace = transcript.replace(" ", "")
i1, i2 = 0, 0
......@@ -288,29 +314,31 @@ def merge_repeats(frames, transcript):
return segments
segments = merge_repeats(frames, transcript)
######################################################################
#
segments = merge_repeats(frames, TRANSCRIPT)
for seg in segments:
print(seg)
######################################################################
# Visualization
# ^^^^^^^^^^^^^
# ~~~~~~~~~~~~~
#
def plot_label_prob(segments, transcript):
fig, ax2 = plt.subplots(figsize=(16, 4))
fig, ax = plt.subplots()
ax2.set_title("frame-level and token-level confidence scores")
ax.set_title("frame-level and token-level confidence scores")
xs, hs, ws = [], [], []
for seg in segments:
if seg.label != "|":
xs.append((seg.end + seg.start) / 2 + 0.4)
hs.append(seg.score)
ws.append(seg.end - seg.start)
ax2.annotate(seg.label, (seg.start + 0.8, -0.07), weight="bold")
ax2.bar(xs, hs, width=ws, color="gray", alpha=0.5, edgecolor="black")
ax.annotate(seg.label, (seg.start + 0.8, -0.07), weight="bold")
ax.bar(xs, hs, width=ws, color="gray", alpha=0.5, edgecolor="black")
xs, hs = [], []
for p in frames:
......@@ -319,27 +347,28 @@ def plot_label_prob(segments, transcript):
xs.append(p.time_index + 1)
hs.append(p.score)
ax2.bar(xs, hs, width=0.5, alpha=0.5)
ax2.axhline(0, color="black")
ax2.set_ylim(-0.1, 1.1)
ax.bar(xs, hs, width=0.5, alpha=0.5)
ax.set_ylim(-0.1, 1.1)
ax.grid(True, axis="y")
fig.tight_layout()
plot_label_prob(segments, transcript)
plt.tight_layout()
plt.show()
plot_label_prob(segments, TRANSCRIPT)
######################################################################
# From the visualized scores, we can see that, for tokens spanning
# multiple frames, e.g. “T” in “THAT”, the token-level confidence
# score is the average of the frame-level confidence scores. To make this
# clearer, we don’t plot confidence scores for blank frames, which were
# plotted in the “Label probability with and without repetition” figure in
# the previous tutorial `“Forced Alignment with
# Wav2Vec2 <https://pytorch.org/audio/stable/tutorials/forced_alignment_tutorial.html>`__.
# the previous tutorial
# `Forced Alignment with Wav2Vec2 <./forced_alignment_tutorial.html>`__.
#
######################################################################
# Obtain word-level alignments and confidence scores
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# --------------------------------------------------
#
......@@ -367,7 +396,7 @@ def merge_words(transcript, segments, separator=" "):
s = 0
segs = segments[i1 + s : i2 + s]
word = "".join([seg.label for seg in segs])
score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)
score = sum(seg.score * len(seg) for seg in segs) / sum(len(seg) for seg in segs)
words.append(Segment(word, segments[i1 + s].start, segments[i2 + s - 1].end, score))
i1 = i2
else:
......@@ -376,59 +405,43 @@ def merge_words(transcript, segments, separator=" "):
return words
word_segments = merge_words(transcript, segments, "|")
word_segments = merge_words(TRANSCRIPT, segments, "|")
######################################################################
# Visualization
# ^^^^^^^^^^^^^
# ~~~~~~~~~~~~~
#
def plot_alignments(segments, word_segments, waveform, input_lengths, scale=10):
fig, ax2 = plt.subplots(figsize=(64, 12))
plt.rcParams.update({"font.size": 30})
def plot_alignments(waveform, emission, segments, word_segments, sample_rate=bundle.sample_rate):
fig, ax = plt.subplots()
# The original waveform
ratio = waveform.size(1) / input_lengths
ax2.plot(waveform)
ax2.set_ylim(-1.0 * scale, 1.0 * scale)
ax2.set_xlim(0, waveform.size(-1))
ax.specgram(waveform[0], Fs=sample_rate)
# The original waveform
ratio = waveform.size(1) / sample_rate / emission.size(1)
for word in word_segments:
x0 = ratio * word.start
x1 = ratio * word.end
ax2.axvspan(x0, x1, alpha=0.1, color="red")
ax2.annotate(f"{word.score:.2f}", (x0, 0.8 * scale))
t0, t1 = ratio * word.start, ratio * word.end
ax.axvspan(t0, t1, facecolor="None", hatch="/", edgecolor="white")
ax.annotate(f"{word.score:.2f}", (t0, sample_rate * 0.51), annotation_clip=False)
for seg in segments:
if seg.label != "|":
ax2.annotate(seg.label, (seg.start * ratio, 0.9 * scale))
ax.annotate(seg.label, (seg.start * ratio, sample_rate * 0.53), annotation_clip=False)
xticks = ax2.get_xticks()
plt.xticks(xticks, xticks / sample_rate, fontsize=50)
ax2.set_xlabel("time [second]", fontsize=40)
ax2.set_yticks([])
ax.set_xlabel("time [second]")
fig.tight_layout()
plot_alignments(
segments,
word_segments,
waveform,
emission.shape[1],
1,
)
plt.show()
plot_alignments(waveform, emission, segments, word_segments)
######################################################################
# A trick to embed the resulting audio in the generated file:
# `IPython.display.Audio` has to be the last call in a cell,
# and there should be only one call per cell.
def display_segment(i, waveform, word_segments, frame_alignment):
ratio = waveform.size(1) / frame_alignment.size(1)
def display_segment(i, waveform, word_segments, frame_alignment, sample_rate=bundle.sample_rate):
ratio = waveform.size(1) / len(frame_alignment)
word = word_segments[i]
x0 = int(ratio * word.start)
x1 = int(ratio * word.end)
......@@ -437,8 +450,10 @@ def display_segment(i, waveform, word_segments, frame_alignment):
return IPython.display.Audio(segment.numpy(), rate=sample_rate)
######################################################################
# Generate the audio for each segment
print(transcript)
print(TRANSCRIPT)
IPython.display.Audio(SPEECH_FILE)
######################################################################
......@@ -488,62 +503,71 @@ display_segment(8, waveform, word_segments, frame_alignment)
######################################################################
# Advanced usage: Dealing with missing transcripts using the <star> token
# ---------------------------------------------------------------------------
# Advanced: Handling transcripts with ``<star>`` token
# ----------------------------------------------------
#
# Now let’s look at how, when the transcript is partially missing, we can
# improve alignment quality using the ``<star>`` token, which is capable of
# modeling any token.
#
# Here we use the same English example as above, but remove the
# beginning text ``“i had that curiosity beside me at”`` from the transcript.
# Aligning audio with such a transcript results in wrong alignments of the
# existing word “this”. However, this issue can be mitigated by using the
# ``<star>`` token to model the missing text.
#
# Reload the emission tensor in order to add the extra dimension corresponding to the <star> token.
with torch.inference_mode():
waveform, _ = torchaudio.load(SPEECH_FILE)
emissions, _ = model(waveform.to(device))
emissions = torch.log_softmax(emissions, dim=-1)
######################################################################
# First, we extend the dictionary to include the ``<star>`` token.
# Append the extra dimension corresponding to the <star> token
extra_dim = torch.zeros(emissions.shape[0], emissions.shape[1], 1)
emissions = torch.cat((emissions.cpu(), extra_dim), 2)
emission = emissions.detach()
DICTIONARY["*"] = len(DICTIONARY)
# Extend the dictionary to include the <star> token.
dictionary["*"] = 29
######################################################################
# Next, we extend the emission tensor with the extra dimension
# corresponding to the ``<star>`` token.
#
assert len(dictionary) == emission.shape[2]
extra_dim = torch.zeros(emission.shape[0], emission.shape[1], 1, device=device)
emission = torch.cat((emission, extra_dim), 2)
assert len(DICTIONARY) == emission.shape[2]
######################################################################
# The following function combines all of the processes above, and computes
# word segments from the emission in one go.
def compute_and_plot_alignments(transcript, dictionary, emission, waveform):
frames, frame_alignment, _ = compute_alignments(transcript, dictionary, emission)
tokens = [dictionary[c] for c in transcript]
alignment, scores = align(emission, tokens)
frames = obtain_token_level_alignments(alignment, scores)
segments = merge_repeats(frames, transcript)
word_segments = merge_words(transcript, segments, "|")
plot_alignments(segments, word_segments, waveform, emission.shape[1], 1)
plt.show()
return word_segments, frame_alignment
plot_alignments(waveform, emission, segments, word_segments)
plt.xlim([0, None])
# original:
word_segments, frame_alignment = compute_and_plot_alignments(transcript, dictionary, emission, waveform)
######################################################################
# **Original**
# Demonstrate the effect of <star> token for dealing with deletion errors
# ("i had that curiosity beside me at" missing from the transcript):
transcript = "THIS|MOMENT"
word_segments, frame_alignment = compute_and_plot_alignments(transcript, dictionary, emission, waveform)
compute_and_plot_alignments(TRANSCRIPT, DICTIONARY, emission, waveform)
######################################################################
# **With <star> token**
#
# Now we replace the first part of the transcript with the ``<star>`` token.
# Replacing the missing transcript with the <star> token:
transcript = "*|THIS|MOMENT"
word_segments, frame_alignment = compute_and_plot_alignments(transcript, dictionary, emission, waveform)
compute_and_plot_alignments("*|THIS|MOMENT", DICTIONARY, emission, waveform)
######################################################################
# **Without <star> token**
#
# As a comparison, the following aligns the partial transcript
# without using the ``<star>`` token.
# It demonstrates the effect of the ``<star>`` token for dealing with deletion errors.
compute_and_plot_alignments("THIS|MOMENT", DICTIONARY, emission, waveform)
######################################################################
# Conclusion
......@@ -551,7 +575,7 @@ word_segments, frame_alignment = compute_and_plot_alignments(transcript, diction
#
# In this tutorial, we looked at how to use torchaudio’s forced alignment
# API to align and segment speech files, and demonstrated one advanced usage:
# how introducing a ``<star>`` token could improve alignment accuracy when
# transcription errors exist.
#
......
......@@ -69,7 +69,7 @@ import torchvision
# -------------------
#
# Firstly, we define the function to collect videos from microphone and
# camera. To be specific, we use :py:func:`~torchaudio.io.StreamReader`
# camera. To be specific, we use :py:class:`~torchaudio.io.StreamReader`
# class for the purpose of data collection, which supports capturing
# audio/video from microphone and camera. For the detailed usage of this
# class, please refer to the
......
......@@ -89,7 +89,7 @@ def plot_sinc_ir(irs, cutoff):
num_filts, window_size = irs.shape
half = window_size // 2
fig, axes = plt.subplots(num_filts, 1, sharex=True, figsize=(6.4, 4.8 * 1.5))
fig, axes = plt.subplots(num_filts, 1, sharex=True, figsize=(9.6, 8))
t = torch.linspace(-half, half - 1, window_size)
for ax, ir, coff, color in zip(axes, irs, cutoff, plt.cm.tab10.colors):
ax.plot(t, ir, linewidth=1.2, color=color, zorder=4, label=f"Cutoff: {coff}")
......@@ -100,7 +100,7 @@ def plot_sinc_ir(irs, cutoff):
"(Frequencies are relative to Nyquist frequency)"
)
axes[-1].set_xticks([i * half // 4 for i in range(-4, 5)])
plt.tight_layout()
fig.tight_layout()
######################################################################
......@@ -130,7 +130,7 @@ def plot_sinc_fr(frs, cutoff, band=False):
num_filts, num_fft = frs.shape
num_ticks = num_filts + 1 if band else num_filts
fig, axes = plt.subplots(num_filts, 1, sharex=True, sharey=True, figsize=(6.4, 4.8 * 1.5))
fig, axes = plt.subplots(num_filts, 1, sharex=True, sharey=True, figsize=(9.6, 8))
for ax, fr, coff, color in zip(axes, frs, cutoff, plt.cm.tab10.colors):
ax.grid(True)
ax.semilogy(fr, color=color, zorder=4, label=f"Cutoff: {coff}")
......@@ -146,7 +146,7 @@ def plot_sinc_fr(frs, cutoff, band=False):
"Frequency response of sinc low-pass filter for different cut-off frequencies\n"
"(Frequencies are relative to Nyquist frequency)"
)
plt.tight_layout()
fig.tight_layout()
######################################################################
......@@ -275,7 +275,7 @@ def plot_ir(magnitudes, ir, num_fft=2048):
axes[i].grid(True)
axes[1].set(title="Frequency Response")
axes[2].set(title="Frequency Response (log-scale)", xlabel="Frequency")
axes[2].legend(loc="lower right")
axes[2].legend(loc="center right")
fig.tight_layout()
......
......@@ -6,15 +6,14 @@ Forced alignment for multilingual data
This tutorial shows how to compute forced alignments for speech data
from multiple non-English languages using ``torchaudio``'s CTC forced alignment
API described in `“CTC forced alignment
tutorial” <https://pytorch.org/audio/stable/tutorials/forced_alignment_tutorial.html>`__
and the multilingual Wav2vec2 model proposed in the paper `“Scaling
API described in `CTC forced alignment tutorial <./forced_alignment_tutorial.html>`__
and the multilingual Wav2vec2 model proposed in the paper `Scaling
Speech Technology to 1,000+
Languages” <https://research.facebook.com/publications/scaling-speech-technology-to-1000-languages/>`__.
Languages <https://research.facebook.com/publications/scaling-speech-technology-to-1000-languages/>`__.
The model was trained on 23K of audio data from 1100+ languages using
the `uroman vocabulary <https://www.isi.edu/~ulf/uroman.html>`__
as targets.
"""
import torch
......@@ -23,53 +22,46 @@ import torchaudio
print(torch.__version__)
print(torchaudio.__version__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
try:
from torchaudio.functional import forced_align
except ModuleNotFoundError:
print(
"Failed to import the forced alignment API. "
"Please install torchaudio nightly builds. "
"Please refer to https://pytorch.org/get-started/locally "
"for instructions to install a nightly build."
)
raise
######################################################################
# Preparation
# -----------
#
# Here we import necessary packages, and define utility functions for
# computing the frame-level alignments (using the API
# ``functional.forced_align``), token-level and word-level alignments, and
# also alignment visualization utilities.
#
# %matplotlib inline
from dataclasses import dataclass
import IPython
import matplotlib.pyplot as plt
from torchaudio.functional import forced_align
torch.random.manual_seed(0)
sample_rate = 16000
######################################################################
#
SAMPLE_RATE = 16000
######################################################################
#
# Here we define utility functions for computing the frame-level
# alignments (using the API :py:func:`torchaudio.functional.forced_align`),
# token-level and word-level alignments.
# For the details of these functions, please refer to
# `CTC forced alignment API tutorial <./ctc_forced_alignment_api_tutorial.html>`__.
#
@dataclass
class Frame:
# This is the index of each token in the transcript,
# i.e. the current frame aligns to the N-th character from the transcript.
token_index: int
time_index: int
score: float
######################################################################
#
@dataclass
class Segment:
label: str
......@@ -78,39 +70,42 @@ class Segment:
score: float
def __repr__(self):
return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"
return f"{self.label:2s} ({self.score:4.2f}): [{self.start:4d}, {self.end:4d})"
@property
def length(self):
def __len__(self):
return self.end - self.start
# compute frame-level and word-level alignments using torchaudio's forced alignment API
######################################################################
#
def compute_alignments(transcript, dictionary, emission):
frames = []
tokens = [dictionary[c] for c in transcript.replace(" ", "")]
targets = torch.tensor(tokens, dtype=torch.int32).unsqueeze(0)
input_lengths = torch.tensor([emission.shape[1]])
target_lengths = torch.tensor([targets.shape[1]])
targets = torch.tensor([tokens], dtype=torch.int32, device=emission.device)
input_lengths = torch.tensor([emission.shape[1]], device=emission.device)
target_lengths = torch.tensor([targets.shape[1]], device=emission.device)
alignment, scores = forced_align(emission, targets, input_lengths, target_lengths, 0)
# This is the key step, where we call the forced alignment API functional.forced_align to compute frame alignments.
frame_alignment, frame_scores = forced_align(emission, targets, input_lengths, target_lengths, 0)
scores = scores.exp() # convert back to probability
alignment, scores = alignment[0].tolist(), scores[0].tolist()
assert frame_alignment.shape[1] == input_lengths[0].item()
assert targets.shape[1] == target_lengths[0].item()
assert len(alignment) == len(scores) == emission.size(1)
token_index = -1
prev_hyp = 0
for i in range(frame_alignment.shape[1]):
if frame_alignment[0][i].item() == 0:
frames = []
for i, (ali, score) in enumerate(zip(alignment, scores)):
if ali == 0:
prev_hyp = 0
continue
if frame_alignment[0][i].item() != prev_hyp:
if ali != prev_hyp:
token_index += 1
frames.append(Frame(token_index, i, frame_scores[0][i].exp().item()))
prev_hyp = frame_alignment[0][i].item()
frames.append(Frame(token_index, i, score))
prev_hyp = ali
# compute frame alignments from token alignments
transcript_nospace = transcript.replace(" ", "")
......@@ -140,52 +135,59 @@ def compute_alignments(transcript, dictionary, emission):
if i1 != i2:
if i3 == len(transcript) - 1:
i2 += 1
s = 0
segs = segments[i1 + s : i2 + s]
word = "".join([seg.label for seg in segs])
score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)
words.append(Segment(word, segments[i1 + s].start, segments[i2 + s - 1].end, score))
segs = segments[i1:i2]
word = "".join([s.label for s in segs])
score = sum(s.score * len(s) for s in segs) / sum(len(s) for s in segs)
words.append(Segment(word, segs[0].start, segs[-1].end + 1, score))
i1 = i2
else:
i2 += 1
i3 += 1
return segments, words
num_frames = frame_alignment.shape[1]
return segments, words, num_frames
######################################################################
#
def plot_emission(emission):
fig, ax = plt.subplots()
ax.imshow(emission.T, aspect="auto")
ax.set_title("Emission")
fig.tight_layout()
# utility function for plotting word alignments
def plot_alignments(segments, word_segments, waveform, input_lengths, scale=10):
fig, ax2 = plt.subplots(figsize=(64, 12))
plt.rcParams.update({"font.size": 30})
# The original waveform
ratio = waveform.size(1) / input_lengths
ax2.plot(waveform)
ax2.set_ylim(-1.0 * scale, 1.0 * scale)
ax2.set_xlim(0, waveform.size(-1))
######################################################################
#
# utility function for plotting word alignments
def plot_alignments(waveform, emission, segments, word_segments, sample_rate=SAMPLE_RATE):
fig, ax = plt.subplots()
ax.specgram(waveform[0], Fs=sample_rate)
xlim = ax.get_xlim()
ratio = waveform.size(1) / sample_rate / emission.size(1)
for word in word_segments:
x0 = ratio * word.start
x1 = ratio * word.end
ax2.axvspan(x0, x1, alpha=0.1, color="red")
ax2.annotate(f"{word.score:.2f}", (x0, 0.8 * scale))
t0, t1 = word.start * ratio, word.end * ratio
ax.axvspan(t0, t1, facecolor="None", hatch="/", edgecolor="white")
ax.annotate(f"{word.score:.2f}", (t0, sample_rate * 0.51), annotation_clip=False)
for seg in segments:
if seg.label != "|":
ax2.annotate(seg.label, (seg.start * ratio, 0.9 * scale))
ax.annotate(seg.label, (seg.start * ratio, sample_rate * 0.53), annotation_clip=False)
xticks = ax2.get_xticks()
plt.xticks(xticks, xticks / sample_rate, fontsize=50)
ax2.set_xlabel("time [second]", fontsize=40)
ax2.set_yticks([])
ax.set_xlabel("time [second]")
ax.set_xlim(xlim)
fig.tight_layout()
return IPython.display.Audio(waveform, rate=sample_rate)
######################################################################
#
# utility function for playing audio segments.
# A trick to embed the resulting audio in the generated file:
# `IPython.display.Audio` has to be the last call in a cell,
# and there should be only one call per cell.
def display_segment(i, waveform, word_segments, num_frames):
def display_segment(i, waveform, word_segments, num_frames, sample_rate=SAMPLE_RATE):
ratio = waveform.size(1) / num_frames
word = word_segments[i]
x0 = int(ratio * word.start)
......@@ -241,26 +243,21 @@ model.load_state_dict(
)
)
model.eval()
model.to(device)
def get_emission(waveform):
with torch.inference_mode():
# NOTE: this step is essential
waveform = torch.nn.functional.layer_norm(waveform, waveform.shape)
emissions, _ = model(waveform)
emissions = torch.log_softmax(emissions, dim=-1)
emission = emissions.cpu().detach()
# Append the extra dimension corresponding to the <star> token
extra_dim = torch.zeros(emissions.shape[0], emissions.shape[1], 1)
emissions = torch.cat((emissions.cpu(), extra_dim), 2)
emission = emissions.detach()
return emission, waveform
emission, _ = model(waveform)
return torch.log_softmax(emission, dim=-1)
# Construct the dictionary
# '@' represents the OOV token, '*' represents the <star> token.
# '@' represents the OOV token
# <pad> and </s> are fairseq's legacy tokens, which are not used.
# The <star> token is omitted as we do not use it in this tutorial.
dictionary = {
"<blank>": 0,
"<pad>": 1,
......@@ -293,7 +290,6 @@ dictionary = {
"'": 28,
"q": 29,
"x": 30,
"*": 31,
}
......@@ -304,11 +300,8 @@ dictionary = {
# romanizer and using it to obtain romanized transcripts, and Python
# commands required for further normalizing the romanized transcript.
#
# %%
# .. code-block:: bash
#
# %%bash
# Save the raw transcript to a file
# echo 'raw text' > text.txt
# git clone https://github.com/isi-nlp/uroman
......@@ -334,141 +327,77 @@ dictionary = {
######################################################################
# German example:
# ~~~~~~~~~~~~~~~~
# German
# ~~~~~~
text_raw = (
"aber seit ich bei ihnen das brot hole brauch ich viel weniger schulze wandte sich ab die kinder taten ihm leid"
)
text_normalized = (
"aber seit ich bei ihnen das brot hole brauch ich viel weniger schulze wandte sich ab die kinder taten ihm leid"
)
speech_file = torchaudio.utils.download_asset("tutorial-assets/10349_8674_000087.flac", progress=False)
waveform, _ = torchaudio.load(speech_file)
emission, waveform = get_emission(waveform)
assert len(dictionary) == emission.shape[2]
transcript = text_normalized
segments, word_segments, num_frames = compute_alignments(transcript, dictionary, emission)
plot_alignments(segments, word_segments, waveform, emission.shape[1])
text_raw = "aber seit ich bei ihnen das brot hole"
text_normalized = "aber seit ich bei ihnen das brot hole"
print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized)
IPython.display.Audio(waveform, rate=sample_rate)
######################################################################
#
display_segment(0, waveform, word_segments, num_frames)
waveform, _ = torchaudio.load(speech_file, frame_offset=int(0.5 * SAMPLE_RATE), num_frames=int(2.5 * SAMPLE_RATE))
emission = get_emission(waveform.to(device))
num_frames = emission.size(1)
plot_emission(emission[0].cpu())
######################################################################
#
display_segment(1, waveform, word_segments, num_frames)
segments, word_segments = compute_alignments(text_normalized, dictionary, emission)
######################################################################
#
display_segment(2, waveform, word_segments, num_frames)
######################################################################
#
display_segment(3, waveform, word_segments, num_frames)
######################################################################
#
display_segment(4, waveform, word_segments, num_frames)
plot_alignments(waveform, emission, segments, word_segments)
######################################################################
#
display_segment(5, waveform, word_segments, num_frames)
######################################################################
#
display_segment(6, waveform, word_segments, num_frames)
######################################################################
#
display_segment(7, waveform, word_segments, num_frames)
######################################################################
#
display_segment(8, waveform, word_segments, num_frames)
######################################################################
#
display_segment(9, waveform, word_segments, num_frames)
######################################################################
#
display_segment(10, waveform, word_segments, num_frames)
######################################################################
#
display_segment(11, waveform, word_segments, num_frames)
######################################################################
#
display_segment(12, waveform, word_segments, num_frames)
display_segment(0, waveform, word_segments, num_frames)
######################################################################
#
display_segment(13, waveform, word_segments, num_frames)
display_segment(1, waveform, word_segments, num_frames)
######################################################################
#
display_segment(14, waveform, word_segments, num_frames)
display_segment(2, waveform, word_segments, num_frames)
######################################################################
#
display_segment(15, waveform, word_segments, num_frames)
######################################################################
#
display_segment(3, waveform, word_segments, num_frames)
display_segment(16, waveform, word_segments, num_frames)
######################################################################
#
display_segment(17, waveform, word_segments, num_frames)
display_segment(4, waveform, word_segments, num_frames)
######################################################################
#
display_segment(18, waveform, word_segments, num_frames)
display_segment(5, waveform, word_segments, num_frames)
######################################################################
#
display_segment(19, waveform, word_segments, num_frames)
display_segment(6, waveform, word_segments, num_frames)
######################################################################
#
display_segment(20, waveform, word_segments, num_frames)
display_segment(7, waveform, word_segments, num_frames)
######################################################################
# Chinese example:
# ~~~~~~~~~~~~~~~~
# Chinese
# ~~~~~~~
#
# Chinese is a character-based language, and there is no explicit word-level
# tokenization (separated by spaces) in its raw written form. In order to
......@@ -478,98 +407,36 @@ display_segment(20, waveform, word_segments, num_frames)
# However, this is not needed if you only want character-level alignments.
#
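# A minimal sketch of word-level tokenization plus pinyin romanization using
# the third-party ``jieba`` and ``pypinyin`` packages (not dependencies of
# this tutorial; the resulting segmentation may differ from the hand-segmented
# transcript below):
#
# .. code-block:: python
#
#    import jieba
#    from pypinyin import lazy_pinyin
#
#    raw = "关服务高端产品仍处于供不应求的局面"
#    words = list(jieba.cut(raw))
#    romanized = " ".join("".join(lazy_pinyin(w)) for w in words)
#    print(words)
#    print(romanized)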
text_raw = "关 服务 高端 产品 仍 处于 供不应求 的 局面"
text_normalized = "guan fuwu gaoduan chanpin reng chuyu gongbuyingqiu de jumian"
speech_file = torchaudio.utils.download_asset("tutorial-assets/mvdr/clean_speech.wav", progress=False)
waveform, _ = torchaudio.load(speech_file)
waveform = waveform[0:1]
emission, waveform = get_emission(waveform)
transcript = text_normalized
segments, word_segments, num_frames = compute_alignments(transcript, dictionary, emission)
plot_alignments(segments, word_segments, waveform, emission.shape[1])
text_raw = "关 服务 高端 产品 仍 处于 供不应求 的 局面"
text_normalized = "guan fuwu gaoduan chanpin reng chuyu gongbuyingqiu de jumian"
print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized)
IPython.display.Audio(waveform, rate=sample_rate)
######################################################################
#
display_segment(0, waveform, word_segments, num_frames)
######################################################################
#
display_segment(1, waveform, word_segments, num_frames)
######################################################################
#
display_segment(2, waveform, word_segments, num_frames)
######################################################################
#
display_segment(3, waveform, word_segments, num_frames)
######################################################################
#
display_segment(4, waveform, word_segments, num_frames)
######################################################################
#
display_segment(5, waveform, word_segments, num_frames)
######################################################################
#
display_segment(6, waveform, word_segments, num_frames)
######################################################################
#
waveform, _ = torchaudio.load(speech_file)
waveform = waveform[0:1]
display_segment(7, waveform, word_segments, num_frames)
emission = get_emission(waveform.to(device))
num_frames = emission.size(1)
plot_emission(emission[0].cpu())
######################################################################
#
display_segment(8, waveform, word_segments, num_frames)
######################################################################
# Polish example:
# ~~~~~~~~~~~~~~~
text_raw = "wtedy ujrzałem na jego brzuchu okrągłą czarną ranę dlaczego mi nie powiedziałeś szepnąłem ze łzami"
text_normalized = "wtedy ujrzalem na jego brzuchu okragla czarna rane dlaczego mi nie powiedziales szepnalem ze lzami"
speech_file = torchaudio.utils.download_asset("tutorial-assets/5090_1447_000088.flac", progress=False)
waveform, _ = torchaudio.load(speech_file)
emission, waveform = get_emission(waveform)
transcript = text_normalized
segments, word_segments = compute_alignments(text_normalized, dictionary, emission)
segments, word_segments, num_frames = compute_alignments(transcript, dictionary, emission)
plot_alignments(segments, word_segments, waveform, emission.shape[1])
print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized)
IPython.display.Audio(waveform, rate=sample_rate)
plot_alignments(waveform, emission, segments, word_segments)
######################################################################
#
display_segment(0, waveform, word_segments, num_frames)
######################################################################
#
......@@ -585,7 +452,6 @@ display_segment(2, waveform, word_segments, num_frames)
display_segment(3, waveform, word_segments, num_frames)
######################################################################
#
......@@ -611,68 +477,40 @@ display_segment(7, waveform, word_segments, num_frames)
display_segment(8, waveform, word_segments, num_frames)
######################################################################
#
display_segment(9, waveform, word_segments, num_frames)
######################################################################
#
# Polish
# ~~~~~~
display_segment(10, waveform, word_segments, num_frames)
speech_file = torchaudio.utils.download_asset("tutorial-assets/5090_1447_000088.flac", progress=False)
######################################################################
#
text_raw = "wtedy ujrzałem na jego brzuchu okrągłą czarną ranę"
text_normalized = "wtedy ujrzalem na jego brzuchu okragla czarna rane"
display_segment(11, waveform, word_segments, num_frames)
print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized)
######################################################################
#
display_segment(12, waveform, word_segments, num_frames)
######################################################################
#
waveform, _ = torchaudio.load(speech_file, num_frames=int(4.5 * SAMPLE_RATE))
display_segment(13, waveform, word_segments, num_frames)
emission = get_emission(waveform.to(device))
num_frames = emission.size(1)
plot_emission(emission[0].cpu())
######################################################################
#
display_segment(14, waveform, word_segments, num_frames)
segments, word_segments = compute_alignments(text_normalized, dictionary, emission)
######################################################################
# Portuguese example:
# ~~~~~~~~~~~~~~~~~~~
text_raw = (
"mas na imensa extensão onde se esconde o inconsciente imortal só me responde um bramido um queixume e nada mais"
)
text_normalized = (
"mas na imensa extensao onde se esconde o inconsciente imortal so me responde um bramido um queixume e nada mais"
)
speech_file = torchaudio.utils.download_asset("tutorial-assets/6566_5323_000027.flac", progress=False)
waveform, _ = torchaudio.load(speech_file)
emission, waveform = get_emission(waveform)
transcript = text_normalized
segments, word_segments, num_frames = compute_alignments(transcript, dictionary, emission)
plot_alignments(segments, word_segments, waveform, emission.shape[1])
print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized)
IPython.display.Audio(waveform, rate=sample_rate)
plot_alignments(waveform, emission, segments, word_segments)
######################################################################
#
display_segment(0, waveform, word_segments, num_frames)
######################################################################
#
......@@ -688,7 +526,6 @@ display_segment(2, waveform, word_segments, num_frames)
display_segment(3, waveform, word_segments, num_frames)
######################################################################
#
......@@ -710,94 +547,38 @@ display_segment(6, waveform, word_segments, num_frames)
display_segment(7, waveform, word_segments, num_frames)
######################################################################
#
display_segment(8, waveform, word_segments, num_frames)
######################################################################
#
display_segment(9, waveform, word_segments, num_frames)
######################################################################
#
display_segment(10, waveform, word_segments, num_frames)
######################################################################
#
display_segment(11, waveform, word_segments, num_frames)
######################################################################
#
display_segment(12, waveform, word_segments, num_frames)
# Portuguese
# ~~~~~~~~~~
######################################################################
#
display_segment(13, waveform, word_segments, num_frames)
######################################################################
#
display_segment(14, waveform, word_segments, num_frames)
######################################################################
#
display_segment(15, waveform, word_segments, num_frames)
speech_file = torchaudio.utils.download_asset("tutorial-assets/6566_5323_000027.flac", progress=False)
######################################################################
#
text_raw = "na imensa extensão onde se esconde o inconsciente imortal"
text_normalized = "na imensa extensao onde se esconde o inconsciente imortal"
display_segment(16, waveform, word_segments, num_frames)
print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized)
######################################################################
#
display_segment(17, waveform, word_segments, num_frames)
waveform, _ = torchaudio.load(speech_file, frame_offset=int(SAMPLE_RATE), num_frames=int(4.6 * SAMPLE_RATE))
######################################################################
#
display_segment(18, waveform, word_segments, num_frames)
emission = get_emission(waveform.to(device))
num_frames = emission.size(1)
plot_emission(emission[0].cpu())
######################################################################
#
display_segment(19, waveform, word_segments, num_frames)
segments, word_segments = compute_alignments(text_normalized, dictionary, emission)
######################################################################
# Italian example:
# ~~~~~~~~~~~~~~~~
text_raw = "elle giacean per terra tutte quante fuor d'una ch'a seder si levò ratto ch'ella ci vide passarsi davante"
text_normalized = (
"elle giacean per terra tutte quante fuor d'una ch'a seder si levo ratto ch'ella ci vide passarsi davante"
)
speech_file = torchaudio.utils.download_asset("tutorial-assets/642_529_000025.flac", progress=False)
waveform, _ = torchaudio.load(speech_file)
emission, waveform = get_emission(waveform)
transcript = text_normalized
segments, word_segments, num_frames = compute_alignments(transcript, dictionary, emission)
plot_alignments(segments, word_segments, waveform, emission.shape[1])
print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized)
IPython.display.Audio(waveform, rate=sample_rate)
plot_alignments(waveform, emission, segments, word_segments)
######################################################################
#
display_segment(0, waveform, word_segments, num_frames)
######################################################################
#
......@@ -813,7 +594,6 @@ display_segment(2, waveform, word_segments, num_frames)
display_segment(3, waveform, word_segments, num_frames)
######################################################################
#
......@@ -840,50 +620,62 @@ display_segment(7, waveform, word_segments, num_frames)
display_segment(8, waveform, word_segments, num_frames)
######################################################################
#
# Italian
# ~~~~~~~
speech_file = torchaudio.utils.download_asset("tutorial-assets/642_529_000025.flac", progress=False)
text_raw = "elle giacean per terra tutte quante"
text_normalized = "elle giacean per terra tutte quante"
display_segment(9, waveform, word_segments, num_frames)
print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized)
######################################################################
#
display_segment(10, waveform, word_segments, num_frames)
waveform, _ = torchaudio.load(speech_file, num_frames=int(4 * SAMPLE_RATE))
emission = get_emission(waveform.to(device))
num_frames = emission.size(1)
plot_emission(emission[0].cpu())
######################################################################
#
display_segment(11, waveform, word_segments, num_frames)
segments, word_segments = compute_alignments(text_normalized, dictionary, emission)
plot_alignments(waveform, emission, segments, word_segments)
######################################################################
#
display_segment(12, waveform, word_segments, num_frames)
display_segment(0, waveform, word_segments, num_frames)
######################################################################
#
display_segment(13, waveform, word_segments, num_frames)
display_segment(1, waveform, word_segments, num_frames)
######################################################################
#
display_segment(14, waveform, word_segments, num_frames)
display_segment(2, waveform, word_segments, num_frames)
######################################################################
#
display_segment(15, waveform, word_segments, num_frames)
display_segment(3, waveform, word_segments, num_frames)
######################################################################
#
display_segment(16, waveform, word_segments, num_frames)
display_segment(4, waveform, word_segments, num_frames)
######################################################################
#
display_segment(17, waveform, word_segments, num_frames)
display_segment(5, waveform, word_segments, num_frames)
######################################################################
# Conclusion
......@@ -894,7 +686,6 @@ display_segment(17, waveform, word_segments, num_frames)
# speech data to transcripts in five languages.
#
######################################################################
# Acknowledgement
# ---------------
......
......@@ -56,16 +56,11 @@ print(device)
# First, we import the necessary packages and fetch the data that we will work on.
#
# %matplotlib inline
from dataclasses import dataclass
import IPython
import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams["figure.figsize"] = [16.0, 4.8]
torch.random.manual_seed(0)
SPEECH_FILE = torchaudio.utils.download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
......@@ -99,17 +94,22 @@ with torch.inference_mode():
emission = emissions[0].cpu().detach()
print(labels)
################################################################################
# Visualization
################################################################################
print(labels)
plt.imshow(emission.T)
plt.colorbar()
plt.title("Frame-wise class probability")
plt.xlabel("Time")
plt.ylabel("Labels")
plt.show()
# ~~~~~~~~~~~~~
def plot():
plt.imshow(emission.T)
plt.colorbar()
plt.title("Frame-wise class probability")
plt.xlabel("Time")
plt.ylabel("Labels")
plot()
######################################################################
# Generate alignment probability (trellis)
......@@ -181,12 +181,17 @@ trellis = get_trellis(emission, tokens)
################################################################################
# Visualization
################################################################################
plt.imshow(trellis.T, origin="lower")
plt.annotate("- Inf", (trellis.size(1) / 5, trellis.size(1) / 1.5))
plt.annotate("+ Inf", (trellis.size(0) - trellis.size(1) / 5, trellis.size(1) / 3))
plt.colorbar()
plt.show()
# ~~~~~~~~~~~~~
def plot():
plt.imshow(trellis.T, origin="lower")
plt.annotate("- Inf", (trellis.size(1) / 5, trellis.size(1) / 1.5))
plt.annotate("+ Inf", (trellis.size(0) - trellis.size(1) / 5, trellis.size(1) / 3))
plt.colorbar()
plot()
######################################################################
# In the above visualization, we can see that there is a trace of high
......@@ -266,7 +271,9 @@ for p in path:
################################################################################
# Visualization
################################################################################
# ~~~~~~~~~~~~~
def plot_trellis_with_path(trellis, path):
# To plot trellis with path, we take advantage of 'nan' value
trellis_with_path = trellis.clone()
......@@ -277,10 +284,14 @@ def plot_trellis_with_path(trellis, path):
plot_trellis_with_path(trellis, path)
plt.title("The path found by backtracking")
plt.show()
######################################################################
# Looking good. Now this path contains repetitions for the same labels, so
# Looking good.
######################################################################
# Segment the path
# ----------------
# Now this path contains repetitions for the same labels, so
# let’s merge them to make it close to the original transcript
# (a sketch of this merge appears after the ``Segment`` class below).
#
# When merging the multiple path points, we simply take the average
......@@ -297,7 +308,7 @@ class Segment:
score: float
def __repr__(self):
return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"
return f"{self.label} ({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"
@property
def length(self):
......@@ -330,7 +341,9 @@ for seg in segments:
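# The merging code itself is elided in this diff. For reference, a minimal
# sketch of the idea (average the scores of consecutive path points that share
# a token), assuming the backtracked ``path`` entries expose ``token_index``,
# ``time_index`` and ``score``:
#
# .. code-block:: python
#
#    def merge_repeats(path):
#        i1, i2 = 0, 0
#        segments = []
#        while i1 < len(path):
#            while i2 < len(path) and path[i1].token_index == path[i2].token_index:
#                i2 += 1
#            score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
#            segments.append(
#                Segment(
#                    transcript[path[i1].token_index],
#                    path[i1].time_index,
#                    path[i2 - 1].time_index + 1,
#                    score,
#                )
#            )
#            i1 = i2
#        return segments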
################################################################################
# Visualization
################################################################################
# ~~~~~~~~~~~~~
def plot_trellis_with_segments(trellis, segments, transcript):
# To plot trellis with path, we take advantage of 'nan' value
trellis_with_path = trellis.clone()
......@@ -338,15 +351,14 @@ def plot_trellis_with_segments(trellis, segments, transcript):
if seg.label != "|":
trellis_with_path[seg.start : seg.end, i] = float("nan")
fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5))
fig, [ax1, ax2] = plt.subplots(2, 1, sharex=True)
ax1.set_title("Path, label and probability for each label")
ax1.imshow(trellis_with_path.T, origin="lower")
ax1.set_xticks([])
ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")
for i, seg in enumerate(segments):
if seg.label != "|":
ax1.annotate(seg.label, (seg.start, i - 0.7), weight="bold")
ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3))
ax1.annotate(seg.label, (seg.start, i - 0.7), size="small")
ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), size="small")
ax2.set_title("Label probability with and without repetition")
xs, hs, ws = [], [], []
......@@ -355,7 +367,7 @@ def plot_trellis_with_segments(trellis, segments, transcript):
xs.append((seg.end + seg.start) / 2 + 0.4)
hs.append(seg.score)
ws.append(seg.end - seg.start)
ax2.annotate(seg.label, (seg.start + 0.8, -0.07), weight="bold")
ax2.annotate(seg.label, (seg.start + 0.8, -0.07))
ax2.bar(xs, hs, width=ws, color="gray", alpha=0.5, edgecolor="black")
xs, hs = [], []
......@@ -367,17 +379,21 @@ def plot_trellis_with_segments(trellis, segments, transcript):
ax2.bar(xs, hs, width=0.5, alpha=0.5)
ax2.axhline(0, color="black")
ax2.set_xlim(ax1.get_xlim())
ax2.grid(True, axis="y")
ax2.set_ylim(-0.1, 1.1)
fig.tight_layout()
plot_trellis_with_segments(trellis, segments, transcript)
plt.tight_layout()
plt.show()
######################################################################
# Looks good. Now let’s merge the words. The Wav2Vec2 model uses ``'|'``
# Looks good.
######################################################################
# Merge the segments into words
# -----------------------------
# Now let’s merge the words. The Wav2Vec2 model uses ``'|'``
# as the word boundary, so we merge the segments before each occurrence of
# ``'|'``.
#
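# For reference, a minimal sketch of that word-level merge, assuming the
# ``Segment`` dataclass defined above (the tutorial's own implementation is
# elided in this diff and may differ in detail):
#
# .. code-block:: python
#
#    def merge_words(segments, separator="|"):
#        words = []
#        i1, i2 = 0, 0
#        while i1 < len(segments):
#            if i2 >= len(segments) or segments[i2].label == separator:
#                if i1 != i2:
#                    segs = segments[i1:i2]
#                    word = "".join(seg.label for seg in segs)
#                    score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)
#                    words.append(Segment(word, segments[i1].start, segments[i2 - 1].end, score))
#                i1 = i2 + 1
#                i2 = i1
#            else:
#                i2 += 1
#        return words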
......@@ -410,16 +426,16 @@ for word in word_segments:
################################################################################
# Visualization
################################################################################
# ~~~~~~~~~~~~~
def plot_alignments(trellis, segments, word_segments, waveform):
trellis_with_path = trellis.clone()
for i, seg in enumerate(segments):
if seg.label != "|":
trellis_with_path[seg.start : seg.end, i] = float("nan")
fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5))
fig, [ax1, ax2] = plt.subplots(2, 1)
ax1.imshow(trellis_with_path.T, origin="lower")
ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")
ax1.set_xticks([])
ax1.set_yticks([])
......@@ -429,8 +445,8 @@ def plot_alignments(trellis, segments, word_segments, waveform):
for i, seg in enumerate(segments):
if seg.label != "|":
ax1.annotate(seg.label, (seg.start, i - 0.7))
ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), fontsize=8)
ax1.annotate(seg.label, (seg.start, i - 0.7), size="small")
ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), size="small")
# The original waveform
ratio = waveform.size(0) / trellis.size(0)
......@@ -450,6 +466,7 @@ def plot_alignments(trellis, segments, word_segments, waveform):
ax2.set_yticks([])
ax2.set_ylim(-1.0, 1.0)
ax2.set_xlim(0, waveform.size(-1))
fig.tight_layout()
plot_alignments(
......@@ -458,7 +475,6 @@ plot_alignments(
word_segments,
waveform[0],
)
plt.show()
################################################################################
......
......@@ -162,11 +162,10 @@ def separate_sources(
def plot_spectrogram(stft, title="Spectrogram"):
magnitude = stft.abs()
spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy()
figure, axis = plt.subplots(1, 1)
img = axis.imshow(spectrogram, cmap="viridis", vmin=-60, vmax=0, origin="lower", aspect="auto")
figure.suptitle(title)
plt.colorbar(img, ax=axis)
plt.show()
_, axis = plt.subplots(1, 1)
axis.imshow(spectrogram, cmap="viridis", vmin=-60, vmax=0, origin="lower", aspect="auto")
axis.set_title(title)
plt.tight_layout()
######################################################################
......@@ -252,7 +251,7 @@ def output_results(original_source: torch.Tensor, predicted_source: torch.Tensor
"SDR score is:",
separation.bss_eval_sources(original_source.detach().numpy(), predicted_source.detach().numpy())[0].mean(),
)
plot_spectrogram(stft(predicted_source)[0], f"Spectrogram {source}")
plot_spectrogram(stft(predicted_source)[0], f"Spectrogram - {source}")
return Audio(predicted_source, rate=sample_rate)
......@@ -294,7 +293,7 @@ mix_spec = mixture[:, frame_start:frame_end].cpu()
#
# Mixture Clip
plot_spectrogram(stft(mix_spec)[0], "Spectrogram Mixture")
plot_spectrogram(stft(mix_spec)[0], "Spectrogram - Mixture")
Audio(mix_spec, rate=sample_rate)
######################################################################
......
......@@ -98,23 +98,21 @@ SAMPLE_NOISE = download_asset("tutorial-assets/mvdr/noise.wav")
#
def plot_spectrogram(stft, title="Spectrogram", xlim=None):
def plot_spectrogram(stft, title="Spectrogram"):
magnitude = stft.abs()
spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy()
figure, axis = plt.subplots(1, 1)
img = axis.imshow(spectrogram, cmap="viridis", vmin=-100, vmax=0, origin="lower", aspect="auto")
figure.suptitle(title)
axis.set_title(title)
plt.colorbar(img, ax=axis)
plt.show()
def plot_mask(mask, title="Mask", xlim=None):
def plot_mask(mask, title="Mask"):
mask = mask.numpy()
figure, axis = plt.subplots(1, 1)
img = axis.imshow(mask, cmap="viridis", origin="lower", aspect="auto")
figure.suptitle(title)
axis.set_title(title)
plt.colorbar(img, ax=axis)
plt.show()
def si_snr(estimate, reference, epsilon=1e-8):
......
......@@ -33,12 +33,9 @@ print(torchaudio.__version__)
import os
import time
import matplotlib
import matplotlib.pyplot as plt
from torchaudio.io import StreamReader
matplotlib.rcParams["image.interpolation"] = "none"
######################################################################
#
# Check the prerequisites
......
......@@ -160,8 +160,7 @@ for i, feats in enumerate(features):
ax[i].set_title(f"Feature from transformer layer {i+1}")
ax[i].set_xlabel("Feature dimension")
ax[i].set_ylabel("Frame (time-axis)")
plt.tight_layout()
plt.show()
fig.tight_layout()
######################################################################
......@@ -190,7 +189,7 @@ plt.imshow(emission[0].cpu().T, interpolation="nearest")
plt.title("Classification result")
plt.xlabel("Frame (time-axis)")
plt.ylabel("Class")
plt.show()
plt.tight_layout()
print("Class labels:", bundle.get_labels())
......
......@@ -82,6 +82,7 @@ try:
from pystoi import stoi
from torchaudio.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE
except ImportError:
try:
import google.colab # noqa: F401
print(
......@@ -95,6 +96,9 @@ except ImportError:
!pip3 install pystoi
"""
)
except Exception:
pass
raise
import matplotlib.pyplot as plt
......@@ -128,27 +132,17 @@ def si_snr(estimate, reference, epsilon=1e-8):
return si_snr.item()
def plot_waveform(waveform, title):
def plot(waveform, title, sample_rate=16000):
wav_numpy = waveform.numpy()
sample_size = waveform.shape[1]
time_axis = torch.arange(0, sample_size) / 16000
time_axis = torch.arange(0, sample_size) / sample_rate
figure, axes = plt.subplots(1, 1)
axes = figure.gca()
axes.plot(time_axis, wav_numpy[0], linewidth=1)
axes.grid(True)
figure, axes = plt.subplots(2, 1)
axes[0].plot(time_axis, wav_numpy[0], linewidth=1)
axes[0].grid(True)
axes[1].specgram(wav_numpy[0], Fs=sample_rate)
figure.suptitle(title)
plt.show(block=False)
def plot_specgram(waveform, sample_rate, title):
wav_numpy = waveform.numpy()
figure, axes = plt.subplots(1, 1)
axes = figure.gca()
axes.specgram(wav_numpy[0], Fs=sample_rate)
figure.suptitle(title)
plt.show(block=False)
######################################################################
......@@ -238,32 +232,28 @@ Audio(WAVEFORM_DISTORTED.numpy()[1], rate=16000)
# Visualize speech sample
#
plot_waveform(WAVEFORM_SPEECH, "Clean Speech")
plot_specgram(WAVEFORM_SPEECH, 16000, "Clean Speech Spectrogram")
plot(WAVEFORM_SPEECH, "Clean Speech")
######################################################################
# Visualize noise sample
#
plot_waveform(WAVEFORM_NOISE, "Noise")
plot_specgram(WAVEFORM_NOISE, 16000, "Noise Spectrogram")
plot(WAVEFORM_NOISE, "Noise")
######################################################################
# Visualize distorted speech with 20dB SNR
#
plot_waveform(WAVEFORM_DISTORTED[0:1], f"Distorted Speech with {snr_dbs[0]}dB SNR")
plot_specgram(WAVEFORM_DISTORTED[0:1], 16000, f"Distorted Speech with {snr_dbs[0]}dB SNR")
plot(WAVEFORM_DISTORTED[0:1], f"Distorted Speech with {snr_dbs[0]}dB SNR")
######################################################################
# Visualize distorted speech with -5dB SNR
#
plot_waveform(WAVEFORM_DISTORTED[1:2], f"Distorted Speech with {snr_dbs[1]}dB SNR")
plot_specgram(WAVEFORM_DISTORTED[1:2], 16000, f"Distorted Speech with {snr_dbs[1]}dB SNR")
plot(WAVEFORM_DISTORTED[1:2], f"Distorted Speech with {snr_dbs[1]}dB SNR")
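# For reference, a minimal sketch of how such distorted samples can be created
# with ``torchaudio.functional.add_noise``; the code that actually produced
# ``WAVEFORM_DISTORTED`` is elided in this diff and may differ in detail:
#
# .. code-block:: python
#
#    import torchaudio.functional as F
#
#    snr_dbs = torch.tensor([20, -5])
#    num_samples = min(WAVEFORM_SPEECH.shape[-1], WAVEFORM_NOISE.shape[-1])
#    WAVEFORM_DISTORTED = F.add_noise(
#        WAVEFORM_SPEECH[:, :num_samples], WAVEFORM_NOISE[:, :num_samples], snr_dbs
#    )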
######################################################################
......