"examples/vscode:/vscode.git/clone" did not exist on "40b3cdf79ea90d20b8adbdea330e3af60d5522bd"
Commit 84b12306 authored by moto, committed by Facebook GitHub Bot

Set and tweak global matplotlib configuration in tutorials (#3515)

Summary:
- Set global matplotlib rc params
- Fix style check
- Fix and update FA tutorial plots
- Add av-asr index cards

Pull Request resolved: https://github.com/pytorch/audio/pull/3515

Reviewed By: huangruizhe

Differential Revision: D47894156

Pulled By: mthrok

fbshipit-source-id: b40d8d31f12ffc2b337e35e632afc216e9d59a6e
parent 8497ee91
@@ -127,6 +127,22 @@ def _get_pattern():
     return ret


+def reset_mpl(gallery_conf, fname):
+    from sphinx_gallery.scrapers import _reset_matplotlib
+
+    _reset_matplotlib(gallery_conf, fname)
+
+    import matplotlib
+
+    matplotlib.rcParams.update(
+        {
+            "image.interpolation": "none",
+            "figure.figsize": (9.6, 4.8),
+            "font.size": 8.0,
+            "axes.axisbelow": True,
+        }
+    )
+
+
 sphinx_gallery_conf = {
     "examples_dirs": [
         "../../examples/tutorials",
@@ -139,6 +155,7 @@ sphinx_gallery_conf = {
     "promote_jupyter_magic": True,
     "first_notebook_cell": None,
     "doc_module": ("torchaudio",),
+    "reset_modules": (reset_mpl, "seaborn"),
 }

 autosummary_generate = True
...
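The ``reset_modules`` entry registered above follows sphinx-gallery's reset-hook convention: each entry is either a module name or a callable that receives ``(gallery_conf, fname)`` before and after every example script. As a rough standalone sketch of what the rc overrides do outside any docs build (the print call is purely illustrative):

    import matplotlib

    matplotlib.rcParams.update(
        {
            "image.interpolation": "none",  # keep imshow() pixels unsmoothed
            "figure.figsize": (9.6, 4.8),   # wider default canvas for gallery figures
            "font.size": 8.0,
            "axes.axisbelow": True,         # draw grid lines behind the plotted data
        }
    )

    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()        # new figures pick up the updated defaults
    print(fig.get_size_inches())    # -> [9.6 4.8]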
@@ -71,8 +71,8 @@ model implementations and application components.
    tutorials/online_asr_tutorial
    tutorials/device_asr
    tutorials/device_avsr
-   tutorials/forced_alignment_for_multilingual_data_tutorial
    tutorials/forced_alignment_tutorial
+   tutorials/forced_alignment_for_multilingual_data_tutorial
    tutorials/tacotron2_pipeline_tutorial
    tutorials/mvdr_tutorial
    tutorials/hybrid_demucs_tutorial
@@ -147,6 +147,13 @@ Tutorials

 .. customcardstart::

+   .. customcarditem::
+      :header: On device audio-visual automatic speech recognition
+      :card_description: Learn how to stream audio and video from laptop webcam and perform audio-visual automatic speech recognition using Emformer-RNNT model.
+      :image: https://download.pytorch.org/torchaudio/doc-assets/avsr/transformed.gif
+      :link: tutorials/device_avsr.html
+      :tags: I/O,Pipelines,RNNT
+
    .. customcarditem::
       :header: Loading waveform Tensors from files and saving them
       :card_description: Learn how to query/load audio files and save waveform tensors to files, using <code>torchaudio.info</code>, <code>torchaudio.load</code> and <code>torchaudio.save</code> functions.
...
@@ -85,7 +85,7 @@ NUM_FRAMES = int(DURATION * SAMPLE_RATE)
 #


-def show(freq, amp, waveform, sample_rate, zoom=None, vol=0.1):
+def plot(freq, amp, waveform, sample_rate, zoom=None, vol=0.1):
     t = torch.arange(waveform.size(0)) / sample_rate

     fig, axes = plt.subplots(4, 1, sharex=True)
@@ -101,7 +101,7 @@ def show(freq, amp, waveform, sample_rate, zoom=None, vol=0.1):
     for i in range(4):
         axes[i].grid(True)
     pos = axes[2].get_position()
-    plt.tight_layout()
+    fig.tight_layout()

     if zoom is not None:
         ax = fig.add_axes([pos.x0 + 0.02, pos.y0 + 0.03, pos.width / 2.5, pos.height / 2.0])
@@ -168,7 +168,7 @@ def sawtooth_wave(freq0, amp0, num_pitches, sample_rate):
 freq0 = torch.full((NUM_FRAMES, 1), F0)
 amp0 = torch.ones((NUM_FRAMES, 1))
 freq, amp, waveform = sawtooth_wave(freq0, amp0, int(SAMPLE_RATE / F0), SAMPLE_RATE)
-show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
+plot(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))

 ######################################################################
 #
@@ -183,7 +183,7 @@ phase = torch.linspace(0, fm * PI2 * DURATION, NUM_FRAMES)
 freq0 = F0 + f_dev * torch.sin(phase).unsqueeze(-1)
 freq, amp, waveform = sawtooth_wave(freq0, amp0, int(SAMPLE_RATE / F0), SAMPLE_RATE)
-show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
+plot(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))

 ######################################################################
 # Square wave
@@ -220,7 +220,7 @@ def square_wave(freq0, amp0, num_pitches, sample_rate):
 freq0 = torch.full((NUM_FRAMES, 1), F0)
 amp0 = torch.ones((NUM_FRAMES, 1))
 freq, amp, waveform = square_wave(freq0, amp0, int(SAMPLE_RATE / F0 / 2), SAMPLE_RATE)
-show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
+plot(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))

 ######################################################################
 # Triangle wave
@@ -256,7 +256,7 @@ def triangle_wave(freq0, amp0, num_pitches, sample_rate):
 #
 freq, amp, waveform = triangle_wave(freq0, amp0, int(SAMPLE_RATE / F0 / 2), SAMPLE_RATE)
-show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
+plot(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))

 ######################################################################
 # Inharmonic Paritials
@@ -296,7 +296,7 @@ amp = torch.stack([amp * (0.5**i) for i in range(num_tones)], dim=-1)
 waveform = oscillator_bank(freq, amp, sample_rate=SAMPLE_RATE)

-show(freq, amp, waveform, SAMPLE_RATE, vol=0.4)
+plot(freq, amp, waveform, SAMPLE_RATE, vol=0.4)

 ######################################################################
 #
@@ -308,7 +308,7 @@ show(freq, amp, waveform, SAMPLE_RATE, vol=0.4)
 freq = extend_pitch(freq0, num_tones)
 waveform = oscillator_bank(freq, amp, sample_rate=SAMPLE_RATE)

-show(freq, amp, waveform, SAMPLE_RATE)
+plot(freq, amp, waveform, SAMPLE_RATE)

 ######################################################################
 # References
...
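The ``plot`` helper renamed above visualizes waveforms produced by ``oscillator_bank`` and ``extend_pitch``, which at the time of this commit live under ``torchaudio.prototype.functional``. A minimal, self-contained sketch; the sample rate, pitch and number of partials are illustrative, not the tutorial's values:

    import torch
    from torchaudio.prototype.functional import extend_pitch, oscillator_bank

    SAMPLE_RATE = 16000          # illustrative constants
    NUM_FRAMES = SAMPLE_RATE     # one second of audio
    F0 = 440.0
    num_partials = 8

    freq0 = torch.full((NUM_FRAMES, 1), F0)     # per-frame fundamental frequency
    amp0 = torch.ones((NUM_FRAMES, 1))          # per-frame amplitude

    freq = extend_pitch(freq0, num_partials)                       # F0, 2*F0, ..., 8*F0
    amp = amp0.expand(NUM_FRAMES, num_partials) / num_partials     # equal, bounded amplitudes

    waveform = oscillator_bank(freq, amp, sample_rate=SAMPLE_RATE)
    print(waveform.shape)  # torch.Size([16000])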
@@ -407,30 +407,45 @@ print(timesteps, timesteps.shape[0])
 #


-def plot_alignments(waveform, emission, tokens, timesteps):
-    fig, ax = plt.subplots(figsize=(32, 10))
-    ax.plot(waveform)
-    ratio = waveform.shape[0] / emission.shape[1]
-    word_start = 0
-
-    for i in range(len(tokens)):
-        if i != 0 and tokens[i - 1] == "|":
-            word_start = timesteps[i]
-        if tokens[i] != "|":
-            plt.annotate(tokens[i].upper(), (timesteps[i] * ratio, waveform.max() * 1.02), size=14)
-        elif i != 0:
-            word_end = timesteps[i]
-            ax.axvspan(word_start * ratio, word_end * ratio, alpha=0.1, color="red")
-
-    xticks = ax.get_xticks()
-    plt.xticks(xticks, xticks / bundle.sample_rate)
-    ax.set_xlabel("time (sec)")
-    ax.set_xlim(0, waveform.shape[0])
-
-
-plot_alignments(waveform[0], emission, predicted_tokens, timesteps)
+def plot_alignments(waveform, emission, tokens, timesteps, sample_rate):
+    t = torch.arange(waveform.size(0)) / sample_rate
+    ratio = waveform.size(0) / emission.size(1) / sample_rate
+
+    chars = []
+    words = []
+    word_start = None
+    for token, timestep in zip(tokens, timesteps * ratio):
+        if token == "|":
+            if word_start is not None:
+                words.append((word_start, timestep))
+            word_start = None
+        else:
+            chars.append((token, timestep))
+            if word_start is None:
+                word_start = timestep
+
+    fig, axes = plt.subplots(3, 1)
+
+    def _plot(ax, xlim):
+        ax.plot(t, waveform)
+        for token, timestep in chars:
+            ax.annotate(token.upper(), (timestep, 0.5))
+        for word_start, word_end in words:
+            ax.axvspan(word_start, word_end, alpha=0.1, color="red")
+        ax.set_ylim(-0.6, 0.7)
+        ax.set_yticks([0])
+        ax.grid(True, axis="y")
+        ax.set_xlim(xlim)
+
+    _plot(axes[0], (0.3, 2.5))
+    _plot(axes[1], (2.5, 4.7))
+    _plot(axes[2], (4.7, 6.9))
+    axes[2].set_xlabel("time (sec)")
+    fig.tight_layout()
+
+
+plot_alignments(waveform[0], emission, predicted_tokens, timesteps, bundle.sample_rate)

 ######################################################################
...
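The key change in ``plot_alignments`` above is that frame indices are converted to seconds through a single ratio between waveform samples, emission frames, and the sample rate. A small hypothetical sketch of that conversion (the tensor sizes are made up):

    import torch

    sample_rate = 16000
    waveform = torch.zeros(1, 54400)     # e.g. a 3.4 second clip
    emission = torch.zeros(1, 169, 29)   # frames produced by the acoustic model

    # seconds covered by one emission frame
    ratio = waveform.size(1) / emission.size(1) / sample_rate

    timestep = 50                        # an emission-frame index
    print(f"frame {timestep} starts at ~{timestep * ratio:.2f} s")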
@@ -100,7 +100,6 @@ def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None):
         if xlim:
             axes[c].set_xlim(xlim)
     figure.suptitle(title)
-    plt.show(block=False)

 ######################################################################
@@ -122,7 +121,6 @@ def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
         if xlim:
             axes[c].set_xlim(xlim)
     figure.suptitle(title)
-    plt.show(block=False)

 ######################################################################
...
-# -*- coding: utf-8 -*-
 """
 Audio Datasets
 ==============
@@ -10,10 +9,6 @@ datasets. Please refer to the official documentation for the list of
 available datasets.
 """

-# When running this tutorial in Google Colab, install the required packages
-# with the following.
-# !pip install torchaudio

 import torch
 import torchaudio
@@ -21,22 +16,13 @@ print(torch.__version__)
 print(torchaudio.__version__)

 ######################################################################
-# Preparing data and utility functions (skip this section)
-# --------------------------------------------------------
 #
-# @title Prepare data and utility functions. {display-mode: "form"}
-# @markdown
-# @markdown You do not need to look into this cell.
-# @markdown Just execute once and you are good to go.
-
-# -------------------------------------------------------------------------------
-# Preparation of data and helper functions.
-# -------------------------------------------------------------------------------

 import os

+import IPython
 import matplotlib.pyplot as plt
-from IPython.display import Audio, display

 _SAMPLE_DIR = "_assets"
@@ -44,34 +30,13 @@ YESNO_DATASET_PATH = os.path.join(_SAMPLE_DIR, "yes_no")
 os.makedirs(YESNO_DATASET_PATH, exist_ok=True)


-def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
+def plot_specgram(waveform, sample_rate, title="Spectrogram"):
     waveform = waveform.numpy()

-    num_channels, _ = waveform.shape
-
-    figure, axes = plt.subplots(num_channels, 1)
-    if num_channels == 1:
-        axes = [axes]
-    for c in range(num_channels):
-        axes[c].specgram(waveform[c], Fs=sample_rate)
-        if num_channels > 1:
-            axes[c].set_ylabel(f"Channel {c+1}")
-        if xlim:
-            axes[c].set_xlim(xlim)
+    figure, ax = plt.subplots()
+    ax.specgram(waveform[0], Fs=sample_rate)
     figure.suptitle(title)
-    plt.show(block=False)
-
-
-def play_audio(waveform, sample_rate):
-    waveform = waveform.numpy()
-
-    num_channels, _ = waveform.shape
-    if num_channels == 1:
-        display(Audio(waveform[0], rate=sample_rate))
-    elif num_channels == 2:
-        display(Audio((waveform[0], waveform[1]), rate=sample_rate))
-    else:
-        raise ValueError("Waveform with more than 2 channels are not supported.")
+    figure.tight_layout()

 ######################################################################
@@ -79,10 +44,25 @@ def play_audio(waveform, sample_rate):
 # :py:class:`torchaudio.datasets.YESNO` dataset.
 #

 dataset = torchaudio.datasets.YESNO(YESNO_DATASET_PATH, download=True)

-for i in [1, 3, 5]:
-    waveform, sample_rate, label = dataset[i]
-    plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}")
-    play_audio(waveform, sample_rate)
+######################################################################
+#
+i = 1
+waveform, sample_rate, label = dataset[i]
+plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}")
+IPython.display.Audio(waveform, rate=sample_rate)
+
+######################################################################
+#
+i = 3
+waveform, sample_rate, label = dataset[i]
+plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}")
+IPython.display.Audio(waveform, rate=sample_rate)
+
+######################################################################
+#
+i = 5
+waveform, sample_rate, label = dataset[i]
+plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}")
+IPython.display.Audio(waveform, rate=sample_rate)
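A hedged, self-contained variant of the simplified workflow above, for running outside the tutorial; the dataset root and sample index here are arbitrary choices:

    import os
    import torchaudio
    import matplotlib.pyplot as plt

    root = "_assets/yes_no"
    os.makedirs(root, exist_ok=True)

    dataset = torchaudio.datasets.YESNO(root, download=True)
    waveform, sample_rate, label = dataset[0]          # waveform: (1, num_samples)

    fig, ax = plt.subplots()
    ax.specgram(waveform[0].numpy(), Fs=sample_rate)   # single-channel spectrogram
    fig.suptitle(f"Sample 0: {label}")
    fig.tight_layout()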
@@ -19,25 +19,19 @@ print(torch.__version__)
 print(torchaudio.__version__)

 ######################################################################
-# Preparing data and utility functions (skip this section)
-# --------------------------------------------------------
+# Preparation
+# -----------
 #
-# @title Prepare data and utility functions. {display-mode: "form"}
-# @markdown
-# @markdown You do not need to look into this cell.
-# @markdown Just execute once and you are good to go.
-# @markdown
-# @markdown In this tutorial, we will use a speech data from [VOiCES dataset](https://iqtlabs.github.io/voices/),
-# @markdown which is licensed under Creative Commos BY 4.0.
-
-# -------------------------------------------------------------------------------
-# Preparation of data and helper functions.
-# -------------------------------------------------------------------------------

 import librosa
 import matplotlib.pyplot as plt
 from torchaudio.utils import download_asset

+######################################################################
+# In this tutorial, we will use a speech data from
+# `VOiCES dataset <https://iqtlabs.github.io/voices/>`__,
+# which is licensed under Creative Commos BY 4.0.

 SAMPLE_WAV_SPEECH_PATH = download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
@@ -75,16 +69,9 @@ def get_spectrogram(
     return spectrogram(waveform)


-def plot_spectrogram(spec, title=None, ylabel="freq_bin", aspect="auto", xmax=None):
-    fig, axs = plt.subplots(1, 1)
-    axs.set_title(title or "Spectrogram (db)")
-    axs.set_ylabel(ylabel)
-    axs.set_xlabel("frame")
-    im = axs.imshow(librosa.power_to_db(spec), origin="lower", aspect=aspect)
-    if xmax:
-        axs.set_xlim((0, xmax))
-    fig.colorbar(im, ax=axs)
-    plt.show(block=False)
+def plot_spec(ax, spec, title, ylabel="freq_bin"):
+    ax.set_title(title)
+    ax.imshow(librosa.power_to_db(spec), origin="lower", aspect="auto")

 ######################################################################
@@ -108,43 +95,47 @@ def plot_spectrogram(spec, title=None, ylabel="freq_bin", aspect="auto", xmax=No
 spec = get_spectrogram(power=None)
 stretch = T.TimeStretch()

-rate = 1.2
-spec_ = stretch(spec, rate)
-plot_spectrogram(torch.abs(spec_[0]), title=f"Stretched x{rate}", aspect="equal", xmax=304)
-
-plot_spectrogram(torch.abs(spec[0]), title="Original", aspect="equal", xmax=304)
-
-rate = 0.9
-spec_ = stretch(spec, rate)
-plot_spectrogram(torch.abs(spec_[0]), title=f"Stretched x{rate}", aspect="equal", xmax=304)
+spec_12 = stretch(spec, overriding_rate=1.2)
+spec_09 = stretch(spec, overriding_rate=0.9)

 ######################################################################
-# TimeMasking
-# -----------
 #
-torch.random.manual_seed(4)

-spec = get_spectrogram()
-plot_spectrogram(spec[0], title="Original")
-
-masking = T.TimeMasking(time_mask_param=80)
-spec = masking(spec)
-
-plot_spectrogram(spec[0], title="Masked along time axis")
+def plot():
+    fig, axes = plt.subplots(3, 1, sharex=True, sharey=True)
+    plot_spec(axes[0], torch.abs(spec_12[0]), title="Stretched x1.2")
+    plot_spec(axes[1], torch.abs(spec[0]), title="Original")
+    plot_spec(axes[2], torch.abs(spec_09[0]), title="Stretched x0.9")
+    fig.tight_layout()
+
+
+plot()

 ######################################################################
-# FrequencyMasking
-# ----------------
+# Time and Frequency Masking
+# --------------------------
 #
 torch.random.manual_seed(4)

+time_masking = T.TimeMasking(time_mask_param=80)
+freq_masking = T.FrequencyMasking(freq_mask_param=80)
+
 spec = get_spectrogram()
-plot_spectrogram(spec[0], title="Original")
-
-masking = T.FrequencyMasking(freq_mask_param=80)
-spec = masking(spec)
-
-plot_spectrogram(spec[0], title="Masked along frequency axis")
+time_masked = time_masking(spec)
+freq_masked = freq_masking(spec)
+
+######################################################################
+#
+def plot():
+    fig, axes = plt.subplots(3, 1, sharex=True, sharey=True)
+    plot_spec(axes[0], spec[0], title="Original")
+    plot_spec(axes[1], time_masked[0], title="Masked along time axis")
+    plot_spec(axes[2], freq_masked[0], title="Masked along frequency axis")
+    fig.tight_layout()
+
+
+plot()
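The refactored code relies on ``overriding_rate``, the per-call rate argument of ``torchaudio.transforms.TimeStretch``, plus the two masking transforms. A minimal sketch on a random complex spectrogram; the shapes and parameters are illustrative only:

    import torch
    import torchaudio.transforms as T

    torch.manual_seed(0)
    # TimeStretch operates on complex spectrograms of shape (..., freq, time)
    spec = torch.randn(1, 201, 300, dtype=torch.cfloat)

    stretch = T.TimeStretch(n_freq=201)
    faster = stretch(spec, overriding_rate=1.2)   # fewer time frames
    slower = stretch(spec, overriding_rate=0.9)   # more time frames
    print(faster.shape, slower.shape)

    # The masking transforms operate on magnitude spectrograms
    mag = spec.abs()
    masked = T.TimeMasking(time_mask_param=80)(mag)
    masked = T.FrequencyMasking(freq_mask_param=80)(masked)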
...@@ -75,7 +75,6 @@ def plot_waveform(waveform, sr, title="Waveform", ax=None): ...@@ -75,7 +75,6 @@ def plot_waveform(waveform, sr, title="Waveform", ax=None):
ax.grid(True) ax.grid(True)
ax.set_xlim([0, time_axis[-1]]) ax.set_xlim([0, time_axis[-1]])
ax.set_title(title) ax.set_title(title)
plt.show(block=False)
def plot_spectrogram(specgram, title=None, ylabel="freq_bin", ax=None): def plot_spectrogram(specgram, title=None, ylabel="freq_bin", ax=None):
...@@ -85,7 +84,6 @@ def plot_spectrogram(specgram, title=None, ylabel="freq_bin", ax=None): ...@@ -85,7 +84,6 @@ def plot_spectrogram(specgram, title=None, ylabel="freq_bin", ax=None):
ax.set_title(title) ax.set_title(title)
ax.set_ylabel(ylabel) ax.set_ylabel(ylabel)
ax.imshow(librosa.power_to_db(specgram), origin="lower", aspect="auto", interpolation="nearest") ax.imshow(librosa.power_to_db(specgram), origin="lower", aspect="auto", interpolation="nearest")
plt.show(block=False)
def plot_fbank(fbank, title=None): def plot_fbank(fbank, title=None):
...@@ -94,7 +92,6 @@ def plot_fbank(fbank, title=None): ...@@ -94,7 +92,6 @@ def plot_fbank(fbank, title=None):
axs.imshow(fbank, aspect="auto") axs.imshow(fbank, aspect="auto")
axs.set_ylabel("frequency bin") axs.set_ylabel("frequency bin")
axs.set_xlabel("mel bin") axs.set_xlabel("mel bin")
plt.show(block=False)
###################################################################### ######################################################################
...@@ -486,7 +483,6 @@ def plot_pitch(waveform, sr, pitch): ...@@ -486,7 +483,6 @@ def plot_pitch(waveform, sr, pitch):
axis2.plot(time_axis, pitch[0], linewidth=2, label="Pitch", color="green") axis2.plot(time_axis, pitch[0], linewidth=2, label="Pitch", color="green")
axis2.legend(loc=0) axis2.legend(loc=0)
plt.show(block=False)
plot_pitch(SPEECH_WAVEFORM, SAMPLE_RATE, pitch) plot_pitch(SPEECH_WAVEFORM, SAMPLE_RATE, pitch)
...@@ -181,7 +181,6 @@ def plot_waveform(waveform, sample_rate): ...@@ -181,7 +181,6 @@ def plot_waveform(waveform, sample_rate):
if num_channels > 1: if num_channels > 1:
axes[c].set_ylabel(f"Channel {c+1}") axes[c].set_ylabel(f"Channel {c+1}")
figure.suptitle("waveform") figure.suptitle("waveform")
plt.show(block=False)
###################################################################### ######################################################################
...@@ -204,7 +203,6 @@ def plot_specgram(waveform, sample_rate, title="Spectrogram"): ...@@ -204,7 +203,6 @@ def plot_specgram(waveform, sample_rate, title="Spectrogram"):
if num_channels > 1: if num_channels > 1:
axes[c].set_ylabel(f"Channel {c+1}") axes[c].set_ylabel(f"Channel {c+1}")
figure.suptitle(title) figure.suptitle(title)
plt.show(block=False)
###################################################################### ######################################################################
......
...@@ -105,7 +105,6 @@ def plot_sweep( ...@@ -105,7 +105,6 @@ def plot_sweep(
axis.yaxis.grid(True, alpha=0.67) axis.yaxis.grid(True, alpha=0.67)
figure.suptitle(f"{title} (sample rate: {sample_rate} Hz)") figure.suptitle(f"{title} (sample rate: {sample_rate} Hz)")
plt.colorbar(cax) plt.colorbar(cax)
plt.show(block=True)
###################################################################### ######################################################################
......
@@ -5,254 +5,277 @@ CTC forced alignment API tutorial
 **Author**: `Xiaohui Zhang <xiaohuizhang@meta.com>`__

-This tutorial shows how to align transcripts to speech with
-``torchaudio``'s CTC forced alignment API proposed in the paper
-`“Scaling Speech Technology to 1,000+
-Languages” <https://research.facebook.com/publications/scaling-speech-technology-to-1000-languages/>`__,
-and one advanced usage, i.e. dealing with transcription errors with a <star> token.
-
-The forced alignment is a process to align transcript with speech.
-Though there’s some overlap in visualization
-diagrams, the scope here is different from the `“Forced Alignment with
-Wav2Vec2” <https://pytorch.org/audio/stable/tutorials/forced_alignment_tutorial.html>`__
-tutorial, which focuses on a step-by-step demonstration of the forced
-alignment generation algorithm (without using an API) described in the
-`paper <https://arxiv.org/abs/2007.09127>`__ with a Wav2Vec2 model.
+This tutorial shows how to align transcripts to speech using
+:py:func:`torchaudio.functional.forced_align`
+which was developed along the work of
+`Scaling Speech Technology to 1,000+ Languages <https://research.facebook.com/publications/scaling-speech-technology-to-1000-languages/>`__.
+
+We cover the basics of forced alignment in `Forced Alignment with
+Wav2Vec2 <./forced_alignment_tutorial.html>`__ with simplified
+step-by-step Python implementations.
+
+:py:func:`~torchaudio.functional.forced_align` has custom CPU and CUDA
+implementations which are more performant than the vanilla Python
+implementation above, and are more accurate.
+It can also handle missing transcript with special <star> token.
+
+For examples of aligning multiple languages, please refer to
+`Forced alignment for multilingual data <./forced_alignment_for_multilingual_data_tutorial.html>`__.
 """

 import torch
 import torchaudio

 print(torch.__version__)
 print(torchaudio.__version__)

-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-print(device)
-
-try:
-    from torchaudio.functional import forced_align
-except ModuleNotFoundError:
-    print(
-        "Failed to import the forced alignment API. "
-        "Please install torchaudio nightly builds. "
-        "Please refer to https://pytorch.org/get-started/locally "
-        "for instructions to install a nightly build."
-    )
-    raise
+######################################################################
+#
+
+from dataclasses import dataclass
+from typing import List
+
+import IPython
+import matplotlib.pyplot as plt

 ######################################################################
-# Basic usages
-# ------------
-#
-# In this section, we cover the following content:
-#
-# 1. Generate frame-wise class probabilites from audio waveform from a CTC
-#    acoustic model.
-# 2. Compute frame-level alignments using TorchAudio’s forced alignment
-#    API.
-# 3. Obtain token-level alignments from frame-level alignments.
-# 4. Obtain word-level alignments from token-level alignments.
 #
+from torchaudio.functional import forced_align
+
+torch.random.manual_seed(0)
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print(device)

 ######################################################################
 # Preparation
-# ~~~~~~~~~~~
+# -----------
 #
-# First we import the necessary packages, and fetch data that we work on.
+# First we prepare the speech data and the transcript we area going
+# to use.
 #

-# %matplotlib inline
-from dataclasses import dataclass
-
-import IPython
-import matplotlib
-import matplotlib.pyplot as plt
-
-matplotlib.rcParams["figure.figsize"] = [16.0, 4.8]
-
-torch.random.manual_seed(0)
-
 SPEECH_FILE = torchaudio.utils.download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
-sample_rate = 16000
+TRANSCRIPT = "I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT"

 ######################################################################
-# Generate frame-wise class posteriors from a CTC acoustic model
-# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# Generating emissions and tokens
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# :py:func:`~torchaudio.functional.forced_align` takes emission and
+# token sequences and outputs timestaps of the tokens and their scores.
 #
-# The first step is to generate the class probabilities (i.e. posteriors)
-# of each audio frame using a CTC model.
-# Here we use :py:func:`torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H`.
+# Emission reperesents the frame-wise probability distribution over
+# tokens, and it can be obtained by passing waveform to an acoustic
+# model.
+# Tokens are numerical expression of transcripts. It can be obtained by
+# simply mapping each character to the index of token list.
+# The emission and the token sequences must be using the same set of tokens.
+#
+# We can use pre-trained Wav2Vec2 model to obtain emission from speech,
+# and map transcript to tokens.
+# Here, we use :py:data:`~torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H`,
+# which bandles pre-trained model weights with associated labels.
 #

 bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
 model = bundle.get_model().to(device)
-labels = bundle.get_labels()

 with torch.inference_mode():
     waveform, _ = torchaudio.load(SPEECH_FILE)
-    emissions, _ = model(waveform.to(device))
-    emissions = torch.log_softmax(emissions, dim=-1)
+    emission, _ = model(waveform.to(device))
+    emission = torch.log_softmax(emission, dim=-1)

-emission = emissions.cpu().detach()
-dictionary = {c: i for i, c in enumerate(labels)}
-
-print(dictionary)
+######################################################################
+#
+def plot_emission(emission):
+    plt.imshow(emission.cpu().T)
+    plt.title("Frame-wise class probabilities")
+    plt.xlabel("Time")
+    plt.ylabel("Labels")
+    plt.tight_layout()
+
+
+plot_emission(emission[0])

 ######################################################################
-# Visualization
-# ^^^^^^^^^^^^^
-#
-
-plt.imshow(emission[0].T)
-plt.colorbar()
-plt.title("Frame-wise class probabilities")
-plt.xlabel("Time")
-plt.ylabel("Labels")
-plt.show()
+# We create a dictionary, which maps each label into token.
+
+labels = bundle.get_labels()
+DICTIONARY = {c: i for i, c in enumerate(labels)}
+
+for k, v in DICTIONARY.items():
+    print(f"{k}: {v}")
+
+######################################################################
+# converting transcript to tokens is as simple as
+
+tokens = [DICTIONARY[c] for c in TRANSCRIPT]
+
+print(" ".join(str(t) for t in tokens))

 ######################################################################
 # Computing frame-level alignments
-# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# --------------------------------
 #
-# Then we call TorchAudio’s forced alignment API to compute the
-# frame-level alignment between each audio frame and each token in the
-# transcript. We first explain the inputs and outputs of the API
-# ``functional.forced_align``. Note that this API works on both CPU and
-# GPU. In the current tutorial we demonstrate it on CPU.
-#
-# **Inputs**:
-#
-# ``emission``: a 2D tensor of size :math:`T \times N`, where :math:`T` is
-# the number of frames (after sub-sampling by the acoustic model, if any),
-# and :math:`N` is the vocabulary size.
-#
-# ``targets``: a 1D tensor vector of size :math:`M`, where :math:`M` is
-# the length of the transcript, and each element is a token ID looked up
-# from the vocabulary. For example, the ``targets`` tensor repsenting the
-# transcript “i had…” is :math:`[5, 18, 4, 16, ...]`.
-#
-# ``input lengths``: :math:`T`.
-#
-# ``target lengths``: :math:`M`.
-#
-# **Outputs**:
-#
-# ``frame_alignment``: a 1D tensor of size :math:`T` storing the aligned
-# token index (looked up from the vocabulary) of each frame, e.g. for the
-# segment corresponding to “i had” in the given example, the
-# frame_alignment is
-# :math:`[...0, 0, 5, 0, 0, 18, 18, 4, 0, 0, 0, 16,...]`, where :math:`0`
-# represents the blank symbol.
-#
-# ``frame_scores``: a 1D tensor of size :math:`T` storing the confidence
-# score (0 to 1) for each each frame. For each frame, the score should be
-# close to one if the alignment quality is good.
+# Now we call TorchAudio’s forced alignment API to compute the
+# frame-level alignment. For the detail of function signature, please
+# refer to :py:func:`~torchaudio.functional.forced_align`.
+#
+
+
+def align(emission, tokens):
+    alignments, scores = forced_align(
+        emission,
+        targets=torch.tensor([tokens], dtype=torch.int32, device=emission.device),
+        input_lengths=torch.tensor([emission.size(1)], device=emission.device),
+        target_lengths=torch.tensor([len(tokens)], device=emission.device),
+        blank=0,
+    )
+
+    scores = scores.exp()  # convert back to probability
+    alignments, scores = alignments[0], scores[0]  # remove batch dimension for simplicity
+    return alignments.tolist(), scores.tolist()
+
+
+frame_alignment, frame_scores = align(emission, tokens)
+
+######################################################################
+# Now let's look at the output.
+# Notice that the alignment is expressed in the frame cordinate of
+# emission, which is different from the original waveform.
+
+for i, (ali, score) in enumerate(zip(frame_alignment, frame_scores)):
+    print(f"{i:3d}: {ali:2d} [{labels[ali]}], {score:.2f}")

 ######################################################################
 #
+# The ``Frame`` instance represents the most likely token at each frame
+# with its confidence.
+#
+# When interpreting it, one must remember that the meaning of blank token
+# and repeated token are context dependent.
+#
+# .. note::
+#
+#    When same token occured after blank tokens, it is not treated as
+#    a repeat, but as a new occurrence.
+#
+#    .. code-block::
+#
+#       a a a b -> a b
+#       a - - b -> a b
+#       a a - b -> a b
+#       a - a b -> a a b
+#         ^^^        ^^^
+#
+# .. code-block::
+#
+#    29:  0 [-], 1.00
+#    30:  7 [I], 1.00  # Start of "I"
+#    31:  0 [-], 0.98  # repeat (blank token)
+#    32:  0 [-], 1.00  # repeat (blank token)
+#    33:  1 [|], 0.85  # Start of "|" (word boundary)
+#    34:  1 [|], 1.00  # repeat (same token)
+#    35:  0 [-], 0.61  # repeat (blank token)
+#    36:  8 [H], 1.00  # Start of "H"
+#    37:  0 [-], 1.00  # repeat (blank token)
+#    38:  4 [A], 1.00  # Start of "A"
+#    39:  0 [-], 0.99  # repeat (blank token)
+#    40: 11 [D], 0.92  # Start of "D"
+#    41:  0 [-], 0.93  # repeat (blank token)
+#    42:  1 [|], 0.98  # Start of "|"
+#    43:  1 [|], 1.00  # repeat (same token)
+#    44:  3 [T], 1.00  # Start of "T"
+#    45:  3 [T], 0.90  # repeat (same token)
+#    46:  8 [H], 1.00  # Start of "H"
+#    47:  0 [-], 1.00  # repeat (blank token)

 ######################################################################
-# Resolve blank and repeated tokens
-# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-#
-# From the outputs ``frame_alignment`` and ``frame_scores``, we generate a
-# list called ``frames`` storing information of all frames aligned to
-# non-blank tokens. Each element contains 1) token_index: the aligned
-# token’s index in the transcript 2) time_index: the current frame’s index
-# in the input audio (or more precisely, the row dimension of the emission
-# matrix) 3) the confidence scores of the current frame.
-#
-# For the given example, the first few elements of the list ``frames``
-# corresponding to “i had” looks as the following:
-#
-# ``Frame(token_index=0, time_index=32, score=0.9994410872459412)``
-#
-# ``Frame(token_index=1, time_index=35, score=0.9980823993682861)``
-#
-# ``Frame(token_index=1, time_index=36, score=0.9295750260353088)``
-#
-# ``Frame(token_index=2, time_index=37, score=0.9997448325157166)``
-#
-# ``Frame(token_index=3, time_index=41, score=0.9991760849952698)``
-#
-# ``...``
-#
-# The interpretation is:
-#
-# The token with index :math:`0` in the transcript, i.e. “i”, is aligned
-# to the :math:`32`\ th audio frame, with confidence :math:`0.9994`. The
-# token with index :math:`1` in the transcript, i.e. “h”, is aligned to
-# the :math:`35`\ th and :math:`36`\ th audio frames, with confidence
-# :math:`0.9981` and :math:`0.9296` respectively. The token with index
-# :math:`2` in the transcript, i.e. “a”, is aligned to the :math:`35`\ th
-# and :math:`36`\ th audio frames, with confidence :math:`0.9997`. The token with index
-# :math:`3` in the transcript, i.e. “d”, is aligned to
-# the :math:`41`\ th audio frame, with confidence :math:`0.9992`.
-#
-# From such information stored in the ``frames`` list, we’ll compute
-# token-level and word-level alignments easily.
+# Next step is to resolve the repetation. So that what alignment represents
+# do not depend on previous alignments.
+# From the outputs ``alignment`` and ``scores``, we generate a
+# list called ``frames`` storing information of all frames aligned to
+# non-blank tokens.
+#
+# Each element contains the following
+#
+# - ``token_index``: the aligned token’s index **in the transcript**
+# - ``time_index``: the current frame’s index in emission
+# - ``score``: scores of the current frame.
+#
+# ``token_index`` is the index of each token in the transcript,
+# i.e. the current frame aligns to the N-th character from the transcript.

 @dataclass
 class Frame:
-    # This is the index of each token in the transcript,
-    # i.e. the current frame aligns to the N-th character from the transcript.
     token_index: int
     time_index: int
     score: float


-def compute_alignments(transcript, dictionary, emission):
-    frames = []
-    tokens = [dictionary[c] for c in transcript.replace(" ", "")]
-
-    targets = torch.tensor(tokens, dtype=torch.int32).unsqueeze(0)
-    input_lengths = torch.tensor([emission.shape[1]])
-    target_lengths = torch.tensor([targets.shape[1]])
-
-    # This is the key step, where we call the forced alignment API functional.forced_align to compute alignments.
-    frame_alignment, frame_scores = forced_align(emission, targets, input_lengths, target_lengths, 0)
-
-    assert frame_alignment.shape[1] == input_lengths[0].item()
-    assert targets.shape[1] == target_lengths[0].item()
-
-    token_index = -1
-    prev_hyp = 0
-    for i in range(frame_alignment.shape[1]):
-        if frame_alignment[0][i].item() == 0:
-            prev_hyp = 0
-            continue
-
-        if frame_alignment[0][i].item() != prev_hyp:
-            token_index += 1
-        frames.append(Frame(token_index, i, frame_scores[0][i].exp().item()))
-        prev_hyp = frame_alignment[0][i].item()
-    return frames, frame_alignment, frame_scores
-
-
-transcript = "I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT"
-frames, frame_alignment, frame_scores = compute_alignments(transcript, dictionary, emission)
+######################################################################
+#
+def obtain_token_level_alignments(alignments, scores) -> List[Frame]:
+    assert len(alignments) == len(scores)
+
+    token_index = -1
+    prev_hyp = 0
+    frames = []
+    for i, (ali, score) in enumerate(zip(alignments, scores)):
+        if ali == 0:
+            prev_hyp = 0
+            continue
+
+        if ali != prev_hyp:
+            token_index += 1
+        frames.append(Frame(token_index, i, score))
+        prev_hyp = ali
+    return frames

 ######################################################################
-# Obtain token-level alignments and confidence scores
-# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 #
+frames = obtain_token_level_alignments(frame_alignment, frame_scores)
+
+print("Time\tLabel\tScore")
+for f in frames:
+    print(f"{f.time_index:3d}\t{TRANSCRIPT[f.token_index]}\t{f.score:.2f}")

 ######################################################################
+# Obtain token-level alignments and confidence scores
+# ---------------------------------------------------
+#
 # The frame-level alignments contains repetations for the same labels.
 # Another format “token-level alignment”, which specifies the aligned
 # frame ranges for each transcript token, contains the same information,
 # while being more convenient to apply to some downstream tasks
 # (e.g. computing word-level alignments).
 #
 # Now we demonstrate how to obtain token-level alignments and confidence
 # scores by simply merging frame-level alignments and averaging
 # frame-level confidence scores.
 #

+######################################################################
+# The following class represents the label, its score and the time span
+# of its occurance.
+#

-# Merge the labels
 @dataclass
 class Segment:
     label: str
@@ -261,13 +284,16 @@ class Segment:
     score: float

     def __repr__(self):
-        return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"
+        return f"{self.label:2s} ({self.score:4.2f}): [{self.start:4d}, {self.end:4d})"

-    @property
-    def length(self):
+    def __len__(self):
         return self.end - self.start


+######################################################################
+#
 def merge_repeats(frames, transcript):
     transcript_nospace = transcript.replace(" ", "")
     i1, i2 = 0, 0
@@ -288,29 +314,31 @@ def merge_repeats(frames, transcript):
     return segments


-segments = merge_repeats(frames, transcript)
+######################################################################
+#
+segments = merge_repeats(frames, TRANSCRIPT)
 for seg in segments:
     print(seg)

 ######################################################################
 # Visualization
-# ^^^^^^^^^^^^^
+# ~~~~~~~~~~~~~
 #


 def plot_label_prob(segments, transcript):
-    fig, ax2 = plt.subplots(figsize=(16, 4))
-    ax2.set_title("frame-level and token-level confidence scores")
+    fig, ax = plt.subplots()
+    ax.set_title("frame-level and token-level confidence scores")
     xs, hs, ws = [], [], []
     for seg in segments:
         if seg.label != "|":
             xs.append((seg.end + seg.start) / 2 + 0.4)
             hs.append(seg.score)
             ws.append(seg.end - seg.start)
-            ax2.annotate(seg.label, (seg.start + 0.8, -0.07), weight="bold")
-    ax2.bar(xs, hs, width=ws, color="gray", alpha=0.5, edgecolor="black")
+            ax.annotate(seg.label, (seg.start + 0.8, -0.07), weight="bold")
+    ax.bar(xs, hs, width=ws, color="gray", alpha=0.5, edgecolor="black")

     xs, hs = [], []
     for p in frames:
@@ -319,27 +347,28 @@ def plot_label_prob(segments, transcript):
         xs.append(p.time_index + 1)
         hs.append(p.score)

-    ax2.bar(xs, hs, width=0.5, alpha=0.5)
-    ax2.axhline(0, color="black")
-    ax2.set_ylim(-0.1, 1.1)
+    ax.bar(xs, hs, width=0.5, alpha=0.5)
+    ax.set_ylim(-0.1, 1.1)
+    ax.grid(True, axis="y")
+    fig.tight_layout()


-plot_label_prob(segments, transcript)
-plt.tight_layout()
-plt.show()
+plot_label_prob(segments, TRANSCRIPT)

 ######################################################################
 # From the visualized scores, we can see that, for tokens spanning over
 # more multiple frames, e.g. “T” in “THAT, the token-level confidence
 # score is the average of frame-level confidence scores. To make this
 # clearer, we don’t plot confidence scores for blank frames, which was
 # plotted in the”Label probability with and without repeatation” figure in
-# the previous tutorial `“Forced Alignment with
-# Wav2Vec2 <https://pytorch.org/audio/stable/tutorials/forced_alignment_tutorial.html>`__.
+# the previous tutorial
+# `Forced Alignment with Wav2Vec2 <./forced_alignment_tutorial.html>`__.
 #

+######################################################################
 # Obtain word-level alignments and confidence scores
-# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# --------------------------------------------------
 #
@@ -367,7 +396,7 @@ def merge_words(transcript, segments, separator=" "):
                 s = 0
             segs = segments[i1 + s : i2 + s]
             word = "".join([seg.label for seg in segs])
-            score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)
+            score = sum(seg.score * len(seg) for seg in segs) / sum(len(seg) for seg in segs)
             words.append(Segment(word, segments[i1 + s].start, segments[i2 + s - 1].end, score))
             i1 = i2
         else:
@@ -376,59 +405,43 @@ def merge_words(transcript, segments, separator=" "):
     return words


-word_segments = merge_words(transcript, segments, "|")
+word_segments = merge_words(TRANSCRIPT, segments, "|")

 ######################################################################
 # Visualization
-# ^^^^^^^^^^^^^
+# ~~~~~~~~~~~~~
 #


-def plot_alignments(segments, word_segments, waveform, input_lengths, scale=10):
-    fig, ax2 = plt.subplots(figsize=(64, 12))
-    plt.rcParams.update({"font.size": 30})
-
-    # The original waveform
-    ratio = waveform.size(1) / input_lengths
-    ax2.plot(waveform)
-    ax2.set_ylim(-1.0 * scale, 1.0 * scale)
-    ax2.set_xlim(0, waveform.size(-1))
-
-    for word in word_segments:
-        x0 = ratio * word.start
-        x1 = ratio * word.end
-        ax2.axvspan(x0, x1, alpha=0.1, color="red")
-        ax2.annotate(f"{word.score:.2f}", (x0, 0.8 * scale))
-
-    for seg in segments:
-        if seg.label != "|":
-            ax2.annotate(seg.label, (seg.start * ratio, 0.9 * scale))
-
-    xticks = ax2.get_xticks()
-    plt.xticks(xticks, xticks / sample_rate, fontsize=50)
-    ax2.set_xlabel("time [second]", fontsize=40)
-    ax2.set_yticks([])
-
-
-plot_alignments(
-    segments,
-    word_segments,
-    waveform,
-    emission.shape[1],
-    1,
-)
-plt.show()
+def plot_alignments(waveform, emission, segments, word_segments, sample_rate=bundle.sample_rate):
+    fig, ax = plt.subplots()
+    ax.specgram(waveform[0], Fs=sample_rate)
+
+    # The original waveform
+    ratio = waveform.size(1) / sample_rate / emission.size(1)
+
+    for word in word_segments:
+        t0, t1 = ratio * word.start, ratio * word.end
+        ax.axvspan(t0, t1, facecolor="None", hatch="/", edgecolor="white")
+        ax.annotate(f"{word.score:.2f}", (t0, sample_rate * 0.51), annotation_clip=False)
+
+    for seg in segments:
+        if seg.label != "|":
+            ax.annotate(seg.label, (seg.start * ratio, sample_rate * 0.53), annotation_clip=False)
+
+    ax.set_xlabel("time [second]")
+    fig.tight_layout()
+
+
+plot_alignments(waveform, emission, segments, word_segments)

 ######################################################################
-# A trick to embed the resulting audio to the generated file.
-# `IPython.display.Audio` has to be the last call in a cell,
-# and there should be only one call par cell.
-def display_segment(i, waveform, word_segments, frame_alignment):
-    ratio = waveform.size(1) / frame_alignment.size(1)
+def display_segment(i, waveform, word_segments, frame_alignment, sample_rate=bundle.sample_rate):
+    ratio = waveform.size(1) / len(frame_alignment)
     word = word_segments[i]
     x0 = int(ratio * word.start)
     x1 = int(ratio * word.end)
@@ -437,8 +450,10 @@ def display_segment(i, waveform, word_segments, frame_alignment):
     return IPython.display.Audio(segment.numpy(), rate=sample_rate)


+######################################################################
 # Generate the audio for each segment
-print(transcript)
+print(TRANSCRIPT)
 IPython.display.Audio(SPEECH_FILE)

 ######################################################################
@@ -488,62 +503,71 @@ display_segment(8, waveform, word_segments, frame_alignment)

 ######################################################################
-# Advanced usage: Dealing with missing transcripts using the <star> token
-# ---------------------------------------------------------------------------
+# Advanced: Handling transcripts with ``<star>`` token
+# ----------------------------------------------------
 #
 # Now let’s look at when the transcript is partially missing, how can we
-# improve alignment quality using the <star> token, which is capable of modeling
+# improve alignment quality using the ``<star>`` token, which is capable of modeling
 # any token.
 #
 # Here we use the same English example as used above. But we remove the
-# beginning text “i had that curiosity beside me at” from the transcript.
+# beginning text ``“i had that curiosity beside me at”`` from the transcript.
 # Aligning audio with such transcript results in wrong alignments of the
 # existing word “this”. However, this issue can be mitigated by using the
-# <star> token to model the missing text.
+# ``<star>`` token to model the missing text.
 #

-# Reload the emission tensor in order to add the extra dimension corresponding to the <star> token.
-with torch.inference_mode():
-    waveform, _ = torchaudio.load(SPEECH_FILE)
-    emissions, _ = model(waveform.to(device))
-    emissions = torch.log_softmax(emissions, dim=-1)
-
-# Append the extra dimension corresponding to the <star> token
-extra_dim = torch.zeros(emissions.shape[0], emissions.shape[1], 1)
-emissions = torch.cat((emissions.cpu(), extra_dim), 2)
-emission = emissions.detach()
-
-# Extend the dictionary to include the <star> token.
-dictionary["*"] = 29
-
-assert len(dictionary) == emission.shape[2]
+######################################################################
+# First, we extend the dictionary to include the ``<star>`` token.
+
+DICTIONARY["*"] = len(DICTIONARY)
+
+######################################################################
+# Next, we extend the emission tensor with the extra dimension
+# corresponding to the ``<star>`` token.
+#
+extra_dim = torch.zeros(emission.shape[0], emission.shape[1], 1, device=device)
+emission = torch.cat((emission, extra_dim), 2)
+
+assert len(DICTIONARY) == emission.shape[2]
+
+######################################################################
+# The following function combines all the processes, and compute
+# word segments from emission in one-go.


 def compute_and_plot_alignments(transcript, dictionary, emission, waveform):
-    frames, frame_alignment, _ = compute_alignments(transcript, dictionary, emission)
+    tokens = [dictionary[c] for c in transcript]
+    alignment, scores = align(emission, tokens)
+    frames = obtain_token_level_alignments(alignment, scores)
     segments = merge_repeats(frames, transcript)
     word_segments = merge_words(transcript, segments, "|")
-    plot_alignments(segments, word_segments, waveform, emission.shape[1], 1)
-    plt.show()
-    return word_segments, frame_alignment
-
-
-# original:
-word_segments, frame_alignment = compute_and_plot_alignments(transcript, dictionary, emission, waveform)
+    plot_alignments(waveform, emission, segments, word_segments)
+    plt.xlim([0, None])

 ######################################################################
-# Demonstrate the effect of <star> token for dealing with deletion errors
-# ("i had that curiosity beside me at" missing from the transcript):
-transcript = "THIS|MOMENT"
-word_segments, frame_alignment = compute_and_plot_alignments(transcript, dictionary, emission, waveform)
+# **Original**
+compute_and_plot_alignments(TRANSCRIPT, DICTIONARY, emission, waveform)

 ######################################################################
-# Replacing the missing transcript with the <star> token:
-transcript = "*|THIS|MOMENT"
-word_segments, frame_alignment = compute_and_plot_alignments(transcript, dictionary, emission, waveform)
+# **With <star> token**
+#
+# Now we replace the first part of the transcript with the ``<star>`` token.
+compute_and_plot_alignments("*|THIS|MOMENT", DICTIONARY, emission, waveform)
+
+######################################################################
+# **Without <star> token**
+#
+# As a comparison, the following aligns the partial transcript
+# without using ``<star>`` token.
+# It demonstrates the effect of ``<star>`` token for dealing with deletion errors.
+compute_and_plot_alignments("THIS|MOMENT", DICTIONARY, emission, waveform)

 ######################################################################
 # Conclusion
@@ -551,7 +575,7 @@ word_segments, frame_alignment = compute_and_plot_alignments(transcript, diction
 #
 # In this tutorial, we looked at how to use torchaudio’s forced alignment
 # API to align and segment speech files, and demonstrated one advanced usage:
-# How introducing a <star> token could improve alignment accuracy when
+# How introducing a ``<star>`` token could improve alignment accuracy when
 # transcription errors exist.
 #
...
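As a plain-Python illustration of the merging rules the rewritten tutorial describes (a blank resets the current run, a repeated token extends it, and a token-level score is the average of its frame-level scores), here is a small worked example with made-up token ids and scores, independent of the tutorial's helpers:

    # Hypothetical frame-level output: token id and score per frame; id 0 is the blank.
    labels = ["-", "|", "A", "B"]
    frame_alignment = [0, 2, 2, 0, 3, 3, 3, 0, 2]
    frame_scores = [1.0, 0.9, 0.7, 1.0, 0.8, 0.6, 1.0, 1.0, 0.95]

    segments = []   # [label, start_frame, end_frame_exclusive, mean_score]
    after_blank = True
    for i, (tok, score) in enumerate(zip(frame_alignment, frame_scores)):
        if tok == 0:
            after_blank = True          # a blank: the next identical token starts a new segment
            continue
        if after_blank or frame_alignment[i - 1] != tok:
            segments.append([labels[tok], i, i + 1, score])
        else:                           # repeat of the same token: extend the current segment
            seg = segments[-1]
            n = seg[2] - seg[1]
            seg[3] = (seg[3] * n + score) / (n + 1)   # running average of frame scores
            seg[2] = i + 1
        after_blank = False

    for label, start, end, score in segments:
        print(f"{label}: frames [{start}, {end}), score {score:.2f}")
    # A: frames [1, 3), score 0.80
    # B: frames [4, 7), score 0.80
    # A: frames [8, 9), score 0.95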
@@ -69,7 +69,7 @@ import torchvision
 # -------------------
 #
 # Firstly, we define the function to collect videos from microphone and
-# camera. To be specific, we use :py:func:`~torchaudio.io.StreamReader`
+# camera. To be specific, we use :py:class:`~torchaudio.io.StreamReader`
 # class for the purpose of data collection, which supports capturing
 # audio/video from microphone and camera. For the detailed usage of this
 # class, please refer to the
...
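``StreamReader`` is a class, hence the ``:py:class:`` fix above. A rough sketch of the capture pattern the tutorial refers to; the device strings and formats below are platform-specific assumptions (v4l2 and ALSA on Linux), not values taken from the tutorial:

    from torchaudio.io import StreamReader

    # Video from a webcam (Linux example; macOS would use format="avfoundation").
    video = StreamReader(src="/dev/video0", format="v4l2")
    video.add_basic_video_stream(frames_per_chunk=30, frame_rate=30)

    # Audio from a microphone (ALSA example).
    audio = StreamReader(src="default", format="alsa")
    audio.add_basic_audio_stream(frames_per_chunk=16000, sample_rate=16000)

    for (video_chunk,), (audio_chunk,) in zip(video.stream(), audio.stream()):
        print(video_chunk.shape, audio_chunk.shape)
        break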
...@@ -89,7 +89,7 @@ def plot_sinc_ir(irs, cutoff): ...@@ -89,7 +89,7 @@ def plot_sinc_ir(irs, cutoff):
num_filts, window_size = irs.shape num_filts, window_size = irs.shape
half = window_size // 2 half = window_size // 2
fig, axes = plt.subplots(num_filts, 1, sharex=True, figsize=(6.4, 4.8 * 1.5)) fig, axes = plt.subplots(num_filts, 1, sharex=True, figsize=(9.6, 8))
t = torch.linspace(-half, half - 1, window_size) t = torch.linspace(-half, half - 1, window_size)
for ax, ir, coff, color in zip(axes, irs, cutoff, plt.cm.tab10.colors): for ax, ir, coff, color in zip(axes, irs, cutoff, plt.cm.tab10.colors):
ax.plot(t, ir, linewidth=1.2, color=color, zorder=4, label=f"Cutoff: {coff}") ax.plot(t, ir, linewidth=1.2, color=color, zorder=4, label=f"Cutoff: {coff}")
...@@ -100,7 +100,7 @@ def plot_sinc_ir(irs, cutoff): ...@@ -100,7 +100,7 @@ def plot_sinc_ir(irs, cutoff):
"(Frequencies are relative to Nyquist frequency)" "(Frequencies are relative to Nyquist frequency)"
) )
axes[-1].set_xticks([i * half // 4 for i in range(-4, 5)]) axes[-1].set_xticks([i * half // 4 for i in range(-4, 5)])
plt.tight_layout() fig.tight_layout()
###################################################################### ######################################################################
@@ -130,7 +130,7 @@ def plot_sinc_fr(frs, cutoff, band=False):
num_filts, num_fft = frs.shape num_filts, num_fft = frs.shape
num_ticks = num_filts + 1 if band else num_filts num_ticks = num_filts + 1 if band else num_filts
fig, axes = plt.subplots(num_filts, 1, sharex=True, sharey=True, figsize=(6.4, 4.8 * 1.5)) fig, axes = plt.subplots(num_filts, 1, sharex=True, sharey=True, figsize=(9.6, 8))
for ax, fr, coff, color in zip(axes, frs, cutoff, plt.cm.tab10.colors): for ax, fr, coff, color in zip(axes, frs, cutoff, plt.cm.tab10.colors):
ax.grid(True) ax.grid(True)
ax.semilogy(fr, color=color, zorder=4, label=f"Cutoff: {coff}") ax.semilogy(fr, color=color, zorder=4, label=f"Cutoff: {coff}")
@@ -146,7 +146,7 @@ def plot_sinc_fr(frs, cutoff, band=False):
"Frequency response of sinc low-pass filter for different cut-off frequencies\n" "Frequency response of sinc low-pass filter for different cut-off frequencies\n"
"(Frequencies are relative to Nyquist frequency)" "(Frequencies are relative to Nyquist frequency)"
) )
plt.tight_layout() fig.tight_layout()
###################################################################### ######################################################################
@@ -275,7 +275,7 @@ def plot_ir(magnitudes, ir, num_fft=2048):
axes[i].grid(True) axes[i].grid(True)
axes[1].set(title="Frequency Response") axes[1].set(title="Frequency Response")
axes[2].set(title="Frequency Response (log-scale)", xlabel="Frequency") axes[2].set(title="Frequency Response (log-scale)", xlabel="Frequency")
axes[2].legend(loc="lower right") axes[2].legend(loc="center right")
fig.tight_layout() fig.tight_layout()
......
@@ -6,15 +6,14 @@ Forced alignment for multilingual data
This tutorial shows how to compute forced alignments for speech data This tutorial shows how to compute forced alignments for speech data
from multiple non-English languages using ``torchaudio``'s CTC forced alignment from multiple non-English languages using ``torchaudio``'s CTC forced alignment
API described in `“CTC forced alignment API described in `CTC forced alignment tutorial <./forced_alignment_tutorial.html>`__
tutorial” <https://pytorch.org/audio/stable/tutorials/forced_alignment_tutorial.html>`__ and the multilingual Wav2vec2 model proposed in the paper `Scaling
and the multilingual Wav2vec2 model proposed in the paper `“Scaling
Speech Technology to 1,000+ Speech Technology to 1,000+
Languages” <https://research.facebook.com/publications/scaling-speech-technology-to-1000-languages/>`__. Languages <https://research.facebook.com/publications/scaling-speech-technology-to-1000-languages/>`__.
The model was trained on 23K hours of audio data from 1100+ languages using The model was trained on 23K hours of audio data from 1100+ languages using
the `uroman vocabulary <https://www.isi.edu/~ulf/uroman.html>`__ the `uroman vocabulary <https://www.isi.edu/~ulf/uroman.html>`__
as targets. as targets.
""" """
import torch import torch
@@ -23,53 +22,46 @@ import torchaudio
print(torch.__version__) print(torch.__version__)
print(torchaudio.__version__) print(torchaudio.__version__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device) print(device)
try:
from torchaudio.functional import forced_align
except ModuleNotFoundError:
print(
"Failed to import the forced alignment API. "
"Please install torchaudio nightly builds. "
"Please refer to https://pytorch.org/get-started/locally "
"for instructions to install a nightly build."
)
raise
###################################################################### ######################################################################
# Preparation # Preparation
# ----------- # -----------
# #
# Here we import necessary packages, and define utility functions for
# computing the frame-level alignments (using the API
# ``functional.forced_align``), token-level and word-level alignments, and
# also alignment visualization utilities.
#
# %matplotlib inline
from dataclasses import dataclass from dataclasses import dataclass
import IPython import IPython
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from torchaudio.functional import forced_align
torch.random.manual_seed(0)
sample_rate = 16000 ######################################################################
#
SAMPLE_RATE = 16000
######################################################################
#
# Here we define utility functions for computing the frame-level
# alignments (using the API :py:func:`torchaudio.functional.forced_align`),
# token-level and word-level alignments.
# For the detail of these functions please refer to
# `CTC forced alignment API tutorial <./ctc_forced_alignment_api_tutorial.html>`__.
#
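Before the helper definitions, here is roughly what the core call they wrap looks like (a sketch only; ``transcript``, ``dictionary`` and ``emission`` are defined later in the tutorial, and ``0`` is the blank token index):

tokens = [dictionary[c] for c in transcript.replace(" ", "")]
targets = torch.tensor([tokens], dtype=torch.int32, device=emission.device)
input_lengths = torch.tensor([emission.shape[1]], device=emission.device)
target_lengths = torch.tensor([targets.shape[1]], device=emission.device)
# one aligned token id and one log-probability per emission frame
alignment, scores = forced_align(emission, targets, input_lengths, target_lengths, 0)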
@dataclass @dataclass
class Frame: class Frame:
# This is the index of each token in the transcript,
# i.e. the current frame aligns to the N-th character from the transcript.
token_index: int token_index: int
time_index: int time_index: int
score: float score: float
######################################################################
#
@dataclass @dataclass
class Segment: class Segment:
label: str label: str
@@ -78,39 +70,42 @@ class Segment:
score: float score: float
def __repr__(self): def __repr__(self):
return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})" return f"{self.label:2s} ({self.score:4.2f}): [{self.start:4d}, {self.end:4d})"
@property def __len__(self):
def length(self):
return self.end - self.start return self.end - self.start
# compute frame-level and word-level alignments using torchaudio's forced alignment API ######################################################################
#
def compute_alignments(transcript, dictionary, emission): def compute_alignments(transcript, dictionary, emission):
frames = []
tokens = [dictionary[c] for c in transcript.replace(" ", "")] tokens = [dictionary[c] for c in transcript.replace(" ", "")]
targets = torch.tensor(tokens, dtype=torch.int32).unsqueeze(0) targets = torch.tensor([tokens], dtype=torch.int32, device=emission.device)
input_lengths = torch.tensor([emission.shape[1]]) input_lengths = torch.tensor([emission.shape[1]], device=emission.device)
target_lengths = torch.tensor([targets.shape[1]]) target_lengths = torch.tensor([targets.shape[1]], device=emission.device)
alignment, scores = forced_align(emission, targets, input_lengths, target_lengths, 0)
# This is the key step, where we call the forced alignment API functional.forced_align to compute frame alignments. scores = scores.exp() # convert back to probability
frame_alignment, frame_scores = forced_align(emission, targets, input_lengths, target_lengths, 0) alignment, scores = alignment[0].tolist(), scores[0].tolist()
assert frame_alignment.shape[1] == input_lengths[0].item() assert len(alignment) == len(scores) == emission.size(1)
assert targets.shape[1] == target_lengths[0].item()
token_index = -1 token_index = -1
prev_hyp = 0 prev_hyp = 0
for i in range(frame_alignment.shape[1]): frames = []
if frame_alignment[0][i].item() == 0: for i, (ali, score) in enumerate(zip(alignment, scores)):
if ali == 0:
prev_hyp = 0 prev_hyp = 0
continue continue
if frame_alignment[0][i].item() != prev_hyp: if ali != prev_hyp:
token_index += 1 token_index += 1
frames.append(Frame(token_index, i, frame_scores[0][i].exp().item())) frames.append(Frame(token_index, i, score))
prev_hyp = frame_alignment[0][i].item() prev_hyp = ali
# compute frame alignments from token alignments # compute frame alignments from token alignments
transcript_nospace = transcript.replace(" ", "") transcript_nospace = transcript.replace(" ", "")
@@ -140,52 +135,59 @@ def compute_alignments(transcript, dictionary, emission):
if i1 != i2: if i1 != i2:
if i3 == len(transcript) - 1: if i3 == len(transcript) - 1:
i2 += 1 i2 += 1
s = 0 segs = segments[i1:i2]
segs = segments[i1 + s : i2 + s] word = "".join([s.label for s in segs])
word = "".join([seg.label for seg in segs]) score = sum(s.score * len(s) for s in segs) / sum(len(s) for s in segs)
score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs) words.append(Segment(word, segs[0].start, segs[-1].end + 1, score))
words.append(Segment(word, segments[i1 + s].start, segments[i2 + s - 1].end, score))
i1 = i2 i1 = i2
else: else:
i2 += 1 i2 += 1
i3 += 1 i3 += 1
return segments, words
num_frames = frame_alignment.shape[1]
return segments, words, num_frames
######################################################################
#
def plot_emission(emission):
fig, ax = plt.subplots()
ax.imshow(emission.T, aspect="auto")
ax.set_title("Emission")
fig.tight_layout()
# utility function for plotting word alignments
def plot_alignments(segments, word_segments, waveform, input_lengths, scale=10):
fig, ax2 = plt.subplots(figsize=(64, 12))
plt.rcParams.update({"font.size": 30})
# The original waveform ######################################################################
ratio = waveform.size(1) / input_lengths #
ax2.plot(waveform)
ax2.set_ylim(-1.0 * scale, 1.0 * scale) # utility function for plotting word alignments
ax2.set_xlim(0, waveform.size(-1)) def plot_alignments(waveform, emission, segments, word_segments, sample_rate=SAMPLE_RATE):
fig, ax = plt.subplots()
ax.specgram(waveform[0], Fs=sample_rate)
xlim = ax.get_xlim()
ratio = waveform.size(1) / sample_rate / emission.size(1)
for word in word_segments: for word in word_segments:
x0 = ratio * word.start t0, t1 = word.start * ratio, word.end * ratio
x1 = ratio * word.end ax.axvspan(t0, t1, facecolor="None", hatch="/", edgecolor="white")
ax2.axvspan(x0, x1, alpha=0.1, color="red") ax.annotate(f"{word.score:.2f}", (t0, sample_rate * 0.51), annotation_clip=False)
ax2.annotate(f"{word.score:.2f}", (x0, 0.8 * scale))
for seg in segments: for seg in segments:
if seg.label != "|": if seg.label != "|":
ax2.annotate(seg.label, (seg.start * ratio, 0.9 * scale)) ax.annotate(seg.label, (seg.start * ratio, sample_rate * 0.53), annotation_clip=False)
xticks = ax2.get_xticks() ax.set_xlabel("time [second]")
plt.xticks(xticks, xticks / sample_rate, fontsize=50) ax.set_xlim(xlim)
ax2.set_xlabel("time [second]", fontsize=40) fig.tight_layout()
ax2.set_yticks([])
return IPython.display.Audio(waveform, rate=sample_rate)
######################################################################
#
# utility function for playing audio segments. # utility function for playing audio segments.
# A trick to embed the resulting audio to the generated file. def display_segment(i, waveform, word_segments, num_frames, sample_rate=SAMPLE_RATE):
# `IPython.display.Audio` has to be the last call in a cell,
# and there should be only one call per cell.
def display_segment(i, waveform, word_segments, num_frames):
ratio = waveform.size(1) / num_frames ratio = waveform.size(1) / num_frames
word = word_segments[i] word = word_segments[i]
x0 = int(ratio * word.start) x0 = int(ratio * word.start)
@@ -241,26 +243,21 @@ model.load_state_dict(
) )
) )
model.eval() model.eval()
model.to(device)
def get_emission(waveform): def get_emission(waveform):
with torch.inference_mode():
# NOTE: this step is essential # NOTE: this step is essential
waveform = torch.nn.functional.layer_norm(waveform, waveform.shape) waveform = torch.nn.functional.layer_norm(waveform, waveform.shape)
emission, _ = model(waveform)
emissions, _ = model(waveform) return torch.log_softmax(emission, dim=-1)
emissions = torch.log_softmax(emissions, dim=-1)
emission = emissions.cpu().detach()
# Append the extra dimension corresponding to the <star> token
extra_dim = torch.zeros(emissions.shape[0], emissions.shape[1], 1)
emissions = torch.cat((emissions.cpu(), extra_dim), 2)
emission = emissions.detach()
return emission, waveform
# Construct the dictionary # Construct the dictionary
# '@' represents the OOV token, '*' represents the <star> token. # '@' represents the OOV token
# <pad> and </s> are fairseq's legacy tokens, which are not used. # <pad> and </s> are fairseq's legacy tokens, which are not used.
# <star> token is omitted as we do not use it in this tutorial
dictionary = { dictionary = {
"<blank>": 0, "<blank>": 0,
"<pad>": 1, "<pad>": 1,
@@ -293,7 +290,6 @@ dictionary = {
"'": 28, "'": 28,
"q": 29, "q": 29,
"x": 30, "x": 30,
"*": 31,
} }
@@ -304,11 +300,8 @@ dictionary = {
# romanizer and using it to obtain romanized transcripts, and Python # romanizer and using it to obtain romanized transcripts, and Python
# commands required for further normalizing the romanized transcript. # commands required for further normalizing the romanized transcript.
# #
# %%
# .. code-block:: bash # .. code-block:: bash
# #
# %%bash
# Save the raw transcript to a file # Save the raw transcript to a file
# echo 'raw text' > text.txt # echo 'raw text' > text.txt
# git clone https://github.com/isi-nlp/uroman # git clone https://github.com/isi-nlp/uroman
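The "Python commands" mentioned above typically amount to a small normalization helper. The exact code is not shown in this diff; an illustrative sketch (assumptions: lower-case the romanized text, drop characters outside the vocabulary defined above, collapse whitespace) could look like this:

import re

def normalize_uroman(text):
    # illustrative only; adjust the kept character set to match the dictionary above
    text = text.lower()
    text = text.replace("’", "'")
    text = re.sub("([^a-z' ])", " ", text)
    text = re.sub(" +", " ", text)
    return text.strip()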
@@ -334,141 +327,77 @@ dictionary = {
###################################################################### ######################################################################
# German example: # German
# ~~~~~~~~~~~~~~~~ # ~~~~~~
text_raw = (
"aber seit ich bei ihnen das brot hole brauch ich viel weniger schulze wandte sich ab die kinder taten ihm leid"
)
text_normalized = (
"aber seit ich bei ihnen das brot hole brauch ich viel weniger schulze wandte sich ab die kinder taten ihm leid"
)
speech_file = torchaudio.utils.download_asset("tutorial-assets/10349_8674_000087.flac", progress=False) speech_file = torchaudio.utils.download_asset("tutorial-assets/10349_8674_000087.flac", progress=False)
waveform, _ = torchaudio.load(speech_file)
emission, waveform = get_emission(waveform)
assert len(dictionary) == emission.shape[2]
transcript = text_normalized text_raw = "aber seit ich bei ihnen das brot hole"
text_normalized = "aber seit ich bei ihnen das brot hole"
segments, word_segments, num_frames = compute_alignments(transcript, dictionary, emission)
plot_alignments(segments, word_segments, waveform, emission.shape[1])
print("Raw Transcript: ", text_raw) print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized) print("Normalized Transcript: ", text_normalized)
IPython.display.Audio(waveform, rate=sample_rate)
###################################################################### ######################################################################
# #
display_segment(0, waveform, word_segments, num_frames) waveform, _ = torchaudio.load(speech_file, frame_offset=int(0.5 * SAMPLE_RATE), num_frames=int(2.5 * SAMPLE_RATE))
emission = get_emission(waveform.to(device))
num_frames = emission.size(1)
plot_emission(emission[0].cpu())
###################################################################### ######################################################################
# #
display_segment(1, waveform, word_segments, num_frames) segments, word_segments = compute_alignments(text_normalized, dictionary, emission)
###################################################################### plot_alignments(waveform, emission, segments, word_segments)
#
display_segment(2, waveform, word_segments, num_frames)
######################################################################
#
display_segment(3, waveform, word_segments, num_frames)
######################################################################
#
display_segment(4, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# #
display_segment(5, waveform, word_segments, num_frames) display_segment(0, waveform, word_segments, num_frames)
######################################################################
#
display_segment(6, waveform, word_segments, num_frames)
######################################################################
#
display_segment(7, waveform, word_segments, num_frames)
######################################################################
#
display_segment(8, waveform, word_segments, num_frames)
######################################################################
#
display_segment(9, waveform, word_segments, num_frames)
######################################################################
#
display_segment(10, waveform, word_segments, num_frames)
######################################################################
#
display_segment(11, waveform, word_segments, num_frames)
######################################################################
#
display_segment(12, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# #
display_segment(13, waveform, word_segments, num_frames) display_segment(1, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# #
display_segment(14, waveform, word_segments, num_frames) display_segment(2, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# #
display_segment(15, waveform, word_segments, num_frames) display_segment(3, waveform, word_segments, num_frames)
######################################################################
#
display_segment(16, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# #
display_segment(17, waveform, word_segments, num_frames) display_segment(4, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# #
display_segment(18, waveform, word_segments, num_frames) display_segment(5, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# #
display_segment(19, waveform, word_segments, num_frames) display_segment(6, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# #
display_segment(20, waveform, word_segments, num_frames) display_segment(7, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# Chinese example: # Chinese
# ~~~~~~~~~~~~~~~~ # ~~~~~~~
# #
# Chinese is a character-based language, and there is no explicit word-level # Chinese is a character-based language, and there is no explicit word-level
# tokenization (separated by spaces) in its raw written form. In order to # tokenization (separated by spaces) in its raw written form. In order to
@@ -478,98 +407,36 @@ display_segment(20, waveform, word_segments, num_frames)
# However, this is not needed if you only want character-level alignments. # However, this is not needed if you only want character-level alignments.
# #
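The normalized transcript used below was produced with uroman. Purely as an illustration of what romanization yields for Chinese, a rough equivalent using the third-party ``pypinyin`` package (an assumption; the tutorial itself does not depend on it) would be:

from pypinyin import lazy_pinyin  # not a torchaudio dependency; shown for illustration

raw = "关 服务 高端 产品 仍 处于 供不应求 的 局面"
romanized = " ".join("".join(lazy_pinyin(word)) for word in raw.split())
print(romanized)  # expected: guan fuwu gaoduan chanpin reng chuyu gongbuyingqiu de jumian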
text_raw = "关 服务 高端 产品 仍 处于 供不应求 的 局面"
text_normalized = "guan fuwu gaoduan chanpin reng chuyu gongbuyingqiu de jumian"
speech_file = torchaudio.utils.download_asset("tutorial-assets/mvdr/clean_speech.wav", progress=False) speech_file = torchaudio.utils.download_asset("tutorial-assets/mvdr/clean_speech.wav", progress=False)
waveform, _ = torchaudio.load(speech_file)
waveform = waveform[0:1]
emission, waveform = get_emission(waveform)
transcript = text_normalized
segments, word_segments, num_frames = compute_alignments(transcript, dictionary, emission) text_raw = "关 服务 高端 产品 仍 处于 供不应求 的 局面"
plot_alignments(segments, word_segments, waveform, emission.shape[1]) text_normalized = "guan fuwu gaoduan chanpin reng chuyu gongbuyingqiu de jumian"
print("Raw Transcript: ", text_raw) print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized) print("Normalized Transcript: ", text_normalized)
IPython.display.Audio(waveform, rate=sample_rate)
###################################################################### ######################################################################
# #
display_segment(0, waveform, word_segments, num_frames) waveform, _ = torchaudio.load(speech_file)
waveform = waveform[0:1]
######################################################################
#
display_segment(1, waveform, word_segments, num_frames)
######################################################################
#
display_segment(2, waveform, word_segments, num_frames)
######################################################################
#
display_segment(3, waveform, word_segments, num_frames)
######################################################################
#
display_segment(4, waveform, word_segments, num_frames)
######################################################################
#
display_segment(5, waveform, word_segments, num_frames)
######################################################################
#
display_segment(6, waveform, word_segments, num_frames)
######################################################################
#
display_segment(7, waveform, word_segments, num_frames) emission = get_emission(waveform.to(device))
num_frames = emission.size(1)
plot_emission(emission[0].cpu())
###################################################################### ######################################################################
# #
display_segment(8, waveform, word_segments, num_frames) segments, word_segments = compute_alignments(text_normalized, dictionary, emission)
######################################################################
# Polish example:
# ~~~~~~~~~~~~~~~
text_raw = "wtedy ujrzałem na jego brzuchu okrągłą czarną ranę dlaczego mi nie powiedziałeś szepnąłem ze łzami"
text_normalized = "wtedy ujrzalem na jego brzuchu okragla czarna rane dlaczego mi nie powiedziales szepnalem ze lzami"
speech_file = torchaudio.utils.download_asset("tutorial-assets/5090_1447_000088.flac", progress=False)
waveform, _ = torchaudio.load(speech_file)
emission, waveform = get_emission(waveform)
transcript = text_normalized
segments, word_segments, num_frames = compute_alignments(transcript, dictionary, emission) plot_alignments(waveform, emission, segments, word_segments)
plot_alignments(segments, word_segments, waveform, emission.shape[1])
print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized)
IPython.display.Audio(waveform, rate=sample_rate)
###################################################################### ######################################################################
# #
display_segment(0, waveform, word_segments, num_frames) display_segment(0, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# #
@@ -585,7 +452,6 @@ display_segment(2, waveform, word_segments, num_frames)
display_segment(3, waveform, word_segments, num_frames) display_segment(3, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# #
@@ -611,68 +477,40 @@ display_segment(7, waveform, word_segments, num_frames)
display_segment(8, waveform, word_segments, num_frames) display_segment(8, waveform, word_segments, num_frames)
######################################################################
#
display_segment(9, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# # Polish
# ~~~~~~
display_segment(10, waveform, word_segments, num_frames) speech_file = torchaudio.utils.download_asset("tutorial-assets/5090_1447_000088.flac", progress=False)
###################################################################### text_raw = "wtedy ujrzałem na jego brzuchu okrągłą czarną ranę"
# text_normalized = "wtedy ujrzalem na jego brzuchu okragla czarna rane"
display_segment(11, waveform, word_segments, num_frames) print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized)
###################################################################### ######################################################################
# #
display_segment(12, waveform, word_segments, num_frames) waveform, _ = torchaudio.load(speech_file, num_frames=int(4.5 * SAMPLE_RATE))
######################################################################
#
display_segment(13, waveform, word_segments, num_frames) emission = get_emission(waveform.to(device))
num_frames = emission.size(1)
plot_emission(emission[0].cpu())
###################################################################### ######################################################################
# #
display_segment(14, waveform, word_segments, num_frames) segments, word_segments = compute_alignments(text_normalized, dictionary, emission)
plot_alignments(waveform, emission, segments, word_segments)
######################################################################
# Portuguese example:
# ~~~~~~~~~~~~~~~~~~~
text_raw = (
"mas na imensa extensão onde se esconde o inconsciente imortal só me responde um bramido um queixume e nada mais"
)
text_normalized = (
"mas na imensa extensao onde se esconde o inconsciente imortal so me responde um bramido um queixume e nada mais"
)
speech_file = torchaudio.utils.download_asset("tutorial-assets/6566_5323_000027.flac", progress=False)
waveform, _ = torchaudio.load(speech_file)
emission, waveform = get_emission(waveform)
transcript = text_normalized
segments, word_segments, num_frames = compute_alignments(transcript, dictionary, emission)
plot_alignments(segments, word_segments, waveform, emission.shape[1])
print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized)
IPython.display.Audio(waveform, rate=sample_rate)
###################################################################### ######################################################################
# #
display_segment(0, waveform, word_segments, num_frames) display_segment(0, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# #
@@ -688,7 +526,6 @@ display_segment(2, waveform, word_segments, num_frames)
display_segment(3, waveform, word_segments, num_frames) display_segment(3, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# #
@@ -710,94 +547,38 @@ display_segment(6, waveform, word_segments, num_frames)
display_segment(7, waveform, word_segments, num_frames) display_segment(7, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# # Portuguese
# ~~~~~~~~~~
display_segment(8, waveform, word_segments, num_frames)
######################################################################
#
display_segment(9, waveform, word_segments, num_frames)
######################################################################
#
display_segment(10, waveform, word_segments, num_frames)
######################################################################
#
display_segment(11, waveform, word_segments, num_frames)
######################################################################
#
display_segment(12, waveform, word_segments, num_frames)
###################################################################### speech_file = torchaudio.utils.download_asset("tutorial-assets/6566_5323_000027.flac", progress=False)
#
display_segment(13, waveform, word_segments, num_frames)
######################################################################
#
display_segment(14, waveform, word_segments, num_frames)
######################################################################
#
display_segment(15, waveform, word_segments, num_frames)
###################################################################### text_raw = "na imensa extensão onde se esconde o inconsciente imortal"
# text_normalized = "na imensa extensao onde se esconde o inconsciente imortal"
display_segment(16, waveform, word_segments, num_frames) print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized)
###################################################################### ######################################################################
# #
display_segment(17, waveform, word_segments, num_frames) waveform, _ = torchaudio.load(speech_file, frame_offset=int(SAMPLE_RATE), num_frames=int(4.6 * SAMPLE_RATE))
###################################################################### emission = get_emission(waveform.to(device))
# num_frames = emission.size(1)
plot_emission(emission[0].cpu())
display_segment(18, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# #
display_segment(19, waveform, word_segments, num_frames) segments, word_segments = compute_alignments(text_normalized, dictionary, emission)
plot_alignments(waveform, emission, segments, word_segments)
######################################################################
# Italian example:
# ~~~~~~~~~~~~~~~~
text_raw = "elle giacean per terra tutte quante fuor d'una ch'a seder si levò ratto ch'ella ci vide passarsi davante"
text_normalized = (
"elle giacean per terra tutte quante fuor d'una ch'a seder si levo ratto ch'ella ci vide passarsi davante"
)
speech_file = torchaudio.utils.download_asset("tutorial-assets/642_529_000025.flac", progress=False)
waveform, _ = torchaudio.load(speech_file)
emission, waveform = get_emission(waveform)
transcript = text_normalized
segments, word_segments, num_frames = compute_alignments(transcript, dictionary, emission)
plot_alignments(segments, word_segments, waveform, emission.shape[1])
print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized)
IPython.display.Audio(waveform, rate=sample_rate)
###################################################################### ######################################################################
# #
display_segment(0, waveform, word_segments, num_frames) display_segment(0, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# #
@@ -813,7 +594,6 @@ display_segment(2, waveform, word_segments, num_frames)
display_segment(3, waveform, word_segments, num_frames) display_segment(3, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# #
@@ -840,50 +620,62 @@ display_segment(7, waveform, word_segments, num_frames)
display_segment(8, waveform, word_segments, num_frames) display_segment(8, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# # Italian
# ~~~~~~~
speech_file = torchaudio.utils.download_asset("tutorial-assets/642_529_000025.flac", progress=False)
text_raw = "elle giacean per terra tutte quante"
text_normalized = "elle giacean per terra tutte quante"
display_segment(9, waveform, word_segments, num_frames) print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized)
###################################################################### ######################################################################
# #
display_segment(10, waveform, word_segments, num_frames) waveform, _ = torchaudio.load(speech_file, num_frames=int(4 * SAMPLE_RATE))
emission = get_emission(waveform.to(device))
num_frames = emission.size(1)
plot_emission(emission[0].cpu())
###################################################################### ######################################################################
# #
display_segment(11, waveform, word_segments, num_frames) segments, word_segments = compute_alignments(text_normalized, dictionary, emission)
plot_alignments(waveform, emission, segments, word_segments)
###################################################################### ######################################################################
# #
display_segment(12, waveform, word_segments, num_frames) display_segment(0, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# #
display_segment(13, waveform, word_segments, num_frames) display_segment(1, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# #
display_segment(14, waveform, word_segments, num_frames) display_segment(2, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# #
display_segment(15, waveform, word_segments, num_frames) display_segment(3, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# #
display_segment(16, waveform, word_segments, num_frames) display_segment(4, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# #
display_segment(17, waveform, word_segments, num_frames) display_segment(5, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# Conclusion # Conclusion
@@ -894,7 +686,6 @@ display_segment(17, waveform, word_segments, num_frames)
# speech data to transcripts in five languages. # speech data to transcripts in five languages.
# #
###################################################################### ######################################################################
# Acknowledgement # Acknowledgement
# --------------- # ---------------
......
@@ -56,16 +56,11 @@ print(device)
# First we import the necessary packages, and fetch data that we work on. # First we import the necessary packages, and fetch data that we work on.
# #
# %matplotlib inline
from dataclasses import dataclass from dataclasses import dataclass
import IPython import IPython
import matplotlib
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
matplotlib.rcParams["figure.figsize"] = [16.0, 4.8]
torch.random.manual_seed(0) torch.random.manual_seed(0)
SPEECH_FILE = torchaudio.utils.download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav") SPEECH_FILE = torchaudio.utils.download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
@@ -99,17 +94,22 @@ with torch.inference_mode():
emission = emissions[0].cpu().detach() emission = emissions[0].cpu().detach()
print(labels)
################################################################################ ################################################################################
# Visualization # Visualization
################################################################################ # ~~~~~~~~~~~~~
print(labels)
plt.imshow(emission.T)
plt.colorbar()
plt.title("Frame-wise class probability")
plt.xlabel("Time")
plt.ylabel("Labels")
plt.show()
def plot():
plt.imshow(emission.T)
plt.colorbar()
plt.title("Frame-wise class probability")
plt.xlabel("Time")
plt.ylabel("Labels")
plot()
###################################################################### ######################################################################
# Generate alignment probability (trellis) # Generate alignment probability (trellis)
@@ -181,12 +181,17 @@ trellis = get_trellis(emission, tokens)
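The construction of the trellis itself falls outside the visible hunks. For orientation, a sketch of a typical CTC trellis build consistent with the visualization below (assumptions: blank index ``0``, ``emission`` of shape ``(num_frames, num_labels)``, ``tokens`` a list of target token ids) is:

def get_trellis(emission, tokens, blank_id=0):
    num_frame = emission.size(0)
    num_tokens = len(tokens)
    trellis = torch.zeros((num_frame, num_tokens))
    trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
    # boundary conditions (the -Inf / +Inf regions annotated in the plot below)
    trellis[0, 1:] = -float("inf")
    trellis[-num_tokens + 1:, 0] = float("inf")
    for t in range(num_frame - 1):
        trellis[t + 1, 1:] = torch.maximum(
            trellis[t, 1:] + emission[t, blank_id],     # stay on the same token
            trellis[t, :-1] + emission[t, tokens[1:]],  # advance to the next token
        )
    return trellis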
################################################################################ ################################################################################
# Visualization # Visualization
################################################################################ # ~~~~~~~~~~~~~
plt.imshow(trellis.T, origin="lower")
plt.annotate("- Inf", (trellis.size(1) / 5, trellis.size(1) / 1.5))
plt.annotate("+ Inf", (trellis.size(0) - trellis.size(1) / 5, trellis.size(1) / 3)) def plot():
plt.colorbar() plt.imshow(trellis.T, origin="lower")
plt.show() plt.annotate("- Inf", (trellis.size(1) / 5, trellis.size(1) / 1.5))
plt.annotate("+ Inf", (trellis.size(0) - trellis.size(1) / 5, trellis.size(1) / 3))
plt.colorbar()
plot()
###################################################################### ######################################################################
# In the above visualization, we can see that there is a trace of high # In the above visualization, we can see that there is a trace of high
@@ -266,7 +271,9 @@ for p in path:
################################################################################ ################################################################################
# Visualization # Visualization
################################################################################ # ~~~~~~~~~~~~~
def plot_trellis_with_path(trellis, path): def plot_trellis_with_path(trellis, path):
# To plot trellis with path, we take advantage of 'nan' value # To plot trellis with path, we take advantage of 'nan' value
trellis_with_path = trellis.clone() trellis_with_path = trellis.clone()
@@ -277,10 +284,14 @@ def plot_trellis_with_path(trellis, path):
plot_trellis_with_path(trellis, path) plot_trellis_with_path(trellis, path)
plt.title("The path found by backtracking") plt.title("The path found by backtracking")
plt.show()
###################################################################### ######################################################################
# Looking good. Now this path contains repetitions for the same labels, so # Looking good.
######################################################################
# Segment the path
# ----------------
# Now this path contains repetitions for the same labels, so
# let’s merge them to make it close to the original transcript. # let’s merge them to make it close to the original transcript.
# #
# When merging the multiple path points, we simply take the average # When merging the multiple path points, we simply take the average
@@ -297,7 +308,7 @@ class Segment:
score: float score: float
def __repr__(self): def __repr__(self):
return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})" return f"{self.label} ({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"
@property @property
def length(self): def length(self):
@@ -330,7 +341,9 @@ for seg in segments:
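The merging step described above (averaging the probability over repeated path points) is elided by these hunks. A sketch of the idea, assuming each entry of ``path`` carries ``token_index``, ``time_index`` and ``score`` attributes and using the ``Segment`` class shown here (the explicit ``transcript`` argument is our addition):

def merge_repeats(path, transcript):
    i1, i2 = 0, 0
    segments = []
    while i1 < len(path):
        # advance i2 while it still points at the same transcript token
        while i2 < len(path) and path[i1].token_index == path[i2].token_index:
            i2 += 1
        # average the per-frame probabilities of the repeated points
        score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
        segments.append(
            Segment(transcript[path[i1].token_index], path[i1].time_index, path[i2 - 1].time_index + 1, score)
        )
        i1 = i2
    return segments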
################################################################################ ################################################################################
# Visualization # Visualization
################################################################################ # ~~~~~~~~~~~~~
def plot_trellis_with_segments(trellis, segments, transcript): def plot_trellis_with_segments(trellis, segments, transcript):
# To plot trellis with path, we take advantage of 'nan' value # To plot trellis with path, we take advantage of 'nan' value
trellis_with_path = trellis.clone() trellis_with_path = trellis.clone()
@@ -338,15 +351,14 @@ def plot_trellis_with_segments(trellis, segments, transcript):
if seg.label != "|": if seg.label != "|":
trellis_with_path[seg.start : seg.end, i] = float("nan") trellis_with_path[seg.start : seg.end, i] = float("nan")
fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5)) fig, [ax1, ax2] = plt.subplots(2, 1, sharex=True)
ax1.set_title("Path, label and probability for each label") ax1.set_title("Path, label and probability for each label")
ax1.imshow(trellis_with_path.T, origin="lower") ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")
ax1.set_xticks([])
for i, seg in enumerate(segments): for i, seg in enumerate(segments):
if seg.label != "|": if seg.label != "|":
ax1.annotate(seg.label, (seg.start, i - 0.7), weight="bold") ax1.annotate(seg.label, (seg.start, i - 0.7), size="small")
ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3)) ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), size="small")
ax2.set_title("Label probability with and without repetation") ax2.set_title("Label probability with and without repetation")
xs, hs, ws = [], [], [] xs, hs, ws = [], [], []
@@ -355,7 +367,7 @@ def plot_trellis_with_segments(trellis, segments, transcript):
xs.append((seg.end + seg.start) / 2 + 0.4) xs.append((seg.end + seg.start) / 2 + 0.4)
hs.append(seg.score) hs.append(seg.score)
ws.append(seg.end - seg.start) ws.append(seg.end - seg.start)
ax2.annotate(seg.label, (seg.start + 0.8, -0.07), weight="bold") ax2.annotate(seg.label, (seg.start + 0.8, -0.07))
ax2.bar(xs, hs, width=ws, color="gray", alpha=0.5, edgecolor="black") ax2.bar(xs, hs, width=ws, color="gray", alpha=0.5, edgecolor="black")
xs, hs = [], [] xs, hs = [], []
@@ -367,17 +379,21 @@ def plot_trellis_with_segments(trellis, segments, transcript):
ax2.bar(xs, hs, width=0.5, alpha=0.5) ax2.bar(xs, hs, width=0.5, alpha=0.5)
ax2.axhline(0, color="black") ax2.axhline(0, color="black")
ax2.set_xlim(ax1.get_xlim()) ax2.grid(True, axis="y")
ax2.set_ylim(-0.1, 1.1) ax2.set_ylim(-0.1, 1.1)
fig.tight_layout()
plot_trellis_with_segments(trellis, segments, transcript) plot_trellis_with_segments(trellis, segments, transcript)
plt.tight_layout()
plt.show()
###################################################################### ######################################################################
# Looks good. Now let’s merge the words. The Wav2Vec2 model uses ``'|'`` # Looks good.
######################################################################
# Merge the segments into words
# -----------------------------
# Now let’s merge the words. The Wav2Vec2 model uses ``'|'``
# as the word boundary, so we merge the segments before each occurrence of # as the word boundary, so we merge the segments before each occurrence of
# ``'|'``. # ``'|'``.
# #
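The word-merging code itself sits outside the visible hunk. A sketch of such a helper, consistent with the ``Segment`` class (and its ``length`` property) defined earlier and the ``'|'`` separator described above:

def merge_words(segments, separator="|"):
    words = []
    i1, i2 = 0, 0
    while i1 < len(segments):
        if i2 >= len(segments) or segments[i2].label == separator:
            if i1 != i2:
                segs = segments[i1:i2]
                word = "".join([seg.label for seg in segs])
                # length-weighted average of the segment scores
                score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)
                words.append(Segment(word, segments[i1].start, segments[i2 - 1].end, score))
            i1 = i2 + 1
            i2 = i1
        else:
            i2 += 1
    return words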
@@ -410,16 +426,16 @@ for word in word_segments:
################################################################################ ################################################################################
# Visualization # Visualization
################################################################################ # ~~~~~~~~~~~~~
def plot_alignments(trellis, segments, word_segments, waveform): def plot_alignments(trellis, segments, word_segments, waveform):
trellis_with_path = trellis.clone() trellis_with_path = trellis.clone()
for i, seg in enumerate(segments): for i, seg in enumerate(segments):
if seg.label != "|": if seg.label != "|":
trellis_with_path[seg.start : seg.end, i] = float("nan") trellis_with_path[seg.start : seg.end, i] = float("nan")
fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5)) fig, [ax1, ax2] = plt.subplots(2, 1)
ax1.imshow(trellis_with_path.T, origin="lower") ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")
ax1.set_xticks([]) ax1.set_xticks([])
ax1.set_yticks([]) ax1.set_yticks([])
@@ -429,8 +445,8 @@ def plot_alignments(trellis, segments, word_segments, waveform):
for i, seg in enumerate(segments): for i, seg in enumerate(segments):
if seg.label != "|": if seg.label != "|":
ax1.annotate(seg.label, (seg.start, i - 0.7)) ax1.annotate(seg.label, (seg.start, i - 0.7), size="small")
ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), fontsize=8) ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), size="small")
# The original waveform # The original waveform
ratio = waveform.size(0) / trellis.size(0) ratio = waveform.size(0) / trellis.size(0)
@@ -450,6 +466,7 @@ def plot_alignments(trellis, segments, word_segments, waveform):
ax2.set_yticks([]) ax2.set_yticks([])
ax2.set_ylim(-1.0, 1.0) ax2.set_ylim(-1.0, 1.0)
ax2.set_xlim(0, waveform.size(-1)) ax2.set_xlim(0, waveform.size(-1))
fig.tight_layout()
plot_alignments( plot_alignments(
@@ -458,7 +475,6 @@ plot_alignments(
word_segments, word_segments,
waveform[0], waveform[0],
) )
plt.show()
################################################################################ ################################################################################
......
@@ -162,11 +162,10 @@ def separate_sources(
def plot_spectrogram(stft, title="Spectrogram"): def plot_spectrogram(stft, title="Spectrogram"):
magnitude = stft.abs() magnitude = stft.abs()
spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy() spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy()
figure, axis = plt.subplots(1, 1) _, axis = plt.subplots(1, 1)
img = axis.imshow(spectrogram, cmap="viridis", vmin=-60, vmax=0, origin="lower", aspect="auto") axis.imshow(spectrogram, cmap="viridis", vmin=-60, vmax=0, origin="lower", aspect="auto")
figure.suptitle(title) axis.set_title(title)
plt.colorbar(img, ax=axis) plt.tight_layout()
plt.show()
###################################################################### ######################################################################
@@ -252,7 +251,7 @@ def output_results(original_source: torch.Tensor, predicted_source: torch.Tensor
"SDR score is:", "SDR score is:",
separation.bss_eval_sources(original_source.detach().numpy(), predicted_source.detach().numpy())[0].mean(), separation.bss_eval_sources(original_source.detach().numpy(), predicted_source.detach().numpy())[0].mean(),
) )
plot_spectrogram(stft(predicted_source)[0], f"Spectrogram {source}") plot_spectrogram(stft(predicted_source)[0], f"Spectrogram - {source}")
return Audio(predicted_source, rate=sample_rate) return Audio(predicted_source, rate=sample_rate)
@@ -294,7 +293,7 @@ mix_spec = mixture[:, frame_start:frame_end].cpu()
# #
# Mixture Clip # Mixture Clip
plot_spectrogram(stft(mix_spec)[0], "Spectrogram Mixture") plot_spectrogram(stft(mix_spec)[0], "Spectrogram - Mixture")
Audio(mix_spec, rate=sample_rate) Audio(mix_spec, rate=sample_rate)
###################################################################### ######################################################################
......
@@ -98,23 +98,21 @@ SAMPLE_NOISE = download_asset("tutorial-assets/mvdr/noise.wav")
# #
def plot_spectrogram(stft, title="Spectrogram", xlim=None): def plot_spectrogram(stft, title="Spectrogram"):
magnitude = stft.abs() magnitude = stft.abs()
spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy() spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy()
figure, axis = plt.subplots(1, 1) figure, axis = plt.subplots(1, 1)
img = axis.imshow(spectrogram, cmap="viridis", vmin=-100, vmax=0, origin="lower", aspect="auto") img = axis.imshow(spectrogram, cmap="viridis", vmin=-100, vmax=0, origin="lower", aspect="auto")
figure.suptitle(title) axis.set_title(title)
plt.colorbar(img, ax=axis) plt.colorbar(img, ax=axis)
plt.show()
def plot_mask(mask, title="Mask", xlim=None): def plot_mask(mask, title="Mask"):
mask = mask.numpy() mask = mask.numpy()
figure, axis = plt.subplots(1, 1) figure, axis = plt.subplots(1, 1)
img = axis.imshow(mask, cmap="viridis", origin="lower", aspect="auto") img = axis.imshow(mask, cmap="viridis", origin="lower", aspect="auto")
figure.suptitle(title) axis.set_title(title)
plt.colorbar(img, ax=axis) plt.colorbar(img, ax=axis)
plt.show()
def si_snr(estimate, reference, epsilon=1e-8): def si_snr(estimate, reference, epsilon=1e-8):
......
@@ -33,12 +33,9 @@ print(torchaudio.__version__)
import os import os
import time import time
import matplotlib
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from torchaudio.io import StreamReader from torchaudio.io import StreamReader
matplotlib.rcParams["image.interpolation"] = "none"
###################################################################### ######################################################################
# #
# Check the prerequisites # Check the prerequisites
......
@@ -160,8 +160,7 @@ for i, feats in enumerate(features):
ax[i].set_title(f"Feature from transformer layer {i+1}") ax[i].set_title(f"Feature from transformer layer {i+1}")
ax[i].set_xlabel("Feature dimension") ax[i].set_xlabel("Feature dimension")
ax[i].set_ylabel("Frame (time-axis)") ax[i].set_ylabel("Frame (time-axis)")
plt.tight_layout() fig.tight_layout()
plt.show()
###################################################################### ######################################################################
@@ -190,7 +189,7 @@ plt.imshow(emission[0].cpu().T, interpolation="nearest")
plt.title("Classification result") plt.title("Classification result")
plt.xlabel("Frame (time-axis)") plt.xlabel("Frame (time-axis)")
plt.ylabel("Class") plt.ylabel("Class")
plt.show() plt.tight_layout()
print("Class labels:", bundle.get_labels()) print("Class labels:", bundle.get_labels())
......
@@ -82,6 +82,7 @@ try:
from pystoi import stoi from pystoi import stoi
from torchaudio.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE from torchaudio.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE
except ImportError: except ImportError:
try:
import google.colab # noqa: F401 import google.colab # noqa: F401
print( print(
@@ -95,6 +96,9 @@ except ImportError:
!pip3 install pystoi !pip3 install pystoi
""" """
) )
except Exception:
pass
raise
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@@ -128,27 +132,17 @@ def si_snr(estimate, reference, epsilon=1e-8):
return si_snr.item() return si_snr.item()
def plot_waveform(waveform, title): def plot(waveform, title, sample_rate=16000):
wav_numpy = waveform.numpy() wav_numpy = waveform.numpy()
sample_size = waveform.shape[1] sample_size = waveform.shape[1]
time_axis = torch.arange(0, sample_size) / 16000 time_axis = torch.arange(0, sample_size) / sample_rate
figure, axes = plt.subplots(1, 1) figure, axes = plt.subplots(2, 1)
axes = figure.gca() axes[0].plot(time_axis, wav_numpy[0], linewidth=1)
axes.plot(time_axis, wav_numpy[0], linewidth=1) axes[0].grid(True)
axes.grid(True) axes[1].specgram(wav_numpy[0], Fs=sample_rate)
figure.suptitle(title) figure.suptitle(title)
plt.show(block=False)
def plot_specgram(waveform, sample_rate, title):
wav_numpy = waveform.numpy()
figure, axes = plt.subplots(1, 1)
axes = figure.gca()
axes.specgram(wav_numpy[0], Fs=sample_rate)
figure.suptitle(title)
plt.show(block=False)
###################################################################### ######################################################################
@@ -238,32 +232,28 @@ Audio(WAVEFORM_DISTORTED.numpy()[1], rate=16000)
# Visualize speech sample # Visualize speech sample
# #
plot_waveform(WAVEFORM_SPEECH, "Clean Speech") plot(WAVEFORM_SPEECH, "Clean Speech")
plot_specgram(WAVEFORM_SPEECH, 16000, "Clean Speech Spectrogram")
###################################################################### ######################################################################
# Visualize noise sample # Visualize noise sample
# #
plot_waveform(WAVEFORM_NOISE, "Noise") plot(WAVEFORM_NOISE, "Noise")
plot_specgram(WAVEFORM_NOISE, 16000, "Noise Spectrogram")
###################################################################### ######################################################################
# Visualize distorted speech with 20dB SNR # Visualize distorted speech with 20dB SNR
# #
plot_waveform(WAVEFORM_DISTORTED[0:1], f"Distorted Speech with {snr_dbs[0]}dB SNR") plot(WAVEFORM_DISTORTED[0:1], f"Distorted Speech with {snr_dbs[0]}dB SNR")
plot_specgram(WAVEFORM_DISTORTED[0:1], 16000, f"Distorted Speech with {snr_dbs[0]}dB SNR")
###################################################################### ######################################################################
# Visualize distorted speech with -5dB SNR # Visualize distorted speech with -5dB SNR
# #
plot_waveform(WAVEFORM_DISTORTED[1:2], f"Distorted Speech with {snr_dbs[1]}dB SNR") plot(WAVEFORM_DISTORTED[1:2], f"Distorted Speech with {snr_dbs[1]}dB SNR")
plot_specgram(WAVEFORM_DISTORTED[1:2], 16000, f"Distorted Speech with {snr_dbs[1]}dB SNR")
###################################################################### ######################################################################
......