Commit 84b12306 authored by moto, committed by Facebook GitHub Bot

Set and tweak global matplotlib configuration in tutorials (#3515)

Summary:
- Set global matplotlib rc params
- Fix style check
- Fix and update FA tutorial plots
- Add av-asr index card

Pull Request resolved: https://github.com/pytorch/audio/pull/3515

Reviewed By: huangruizhe

Differential Revision: D47894156

Pulled By: mthrok

fbshipit-source-id: b40d8d31f12ffc2b337e35e632afc216e9d59a6e
parent 8497ee91
......@@ -127,6 +127,22 @@ def _get_pattern():
return ret
def reset_mpl(gallery_conf, fname):
from sphinx_gallery.scrapers import _reset_matplotlib
_reset_matplotlib(gallery_conf, fname)
import matplotlib
matplotlib.rcParams.update(
{
"image.interpolation": "none",
"figure.figsize": (9.6, 4.8),
"font.size": 8.0,
"axes.axisbelow": True,
}
)
sphinx_gallery_conf = {
"examples_dirs": [
"../../examples/tutorials",
......@@ -139,6 +155,7 @@ sphinx_gallery_conf = {
"promote_jupyter_magic": True,
"first_notebook_cell": None,
"doc_module": ("torchaudio",),
"reset_modules": (reset_mpl, "seaborn"),
}
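# The snippet below is not part of this change; it is a rough sketch of how
# sphinx-gallery consumes the "reset_modules" entries above (based on its
# documented behavior): string entries such as "seaborn" select built-in
# resetters, while callables such as reset_mpl are invoked with
# (gallery_conf, fname) before each example script runs, so the rcParams set
# above are re-applied for every tutorial.
def _apply_resetters(gallery_conf, fname):
    for resetter in sphinx_gallery_conf["reset_modules"]:
        if callable(resetter):
            resetter(gallery_conf, fname)
        # string entries are resolved to sphinx-gallery's built-in reset hooks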
autosummary_generate = True
......
......@@ -71,8 +71,8 @@ model implementations and application components.
tutorials/online_asr_tutorial
tutorials/device_asr
tutorials/device_avsr
tutorials/forced_alignment_for_multilingual_data_tutorial
tutorials/forced_alignment_tutorial
tutorials/forced_alignment_for_multilingual_data_tutorial
tutorials/tacotron2_pipeline_tutorial
tutorials/mvdr_tutorial
tutorials/hybrid_demucs_tutorial
......@@ -147,6 +147,13 @@ Tutorials
.. customcardstart::
.. customcarditem::
:header: On device audio-visual automatic speech recognition
:card_description: Learn how to stream audio and video from a laptop webcam and perform audio-visual automatic speech recognition using the Emformer-RNNT model.
:image: https://download.pytorch.org/torchaudio/doc-assets/avsr/transformed.gif
:link: tutorials/device_avsr.html
:tags: I/O,Pipelines,RNNT
.. customcarditem::
:header: Loading waveform Tensors from files and saving them
:card_description: Learn how to query/load audio files and save waveform tensors to files, using <code>torchaudio.info</code>, <code>torchaudio.load</code> and <code>torchaudio.save</code> functions.
......
......@@ -85,7 +85,7 @@ NUM_FRAMES = int(DURATION * SAMPLE_RATE)
#
def show(freq, amp, waveform, sample_rate, zoom=None, vol=0.1):
def plot(freq, amp, waveform, sample_rate, zoom=None, vol=0.1):
t = torch.arange(waveform.size(0)) / sample_rate
fig, axes = plt.subplots(4, 1, sharex=True)
......@@ -101,7 +101,7 @@ def show(freq, amp, waveform, sample_rate, zoom=None, vol=0.1):
for i in range(4):
axes[i].grid(True)
pos = axes[2].get_position()
plt.tight_layout()
fig.tight_layout()
if zoom is not None:
ax = fig.add_axes([pos.x0 + 0.02, pos.y0 + 0.03, pos.width / 2.5, pos.height / 2.0])
......@@ -168,7 +168,7 @@ def sawtooth_wave(freq0, amp0, num_pitches, sample_rate):
freq0 = torch.full((NUM_FRAMES, 1), F0)
amp0 = torch.ones((NUM_FRAMES, 1))
freq, amp, waveform = sawtooth_wave(freq0, amp0, int(SAMPLE_RATE / F0), SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
plot(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
######################################################################
#
......@@ -183,7 +183,7 @@ phase = torch.linspace(0, fm * PI2 * DURATION, NUM_FRAMES)
freq0 = F0 + f_dev * torch.sin(phase).unsqueeze(-1)
freq, amp, waveform = sawtooth_wave(freq0, amp0, int(SAMPLE_RATE / F0), SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
plot(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
######################################################################
# Square wave
......@@ -220,7 +220,7 @@ def square_wave(freq0, amp0, num_pitches, sample_rate):
freq0 = torch.full((NUM_FRAMES, 1), F0)
amp0 = torch.ones((NUM_FRAMES, 1))
freq, amp, waveform = square_wave(freq0, amp0, int(SAMPLE_RATE / F0 / 2), SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
plot(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
######################################################################
# Triangle wave
......@@ -256,7 +256,7 @@ def triangle_wave(freq0, amp0, num_pitches, sample_rate):
#
freq, amp, waveform = triangle_wave(freq0, amp0, int(SAMPLE_RATE / F0 / 2), SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
plot(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
######################################################################
# Inharmonic Partials
......@@ -296,7 +296,7 @@ amp = torch.stack([amp * (0.5**i) for i in range(num_tones)], dim=-1)
waveform = oscillator_bank(freq, amp, sample_rate=SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE, vol=0.4)
plot(freq, amp, waveform, SAMPLE_RATE, vol=0.4)
######################################################################
#
......@@ -308,7 +308,7 @@ show(freq, amp, waveform, SAMPLE_RATE, vol=0.4)
freq = extend_pitch(freq0, num_tones)
waveform = oscillator_bank(freq, amp, sample_rate=SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE)
plot(freq, amp, waveform, SAMPLE_RATE)
######################################################################
# References
......
......@@ -407,30 +407,45 @@ print(timesteps, timesteps.shape[0])
#
def plot_alignments(waveform, emission, tokens, timesteps):
fig, ax = plt.subplots(figsize=(32, 10))
ax.plot(waveform)
ratio = waveform.shape[0] / emission.shape[1]
word_start = 0
for i in range(len(tokens)):
if i != 0 and tokens[i - 1] == "|":
word_start = timesteps[i]
if tokens[i] != "|":
plt.annotate(tokens[i].upper(), (timesteps[i] * ratio, waveform.max() * 1.02), size=14)
elif i != 0:
word_end = timesteps[i]
ax.axvspan(word_start * ratio, word_end * ratio, alpha=0.1, color="red")
xticks = ax.get_xticks()
plt.xticks(xticks, xticks / bundle.sample_rate)
ax.set_xlabel("time (sec)")
ax.set_xlim(0, waveform.shape[0])
plot_alignments(waveform[0], emission, predicted_tokens, timesteps)
def plot_alignments(waveform, emission, tokens, timesteps, sample_rate):
t = torch.arange(waveform.size(0)) / sample_rate
ratio = waveform.size(0) / emission.size(1) / sample_rate
chars = []
words = []
word_start = None
for token, timestep in zip(tokens, timesteps * ratio):
if token == "|":
if word_start is not None:
words.append((word_start, timestep))
word_start = None
else:
chars.append((token, timestep))
if word_start is None:
word_start = timestep
fig, axes = plt.subplots(3, 1)
def _plot(ax, xlim):
ax.plot(t, waveform)
for token, timestep in chars:
ax.annotate(token.upper(), (timestep, 0.5))
for word_start, word_end in words:
ax.axvspan(word_start, word_end, alpha=0.1, color="red")
ax.set_ylim(-0.6, 0.7)
ax.set_yticks([0])
ax.grid(True, axis="y")
ax.set_xlim(xlim)
_plot(axes[0], (0.3, 2.5))
_plot(axes[1], (2.5, 4.7))
_plot(axes[2], (4.7, 6.9))
axes[2].set_xlabel("time (sec)")
fig.tight_layout()
plot_alignments(waveform[0], emission, predicted_tokens, timesteps, bundle.sample_rate)
######################################################################
......
......@@ -100,7 +100,6 @@ def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None):
if xlim:
axes[c].set_xlim(xlim)
figure.suptitle(title)
plt.show(block=False)
######################################################################
......@@ -122,7 +121,6 @@ def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
if xlim:
axes[c].set_xlim(xlim)
figure.suptitle(title)
plt.show(block=False)
######################################################################
......
# -*- coding: utf-8 -*-
"""
Audio Datasets
==============
......@@ -10,10 +9,6 @@ datasets. Please refer to the official documentation for the list of
available datasets.
"""
# When running this tutorial in Google Colab, install the required packages
# with the following.
# !pip install torchaudio
import torch
import torchaudio
......@@ -21,22 +16,13 @@ print(torch.__version__)
print(torchaudio.__version__)
######################################################################
# Preparing data and utility functions (skip this section)
# --------------------------------------------------------
#
# @title Prepare data and utility functions. {display-mode: "form"}
# @markdown
# @markdown You do not need to look into this cell.
# @markdown Just execute once and you are good to go.
# -------------------------------------------------------------------------------
# Preparation of data and helper functions.
# -------------------------------------------------------------------------------
import os
import IPython
import matplotlib.pyplot as plt
from IPython.display import Audio, display
_SAMPLE_DIR = "_assets"
......@@ -44,34 +30,13 @@ YESNO_DATASET_PATH = os.path.join(_SAMPLE_DIR, "yes_no")
os.makedirs(YESNO_DATASET_PATH, exist_ok=True)
def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
def plot_specgram(waveform, sample_rate, title="Spectrogram"):
waveform = waveform.numpy()
num_channels, _ = waveform.shape
figure, axes = plt.subplots(num_channels, 1)
if num_channels == 1:
axes = [axes]
for c in range(num_channels):
axes[c].specgram(waveform[c], Fs=sample_rate)
if num_channels > 1:
axes[c].set_ylabel(f"Channel {c+1}")
if xlim:
axes[c].set_xlim(xlim)
figure, ax = plt.subplots()
ax.specgram(waveform[0], Fs=sample_rate)
figure.suptitle(title)
plt.show(block=False)
def play_audio(waveform, sample_rate):
waveform = waveform.numpy()
num_channels, _ = waveform.shape
if num_channels == 1:
display(Audio(waveform[0], rate=sample_rate))
elif num_channels == 2:
display(Audio((waveform[0], waveform[1]), rate=sample_rate))
else:
raise ValueError("Waveform with more than 2 channels are not supported.")
figure.tight_layout()
######################################################################
......@@ -79,10 +44,25 @@ def play_audio(waveform, sample_rate):
# :py:class:`torchaudio.datasets.YESNO` dataset.
#
dataset = torchaudio.datasets.YESNO(YESNO_DATASET_PATH, download=True)
for i in [1, 3, 5]:
waveform, sample_rate, label = dataset[i]
plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}")
play_audio(waveform, sample_rate)
######################################################################
#
i = 1
waveform, sample_rate, label = dataset[i]
plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}")
IPython.display.Audio(waveform, rate=sample_rate)
######################################################################
#
i = 3
waveform, sample_rate, label = dataset[i]
plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}")
IPython.display.Audio(waveform, rate=sample_rate)
######################################################################
#
i = 5
waveform, sample_rate, label = dataset[i]
plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}")
IPython.display.Audio(waveform, rate=sample_rate)
......@@ -19,25 +19,19 @@ print(torch.__version__)
print(torchaudio.__version__)
######################################################################
# Preparing data and utility functions (skip this section)
# --------------------------------------------------------
# Preparation
# -----------
#
# @title Prepare data and utility functions. {display-mode: "form"}
# @markdown
# @markdown You do not need to look into this cell.
# @markdown Just execute once and you are good to go.
# @markdown
# @markdown In this tutorial, we will use a speech data from [VOiCES dataset](https://iqtlabs.github.io/voices/),
# @markdown which is licensed under Creative Commos BY 4.0.
# -------------------------------------------------------------------------------
# Preparation of data and helper functions.
# -------------------------------------------------------------------------------
import librosa
import matplotlib.pyplot as plt
from torchaudio.utils import download_asset
######################################################################
# In this tutorial, we will use speech data from
# `VOiCES dataset <https://iqtlabs.github.io/voices/>`__,
# which is licensed under Creative Commons BY 4.0.
SAMPLE_WAV_SPEECH_PATH = download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
......@@ -75,16 +69,9 @@ def get_spectrogram(
return spectrogram(waveform)
def plot_spectrogram(spec, title=None, ylabel="freq_bin", aspect="auto", xmax=None):
fig, axs = plt.subplots(1, 1)
axs.set_title(title or "Spectrogram (db)")
axs.set_ylabel(ylabel)
axs.set_xlabel("frame")
im = axs.imshow(librosa.power_to_db(spec), origin="lower", aspect=aspect)
if xmax:
axs.set_xlim((0, xmax))
fig.colorbar(im, ax=axs)
plt.show(block=False)
def plot_spec(ax, spec, title, ylabel="freq_bin"):
ax.set_title(title)
ax.imshow(librosa.power_to_db(spec), origin="lower", aspect="auto")
######################################################################
......@@ -108,43 +95,47 @@ def plot_spectrogram(spec, title=None, ylabel="freq_bin", aspect="auto", xmax=No
spec = get_spectrogram(power=None)
stretch = T.TimeStretch()
rate = 1.2
spec_ = stretch(spec, rate)
plot_spectrogram(torch.abs(spec_[0]), title=f"Stretched x{rate}", aspect="equal", xmax=304)
plot_spectrogram(torch.abs(spec[0]), title="Original", aspect="equal", xmax=304)
rate = 0.9
spec_ = stretch(spec, rate)
plot_spectrogram(torch.abs(spec_[0]), title=f"Stretched x{rate}", aspect="equal", xmax=304)
spec_12 = stretch(spec, overriding_rate=1.2)
spec_09 = stretch(spec, overriding_rate=0.9)
######################################################################
# TimeMasking
# -----------
#
torch.random.manual_seed(4)
spec = get_spectrogram()
plot_spectrogram(spec[0], title="Original")
def plot():
fig, axes = plt.subplots(3, 1, sharex=True, sharey=True)
plot_spec(axes[0], torch.abs(spec_12[0]), title="Stretched x1.2")
plot_spec(axes[1], torch.abs(spec[0]), title="Original")
plot_spec(axes[2], torch.abs(spec_09[0]), title="Stretched x0.9")
fig.tight_layout()
masking = T.TimeMasking(time_mask_param=80)
spec = masking(spec)
plot_spectrogram(spec[0], title="Masked along time axis")
plot()
######################################################################
# FrequencyMasking
# ----------------
# Time and Frequency Masking
# --------------------------
#
torch.random.manual_seed(4)
time_masking = T.TimeMasking(time_mask_param=80)
freq_masking = T.FrequencyMasking(freq_mask_param=80)
spec = get_spectrogram()
plot_spectrogram(spec[0], title="Original")
time_masked = time_masking(spec)
freq_masked = freq_masking(spec)
######################################################################
#
def plot():
fig, axes = plt.subplots(3, 1, sharex=True, sharey=True)
plot_spec(axes[0], spec[0], title="Original")
plot_spec(axes[1], time_masked[0], title="Masked along time axis")
plot_spec(axes[2], freq_masked[0], title="Masked along frequency axis")
fig.tight_layout()
masking = T.FrequencyMasking(freq_mask_param=80)
spec = masking(spec)
plot_spectrogram(spec[0], title="Masked along frequency axis")
plot()
......@@ -75,7 +75,6 @@ def plot_waveform(waveform, sr, title="Waveform", ax=None):
ax.grid(True)
ax.set_xlim([0, time_axis[-1]])
ax.set_title(title)
plt.show(block=False)
def plot_spectrogram(specgram, title=None, ylabel="freq_bin", ax=None):
......@@ -85,7 +84,6 @@ def plot_spectrogram(specgram, title=None, ylabel="freq_bin", ax=None):
ax.set_title(title)
ax.set_ylabel(ylabel)
ax.imshow(librosa.power_to_db(specgram), origin="lower", aspect="auto", interpolation="nearest")
plt.show(block=False)
def plot_fbank(fbank, title=None):
......@@ -94,7 +92,6 @@ def plot_fbank(fbank, title=None):
axs.imshow(fbank, aspect="auto")
axs.set_ylabel("frequency bin")
axs.set_xlabel("mel bin")
plt.show(block=False)
######################################################################
......@@ -486,7 +483,6 @@ def plot_pitch(waveform, sr, pitch):
axis2.plot(time_axis, pitch[0], linewidth=2, label="Pitch", color="green")
axis2.legend(loc=0)
plt.show(block=False)
plot_pitch(SPEECH_WAVEFORM, SAMPLE_RATE, pitch)
......@@ -181,7 +181,6 @@ def plot_waveform(waveform, sample_rate):
if num_channels > 1:
axes[c].set_ylabel(f"Channel {c+1}")
figure.suptitle("waveform")
plt.show(block=False)
######################################################################
......@@ -204,7 +203,6 @@ def plot_specgram(waveform, sample_rate, title="Spectrogram"):
if num_channels > 1:
axes[c].set_ylabel(f"Channel {c+1}")
figure.suptitle(title)
plt.show(block=False)
######################################################################
......
......@@ -105,7 +105,6 @@ def plot_sweep(
axis.yaxis.grid(True, alpha=0.67)
figure.suptitle(f"{title} (sample rate: {sample_rate} Hz)")
plt.colorbar(cax)
plt.show(block=True)
######################################################################
......
......@@ -5,254 +5,277 @@ CTC forced alignment API tutorial
**Author**: `Xiaohui Zhang <xiaohuizhang@meta.com>`__
This tutorial shows how to align transcripts to speech with
``torchaudio``'s CTC forced alignment API proposed in the paper
`“Scaling Speech Technology to 1,000+
Languages” <https://research.facebook.com/publications/scaling-speech-technology-to-1000-languages/>`__,
and one advanced usage, i.e. dealing with transcription errors with a <star> token.
Though there’s some overlap in visualization
diagrams, the scope here is different from the `“Forced Alignment with
Wav2Vec2” <https://pytorch.org/audio/stable/tutorials/forced_alignment_tutorial.html>`__
tutorial, which focuses on a step-by-step demonstration of the forced
alignment generation algorithm (without using an API) described in the
`paper <https://arxiv.org/abs/2007.09127>`__ with a Wav2Vec2 model.
This tutorial shows how to align transcripts to speech using
:py:func:`torchaudio.functional.forced_align`
which was developed along the work of
`Scaling Speech Technology to 1,000+ Languages <https://research.facebook.com/publications/scaling-speech-technology-to-1000-languages/>`__.
Forced alignment is the process of aligning a transcript with speech.
We cover the basics of forced alignment in `Forced Alignment with
Wav2Vec2 <./forced_alignment_tutorial.html>`__ with simplified
step-by-step Python implementations.
:py:func:`~torchaudio.functional.forced_align` has custom CPU and CUDA
implementations which are more performant than the vanilla Python
implementation above, and are more accurate.
It can also handle missing transcripts with the special <star> token.
For examples of aligning multiple languages, please refer to
`Forced alignment for multilingual data <./forced_alignment_for_multilingual_data_tutorial.html>`__.
"""
import torch
import torchaudio
print(torch.__version__)
print(torchaudio.__version__)
######################################################################
#
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
from dataclasses import dataclass
from typing import List
try:
from torchaudio.functional import forced_align
except ModuleNotFoundError:
print(
"Failed to import the forced alignment API. "
"Please install torchaudio nightly builds. "
"Please refer to https://pytorch.org/get-started/locally "
"for instructions to install a nightly build."
)
raise
import IPython
import matplotlib.pyplot as plt
######################################################################
# Basic usages
# ------------
#
# In this section, we cover the following content:
#
# 1. Generate frame-wise class probabilities from an audio waveform using a CTC
# acoustic model.
# 2. Compute frame-level alignments using TorchAudio’s forced alignment
# API.
# 3. Obtain token-level alignments from frame-level alignments.
# 4. Obtain word-level alignments from token-level alignments.
#
from torchaudio.functional import forced_align
torch.random.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
######################################################################
# Preparation
# ~~~~~~~~~~~
# -----------
#
# First we import the necessary packages, and fetch data that we work on.
# First we prepare the speech data and the transcript we are going
# to use.
#
# %matplotlib inline
from dataclasses import dataclass
import IPython
import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams["figure.figsize"] = [16.0, 4.8]
torch.random.manual_seed(0)
SPEECH_FILE = torchaudio.utils.download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
sample_rate = 16000
TRANSCRIPT = "I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT"
######################################################################
# Generate frame-wise class posteriors from a CTC acoustic model
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Generating emissions and tokens
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# :py:func:`~torchaudio.functional.forced_align` takes emission and
# token sequences and outputs timestamps of the tokens and their scores.
#
# The first step is to generate the class probabilities (i.e. posteriors)
# of each audio frame using a CTC model.
# Here we use :py:func:`torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H`.
# The emission represents the frame-wise probability distribution over
# tokens, and it can be obtained by passing a waveform to an acoustic
# model.
# Tokens are numerical expressions of the transcript. They can be obtained
# by simply mapping each character to its index in the token list.
# The emission and the token sequence must use the same set of tokens.
#
# We can use a pre-trained Wav2Vec2 model to obtain the emission from speech,
# and map the transcript to tokens.
# Here, we use :py:data:`~torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H`,
# which bundles pre-trained model weights with the associated labels.
#
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model().to(device)
labels = bundle.get_labels()
with torch.inference_mode():
waveform, _ = torchaudio.load(SPEECH_FILE)
emissions, _ = model(waveform.to(device))
emissions = torch.log_softmax(emissions, dim=-1)
emission, _ = model(waveform.to(device))
emission = torch.log_softmax(emission, dim=-1)
######################################################################
#
emission = emissions.cpu().detach()
dictionary = {c: i for i, c in enumerate(labels)}
def plot_emission(emission):
plt.imshow(emission.cpu().T)
plt.title("Frame-wise class probabilities")
plt.xlabel("Time")
plt.ylabel("Labels")
plt.tight_layout()
print(dictionary)
plot_emission(emission[0])
######################################################################
# Visualization
# ^^^^^^^^^^^^^
#
# We create a dictionary, which maps each label to a token (its index).
labels = bundle.get_labels()
DICTIONARY = {c: i for i, c in enumerate(labels)}
for k, v in DICTIONARY.items():
print(f"{k}: {v}")
######################################################################
# Converting the transcript to tokens is as simple as:
plt.imshow(emission[0].T)
plt.colorbar()
plt.title("Frame-wise class probabilities")
plt.xlabel("Time")
plt.ylabel("Labels")
plt.show()
tokens = [DICTIONARY[c] for c in TRANSCRIPT]
print(" ".join(str(t) for t in tokens))
######################################################################
# Computing frame-level alignments
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# --------------------------------
#
# Then we call TorchAudio’s forced alignment API to compute the
# frame-level alignment between each audio frame and each token in the
# transcript. We first explain the inputs and outputs of the API
# ``functional.forced_align``. Note that this API works on both CPU and
# GPU. In the current tutorial we demonstrate it on CPU.
# Now we call TorchAudio’s forced alignment API to compute the
# frame-level alignment. For the details of the function signature, please
# refer to :py:func:`~torchaudio.functional.forced_align`.
#
# **Inputs**:
#
# ``emission``: a 2D tensor of size :math:`T \times N`, where :math:`T` is
# the number of frames (after sub-sampling by the acoustic model, if any),
# and :math:`N` is the vocabulary size.
#
# ``targets``: a 1D tensor vector of size :math:`M`, where :math:`M` is
# the length of the transcript, and each element is a token ID looked up
# from the vocabulary. For example, the ``targets`` tensor repsenting the
# transcript “i had…” is :math:`[5, 18, 4, 16, ...]`.
#
# ``input lengths``: :math:`T`.
#
# ``target lengths``: :math:`M`.
#
# **Outputs**:
#
# ``frame_alignment``: a 1D tensor of size :math:`T` storing the aligned
# token index (looked up from the vocabulary) of each frame, e.g. for the
# segment corresponding to “i had” in the given example , the
# frame_alignment is
# :math:`[...0, 0, 5, 0, 0, 18, 18, 4, 0, 0, 0, 16,...]`, where :math:`0`
# represents the blank symbol.
def align(emission, tokens):
alignments, scores = forced_align(
emission,
targets=torch.tensor([tokens], dtype=torch.int32, device=emission.device),
input_lengths=torch.tensor([emission.size(1)], device=emission.device),
target_lengths=torch.tensor([len(tokens)], device=emission.device),
blank=0,
)
scores = scores.exp() # convert back to probability
alignments, scores = alignments[0], scores[0] # remove batch dimension for simplicity
return alignments.tolist(), scores.tolist()
frame_alignment, frame_scores = align(emission, tokens)
######################################################################
# Now let's look at the output.
# Notice that the alignment is expressed in the frame coordinate of the
# emission, which is different from that of the original waveform.
for i, (ali, score) in enumerate(zip(frame_alignment, frame_scores)):
print(f"{i:3d}: {ali:2d} [{labels[ali]}], {score:.2f}")
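######################################################################
# (Illustrative addition, not part of this change-set: one emission frame can
# be converted to a time in the original waveform using the same ratio that
# the plotting code below relies on.)

ratio = waveform.size(1) / emission.size(1) / bundle.sample_rate
print(f"1 frame = {ratio:.4f} sec of audio")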
######################################################################
#
# ``frame_scores``: a 1D tensor of size :math:`T` storing the confidence
# score (0 to 1) for each frame. For each frame, the score should be
# close to one if the alignment quality is good.
# The ``Frame`` instance represents the most likely token at each frame
# with its confidence.
#
# When interpreting it, one must remember that the meanings of the blank
# token and repeated tokens are context dependent.
#
# .. note::
#
#    When the same token occurs after a blank token, it is not treated as
#    a repeat, but as a new occurrence.
#
# .. code-block::
#
# a a a b -> a b
# a - - b -> a b
# a a - b -> a b
# a - a b -> a a b
# ^^^ ^^^
#
# .. code-block::
#
# 29: 0 [-], 1.00
# 30: 7 [I], 1.00 # Start of "I"
# 31: 0 [-], 0.98 # repeat (blank token)
# 32: 0 [-], 1.00 # repeat (blank token)
# 33: 1 [|], 0.85 # Start of "|" (word boundary)
# 34: 1 [|], 1.00 # repeat (same token)
# 35: 0 [-], 0.61 # repeat (blank token)
# 36: 8 [H], 1.00 # Start of "H"
# 37: 0 [-], 1.00 # repeat (blank token)
# 38: 4 [A], 1.00 # Start of "A"
# 39: 0 [-], 0.99 # repeat (blank token)
# 40: 11 [D], 0.92 # Start of "D"
# 41: 0 [-], 0.93 # repeat (blank token)
# 42: 1 [|], 0.98 # Start of "|"
# 43: 1 [|], 1.00 # repeat (same token)
# 44: 3 [T], 1.00 # Start of "T"
# 45: 3 [T], 0.90 # repeat (same token)
# 46: 8 [H], 1.00 # Start of "H"
# 47: 0 [-], 1.00 # repeat (blank token)
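######################################################################
# The collapse rule described in the note above can be made concrete with a
# tiny, hypothetical helper (not part of this tutorial): blank frames are
# dropped, and a token merges with the previous frame only when no blank
# separates them.

def ctc_collapse(frame_tokens, blank=0):
    collapsed = []
    prev = blank
    for tok in frame_tokens:
        if tok != blank and tok != prev:
            collapsed.append(tok)
        prev = tok
    return collapsed

# "a a - a b" collapses to "a a b": the "a" after the blank is a new occurrence.
assert ctc_collapse([1, 1, 0, 1, 2]) == [1, 1, 2]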
######################################################################
# Resolve blank and repeated tokens
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# From the outputs ``frame_alignment`` and ``frame_scores``, we generate a
# The next step is to resolve the repetitions, so that what each alignment
# represents does not depend on the previous alignments.
# From the outputs ``alignment`` and ``scores``, we generate a
# list called ``frames`` storing information of all frames aligned to
# non-blank tokens. Each element contains 1) token_index: the aligned
# token’s index in the transcript 2) time_index: the current frame’s index
# in the input audio (or more precisely, the row dimension of the emission
# matrix) 3) the confidence scores of the current frame.
#
# For the given example, the first few elements of the list ``frames``
# corresponding to “i had” looks as the following:
#
# ``Frame(token_index=0, time_index=32, score=0.9994410872459412)``
# non-blank tokens.
#
# ``Frame(token_index=1, time_index=35, score=0.9980823993682861)``
# Each element contains the following
#
# ``Frame(token_index=1, time_index=36, score=0.9295750260353088)``
#
# ``Frame(token_index=2, time_index=37, score=0.9997448325157166)``
#
# ``Frame(token_index=3, time_index=41, score=0.9991760849952698)``
#
# ``...``
#
# The interpretation is:
#
# The token with index :math:`0` in the transcript, i.e. “i”, is aligned
# to the :math:`32`\ th audio frame, with confidence :math:`0.9994`. The
# token with index :math:`1` in the transcript, i.e. “h”, is aligned to
# the :math:`35`\ th and :math:`36`\ th audio frames, with confidence
# :math:`0.9981` and :math:`0.9296` respectively. The token with index
# :math:`2` in the transcript, i.e. “a”, is aligned to the :math:`37`\ th
# audio frame, with confidence :math:`0.9997`. The
# token with index :math:`3` in the transcript, i.e. “d”, is aligned to
# the :math:`41`\ th audio frame, with confidence :math:`0.9992`.
#
# From such information stored in the ``frames`` list, we’ll compute
# token-level and word-level alignments easily.
# - ``token_index``: the aligned token’s index **in the transcript**
# - ``time_index``: the current frame’s index in emission
# - ``score``: the confidence score of the current frame.
#
# ``token_index`` is the index of each token in the transcript,
# i.e. the current frame aligns to the N-th character from the transcript.
@dataclass
class Frame:
# This is the index of each token in the transcript,
# i.e. the current frame aligns to the N-th character from the transcript.
token_index: int
time_index: int
score: float
def compute_alignments(transcript, dictionary, emission):
frames = []
tokens = [dictionary[c] for c in transcript.replace(" ", "")]
targets = torch.tensor(tokens, dtype=torch.int32).unsqueeze(0)
input_lengths = torch.tensor([emission.shape[1]])
target_lengths = torch.tensor([targets.shape[1]])
######################################################################
#
# This is the key step, where we call the forced alignment API functional.forced_align to compute alignments.
frame_alignment, frame_scores = forced_align(emission, targets, input_lengths, target_lengths, 0)
assert frame_alignment.shape[1] == input_lengths[0].item()
assert targets.shape[1] == target_lengths[0].item()
def obtain_token_level_alignments(alignments, scores) -> List[Frame]:
assert len(alignments) == len(scores)
token_index = -1
prev_hyp = 0
for i in range(frame_alignment.shape[1]):
if frame_alignment[0][i].item() == 0:
frames = []
for i, (ali, score) in enumerate(zip(alignments, scores)):
if ali == 0:
prev_hyp = 0
continue
if frame_alignment[0][i].item() != prev_hyp:
if ali != prev_hyp:
token_index += 1
frames.append(Frame(token_index, i, frame_scores[0][i].exp().item()))
prev_hyp = frame_alignment[0][i].item()
return frames, frame_alignment, frame_scores
transcript = "I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT"
frames, frame_alignment, frame_scores = compute_alignments(transcript, dictionary, emission)
frames.append(Frame(token_index, i, score))
prev_hyp = ali
return frames
######################################################################
# Obtain token-level alignments and confidence scores
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
frames = obtain_token_level_alignments(frame_alignment, frame_scores)
print("Time\tLabel\tScore")
for f in frames:
print(f"{f.time_index:3d}\t{TRANSCRIPT[f.token_index]}\t{f.score:.2f}")
######################################################################
# Obtain token-level alignments and confidence scores
# ---------------------------------------------------
#
# The frame-level alignments contain repetitions of the same labels.
# Another format, “token-level alignment”, which specifies the aligned
# frame ranges for each transcript token, contains the same information,
# while being more convenient to apply to some downstream tasks
# (e.g. computing word-level alignments).
#
# Now we demonstrate how to obtain token-level alignments and confidence
# scores by simply merging frame-level alignments and averaging
# frame-level confidence scores.
#
######################################################################
# The following class represents the label, its score and the time span
# of its occurrence.
#
# Merge the labels
@dataclass
class Segment:
label: str
......@@ -261,13 +284,16 @@ class Segment:
score: float
def __repr__(self):
return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"
return f"{self.label:2s} ({self.score:4.2f}): [{self.start:4d}, {self.end:4d})"
@property
def length(self):
def __len__(self):
return self.end - self.start
######################################################################
#
def merge_repeats(frames, transcript):
transcript_nospace = transcript.replace(" ", "")
i1, i2 = 0, 0
......@@ -288,29 +314,31 @@ def merge_repeats(frames, transcript):
return segments
segments = merge_repeats(frames, transcript)
######################################################################
#
segments = merge_repeats(frames, TRANSCRIPT)
for seg in segments:
print(seg)
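######################################################################
# The body of ``merge_repeats`` is elided in this hunk. A minimal sketch of
# the merging it performs (an assumption, not the exact implementation):
# consecutive frames sharing the same ``token_index`` are grouped into one
# ``Segment`` whose score is the average of the frame-level scores.

def merge_repeats_sketch(frames, transcript):
    transcript_nospace = transcript.replace(" ", "")
    segments = []
    i1, i2 = 0, 0
    while i1 < len(frames):
        # advance i2 while the frames still point at the same transcript token
        while i2 < len(frames) and frames[i1].token_index == frames[i2].token_index:
            i2 += 1
        score = sum(f.score for f in frames[i1:i2]) / (i2 - i1)
        segments.append(
            Segment(
                transcript_nospace[frames[i1].token_index],
                frames[i1].time_index,
                frames[i2 - 1].time_index + 1,
                score,
            )
        )
        i1 = i2
    return segments

# Calling merge_repeats_sketch(frames, TRANSCRIPT) should yield segments
# equivalent to the ones printed above.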
######################################################################
# Visualization
# ^^^^^^^^^^^^^
# ~~~~~~~~~~~~~
#
def plot_label_prob(segments, transcript):
fig, ax2 = plt.subplots(figsize=(16, 4))
fig, ax = plt.subplots()
ax2.set_title("frame-level and token-level confidence scores")
ax.set_title("frame-level and token-level confidence scores")
xs, hs, ws = [], [], []
for seg in segments:
if seg.label != "|":
xs.append((seg.end + seg.start) / 2 + 0.4)
hs.append(seg.score)
ws.append(seg.end - seg.start)
ax2.annotate(seg.label, (seg.start + 0.8, -0.07), weight="bold")
ax2.bar(xs, hs, width=ws, color="gray", alpha=0.5, edgecolor="black")
ax.annotate(seg.label, (seg.start + 0.8, -0.07), weight="bold")
ax.bar(xs, hs, width=ws, color="gray", alpha=0.5, edgecolor="black")
xs, hs = [], []
for p in frames:
......@@ -319,27 +347,28 @@ def plot_label_prob(segments, transcript):
xs.append(p.time_index + 1)
hs.append(p.score)
ax2.bar(xs, hs, width=0.5, alpha=0.5)
ax2.axhline(0, color="black")
ax2.set_ylim(-0.1, 1.1)
ax.bar(xs, hs, width=0.5, alpha=0.5)
ax.set_ylim(-0.1, 1.1)
ax.grid(True, axis="y")
fig.tight_layout()
plot_label_prob(segments, transcript)
plt.tight_layout()
plt.show()
plot_label_prob(segments, TRANSCRIPT)
######################################################################
# From the visualized scores, we can see that, for tokens spanning
# multiple frames, e.g. “T” in “THAT”, the token-level confidence
# score is the average of the frame-level confidence scores. To make this
# clearer, we don’t plot confidence scores for blank frames, which were
# plotted in the “Label probability with and without repetition” figure in
# the previous tutorial
# `Forced Alignment with Wav2Vec2 <./forced_alignment_tutorial.html>`__.
#
######################################################################
# Obtain word-level alignments and confidence scores
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# --------------------------------------------------
#
......@@ -367,7 +396,7 @@ def merge_words(transcript, segments, separator=" "):
s = 0
segs = segments[i1 + s : i2 + s]
word = "".join([seg.label for seg in segs])
score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)
score = sum(seg.score * len(seg) for seg in segs) / sum(len(seg) for seg in segs)
words.append(Segment(word, segments[i1 + s].start, segments[i2 + s - 1].end, score))
i1 = i2
else:
......@@ -376,59 +405,43 @@ def merge_words(transcript, segments, separator=" "):
return words
word_segments = merge_words(transcript, segments, "|")
word_segments = merge_words(TRANSCRIPT, segments, "|")
######################################################################
# Visualization
# ^^^^^^^^^^^^^
# ~~~~~~~~~~~~~
#
def plot_alignments(segments, word_segments, waveform, input_lengths, scale=10):
fig, ax2 = plt.subplots(figsize=(64, 12))
plt.rcParams.update({"font.size": 30})
def plot_alignments(waveform, emission, segments, word_segments, sample_rate=bundle.sample_rate):
fig, ax = plt.subplots()
# The original waveform
ratio = waveform.size(1) / input_lengths
ax2.plot(waveform)
ax2.set_ylim(-1.0 * scale, 1.0 * scale)
ax2.set_xlim(0, waveform.size(-1))
ax.specgram(waveform[0], Fs=sample_rate)
# The original waveform
ratio = waveform.size(1) / sample_rate / emission.size(1)
for word in word_segments:
x0 = ratio * word.start
x1 = ratio * word.end
ax2.axvspan(x0, x1, alpha=0.1, color="red")
ax2.annotate(f"{word.score:.2f}", (x0, 0.8 * scale))
t0, t1 = ratio * word.start, ratio * word.end
ax.axvspan(t0, t1, facecolor="None", hatch="/", edgecolor="white")
ax.annotate(f"{word.score:.2f}", (t0, sample_rate * 0.51), annotation_clip=False)
for seg in segments:
if seg.label != "|":
ax2.annotate(seg.label, (seg.start * ratio, 0.9 * scale))
ax.annotate(seg.label, (seg.start * ratio, sample_rate * 0.53), annotation_clip=False)
xticks = ax2.get_xticks()
plt.xticks(xticks, xticks / sample_rate, fontsize=50)
ax2.set_xlabel("time [second]", fontsize=40)
ax2.set_yticks([])
ax.set_xlabel("time [second]")
fig.tight_layout()
plot_alignments(
segments,
word_segments,
waveform,
emission.shape[1],
1,
)
plt.show()
plot_alignments(waveform, emission, segments, word_segments)
######################################################################
# A trick to embed the resulting audio in the generated file:
# `IPython.display.Audio` has to be the last call in a cell,
# and there should be only one call per cell.
def display_segment(i, waveform, word_segments, frame_alignment):
ratio = waveform.size(1) / frame_alignment.size(1)
def display_segment(i, waveform, word_segments, frame_alignment, sample_rate=bundle.sample_rate):
ratio = waveform.size(1) / len(frame_alignment)
word = word_segments[i]
x0 = int(ratio * word.start)
x1 = int(ratio * word.end)
......@@ -437,8 +450,10 @@ def display_segment(i, waveform, word_segments, frame_alignment):
return IPython.display.Audio(segment.numpy(), rate=sample_rate)
######################################################################
# Generate the audio for each segment
print(transcript)
print(TRANSCRIPT)
IPython.display.Audio(SPEECH_FILE)
######################################################################
......@@ -488,62 +503,71 @@ display_segment(8, waveform, word_segments, frame_alignment)
######################################################################
# Advanced usage: Dealing with missing transcripts using the <star> token
# ---------------------------------------------------------------------------
# Advanced: Handling transcripts with ``<star>`` token
# ----------------------------------------------------
#
# Now let’s look at how we can improve alignment quality when the transcript
# is partially missing, using the ``<star>`` token, which is capable of
# modeling any token.
#
# Here we use the same English example as above, but we remove the
# beginning text ``“i had that curiosity beside me at”`` from the transcript.
# Aligning audio with such a transcript results in wrong alignments of the
# existing word “this”. However, this issue can be mitigated by using the
# ``<star>`` token to model the missing text.
#
# Reload the emission tensor in order to add the extra dimension corresponding to the <star> token.
with torch.inference_mode():
waveform, _ = torchaudio.load(SPEECH_FILE)
emissions, _ = model(waveform.to(device))
emissions = torch.log_softmax(emissions, dim=-1)
######################################################################
# First, we extend the dictionary to include the ``<star>`` token.
# Append the extra dimension corresponding to the <star> token
extra_dim = torch.zeros(emissions.shape[0], emissions.shape[1], 1)
emissions = torch.cat((emissions.cpu(), extra_dim), 2)
emission = emissions.detach()
DICTIONARY["*"] = len(DICTIONARY)
# Extend the dictionary to include the <star> token.
dictionary["*"] = 29
######################################################################
# Next, we extend the emission tensor with the extra dimension
# corresponding to the ``<star>`` token.
#
assert len(dictionary) == emission.shape[2]
extra_dim = torch.zeros(emission.shape[0], emission.shape[1], 1, device=device)
emission = torch.cat((emission, extra_dim), 2)
assert len(DICTIONARY) == emission.shape[2]
######################################################################
# The following function combines all the processes, and computes
# word segments from the emission in one go.
def compute_and_plot_alignments(transcript, dictionary, emission, waveform):
frames, frame_alignment, _ = compute_alignments(transcript, dictionary, emission)
tokens = [dictionary[c] for c in transcript]
alignment, scores = align(emission, tokens)
frames = obtain_token_level_alignments(alignment, scores)
segments = merge_repeats(frames, transcript)
word_segments = merge_words(transcript, segments, "|")
plot_alignments(segments, word_segments, waveform, emission.shape[1], 1)
plt.show()
return word_segments, frame_alignment
plot_alignments(waveform, emission, segments, word_segments)
plt.xlim([0, None])
# original:
word_segments, frame_alignment = compute_and_plot_alignments(transcript, dictionary, emission, waveform)
######################################################################
# **Original**
# Demonstrate the effect of <star> token for dealing with deletion errors
# ("i had that curiosity beside me at" missing from the transcript):
transcript = "THIS|MOMENT"
word_segments, frame_alignment = compute_and_plot_alignments(transcript, dictionary, emission, waveform)
compute_and_plot_alignments(TRANSCRIPT, DICTIONARY, emission, waveform)
######################################################################
# **With <star> token**
#
# Now we replace the first part of the transcript with the ``<star>`` token.
# Replacing the missing transcript with the <star> token:
transcript = "*|THIS|MOMENT"
word_segments, frame_alignment = compute_and_plot_alignments(transcript, dictionary, emission, waveform)
compute_and_plot_alignments("*|THIS|MOMENT", DICTIONARY, emission, waveform)
######################################################################
# **Without <star> token**
#
# As a comparison, the following aligns the partial transcript
# without using ``<star>`` token.
# It demonstrates the effect of the ``<star>`` token for dealing with deletion errors.
compute_and_plot_alignments("THIS|MOMENT", DICTIONARY, emission, waveform)
######################################################################
# Conclusion
......@@ -551,7 +575,7 @@ word_segments, frame_alignment = compute_and_plot_alignments(transcript, diction
#
# In this tutorial, we looked at how to use torchaudio’s forced alignment
# API to align and segment speech files, and demonstrated one advanced usage:
# how introducing a ``<star>`` token could improve alignment accuracy when
# transcription errors exist.
#
......
......@@ -69,7 +69,7 @@ import torchvision
# -------------------
#
# Firstly, we define the function to collect videos from microphone and
# camera. To be specific, we use :py:func:`~torchaudio.io.StreamReader`
# camera. To be specific, we use :py:class:`~torchaudio.io.StreamReader`
# class for the purpose of data collection, which supports capturing
# audio/video from microphone and camera. For the detailed usage of this
# class, please refer to the
......
......@@ -89,7 +89,7 @@ def plot_sinc_ir(irs, cutoff):
num_filts, window_size = irs.shape
half = window_size // 2
fig, axes = plt.subplots(num_filts, 1, sharex=True, figsize=(6.4, 4.8 * 1.5))
fig, axes = plt.subplots(num_filts, 1, sharex=True, figsize=(9.6, 8))
t = torch.linspace(-half, half - 1, window_size)
for ax, ir, coff, color in zip(axes, irs, cutoff, plt.cm.tab10.colors):
ax.plot(t, ir, linewidth=1.2, color=color, zorder=4, label=f"Cutoff: {coff}")
......@@ -100,7 +100,7 @@ def plot_sinc_ir(irs, cutoff):
"(Frequencies are relative to Nyquist frequency)"
)
axes[-1].set_xticks([i * half // 4 for i in range(-4, 5)])
plt.tight_layout()
fig.tight_layout()
######################################################################
......@@ -130,7 +130,7 @@ def plot_sinc_fr(frs, cutoff, band=False):
num_filts, num_fft = frs.shape
num_ticks = num_filts + 1 if band else num_filts
fig, axes = plt.subplots(num_filts, 1, sharex=True, sharey=True, figsize=(6.4, 4.8 * 1.5))
fig, axes = plt.subplots(num_filts, 1, sharex=True, sharey=True, figsize=(9.6, 8))
for ax, fr, coff, color in zip(axes, frs, cutoff, plt.cm.tab10.colors):
ax.grid(True)
ax.semilogy(fr, color=color, zorder=4, label=f"Cutoff: {coff}")
......@@ -146,7 +146,7 @@ def plot_sinc_fr(frs, cutoff, band=False):
"Frequency response of sinc low-pass filter for different cut-off frequencies\n"
"(Frequencies are relative to Nyquist frequency)"
)
plt.tight_layout()
fig.tight_layout()
######################################################################
......@@ -275,7 +275,7 @@ def plot_ir(magnitudes, ir, num_fft=2048):
axes[i].grid(True)
axes[1].set(title="Frequency Response")
axes[2].set(title="Frequency Response (log-scale)", xlabel="Frequency")
axes[2].legend(loc="lower right")
axes[2].legend(loc="center right")
fig.tight_layout()
......
......@@ -6,15 +6,14 @@ Forced alignment for multilingual data
This tutorial shows how to compute forced alignments for speech data
from multiple non-English languages using ``torchaudio``'s CTC forced alignment
API described in `“CTC forced alignment
tutorial” <https://pytorch.org/audio/stable/tutorials/forced_alignment_tutorial.html>`__
and the multilingual Wav2vec2 model proposed in the paper `“Scaling
API described in `CTC forced alignment tutorial <./forced_alignment_tutorial.html>`__
and the multilingual Wav2vec2 model proposed in the paper `Scaling
Speech Technology to 1,000+
Languages” <https://research.facebook.com/publications/scaling-speech-technology-to-1000-languages/>`__.
Languages <https://research.facebook.com/publications/scaling-speech-technology-to-1000-languages/>`__.
The model was trained on 23K hours of audio data from 1100+ languages using
the `uroman vocabulary <https://www.isi.edu/~ulf/uroman.html>`__
as targets.
"""
import torch
......@@ -23,53 +22,46 @@ import torchaudio
print(torch.__version__)
print(torchaudio.__version__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
try:
from torchaudio.functional import forced_align
except ModuleNotFoundError:
print(
"Failed to import the forced alignment API. "
"Please install torchaudio nightly builds. "
"Please refer to https://pytorch.org/get-started/locally "
"for instructions to install a nightly build."
)
raise
######################################################################
# Preparation
# -----------
#
# Here we import necessary packages, and define utility functions for
# computing the frame-level alignments (using the API
# ``functional.forced_align``), token-level and word-level alignments, and
# also alignment visualization utilities.
#
# %matplotlib inline
from dataclasses import dataclass
import IPython
import matplotlib.pyplot as plt
from torchaudio.functional import forced_align
torch.random.manual_seed(0)
sample_rate = 16000
######################################################################
#
SAMPLE_RATE = 16000
######################################################################
#
# Here we define utility functions for computing the frame-level
# alignments (using the API :py:func:`torchaudio.functional.forced_align`),
# token-level and word-level alignments.
# For the details of these functions, please refer to
# `CTC forced alignment API tutorial <./ctc_forced_alignment_api_tutorial.html>`__.
#
@dataclass
class Frame:
# This is the index of each token in the transcript,
# i.e. the current frame aligns to the N-th character from the transcript.
token_index: int
time_index: int
score: float
######################################################################
#
@dataclass
class Segment:
label: str
......@@ -78,39 +70,42 @@ class Segment:
score: float
def __repr__(self):
return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"
return f"{self.label:2s} ({self.score:4.2f}): [{self.start:4d}, {self.end:4d})"
@property
def length(self):
def __len__(self):
return self.end - self.start
# compute frame-level and word-level alignments using torchaudio's forced alignment API
######################################################################
#
def compute_alignments(transcript, dictionary, emission):
frames = []
tokens = [dictionary[c] for c in transcript.replace(" ", "")]
targets = torch.tensor(tokens, dtype=torch.int32).unsqueeze(0)
input_lengths = torch.tensor([emission.shape[1]])
target_lengths = torch.tensor([targets.shape[1]])
targets = torch.tensor([tokens], dtype=torch.int32, device=emission.device)
input_lengths = torch.tensor([emission.shape[1]], device=emission.device)
target_lengths = torch.tensor([targets.shape[1]], device=emission.device)
alignment, scores = forced_align(emission, targets, input_lengths, target_lengths, 0)
# This is the key step, where we call the forced alignment API functional.forced_align to compute frame alignments.
frame_alignment, frame_scores = forced_align(emission, targets, input_lengths, target_lengths, 0)
scores = scores.exp() # convert back to probability
alignment, scores = alignment[0].tolist(), scores[0].tolist()
assert frame_alignment.shape[1] == input_lengths[0].item()
assert targets.shape[1] == target_lengths[0].item()
assert len(alignment) == len(scores) == emission.size(1)
token_index = -1
prev_hyp = 0
for i in range(frame_alignment.shape[1]):
if frame_alignment[0][i].item() == 0:
frames = []
for i, (ali, score) in enumerate(zip(alignment, scores)):
if ali == 0:
prev_hyp = 0
continue
if frame_alignment[0][i].item() != prev_hyp:
if ali != prev_hyp:
token_index += 1
frames.append(Frame(token_index, i, frame_scores[0][i].exp().item()))
prev_hyp = frame_alignment[0][i].item()
frames.append(Frame(token_index, i, score))
prev_hyp = ali
# compute frame alignments from token alignments
transcript_nospace = transcript.replace(" ", "")
......@@ -140,52 +135,59 @@ def compute_alignments(transcript, dictionary, emission):
if i1 != i2:
if i3 == len(transcript) - 1:
i2 += 1
s = 0
segs = segments[i1 + s : i2 + s]
word = "".join([seg.label for seg in segs])
score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)
words.append(Segment(word, segments[i1 + s].start, segments[i2 + s - 1].end, score))
segs = segments[i1:i2]
word = "".join([s.label for s in segs])
score = sum(s.score * len(s) for s in segs) / sum(len(s) for s in segs)
words.append(Segment(word, segs[0].start, segs[-1].end + 1, score))
i1 = i2
else:
i2 += 1
i3 += 1
return segments, words
num_frames = frame_alignment.shape[1]
return segments, words, num_frames
######################################################################
#
def plot_emission(emission):
fig, ax = plt.subplots()
ax.imshow(emission.T, aspect="auto")
ax.set_title("Emission")
fig.tight_layout()
# utility function for plotting word alignments
def plot_alignments(segments, word_segments, waveform, input_lengths, scale=10):
fig, ax2 = plt.subplots(figsize=(64, 12))
plt.rcParams.update({"font.size": 30})
# The original waveform
ratio = waveform.size(1) / input_lengths
ax2.plot(waveform)
ax2.set_ylim(-1.0 * scale, 1.0 * scale)
ax2.set_xlim(0, waveform.size(-1))
######################################################################
#
# utility function for plotting word alignments
def plot_alignments(waveform, emission, segments, word_segments, sample_rate=SAMPLE_RATE):
fig, ax = plt.subplots()
ax.specgram(waveform[0], Fs=sample_rate)
xlim = ax.get_xlim()
ratio = waveform.size(1) / sample_rate / emission.size(1)
for word in word_segments:
x0 = ratio * word.start
x1 = ratio * word.end
ax2.axvspan(x0, x1, alpha=0.1, color="red")
ax2.annotate(f"{word.score:.2f}", (x0, 0.8 * scale))
t0, t1 = word.start * ratio, word.end * ratio
ax.axvspan(t0, t1, facecolor="None", hatch="/", edgecolor="white")
ax.annotate(f"{word.score:.2f}", (t0, sample_rate * 0.51), annotation_clip=False)
for seg in segments:
if seg.label != "|":
ax2.annotate(seg.label, (seg.start * ratio, 0.9 * scale))
ax.annotate(seg.label, (seg.start * ratio, sample_rate * 0.53), annotation_clip=False)
xticks = ax2.get_xticks()
plt.xticks(xticks, xticks / sample_rate, fontsize=50)
ax2.set_xlabel("time [second]", fontsize=40)
ax2.set_yticks([])
ax.set_xlabel("time [second]")
ax.set_xlim(xlim)
fig.tight_layout()
return IPython.display.Audio(waveform, rate=sample_rate)
######################################################################
#
# utility function for playing audio segments.
# A trick to embed the resulting audio in the generated file:
# `IPython.display.Audio` has to be the last call in a cell,
# and there should be only one call per cell.
def display_segment(i, waveform, word_segments, num_frames):
def display_segment(i, waveform, word_segments, num_frames, sample_rate=SAMPLE_RATE):
ratio = waveform.size(1) / num_frames
word = word_segments[i]
x0 = int(ratio * word.start)
......@@ -241,26 +243,21 @@ model.load_state_dict(
)
)
model.eval()
model.to(device)
def get_emission(waveform):
# NOTE: this step is essential
waveform = torch.nn.functional.layer_norm(waveform, waveform.shape)
emissions, _ = model(waveform)
emissions = torch.log_softmax(emissions, dim=-1)
emission = emissions.cpu().detach()
# Append the extra dimension corresponding to the <star> token
extra_dim = torch.zeros(emissions.shape[0], emissions.shape[1], 1)
emissions = torch.cat((emissions.cpu(), extra_dim), 2)
emission = emissions.detach()
return emission, waveform
with torch.inference_mode():
# NOTE: this step is essential
waveform = torch.nn.functional.layer_norm(waveform, waveform.shape)
emission, _ = model(waveform)
return torch.log_softmax(emission, dim=-1)
# Construct the dictionary
# '@' represents the OOV token, '*' represents the <star> token.
# '@' represents the OOV token
# <pad> and </s> are fairseq's legacy tokens, which are not used.
# The <star> token is omitted as we do not use it in this tutorial.
dictionary = {
"<blank>": 0,
"<pad>": 1,
......@@ -293,7 +290,6 @@ dictionary = {
"'": 28,
"q": 29,
"x": 30,
"*": 31,
}
......@@ -304,11 +300,8 @@ dictionary = {
# romanizer and using it to obtain romanized transcripts, and Python
# commands required for further normalizing the romanized transcript.
#
# %%
# .. code-block:: bash
#
# %%bash
# Save the raw transcript to a file
# echo 'raw text' > text.txt
# git clone https://github.com/isi-nlp/uroman
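######################################################################
# The Python normalization commands mentioned above are elided in this hunk;
# a minimal sketch of such a normalization step (an assumption, not the
# tutorial's exact code) lower-cases the romanized text and strips everything
# outside the dictionary's character set:

import re

def normalize_uroman_output(text):
    text = text.lower()
    text = text.replace("’", "'")
    text = re.sub("([^a-z' ])", " ", text)  # drop characters not in the dictionary
    return re.sub(" +", " ", text).strip()  # collapse repeated whitespace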
......@@ -334,141 +327,77 @@ dictionary = {
######################################################################
# German example:
# ~~~~~~~~~~~~~~~~
# German
# ~~~~~~
text_raw = (
"aber seit ich bei ihnen das brot hole brauch ich viel weniger schulze wandte sich ab die kinder taten ihm leid"
)
text_normalized = (
"aber seit ich bei ihnen das brot hole brauch ich viel weniger schulze wandte sich ab die kinder taten ihm leid"
)
speech_file = torchaudio.utils.download_asset("tutorial-assets/10349_8674_000087.flac", progress=False)
waveform, _ = torchaudio.load(speech_file)
emission, waveform = get_emission(waveform)
assert len(dictionary) == emission.shape[2]
transcript = text_normalized
segments, word_segments, num_frames = compute_alignments(transcript, dictionary, emission)
plot_alignments(segments, word_segments, waveform, emission.shape[1])
text_raw = "aber seit ich bei ihnen das brot hole"
text_normalized = "aber seit ich bei ihnen das brot hole"
print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized)
IPython.display.Audio(waveform, rate=sample_rate)
######################################################################
#
display_segment(0, waveform, word_segments, num_frames)
waveform, _ = torchaudio.load(speech_file, frame_offset=int(0.5 * SAMPLE_RATE), num_frames=int(2.5 * SAMPLE_RATE))
emission = get_emission(waveform.to(device))
num_frames = emission.size(1)
plot_emission(emission[0].cpu())
######################################################################
#
display_segment(1, waveform, word_segments, num_frames)
segments, word_segments = compute_alignments(text_normalized, dictionary, emission)
######################################################################
#
display_segment(2, waveform, word_segments, num_frames)
######################################################################
#
display_segment(3, waveform, word_segments, num_frames)
######################################################################
#
display_segment(4, waveform, word_segments, num_frames)
plot_alignments(waveform, emission, segments, word_segments)
######################################################################
#
display_segment(5, waveform, word_segments, num_frames)
######################################################################
#
display_segment(6, waveform, word_segments, num_frames)
######################################################################
#
display_segment(7, waveform, word_segments, num_frames)
######################################################################
#
display_segment(8, waveform, word_segments, num_frames)
######################################################################
#
display_segment(9, waveform, word_segments, num_frames)
######################################################################
#
display_segment(10, waveform, word_segments, num_frames)
######################################################################
#
display_segment(11, waveform, word_segments, num_frames)
######################################################################
#
display_segment(12, waveform, word_segments, num_frames)
display_segment(0, waveform, word_segments, num_frames)
######################################################################
#
display_segment(13, waveform, word_segments, num_frames)
display_segment(1, waveform, word_segments, num_frames)
######################################################################
#
display_segment(14, waveform, word_segments, num_frames)
display_segment(2, waveform, word_segments, num_frames)
######################################################################
#
display_segment(15, waveform, word_segments, num_frames)
######################################################################
#
display_segment(3, waveform, word_segments, num_frames)
display_segment(16, waveform, word_segments, num_frames)
######################################################################
#
display_segment(17, waveform, word_segments, num_frames)
display_segment(4, waveform, word_segments, num_frames)
######################################################################
#
display_segment(18, waveform, word_segments, num_frames)
display_segment(5, waveform, word_segments, num_frames)
######################################################################
#
display_segment(19, waveform, word_segments, num_frames)
display_segment(6, waveform, word_segments, num_frames)
######################################################################
#
display_segment(20, waveform, word_segments, num_frames)
display_segment(7, waveform, word_segments, num_frames)
######################################################################
# Chinese example:
# ~~~~~~~~~~~~~~~~
# Chinese
# ~~~~~~~
#
# Chinese is a character-based language, and there is no explicit word-level
# tokenization (separated by spaces) in its raw written form. In order to
......@@ -478,98 +407,36 @@ display_segment(20, waveform, word_segments, num_frames)
# However, this is not needed if you only want character-level alignments.
#
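######################################################################
# As a small illustration (not part of the tutorial pipeline): character-level
# units can be taken directly from the raw text, while word-level units come
# from the romanized, space-separated transcript.

example_char_units = list("关 服务 高端 产品".replace(" ", ""))
example_word_units = "guan fuwu gaoduan chanpin".split()
print(example_char_units)
print(example_word_units)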
text_raw = "关 服务 高端 产品 仍 处于 供不应求 的 局面"
text_normalized = "guan fuwu gaoduan chanpin reng chuyu gongbuyingqiu de jumian"
speech_file = torchaudio.utils.download_asset("tutorial-assets/mvdr/clean_speech.wav", progress=False)
waveform, _ = torchaudio.load(speech_file)
waveform = waveform[0:1]
emission, waveform = get_emission(waveform)
transcript = text_normalized
segments, word_segments, num_frames = compute_alignments(transcript, dictionary, emission)
plot_alignments(segments, word_segments, waveform, emission.shape[1])
text_raw = "关 服务 高端 产品 仍 处于 供不应求 的 局面"
text_normalized = "guan fuwu gaoduan chanpin reng chuyu gongbuyingqiu de jumian"
print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized)
IPython.display.Audio(waveform, rate=sample_rate)
######################################################################
#
display_segment(0, waveform, word_segments, num_frames)
######################################################################
#
display_segment(1, waveform, word_segments, num_frames)
######################################################################
#
display_segment(2, waveform, word_segments, num_frames)
######################################################################
#
display_segment(3, waveform, word_segments, num_frames)
######################################################################
#
display_segment(4, waveform, word_segments, num_frames)
######################################################################
#
display_segment(5, waveform, word_segments, num_frames)
######################################################################
#
display_segment(6, waveform, word_segments, num_frames)
######################################################################
#
waveform, _ = torchaudio.load(speech_file)
waveform = waveform[0:1]
display_segment(7, waveform, word_segments, num_frames)
emission = get_emission(waveform.to(device))
num_frames = emission.size(1)
plot_emission(emission[0].cpu())
######################################################################
#
display_segment(8, waveform, word_segments, num_frames)
######################################################################
# Polish example:
# ~~~~~~~~~~~~~~~
text_raw = "wtedy ujrzałem na jego brzuchu okrągłą czarną ranę dlaczego mi nie powiedziałeś szepnąłem ze łzami"
text_normalized = "wtedy ujrzalem na jego brzuchu okragla czarna rane dlaczego mi nie powiedziales szepnalem ze lzami"
speech_file = torchaudio.utils.download_asset("tutorial-assets/5090_1447_000088.flac", progress=False)
waveform, _ = torchaudio.load(speech_file)
emission, waveform = get_emission(waveform)
transcript = text_normalized
segments, word_segments = compute_alignments(text_normalized, dictionary, emission)
segments, word_segments, num_frames = compute_alignments(transcript, dictionary, emission)
plot_alignments(segments, word_segments, waveform, emission.shape[1])
print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized)
IPython.display.Audio(waveform, rate=sample_rate)
plot_alignments(waveform, emission, segments, word_segments)
######################################################################
#
display_segment(0, waveform, word_segments, num_frames)
######################################################################
#
......@@ -585,7 +452,6 @@ display_segment(2, waveform, word_segments, num_frames)
display_segment(3, waveform, word_segments, num_frames)
######################################################################
#
......@@ -611,68 +477,40 @@ display_segment(7, waveform, word_segments, num_frames)
display_segment(8, waveform, word_segments, num_frames)
######################################################################
#
display_segment(9, waveform, word_segments, num_frames)
######################################################################
#
# Polish
# ~~~~~~
display_segment(10, waveform, word_segments, num_frames)
speech_file = torchaudio.utils.download_asset("tutorial-assets/5090_1447_000088.flac", progress=False)
######################################################################
#
text_raw = "wtedy ujrzałem na jego brzuchu okrągłą czarną ranę"
text_normalized = "wtedy ujrzalem na jego brzuchu okragla czarna rane"
display_segment(11, waveform, word_segments, num_frames)
print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized)
######################################################################
#
display_segment(12, waveform, word_segments, num_frames)
######################################################################
#
waveform, _ = torchaudio.load(speech_file, num_frames=int(4.5 * SAMPLE_RATE))
display_segment(13, waveform, word_segments, num_frames)
emission = get_emission(waveform.to(device))
num_frames = emission.size(1)
plot_emission(emission[0].cpu())
######################################################################
#
display_segment(14, waveform, word_segments, num_frames)
segments, word_segments = compute_alignments(text_normalized, dictionary, emission)
######################################################################
# Portuguese example:
# ~~~~~~~~~~~~~~~~~~~
text_raw = (
"mas na imensa extensão onde se esconde o inconsciente imortal só me responde um bramido um queixume e nada mais"
)
text_normalized = (
"mas na imensa extensao onde se esconde o inconsciente imortal so me responde um bramido um queixume e nada mais"
)
speech_file = torchaudio.utils.download_asset("tutorial-assets/6566_5323_000027.flac", progress=False)
waveform, _ = torchaudio.load(speech_file)
emission, waveform = get_emission(waveform)
transcript = text_normalized
segments, word_segments, num_frames = compute_alignments(transcript, dictionary, emission)
plot_alignments(segments, word_segments, waveform, emission.shape[1])
print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized)
IPython.display.Audio(waveform, rate=sample_rate)
plot_alignments(waveform, emission, segments, word_segments)
######################################################################
#
display_segment(0, waveform, word_segments, num_frames)
######################################################################
#
......@@ -688,7 +526,6 @@ display_segment(2, waveform, word_segments, num_frames)
display_segment(3, waveform, word_segments, num_frames)
######################################################################
#
......@@ -710,94 +547,38 @@ display_segment(6, waveform, word_segments, num_frames)
display_segment(7, waveform, word_segments, num_frames)
######################################################################
#
display_segment(8, waveform, word_segments, num_frames)
######################################################################
#
display_segment(9, waveform, word_segments, num_frames)
######################################################################
#
display_segment(10, waveform, word_segments, num_frames)
######################################################################
#
display_segment(11, waveform, word_segments, num_frames)
######################################################################
#
display_segment(12, waveform, word_segments, num_frames)
# Portuguese
# ~~~~~~~~~~
######################################################################
#
display_segment(13, waveform, word_segments, num_frames)
######################################################################
#
display_segment(14, waveform, word_segments, num_frames)
######################################################################
#
display_segment(15, waveform, word_segments, num_frames)
speech_file = torchaudio.utils.download_asset("tutorial-assets/6566_5323_000027.flac", progress=False)
######################################################################
#
text_raw = "na imensa extensão onde se esconde o inconsciente imortal"
text_normalized = "na imensa extensao onde se esconde o inconsciente imortal"
display_segment(16, waveform, word_segments, num_frames)
print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized)
######################################################################
#
display_segment(17, waveform, word_segments, num_frames)
waveform, _ = torchaudio.load(speech_file, frame_offset=int(SAMPLE_RATE), num_frames=int(4.6 * SAMPLE_RATE))
######################################################################
#
display_segment(18, waveform, word_segments, num_frames)
emission = get_emission(waveform.to(device))
num_frames = emission.size(1)
plot_emission(emission[0].cpu())
######################################################################
#
display_segment(19, waveform, word_segments, num_frames)
segments, word_segments = compute_alignments(text_normalized, dictionary, emission)
######################################################################
# Italian example:
# ~~~~~~~~~~~~~~~~
text_raw = "elle giacean per terra tutte quante fuor d'una ch'a seder si levò ratto ch'ella ci vide passarsi davante"
text_normalized = (
"elle giacean per terra tutte quante fuor d'una ch'a seder si levo ratto ch'ella ci vide passarsi davante"
)
speech_file = torchaudio.utils.download_asset("tutorial-assets/642_529_000025.flac", progress=False)
waveform, _ = torchaudio.load(speech_file)
emission, waveform = get_emission(waveform)
transcript = text_normalized
segments, word_segments, num_frames = compute_alignments(transcript, dictionary, emission)
plot_alignments(segments, word_segments, waveform, emission.shape[1])
print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized)
IPython.display.Audio(waveform, rate=sample_rate)
plot_alignments(waveform, emission, segments, word_segments)
######################################################################
#
display_segment(0, waveform, word_segments, num_frames)
######################################################################
#
......@@ -813,7 +594,6 @@ display_segment(2, waveform, word_segments, num_frames)
display_segment(3, waveform, word_segments, num_frames)
######################################################################
#
......@@ -840,50 +620,62 @@ display_segment(7, waveform, word_segments, num_frames)
display_segment(8, waveform, word_segments, num_frames)
######################################################################
#
# Italian
# ~~~~~~~
speech_file = torchaudio.utils.download_asset("tutorial-assets/642_529_000025.flac", progress=False)
text_raw = "elle giacean per terra tutte quante"
text_normalized = "elle giacean per terra tutte quante"
display_segment(9, waveform, word_segments, num_frames)
print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized)
######################################################################
#
display_segment(10, waveform, word_segments, num_frames)
waveform, _ = torchaudio.load(speech_file, num_frames=int(4 * SAMPLE_RATE))
emission = get_emission(waveform.to(device))
num_frames = emission.size(1)
plot_emission(emission[0].cpu())
######################################################################
#
display_segment(11, waveform, word_segments, num_frames)
segments, word_segments = compute_alignments(text_normalized, dictionary, emission)
plot_alignments(waveform, emission, segments, word_segments)
######################################################################
#
display_segment(12, waveform, word_segments, num_frames)
display_segment(0, waveform, word_segments, num_frames)
######################################################################
#
display_segment(13, waveform, word_segments, num_frames)
display_segment(1, waveform, word_segments, num_frames)
######################################################################
#
display_segment(14, waveform, word_segments, num_frames)
display_segment(2, waveform, word_segments, num_frames)
######################################################################
#
display_segment(15, waveform, word_segments, num_frames)
display_segment(3, waveform, word_segments, num_frames)
######################################################################
#
display_segment(16, waveform, word_segments, num_frames)
display_segment(4, waveform, word_segments, num_frames)
######################################################################
#
display_segment(17, waveform, word_segments, num_frames)
display_segment(5, waveform, word_segments, num_frames)
######################################################################
# Conclusion
......@@ -894,7 +686,6 @@ display_segment(17, waveform, word_segments, num_frames)
# speech data to transcripts in five languages.
#
######################################################################
# Acknowledgement
# ---------------
......
......@@ -56,16 +56,11 @@ print(device)
# First we import the necessary packages and fetch the data that we work on.
#
# %matplotlib inline
from dataclasses import dataclass
import IPython
import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams["figure.figsize"] = [16.0, 4.8]
torch.random.manual_seed(0)
SPEECH_FILE = torchaudio.utils.download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
......@@ -99,17 +94,22 @@ with torch.inference_mode():
emission = emissions[0].cpu().detach()
print(labels)
################################################################################
# Visualization
################################################################################
print(labels)
plt.imshow(emission.T)
plt.colorbar()
plt.title("Frame-wise class probability")
plt.xlabel("Time")
plt.ylabel("Labels")
plt.show()
# ~~~~~~~~~~~~~
def plot():
plt.imshow(emission.T)
plt.colorbar()
plt.title("Frame-wise class probability")
plt.xlabel("Time")
plt.ylabel("Labels")
plot()
######################################################################
# Generate alignment probability (trellis)
......@@ -181,12 +181,17 @@ trellis = get_trellis(emission, tokens)
################################################################################
# Visualization
################################################################################
plt.imshow(trellis.T, origin="lower")
plt.annotate("- Inf", (trellis.size(1) / 5, trellis.size(1) / 1.5))
plt.annotate("+ Inf", (trellis.size(0) - trellis.size(1) / 5, trellis.size(1) / 3))
plt.colorbar()
plt.show()
# ~~~~~~~~~~~~~
def plot():
plt.imshow(trellis.T, origin="lower")
plt.annotate("- Inf", (trellis.size(1) / 5, trellis.size(1) / 1.5))
plt.annotate("+ Inf", (trellis.size(0) - trellis.size(1) / 5, trellis.size(1) / 3))
plt.colorbar()
plot()
######################################################################
# In the above visualization, we can see that there is a trace of high
......@@ -266,7 +271,9 @@ for p in path:
################################################################################
# Visualization
################################################################################
# ~~~~~~~~~~~~~
def plot_trellis_with_path(trellis, path):
# To plot trellis with path, we take advantage of 'nan' value
trellis_with_path = trellis.clone()
......@@ -277,10 +284,14 @@ def plot_trellis_with_path(trellis, path):
plot_trellis_with_path(trellis, path)
plt.title("The path found by backtracking")
plt.show()
######################################################################
# Looking good. Now this path contains repetitions for the same labels, so
# Looking good.
######################################################################
# Segment the path
# ----------------
# Now this path contains repetitions for the same labels, so
# let’s merge them to make it close to the original transcript.
#
# When merging the multiple path points, we simply take the average
......@@ -297,7 +308,7 @@ class Segment:
score: float
def __repr__(self):
return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"
return f"{self.label} ({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"
@property
def length(self):
......@@ -330,7 +341,9 @@ for seg in segments:
################################################################################
# Visualization
################################################################################
# ~~~~~~~~~~~~~
def plot_trellis_with_segments(trellis, segments, transcript):
# To plot trellis with path, we take advantage of 'nan' value
trellis_with_path = trellis.clone()
......@@ -338,15 +351,14 @@ def plot_trellis_with_segments(trellis, segments, transcript):
if seg.label != "|":
trellis_with_path[seg.start : seg.end, i] = float("nan")
fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5))
fig, [ax1, ax2] = plt.subplots(2, 1, sharex=True)
ax1.set_title("Path, label and probability for each label")
ax1.imshow(trellis_with_path.T, origin="lower")
ax1.set_xticks([])
ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")
for i, seg in enumerate(segments):
if seg.label != "|":
ax1.annotate(seg.label, (seg.start, i - 0.7), weight="bold")
ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3))
ax1.annotate(seg.label, (seg.start, i - 0.7), size="small")
ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), size="small")
ax2.set_title("Label probability with and without repetition")
xs, hs, ws = [], [], []
......@@ -355,7 +367,7 @@ def plot_trellis_with_segments(trellis, segments, transcript):
xs.append((seg.end + seg.start) / 2 + 0.4)
hs.append(seg.score)
ws.append(seg.end - seg.start)
ax2.annotate(seg.label, (seg.start + 0.8, -0.07), weight="bold")
ax2.annotate(seg.label, (seg.start + 0.8, -0.07))
ax2.bar(xs, hs, width=ws, color="gray", alpha=0.5, edgecolor="black")
xs, hs = [], []
......@@ -367,17 +379,21 @@ def plot_trellis_with_segments(trellis, segments, transcript):
ax2.bar(xs, hs, width=0.5, alpha=0.5)
ax2.axhline(0, color="black")
ax2.set_xlim(ax1.get_xlim())
ax2.grid(True, axis="y")
ax2.set_ylim(-0.1, 1.1)
fig.tight_layout()
plot_trellis_with_segments(trellis, segments, transcript)
plt.tight_layout()
plt.show()
######################################################################
# Looks good. Now let’s merge the words. The Wav2Vec2 model uses ``'|'``
# Looks good.
######################################################################
# Merge the segments into words
# -----------------------------
# Now let’s merge the words. The Wav2Vec2 model uses ``'|'``
# as the word boundary, so we merge the segments before each occurrence of
# ``'|'``.
#
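######################################################################
# One possible way to implement this merge, as a sketch (the tutorial's own
# merging helper is not shown here): group consecutive non-separator
# segments into a word and combine their scores with a length-weighted
# average.


def merge_words_sketch(segments, separator="|"):
    words, group = [], []
    # A trailing separator acts as a sentinel so the last group is flushed too.
    for seg in list(segments) + [Segment(label=separator, start=-1, end=-1, score=0.0)]:
        if seg.label == separator:
            if group:
                word = "".join(s.label for s in group)
                score = sum(s.score * s.length for s in group) / sum(s.length for s in group)
                words.append(Segment(label=word, start=group[0].start, end=group[-1].end, score=score))
            group = []
        else:
            group.append(seg)
    return words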
......@@ -410,16 +426,16 @@ for word in word_segments:
################################################################################
# Visualization
################################################################################
# ~~~~~~~~~~~~~
def plot_alignments(trellis, segments, word_segments, waveform):
trellis_with_path = trellis.clone()
for i, seg in enumerate(segments):
if seg.label != "|":
trellis_with_path[seg.start : seg.end, i] = float("nan")
fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5))
fig, [ax1, ax2] = plt.subplots(2, 1)
ax1.imshow(trellis_with_path.T, origin="lower")
ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")
ax1.set_xticks([])
ax1.set_yticks([])
......@@ -429,8 +445,8 @@ def plot_alignments(trellis, segments, word_segments, waveform):
for i, seg in enumerate(segments):
if seg.label != "|":
ax1.annotate(seg.label, (seg.start, i - 0.7))
ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), fontsize=8)
ax1.annotate(seg.label, (seg.start, i - 0.7), size="small")
ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), size="small")
# The original waveform
ratio = waveform.size(0) / trellis.size(0)
......@@ -450,6 +466,7 @@ def plot_alignments(trellis, segments, word_segments, waveform):
ax2.set_yticks([])
ax2.set_ylim(-1.0, 1.0)
ax2.set_xlim(0, waveform.size(-1))
fig.tight_layout()
plot_alignments(
......@@ -458,7 +475,6 @@ plot_alignments(
word_segments,
waveform[0],
)
plt.show()
################################################################################
......
......@@ -162,11 +162,10 @@ def separate_sources(
def plot_spectrogram(stft, title="Spectrogram"):
magnitude = stft.abs()
spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy()
figure, axis = plt.subplots(1, 1)
img = axis.imshow(spectrogram, cmap="viridis", vmin=-60, vmax=0, origin="lower", aspect="auto")
figure.suptitle(title)
plt.colorbar(img, ax=axis)
plt.show()
_, axis = plt.subplots(1, 1)
axis.imshow(spectrogram, cmap="viridis", vmin=-60, vmax=0, origin="lower", aspect="auto")
axis.set_title(title)
plt.tight_layout()
######################################################################
......@@ -252,7 +251,7 @@ def output_results(original_source: torch.Tensor, predicted_source: torch.Tensor
"SDR score is:",
separation.bss_eval_sources(original_source.detach().numpy(), predicted_source.detach().numpy())[0].mean(),
)
plot_spectrogram(stft(predicted_source)[0], f"Spectrogram {source}")
plot_spectrogram(stft(predicted_source)[0], f"Spectrogram - {source}")
return Audio(predicted_source, rate=sample_rate)
......@@ -294,7 +293,7 @@ mix_spec = mixture[:, frame_start:frame_end].cpu()
#
# Mixture Clip
plot_spectrogram(stft(mix_spec)[0], "Spectrogram Mixture")
plot_spectrogram(stft(mix_spec)[0], "Spectrogram - Mixture")
Audio(mix_spec, rate=sample_rate)
######################################################################
......
......@@ -98,23 +98,21 @@ SAMPLE_NOISE = download_asset("tutorial-assets/mvdr/noise.wav")
#
def plot_spectrogram(stft, title="Spectrogram", xlim=None):
def plot_spectrogram(stft, title="Spectrogram"):
magnitude = stft.abs()
spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy()
figure, axis = plt.subplots(1, 1)
img = axis.imshow(spectrogram, cmap="viridis", vmin=-100, vmax=0, origin="lower", aspect="auto")
figure.suptitle(title)
axis.set_title(title)
plt.colorbar(img, ax=axis)
plt.show()
def plot_mask(mask, title="Mask", xlim=None):
def plot_mask(mask, title="Mask"):
mask = mask.numpy()
figure, axis = plt.subplots(1, 1)
img = axis.imshow(mask, cmap="viridis", origin="lower", aspect="auto")
figure.suptitle(title)
axis.set_title(title)
plt.colorbar(img, ax=axis)
plt.show()
def si_snr(estimate, reference, epsilon=1e-8):
......
......@@ -33,12 +33,9 @@ print(torchaudio.__version__)
import os
import time
import matplotlib
import matplotlib.pyplot as plt
from torchaudio.io import StreamReader
matplotlib.rcParams["image.interpolation"] = "none"
######################################################################
#
# Check the prerequisites
......
......@@ -160,8 +160,7 @@ for i, feats in enumerate(features):
ax[i].set_title(f"Feature from transformer layer {i+1}")
ax[i].set_xlabel("Feature dimension")
ax[i].set_ylabel("Frame (time-axis)")
plt.tight_layout()
plt.show()
fig.tight_layout()
######################################################################
......@@ -190,7 +189,7 @@ plt.imshow(emission[0].cpu().T, interpolation="nearest")
plt.title("Classification result")
plt.xlabel("Frame (time-axis)")
plt.ylabel("Class")
plt.show()
plt.tight_layout()
print("Class labels:", bundle.get_labels())
......
......@@ -82,19 +82,23 @@ try:
from pystoi import stoi
from torchaudio.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE
except ImportError:
import google.colab # noqa: F401
print(
"""
To enable running this notebook in Google Colab, install nightly
torch and torchaudio builds by adding the following code block to the top
of the notebook before running it:
!pip3 uninstall -y torch torchvision torchaudio
!pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu
!pip3 install pesq
!pip3 install pystoi
"""
)
try:
import google.colab # noqa: F401
print(
"""
To enable running this notebook in Google Colab, install nightly
torch and torchaudio builds by adding the following code block to the top
of the notebook before running it:
!pip3 uninstall -y torch torchvision torchaudio
!pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu
!pip3 install pesq
!pip3 install pystoi
"""
)
except Exception:
pass
raise
import matplotlib.pyplot as plt
......@@ -128,27 +132,17 @@ def si_snr(estimate, reference, epsilon=1e-8):
return si_snr.item()
def plot_waveform(waveform, title):
def plot(waveform, title, sample_rate=16000):
wav_numpy = waveform.numpy()
sample_size = waveform.shape[1]
time_axis = torch.arange(0, sample_size) / 16000
figure, axes = plt.subplots(1, 1)
axes = figure.gca()
axes.plot(time_axis, wav_numpy[0], linewidth=1)
axes.grid(True)
figure.suptitle(title)
plt.show(block=False)
time_axis = torch.arange(0, sample_size) / sample_rate
def plot_specgram(waveform, sample_rate, title):
wav_numpy = waveform.numpy()
figure, axes = plt.subplots(1, 1)
axes = figure.gca()
axes.specgram(wav_numpy[0], Fs=sample_rate)
figure, axes = plt.subplots(2, 1)
axes[0].plot(time_axis, wav_numpy[0], linewidth=1)
axes[0].grid(True)
axes[1].specgram(wav_numpy[0], Fs=sample_rate)
figure.suptitle(title)
plt.show(block=False)
######################################################################
......@@ -238,32 +232,28 @@ Audio(WAVEFORM_DISTORTED.numpy()[1], rate=16000)
# Visualize speech sample
#
plot_waveform(WAVEFORM_SPEECH, "Clean Speech")
plot_specgram(WAVEFORM_SPEECH, 16000, "Clean Speech Spectrogram")
plot(WAVEFORM_SPEECH, "Clean Speech")
######################################################################
# Visualize noise sample
#
plot_waveform(WAVEFORM_NOISE, "Noise")
plot_specgram(WAVEFORM_NOISE, 16000, "Noise Spectrogram")
plot(WAVEFORM_NOISE, "Noise")
######################################################################
# Visualize distorted speech with 20dB SNR
#
plot_waveform(WAVEFORM_DISTORTED[0:1], f"Distorted Speech with {snr_dbs[0]}dB SNR")
plot_specgram(WAVEFORM_DISTORTED[0:1], 16000, f"Distorted Speech with {snr_dbs[0]}dB SNR")
plot(WAVEFORM_DISTORTED[0:1], f"Distorted Speech with {snr_dbs[0]}dB SNR")
######################################################################
# Visualize distorted speech with -5dB SNR
#
plot_waveform(WAVEFORM_DISTORTED[1:2], f"Distorted Speech with {snr_dbs[1]}dB SNR")
plot_specgram(WAVEFORM_DISTORTED[1:2], 16000, f"Distorted Speech with {snr_dbs[1]}dB SNR")
plot(WAVEFORM_DISTORTED[1:2], f"Distorted Speech with {snr_dbs[1]}dB SNR")
######################################################################
......