"docs/source/api_reference/indexing.rst" did not exist on "6ee722950ec46ebee6fa1c54c4ae9cc770ec1203"
Commit 84b12306 authored by moto, committed by Facebook GitHub Bot

Set and tweak global matplotlib configuration in tutorials (#3515)

Summary:
- Set global matplotlib rc params
- Fix style check
- Fix and update FA tutorial plots
- Add av-asr index card

Pull Request resolved: https://github.com/pytorch/audio/pull/3515

Reviewed By: huangruizhe

Differential Revision: D47894156

Pulled By: mthrok

fbshipit-source-id: b40d8d31f12ffc2b337e35e632afc216e9d59a6e
parent 8497ee91
@@ -127,6 +127,22 @@ def _get_pattern():
    return ret


+def reset_mpl(gallery_conf, fname):
+    from sphinx_gallery.scrapers import _reset_matplotlib
+
+    _reset_matplotlib(gallery_conf, fname)
+
+    import matplotlib
+
+    matplotlib.rcParams.update(
+        {
+            "image.interpolation": "none",
+            "figure.figsize": (9.6, 4.8),
+            "font.size": 8.0,
+            "axes.axisbelow": True,
+        }
+    )
+
+
sphinx_gallery_conf = {
    "examples_dirs": [
        "../../examples/tutorials",
@@ -139,6 +155,7 @@ sphinx_gallery_conf = {
    "promote_jupyter_magic": True,
    "first_notebook_cell": None,
    "doc_module": ("torchaudio",),
+    "reset_modules": (reset_mpl, "seaborn"),
}

autosummary_generate = True
......
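The ``reset_modules`` hook above is what applies these defaults before each tutorial runs. For readers who want to reproduce the tutorial styling in a local script or notebook, a minimal sketch (assuming only matplotlib is installed) that sets the same rc values directly:

    import matplotlib

    # Same global defaults that reset_mpl applies before each gallery example runs.
    matplotlib.rcParams.update(
        {
            "image.interpolation": "none",  # no smoothing for spectrogram-like images
            "figure.figsize": (9.6, 4.8),   # wider default figures
            "font.size": 8.0,
            "axes.axisbelow": True,         # draw grids behind plotted data
        }
    )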
@@ -71,8 +71,8 @@ model implementations and application components.
   tutorials/online_asr_tutorial
   tutorials/device_asr
   tutorials/device_avsr
-  tutorials/forced_alignment_for_multilingual_data_tutorial
   tutorials/forced_alignment_tutorial
+  tutorials/forced_alignment_for_multilingual_data_tutorial
   tutorials/tacotron2_pipeline_tutorial
   tutorials/mvdr_tutorial
   tutorials/hybrid_demucs_tutorial
@@ -147,6 +147,13 @@ Tutorials
.. customcardstart::

+.. customcarditem::
+   :header: On device audio-visual automatic speech recognition
+   :card_description: Learn how to stream audio and video from laptop webcam and perform audio-visual automatic speech recognition using Emformer-RNNT model.
+   :image: https://download.pytorch.org/torchaudio/doc-assets/avsr/transformed.gif
+   :link: tutorials/device_avsr.html
+   :tags: I/O,Pipelines,RNNT
+
.. customcarditem::
   :header: Loading waveform Tensors from files and saving them
   :card_description: Learn how to query/load audio files and save waveform tensors to files, using <code>torchaudio.info</code>, <code>torchaudio.load</code> and <code>torchaudio.save</code> functions.
......
...@@ -85,7 +85,7 @@ NUM_FRAMES = int(DURATION * SAMPLE_RATE) ...@@ -85,7 +85,7 @@ NUM_FRAMES = int(DURATION * SAMPLE_RATE)
# #
def show(freq, amp, waveform, sample_rate, zoom=None, vol=0.1): def plot(freq, amp, waveform, sample_rate, zoom=None, vol=0.1):
t = torch.arange(waveform.size(0)) / sample_rate t = torch.arange(waveform.size(0)) / sample_rate
fig, axes = plt.subplots(4, 1, sharex=True) fig, axes = plt.subplots(4, 1, sharex=True)
...@@ -101,7 +101,7 @@ def show(freq, amp, waveform, sample_rate, zoom=None, vol=0.1): ...@@ -101,7 +101,7 @@ def show(freq, amp, waveform, sample_rate, zoom=None, vol=0.1):
for i in range(4): for i in range(4):
axes[i].grid(True) axes[i].grid(True)
pos = axes[2].get_position() pos = axes[2].get_position()
plt.tight_layout() fig.tight_layout()
if zoom is not None: if zoom is not None:
ax = fig.add_axes([pos.x0 + 0.02, pos.y0 + 0.03, pos.width / 2.5, pos.height / 2.0]) ax = fig.add_axes([pos.x0 + 0.02, pos.y0 + 0.03, pos.width / 2.5, pos.height / 2.0])
...@@ -168,7 +168,7 @@ def sawtooth_wave(freq0, amp0, num_pitches, sample_rate): ...@@ -168,7 +168,7 @@ def sawtooth_wave(freq0, amp0, num_pitches, sample_rate):
freq0 = torch.full((NUM_FRAMES, 1), F0) freq0 = torch.full((NUM_FRAMES, 1), F0)
amp0 = torch.ones((NUM_FRAMES, 1)) amp0 = torch.ones((NUM_FRAMES, 1))
freq, amp, waveform = sawtooth_wave(freq0, amp0, int(SAMPLE_RATE / F0), SAMPLE_RATE) freq, amp, waveform = sawtooth_wave(freq0, amp0, int(SAMPLE_RATE / F0), SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0)) plot(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
###################################################################### ######################################################################
# #
...@@ -183,7 +183,7 @@ phase = torch.linspace(0, fm * PI2 * DURATION, NUM_FRAMES) ...@@ -183,7 +183,7 @@ phase = torch.linspace(0, fm * PI2 * DURATION, NUM_FRAMES)
freq0 = F0 + f_dev * torch.sin(phase).unsqueeze(-1) freq0 = F0 + f_dev * torch.sin(phase).unsqueeze(-1)
freq, amp, waveform = sawtooth_wave(freq0, amp0, int(SAMPLE_RATE / F0), SAMPLE_RATE) freq, amp, waveform = sawtooth_wave(freq0, amp0, int(SAMPLE_RATE / F0), SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0)) plot(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
###################################################################### ######################################################################
# Square wave # Square wave
...@@ -220,7 +220,7 @@ def square_wave(freq0, amp0, num_pitches, sample_rate): ...@@ -220,7 +220,7 @@ def square_wave(freq0, amp0, num_pitches, sample_rate):
freq0 = torch.full((NUM_FRAMES, 1), F0) freq0 = torch.full((NUM_FRAMES, 1), F0)
amp0 = torch.ones((NUM_FRAMES, 1)) amp0 = torch.ones((NUM_FRAMES, 1))
freq, amp, waveform = square_wave(freq0, amp0, int(SAMPLE_RATE / F0 / 2), SAMPLE_RATE) freq, amp, waveform = square_wave(freq0, amp0, int(SAMPLE_RATE / F0 / 2), SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0)) plot(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
###################################################################### ######################################################################
# Triangle wave # Triangle wave
...@@ -256,7 +256,7 @@ def triangle_wave(freq0, amp0, num_pitches, sample_rate): ...@@ -256,7 +256,7 @@ def triangle_wave(freq0, amp0, num_pitches, sample_rate):
# #
freq, amp, waveform = triangle_wave(freq0, amp0, int(SAMPLE_RATE / F0 / 2), SAMPLE_RATE) freq, amp, waveform = triangle_wave(freq0, amp0, int(SAMPLE_RATE / F0 / 2), SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0)) plot(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
###################################################################### ######################################################################
# Inharmonic Partials
...@@ -296,7 +296,7 @@ amp = torch.stack([amp * (0.5**i) for i in range(num_tones)], dim=-1) ...@@ -296,7 +296,7 @@ amp = torch.stack([amp * (0.5**i) for i in range(num_tones)], dim=-1)
waveform = oscillator_bank(freq, amp, sample_rate=SAMPLE_RATE) waveform = oscillator_bank(freq, amp, sample_rate=SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE, vol=0.4) plot(freq, amp, waveform, SAMPLE_RATE, vol=0.4)
###################################################################### ######################################################################
# #
...@@ -308,7 +308,7 @@ show(freq, amp, waveform, SAMPLE_RATE, vol=0.4) ...@@ -308,7 +308,7 @@ show(freq, amp, waveform, SAMPLE_RATE, vol=0.4)
freq = extend_pitch(freq0, num_tones) freq = extend_pitch(freq0, num_tones)
waveform = oscillator_bank(freq, amp, sample_rate=SAMPLE_RATE) waveform = oscillator_bank(freq, amp, sample_rate=SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE) plot(freq, amp, waveform, SAMPLE_RATE)
###################################################################### ######################################################################
# References # References
......
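The calls being renamed above wrap ``oscillator_bank`` from ``torchaudio.prototype.functional``, which synthesizes a waveform from per-frame frequencies and amplitudes. A minimal self-contained sketch of that call, separate from the tutorial code (it assumes the prototype module is available in your torchaudio build):

    import torch
    from torchaudio.prototype.functional import oscillator_bank

    SAMPLE_RATE = 16000
    NUM_FRAMES = SAMPLE_RATE  # one second of audio

    # One oscillator held at 440 Hz with constant amplitude.
    freq = torch.full((NUM_FRAMES, 1), 440.0)
    amp = torch.ones((NUM_FRAMES, 1))

    waveform = oscillator_bank(freq, amp, sample_rate=SAMPLE_RATE)
    print(waveform.shape)  # torch.Size([16000])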
...@@ -407,30 +407,45 @@ print(timesteps, timesteps.shape[0]) ...@@ -407,30 +407,45 @@ print(timesteps, timesteps.shape[0])
# #
def plot_alignments(waveform, emission, tokens, timesteps): def plot_alignments(waveform, emission, tokens, timesteps, sample_rate):
fig, ax = plt.subplots(figsize=(32, 10))
t = torch.arange(waveform.size(0)) / sample_rate
ax.plot(waveform) ratio = waveform.size(0) / emission.size(1) / sample_rate
ratio = waveform.shape[0] / emission.shape[1] chars = []
word_start = 0 words = []
word_start = None
for i in range(len(tokens)): for token, timestep in zip(tokens, timesteps * ratio):
if i != 0 and tokens[i - 1] == "|": if token == "|":
word_start = timesteps[i] if word_start is not None:
if tokens[i] != "|": words.append((word_start, timestep))
plt.annotate(tokens[i].upper(), (timesteps[i] * ratio, waveform.max() * 1.02), size=14) word_start = None
elif i != 0: else:
word_end = timesteps[i] chars.append((token, timestep))
ax.axvspan(word_start * ratio, word_end * ratio, alpha=0.1, color="red") if word_start is None:
word_start = timestep
xticks = ax.get_xticks()
plt.xticks(xticks, xticks / bundle.sample_rate) fig, axes = plt.subplots(3, 1)
ax.set_xlabel("time (sec)")
ax.set_xlim(0, waveform.shape[0]) def _plot(ax, xlim):
ax.plot(t, waveform)
for token, timestep in chars:
plot_alignments(waveform[0], emission, predicted_tokens, timesteps) ax.annotate(token.upper(), (timestep, 0.5))
for word_start, word_end in words:
ax.axvspan(word_start, word_end, alpha=0.1, color="red")
ax.set_ylim(-0.6, 0.7)
ax.set_yticks([0])
ax.grid(True, axis="y")
ax.set_xlim(xlim)
_plot(axes[0], (0.3, 2.5))
_plot(axes[1], (2.5, 4.7))
_plot(axes[2], (4.7, 6.9))
axes[2].set_xlabel("time (sec)")
fig.tight_layout()
plot_alignments(waveform[0], emission, predicted_tokens, timesteps, bundle.sample_rate)
###################################################################### ######################################################################
......
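The main change in the rewritten plotting code above is that token time stamps are converted from emission-frame indices to seconds before drawing, via ``ratio = waveform.size(0) / emission.size(1) / sample_rate``. A small illustrative sketch of that conversion (the helper name and the numbers are hypothetical):

    import torch

    def frames_to_seconds(timesteps, num_samples, num_frames, sample_rate):
        # seconds per emission frame = (samples per frame) / sample_rate
        ratio = num_samples / num_frames / sample_rate
        return timesteps * ratio

    # e.g. a token emitted at frame 155 of a ~6 s, 16 kHz clip with 306 emission frames
    print(frames_to_seconds(torch.tensor([155]), num_samples=98304, num_frames=306, sample_rate=16000))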
...@@ -100,7 +100,6 @@ def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None): ...@@ -100,7 +100,6 @@ def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None):
if xlim: if xlim:
axes[c].set_xlim(xlim) axes[c].set_xlim(xlim)
figure.suptitle(title) figure.suptitle(title)
plt.show(block=False)
###################################################################### ######################################################################
...@@ -122,7 +121,6 @@ def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None): ...@@ -122,7 +121,6 @@ def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
if xlim: if xlim:
axes[c].set_xlim(xlim) axes[c].set_xlim(xlim)
figure.suptitle(title) figure.suptitle(title)
plt.show(block=False)
###################################################################### ######################################################################
......
# -*- coding: utf-8 -*-
""" """
Audio Datasets Audio Datasets
============== ==============
...@@ -10,10 +9,6 @@ datasets. Please refer to the official documentation for the list of ...@@ -10,10 +9,6 @@ datasets. Please refer to the official documentation for the list of
available datasets. available datasets.
""" """
# When running this tutorial in Google Colab, install the required packages
# with the following.
# !pip install torchaudio
import torch import torch
import torchaudio import torchaudio
...@@ -21,22 +16,13 @@ print(torch.__version__) ...@@ -21,22 +16,13 @@ print(torch.__version__)
print(torchaudio.__version__) print(torchaudio.__version__)
###################################################################### ######################################################################
# Preparing data and utility functions (skip this section)
# --------------------------------------------------------
# #
# @title Prepare data and utility functions. {display-mode: "form"}
# @markdown
# @markdown You do not need to look into this cell.
# @markdown Just execute once and you are good to go.
# -------------------------------------------------------------------------------
# Preparation of data and helper functions.
# -------------------------------------------------------------------------------
import os import os
import IPython
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from IPython.display import Audio, display
_SAMPLE_DIR = "_assets" _SAMPLE_DIR = "_assets"
...@@ -44,34 +30,13 @@ YESNO_DATASET_PATH = os.path.join(_SAMPLE_DIR, "yes_no") ...@@ -44,34 +30,13 @@ YESNO_DATASET_PATH = os.path.join(_SAMPLE_DIR, "yes_no")
os.makedirs(YESNO_DATASET_PATH, exist_ok=True) os.makedirs(YESNO_DATASET_PATH, exist_ok=True)
def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None): def plot_specgram(waveform, sample_rate, title="Spectrogram"):
waveform = waveform.numpy() waveform = waveform.numpy()
num_channels, _ = waveform.shape figure, ax = plt.subplots()
ax.specgram(waveform[0], Fs=sample_rate)
figure, axes = plt.subplots(num_channels, 1)
if num_channels == 1:
axes = [axes]
for c in range(num_channels):
axes[c].specgram(waveform[c], Fs=sample_rate)
if num_channels > 1:
axes[c].set_ylabel(f"Channel {c+1}")
if xlim:
axes[c].set_xlim(xlim)
figure.suptitle(title) figure.suptitle(title)
plt.show(block=False) figure.tight_layout()
def play_audio(waveform, sample_rate):
waveform = waveform.numpy()
num_channels, _ = waveform.shape
if num_channels == 1:
display(Audio(waveform[0], rate=sample_rate))
elif num_channels == 2:
display(Audio((waveform[0], waveform[1]), rate=sample_rate))
else:
raise ValueError("Waveform with more than 2 channels are not supported.")
###################################################################### ######################################################################
...@@ -79,10 +44,25 @@ def play_audio(waveform, sample_rate): ...@@ -79,10 +44,25 @@ def play_audio(waveform, sample_rate):
# :py:class:`torchaudio.datasets.YESNO` dataset. # :py:class:`torchaudio.datasets.YESNO` dataset.
# #
dataset = torchaudio.datasets.YESNO(YESNO_DATASET_PATH, download=True) dataset = torchaudio.datasets.YESNO(YESNO_DATASET_PATH, download=True)
for i in [1, 3, 5]: ######################################################################
waveform, sample_rate, label = dataset[i] #
plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}") i = 1
play_audio(waveform, sample_rate) waveform, sample_rate, label = dataset[i]
plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}")
IPython.display.Audio(waveform, rate=sample_rate)
######################################################################
#
i = 3
waveform, sample_rate, label = dataset[i]
plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}")
IPython.display.Audio(waveform, rate=sample_rate)
######################################################################
#
i = 5
waveform, sample_rate, label = dataset[i]
plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}")
IPython.display.Audio(waveform, rate=sample_rate)
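Each ``YESNO`` item is a ``(waveform, sample_rate, labels)`` tuple, so the dataset can also be fed to a ``DataLoader``. A minimal sketch, not part of the tutorial (``batch_size=1`` avoids the need for a padding collate function, since the recordings have different lengths):

    import torchaudio
    from torch.utils.data import DataLoader

    dataset = torchaudio.datasets.YESNO("_assets/yes_no", download=True)
    loader = DataLoader(dataset, batch_size=1, shuffle=True)

    for waveform, sample_rate, labels in loader:
        print(waveform.shape, sample_rate, labels)  # e.g. torch.Size([1, 1, N]), tensor([8000]), eight 0/1 labels
        break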
...@@ -19,25 +19,19 @@ print(torch.__version__) ...@@ -19,25 +19,19 @@ print(torch.__version__)
print(torchaudio.__version__) print(torchaudio.__version__)
###################################################################### ######################################################################
# Preparing data and utility functions (skip this section) # Preparation
# -------------------------------------------------------- # -----------
# #
# @title Prepare data and utility functions. {display-mode: "form"}
# @markdown
# @markdown You do not need to look into this cell.
# @markdown Just execute once and you are good to go.
# @markdown
# @markdown In this tutorial, we will use a speech data from [VOiCES dataset](https://iqtlabs.github.io/voices/),
# @markdown which is licensed under Creative Commos BY 4.0.
# -------------------------------------------------------------------------------
# Preparation of data and helper functions.
# -------------------------------------------------------------------------------
import librosa import librosa
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from torchaudio.utils import download_asset from torchaudio.utils import download_asset
######################################################################
# In this tutorial, we will use speech data from
# `VOiCES dataset <https://iqtlabs.github.io/voices/>`__,
# which is licensed under Creative Commons BY 4.0.
SAMPLE_WAV_SPEECH_PATH = download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav") SAMPLE_WAV_SPEECH_PATH = download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
...@@ -75,16 +69,9 @@ def get_spectrogram( ...@@ -75,16 +69,9 @@ def get_spectrogram(
return spectrogram(waveform) return spectrogram(waveform)
def plot_spectrogram(spec, title=None, ylabel="freq_bin", aspect="auto", xmax=None): def plot_spec(ax, spec, title, ylabel="freq_bin"):
fig, axs = plt.subplots(1, 1) ax.set_title(title)
axs.set_title(title or "Spectrogram (db)") ax.imshow(librosa.power_to_db(spec), origin="lower", aspect="auto")
axs.set_ylabel(ylabel)
axs.set_xlabel("frame")
im = axs.imshow(librosa.power_to_db(spec), origin="lower", aspect=aspect)
if xmax:
axs.set_xlim((0, xmax))
fig.colorbar(im, ax=axs)
plt.show(block=False)
###################################################################### ######################################################################
...@@ -108,43 +95,47 @@ def plot_spectrogram(spec, title=None, ylabel="freq_bin", aspect="auto", xmax=No ...@@ -108,43 +95,47 @@ def plot_spectrogram(spec, title=None, ylabel="freq_bin", aspect="auto", xmax=No
spec = get_spectrogram(power=None) spec = get_spectrogram(power=None)
stretch = T.TimeStretch() stretch = T.TimeStretch()
rate = 1.2 spec_12 = stretch(spec, overriding_rate=1.2)
spec_ = stretch(spec, rate) spec_09 = stretch(spec, overriding_rate=0.9)
plot_spectrogram(torch.abs(spec_[0]), title=f"Stretched x{rate}", aspect="equal", xmax=304)
plot_spectrogram(torch.abs(spec[0]), title="Original", aspect="equal", xmax=304)
rate = 0.9
spec_ = stretch(spec, rate)
plot_spectrogram(torch.abs(spec_[0]), title=f"Stretched x{rate}", aspect="equal", xmax=304)
###################################################################### ######################################################################
# TimeMasking
# -----------
# #
torch.random.manual_seed(4)
spec = get_spectrogram() def plot():
plot_spectrogram(spec[0], title="Original") fig, axes = plt.subplots(3, 1, sharex=True, sharey=True)
plot_spec(axes[0], torch.abs(spec_12[0]), title="Stretched x1.2")
plot_spec(axes[1], torch.abs(spec[0]), title="Original")
plot_spec(axes[2], torch.abs(spec_09[0]), title="Stretched x0.9")
fig.tight_layout()
masking = T.TimeMasking(time_mask_param=80)
spec = masking(spec)
plot_spectrogram(spec[0], title="Masked along time axis") plot()
###################################################################### ######################################################################
# FrequencyMasking # Time and Frequency Masking
# ---------------- # --------------------------
# #
torch.random.manual_seed(4) torch.random.manual_seed(4)
time_masking = T.TimeMasking(time_mask_param=80)
freq_masking = T.FrequencyMasking(freq_mask_param=80)
spec = get_spectrogram() spec = get_spectrogram()
plot_spectrogram(spec[0], title="Original") time_masked = time_masking(spec)
freq_masked = freq_masking(spec)
######################################################################
#
def plot():
fig, axes = plt.subplots(3, 1, sharex=True, sharey=True)
plot_spec(axes[0], spec[0], title="Original")
plot_spec(axes[1], time_masked[0], title="Masked along time axis")
plot_spec(axes[2], freq_masked[0], title="Masked along frequency axis")
fig.tight_layout()
masking = T.FrequencyMasking(freq_mask_param=80)
spec = masking(spec)
plot_spectrogram(spec[0], title="Masked along frequency axis") plot()
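The two maskings shown above are the building blocks of SpecAugment-style augmentation and can be chained. A minimal self-contained sketch (the random tensor stands in for the spectrogram produced by ``get_spectrogram`` in the tutorial):

    import torch
    import torchaudio.transforms as T

    torch.random.manual_seed(4)

    # Stand-in power spectrogram with shape (channel, freq, time).
    spec = torch.rand(1, 201, 300)

    # Apply both maskings in sequence, as SpecAugment would.
    augment = torch.nn.Sequential(
        T.TimeMasking(time_mask_param=80),
        T.FrequencyMasking(freq_mask_param=80),
    )
    augmented = augment(spec)
    print(augmented.shape)  # masking preserves the shape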
...@@ -75,7 +75,6 @@ def plot_waveform(waveform, sr, title="Waveform", ax=None): ...@@ -75,7 +75,6 @@ def plot_waveform(waveform, sr, title="Waveform", ax=None):
ax.grid(True) ax.grid(True)
ax.set_xlim([0, time_axis[-1]]) ax.set_xlim([0, time_axis[-1]])
ax.set_title(title) ax.set_title(title)
plt.show(block=False)
def plot_spectrogram(specgram, title=None, ylabel="freq_bin", ax=None): def plot_spectrogram(specgram, title=None, ylabel="freq_bin", ax=None):
...@@ -85,7 +84,6 @@ def plot_spectrogram(specgram, title=None, ylabel="freq_bin", ax=None): ...@@ -85,7 +84,6 @@ def plot_spectrogram(specgram, title=None, ylabel="freq_bin", ax=None):
ax.set_title(title) ax.set_title(title)
ax.set_ylabel(ylabel) ax.set_ylabel(ylabel)
ax.imshow(librosa.power_to_db(specgram), origin="lower", aspect="auto", interpolation="nearest") ax.imshow(librosa.power_to_db(specgram), origin="lower", aspect="auto", interpolation="nearest")
plt.show(block=False)
def plot_fbank(fbank, title=None): def plot_fbank(fbank, title=None):
...@@ -94,7 +92,6 @@ def plot_fbank(fbank, title=None): ...@@ -94,7 +92,6 @@ def plot_fbank(fbank, title=None):
axs.imshow(fbank, aspect="auto") axs.imshow(fbank, aspect="auto")
axs.set_ylabel("frequency bin") axs.set_ylabel("frequency bin")
axs.set_xlabel("mel bin") axs.set_xlabel("mel bin")
plt.show(block=False)
###################################################################### ######################################################################
...@@ -486,7 +483,6 @@ def plot_pitch(waveform, sr, pitch): ...@@ -486,7 +483,6 @@ def plot_pitch(waveform, sr, pitch):
axis2.plot(time_axis, pitch[0], linewidth=2, label="Pitch", color="green") axis2.plot(time_axis, pitch[0], linewidth=2, label="Pitch", color="green")
axis2.legend(loc=0) axis2.legend(loc=0)
plt.show(block=False)
plot_pitch(SPEECH_WAVEFORM, SAMPLE_RATE, pitch) plot_pitch(SPEECH_WAVEFORM, SAMPLE_RATE, pitch)
...@@ -181,7 +181,6 @@ def plot_waveform(waveform, sample_rate): ...@@ -181,7 +181,6 @@ def plot_waveform(waveform, sample_rate):
if num_channels > 1: if num_channels > 1:
axes[c].set_ylabel(f"Channel {c+1}") axes[c].set_ylabel(f"Channel {c+1}")
figure.suptitle("waveform") figure.suptitle("waveform")
plt.show(block=False)
###################################################################### ######################################################################
...@@ -204,7 +203,6 @@ def plot_specgram(waveform, sample_rate, title="Spectrogram"): ...@@ -204,7 +203,6 @@ def plot_specgram(waveform, sample_rate, title="Spectrogram"):
if num_channels > 1: if num_channels > 1:
axes[c].set_ylabel(f"Channel {c+1}") axes[c].set_ylabel(f"Channel {c+1}")
figure.suptitle(title) figure.suptitle(title)
plt.show(block=False)
###################################################################### ######################################################################
......
...@@ -105,7 +105,6 @@ def plot_sweep( ...@@ -105,7 +105,6 @@ def plot_sweep(
axis.yaxis.grid(True, alpha=0.67) axis.yaxis.grid(True, alpha=0.67)
figure.suptitle(f"{title} (sample rate: {sample_rate} Hz)") figure.suptitle(f"{title} (sample rate: {sample_rate} Hz)")
plt.colorbar(cax) plt.colorbar(cax)
plt.show(block=True)
###################################################################### ######################################################################
......
@@ -5,254 +5,277 @@ CTC forced alignment API tutorial
**Author**: `Xiaohui Zhang <xiaohuizhang@meta.com>`__

-This tutorial shows how to align transcripts to speech with
-``torchaudio``'s CTC forced alignment API proposed in the paper
-`“Scaling Speech Technology to 1,000+
-Languages” <https://research.facebook.com/publications/scaling-speech-technology-to-1000-languages/>`__,
-and one advanced usage, i.e. dealing with transcription errors with a <star> token.
-
-Though there’s some overlap in visualization
-diagrams, the scope here is different from the `“Forced Alignment with
-Wav2Vec2” <https://pytorch.org/audio/stable/tutorials/forced_alignment_tutorial.html>`__
-tutorial, which focuses on a step-by-step demonstration of the forced
-alignment generation algorithm (without using an API) described in the
-`paper <https://arxiv.org/abs/2007.09127>`__ with a Wav2Vec2 model.
+This tutorial shows how to align transcripts to speech using
+:py:func:`torchaudio.functional.forced_align`,
+which was developed along with the work of
+`Scaling Speech Technology to 1,000+ Languages <https://research.facebook.com/publications/scaling-speech-technology-to-1000-languages/>`__.
+
+Forced alignment is the process of aligning a transcript with speech.
+We cover the basics of forced alignment in `Forced Alignment with
+Wav2Vec2 <./forced_alignment_tutorial.html>`__ with simplified
+step-by-step Python implementations.
+
+:py:func:`~torchaudio.functional.forced_align` has custom CPU and CUDA
+implementations which are more performant than the vanilla Python
+implementation above, and are more accurate.
+It can also handle missing transcripts with a special <star> token.
+
+For examples of aligning multiple languages, please refer to
+`Forced alignment for multilingual data <./forced_alignment_for_multilingual_data_tutorial.html>`__.
"""
import torch import torch
import torchaudio import torchaudio
print(torch.__version__) print(torch.__version__)
print(torchaudio.__version__) print(torchaudio.__version__)
######################################################################
#
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") from dataclasses import dataclass
print(device) from typing import List
try: import IPython
from torchaudio.functional import forced_align import matplotlib.pyplot as plt
except ModuleNotFoundError:
print(
"Failed to import the forced alignment API. "
"Please install torchaudio nightly builds. "
"Please refer to https://pytorch.org/get-started/locally "
"for instructions to install a nightly build."
)
raise
###################################################################### ######################################################################
# Basic usages
# ------------
#
# In this section, we cover the following content:
#
# 1. Generate frame-wise class probabilites from audio waveform from a CTC
# acoustic model.
# 2. Compute frame-level alignments using TorchAudio’s forced alignment
# API.
# 3. Obtain token-level alignments from frame-level alignments.
# 4. Obtain word-level alignments from token-level alignments.
# #
from torchaudio.functional import forced_align
torch.random.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
###################################################################### ######################################################################
# Preparation # Preparation
# ~~~~~~~~~~~ # -----------
# #
# First we import the necessary packages, and fetch data that we work on. # First we prepare the speech data and the transcript we are going
# to use.
# #
# %matplotlib inline
from dataclasses import dataclass
import IPython
import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams["figure.figsize"] = [16.0, 4.8]
torch.random.manual_seed(0)
SPEECH_FILE = torchaudio.utils.download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav") SPEECH_FILE = torchaudio.utils.download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
sample_rate = 16000 TRANSCRIPT = "I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT"
######################################################################
-# Generate frame-wise class posteriors from a CTC acoustic model
-# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-#
-# The first step is to generate the class probabilities (i.e. posteriors)
-# of each audio frame using a CTC model.
-# Here we use :py:func:`torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H`.
+# Generating emissions and tokens
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# :py:func:`~torchaudio.functional.forced_align` takes emission and
+# token sequences and outputs timestamps of the tokens and their scores.
+#
+# Emission represents the frame-wise probability distribution over
+# tokens, and it can be obtained by passing a waveform to an acoustic
+# model.
+# Tokens are numerical expressions of the transcript. They can be obtained
+# by simply mapping each character to its index in the token list.
+# The emission and the token sequences must use the same set of tokens.
+#
+# We can use a pre-trained Wav2Vec2 model to obtain emission from speech
+# and map the transcript to tokens.
+# Here, we use :py:data:`~torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H`,
+# which bundles pre-trained model weights with associated labels.
#
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model().to(device) model = bundle.get_model().to(device)
labels = bundle.get_labels()
with torch.inference_mode(): with torch.inference_mode():
waveform, _ = torchaudio.load(SPEECH_FILE) waveform, _ = torchaudio.load(SPEECH_FILE)
emissions, _ = model(waveform.to(device)) emission, _ = model(waveform.to(device))
emissions = torch.log_softmax(emissions, dim=-1) emission = torch.log_softmax(emission, dim=-1)
######################################################################
#
emission = emissions.cpu().detach() def plot_emission(emission):
dictionary = {c: i for i, c in enumerate(labels)} plt.imshow(emission.cpu().T)
plt.title("Frame-wise class probabilities")
plt.xlabel("Time")
plt.ylabel("Labels")
plt.tight_layout()
print(dictionary)
plot_emission(emission[0])
###################################################################### ######################################################################
# Visualization # We create a dictionary, which maps each label into token.
# ^^^^^^^^^^^^^
# labels = bundle.get_labels()
DICTIONARY = {c: i for i, c in enumerate(labels)}
for k, v in DICTIONARY.items():
print(f"{k}: {v}")
######################################################################
# converting transcript to tokens is as simple as
plt.imshow(emission[0].T) tokens = [DICTIONARY[c] for c in TRANSCRIPT]
plt.colorbar()
plt.title("Frame-wise class probabilities")
plt.xlabel("Time")
plt.ylabel("Labels")
plt.show()
print(" ".join(str(t) for t in tokens))
###################################################################### ######################################################################
# Computing frame-level alignments # Computing frame-level alignments
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # --------------------------------
# #
# Then we call TorchAudio’s forced alignment API to compute the # Now we call TorchAudio’s forced alignment API to compute the
# frame-level alignment between each audio frame and each token in the # frame-level alignment. For the detail of function signature, please
# transcript. We first explain the inputs and outputs of the API # refer to :py:func:`~torchaudio.functional.forced_align`.
# ``functional.forced_align``. Note that this API works on both CPU and
# GPU. In the current tutorial we demonstrate it on CPU.
# #
# **Inputs**:
# #
# ``emission``: a 2D tensor of size :math:`T \times N`, where :math:`T` is
# the number of frames (after sub-sampling by the acoustic model, if any),
# and :math:`N` is the vocabulary size. def align(emission, tokens):
# alignments, scores = forced_align(
# ``targets``: a 1D tensor vector of size :math:`M`, where :math:`M` is emission,
# the length of the transcript, and each element is a token ID looked up targets=torch.tensor([tokens], dtype=torch.int32, device=emission.device),
# from the vocabulary. For example, the ``targets`` tensor repsenting the input_lengths=torch.tensor([emission.size(1)], device=emission.device),
# transcript “i had…” is :math:`[5, 18, 4, 16, ...]`. target_lengths=torch.tensor([len(tokens)], device=emission.device),
# blank=0,
# ``input lengths``: :math:`T`. )
#
# ``target lengths``: :math:`M`. scores = scores.exp() # convert back to probability
# alignments, scores = alignments[0], scores[0] # remove batch dimension for simplicity
# **Outputs**: return alignments.tolist(), scores.tolist()
#
# ``frame_alignment``: a 1D tensor of size :math:`T` storing the aligned
# token index (looked up from the vocabulary) of each frame, e.g. for the frame_alignment, frame_scores = align(emission, tokens)
# segment corresponding to “i had” in the given example , the
# frame_alignment is ######################################################################
# :math:`[...0, 0, 5, 0, 0, 18, 18, 4, 0, 0, 0, 16,...]`, where :math:`0` # Now let's look at the output.
# represents the blank symbol. # Notice that the alignment is expressed in the frame coordinate of
# emission, which is different from the original waveform.
for i, (ali, score) in enumerate(zip(frame_alignment, frame_scores)):
print(f"{i:3d}: {ali:2d} [{labels[ali]}], {score:.2f}")
######################################################################
# #
# ``frame_scores``: a 1D tensor of size :math:`T` storing the confidence # The ``Frame`` instance represents the most likely token at each frame
# score (0 to 1) for each each frame. For each frame, the score should be # with its confidence.
# close to one if the alignment quality is good. #
# When interpreting it, one must remember that the meaning of blank tokens
# and repeated tokens is context dependent.
#
# .. note::
#
# When the same token occurs after a blank token, it is not treated as
# a repeat, but as a new occurrence.
#
# .. code-block::
#
# a a a b -> a b
# a - - b -> a b
# a a - b -> a b
# a - a b -> a a b
# ^^^ ^^^
#
# .. code-block::
#
# 29: 0 [-], 1.00
# 30: 7 [I], 1.00 # Start of "I"
# 31: 0 [-], 0.98 # repeat (blank token)
# 32: 0 [-], 1.00 # repeat (blank token)
# 33: 1 [|], 0.85 # Start of "|" (word boundary)
# 34: 1 [|], 1.00 # repeat (same token)
# 35: 0 [-], 0.61 # repeat (blank token)
# 36: 8 [H], 1.00 # Start of "H"
# 37: 0 [-], 1.00 # repeat (blank token)
# 38: 4 [A], 1.00 # Start of "A"
# 39: 0 [-], 0.99 # repeat (blank token)
# 40: 11 [D], 0.92 # Start of "D"
# 41: 0 [-], 0.93 # repeat (blank token)
# 42: 1 [|], 0.98 # Start of "|"
# 43: 1 [|], 1.00 # repeat (same token)
# 44: 3 [T], 1.00 # Start of "T"
# 45: 3 [T], 0.90 # repeat (same token)
# 46: 8 [H], 1.00 # Start of "H"
# 47: 0 [-], 1.00 # repeat (blank token)
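To make the collapse rules above concrete, here is a small illustrative helper (not part of the tutorial) that merges a frame-level alignment into tokens, treating a repeat that follows a blank as a new occurrence:

    def collapse_ctc(frame_tokens, blank=0):
        # Merge repeats, but only when they are not separated by a blank.
        out = []
        prev = blank
        for tok in frame_tokens:
            if tok != blank and tok != prev:
                out.append(tok)
            prev = tok
        return out

    print(collapse_ctc([1, 0, 1, 2]))  # "a - a b" -> [1, 1, 2] (two occurrences of "a")
    print(collapse_ctc([1, 1, 0, 2]))  # "a a - b" -> [1, 2]    (one occurrence of "a")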
######################################################################
# Resolve blank and repeated tokens
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# #
# From the outputs ``frame_alignment`` and ``frame_scores``, we generate a # The next step is to resolve the repetition, so that what the alignment represents
# does not depend on previous alignments.
# From the outputs ``alignment`` and ``scores``, we generate a
# list called ``frames`` storing information of all frames aligned to # list called ``frames`` storing information of all frames aligned to
# non-blank tokens. Each element contains 1) token_index: the aligned # non-blank tokens.
# token’s index in the transcript 2) time_index: the current frame’s index
# in the input audio (or more precisely, the row dimension of the emission
# matrix) 3) the confidence scores of the current frame.
#
# For the given example, the first few elements of the list ``frames``
# corresponding to “i had” looks as the following:
#
# ``Frame(token_index=0, time_index=32, score=0.9994410872459412)``
# #
# ``Frame(token_index=1, time_index=35, score=0.9980823993682861)`` # Each element contains the following
# #
# ``Frame(token_index=1, time_index=36, score=0.9295750260353088)`` # - ``token_index``: the aligned token’s index **in the transcript**
# # - ``time_index``: the current frame’s index in emission
# ``Frame(token_index=2, time_index=37, score=0.9997448325157166)`` # - ``score``: scores of the current frame.
#
# ``Frame(token_index=3, time_index=41, score=0.9991760849952698)``
#
# ``...``
#
# The interpretation is:
#
# The token with index :math:`0` in the transcript, i.e. “i”, is aligned
# to the :math:`32`\ th audio frame, with confidence :math:`0.9994`. The
# token with index :math:`1` in the transcript, i.e. “h”, is aligned to
# the :math:`35`\ th and :math:`36`\ th audio frames, with confidence
# :math:`0.9981` and :math:`0.9296` respectively. The token with index
# :math:`2` in the transcript, i.e. “a”, is aligned to the :math:`35`\ th
# and :math:`36`\ th audio frames, with confidence :math:`0.9997`. The
# token with index :math:`3` in the transcript, i.e. “d”, is aligned to
# the :math:`41`\ th audio frame, with confidence :math:`0.9992`.
#
# From such information stored in the ``frames`` list, we’ll compute
# token-level and word-level alignments easily.
# #
# ``token_index`` is the index of each token in the transcript,
# i.e. the current frame aligns to the N-th character from the transcript.
@dataclass @dataclass
class Frame: class Frame:
# This is the index of each token in the transcript,
# i.e. the current frame aligns to the N-th character from the transcript.
token_index: int token_index: int
time_index: int time_index: int
score: float score: float
def compute_alignments(transcript, dictionary, emission): ######################################################################
frames = [] #
tokens = [dictionary[c] for c in transcript.replace(" ", "")]
targets = torch.tensor(tokens, dtype=torch.int32).unsqueeze(0)
input_lengths = torch.tensor([emission.shape[1]])
target_lengths = torch.tensor([targets.shape[1]])
# This is the key step, where we call the forced alignment API functional.forced_align to compute alignments.
frame_alignment, frame_scores = forced_align(emission, targets, input_lengths, target_lengths, 0)
assert frame_alignment.shape[1] == input_lengths[0].item() def obtain_token_level_alignments(alignments, scores) -> List[Frame]:
assert targets.shape[1] == target_lengths[0].item() assert len(alignments) == len(scores)
token_index = -1 token_index = -1
prev_hyp = 0 prev_hyp = 0
for i in range(frame_alignment.shape[1]): frames = []
if frame_alignment[0][i].item() == 0: for i, (ali, score) in enumerate(zip(alignments, scores)):
if ali == 0:
prev_hyp = 0 prev_hyp = 0
continue continue
if frame_alignment[0][i].item() != prev_hyp: if ali != prev_hyp:
token_index += 1 token_index += 1
frames.append(Frame(token_index, i, frame_scores[0][i].exp().item())) frames.append(Frame(token_index, i, score))
prev_hyp = frame_alignment[0][i].item() prev_hyp = ali
return frames, frame_alignment, frame_scores return frames
transcript = "I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT"
frames, frame_alignment, frame_scores = compute_alignments(transcript, dictionary, emission)
###################################################################### ######################################################################
# Obtain token-level alignments and confidence scores
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# #
frames = obtain_token_level_alignments(frame_alignment, frame_scores)
print("Time\tLabel\tScore")
for f in frames:
print(f"{f.time_index:3d}\t{TRANSCRIPT[f.token_index]}\t{f.score:.2f}")
###################################################################### ######################################################################
# Obtain token-level alignments and confidence scores
# ---------------------------------------------------
#
# The frame-level alignment contains repetitions for the same labels.
# Another format “token-level alignment”, which specifies the aligned # Another format “token-level alignment”, which specifies the aligned
# frame ranges for each transcript token, contains the same information, # frame ranges for each transcript token, contains the same information,
# while being more convenient to apply to some downstream tasks # while being more convenient to apply to some downstream tasks
# (e.g. computing word-level alignments). # (e.g. computing word-level alignments).
# #
# Now we demonstrate how to obtain token-level alignments and confidence # Now we demonstrate how to obtain token-level alignments and confidence
# scores by simply merging frame-level alignments and averaging # scores by simply merging frame-level alignments and averaging
# frame-level confidence scores. # frame-level confidence scores.
# #
######################################################################
# The following class represents the label, its score and the time span
# of its occurrence.
#
# Merge the labels
@dataclass @dataclass
class Segment: class Segment:
label: str label: str
...@@ -261,13 +284,16 @@ class Segment: ...@@ -261,13 +284,16 @@ class Segment:
score: float score: float
def __repr__(self): def __repr__(self):
return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})" return f"{self.label:2s} ({self.score:4.2f}): [{self.start:4d}, {self.end:4d})"
@property def __len__(self):
def length(self):
return self.end - self.start return self.end - self.start
######################################################################
#
def merge_repeats(frames, transcript): def merge_repeats(frames, transcript):
transcript_nospace = transcript.replace(" ", "") transcript_nospace = transcript.replace(" ", "")
i1, i2 = 0, 0 i1, i2 = 0, 0
...@@ -288,29 +314,31 @@ def merge_repeats(frames, transcript): ...@@ -288,29 +314,31 @@ def merge_repeats(frames, transcript):
return segments return segments
segments = merge_repeats(frames, transcript) ######################################################################
#
segments = merge_repeats(frames, TRANSCRIPT)
for seg in segments: for seg in segments:
print(seg) print(seg)
###################################################################### ######################################################################
# Visualization # Visualization
# ^^^^^^^^^^^^^ # ~~~~~~~~~~~~~
# #
def plot_label_prob(segments, transcript): def plot_label_prob(segments, transcript):
fig, ax2 = plt.subplots(figsize=(16, 4)) fig, ax = plt.subplots()
ax2.set_title("frame-level and token-level confidence scores") ax.set_title("frame-level and token-level confidence scores")
xs, hs, ws = [], [], [] xs, hs, ws = [], [], []
for seg in segments: for seg in segments:
if seg.label != "|": if seg.label != "|":
xs.append((seg.end + seg.start) / 2 + 0.4) xs.append((seg.end + seg.start) / 2 + 0.4)
hs.append(seg.score) hs.append(seg.score)
ws.append(seg.end - seg.start) ws.append(seg.end - seg.start)
ax2.annotate(seg.label, (seg.start + 0.8, -0.07), weight="bold") ax.annotate(seg.label, (seg.start + 0.8, -0.07), weight="bold")
ax2.bar(xs, hs, width=ws, color="gray", alpha=0.5, edgecolor="black") ax.bar(xs, hs, width=ws, color="gray", alpha=0.5, edgecolor="black")
xs, hs = [], [] xs, hs = [], []
for p in frames: for p in frames:
...@@ -319,27 +347,28 @@ def plot_label_prob(segments, transcript): ...@@ -319,27 +347,28 @@ def plot_label_prob(segments, transcript):
xs.append(p.time_index + 1) xs.append(p.time_index + 1)
hs.append(p.score) hs.append(p.score)
ax2.bar(xs, hs, width=0.5, alpha=0.5) ax.bar(xs, hs, width=0.5, alpha=0.5)
ax2.axhline(0, color="black") ax.set_ylim(-0.1, 1.1)
ax2.set_ylim(-0.1, 1.1) ax.grid(True, axis="y")
fig.tight_layout()
plot_label_prob(segments, transcript) plot_label_prob(segments, TRANSCRIPT)
plt.tight_layout()
plt.show()
###################################################################### ######################################################################
# From the visualized scores, we can see that, for tokens spanning
# multiple frames, e.g. “T” in “THAT”, the token-level confidence
# score is the average of the frame-level confidence scores. To make this
# clearer, we don’t plot confidence scores for blank frames, which were
# plotted in the “Label probability with and without repetition” figure in
# the previous tutorial
# `Forced Alignment with Wav2Vec2 <./forced_alignment_tutorial.html>`__.
# #
######################################################################
# Obtain word-level alignments and confidence scores # Obtain word-level alignments and confidence scores
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # --------------------------------------------------
# #
...@@ -367,7 +396,7 @@ def merge_words(transcript, segments, separator=" "): ...@@ -367,7 +396,7 @@ def merge_words(transcript, segments, separator=" "):
s = 0 s = 0
segs = segments[i1 + s : i2 + s] segs = segments[i1 + s : i2 + s]
word = "".join([seg.label for seg in segs]) word = "".join([seg.label for seg in segs])
score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs) score = sum(seg.score * len(seg) for seg in segs) / sum(len(seg) for seg in segs)
words.append(Segment(word, segments[i1 + s].start, segments[i2 + s - 1].end, score)) words.append(Segment(word, segments[i1 + s].start, segments[i2 + s - 1].end, score))
i1 = i2 i1 = i2
else: else:
...@@ -376,59 +405,43 @@ def merge_words(transcript, segments, separator=" "): ...@@ -376,59 +405,43 @@ def merge_words(transcript, segments, separator=" "):
return words return words
word_segments = merge_words(transcript, segments, "|") word_segments = merge_words(TRANSCRIPT, segments, "|")
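The word score computed in ``merge_words`` above is a duration-weighted average of the token scores. A tiny illustration with hypothetical numbers:

    # Hypothetical token segments of one word: (score, number of frames)
    token_scores = [(0.9, 3), (0.6, 1), (0.8, 2)]
    word_score = sum(s * n for s, n in token_scores) / sum(n for _, n in token_scores)
    print(round(word_score, 3))  # 0.817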
###################################################################### ######################################################################
# Visualization # Visualization
# ^^^^^^^^^^^^^ # ~~~~~~~~~~~~~
# #
def plot_alignments(segments, word_segments, waveform, input_lengths, scale=10): def plot_alignments(waveform, emission, segments, word_segments, sample_rate=bundle.sample_rate):
fig, ax2 = plt.subplots(figsize=(64, 12)) fig, ax = plt.subplots()
plt.rcParams.update({"font.size": 30})
# The original waveform ax.specgram(waveform[0], Fs=sample_rate)
ratio = waveform.size(1) / input_lengths
ax2.plot(waveform)
ax2.set_ylim(-1.0 * scale, 1.0 * scale)
ax2.set_xlim(0, waveform.size(-1))
# The original waveform
ratio = waveform.size(1) / sample_rate / emission.size(1)
for word in word_segments: for word in word_segments:
x0 = ratio * word.start t0, t1 = ratio * word.start, ratio * word.end
x1 = ratio * word.end ax.axvspan(t0, t1, facecolor="None", hatch="/", edgecolor="white")
ax2.axvspan(x0, x1, alpha=0.1, color="red") ax.annotate(f"{word.score:.2f}", (t0, sample_rate * 0.51), annotation_clip=False)
ax2.annotate(f"{word.score:.2f}", (x0, 0.8 * scale))
for seg in segments: for seg in segments:
if seg.label != "|": if seg.label != "|":
ax2.annotate(seg.label, (seg.start * ratio, 0.9 * scale)) ax.annotate(seg.label, (seg.start * ratio, sample_rate * 0.53), annotation_clip=False)
xticks = ax2.get_xticks() ax.set_xlabel("time [second]")
plt.xticks(xticks, xticks / sample_rate, fontsize=50) fig.tight_layout()
ax2.set_xlabel("time [second]", fontsize=40)
ax2.set_yticks([])
plot_alignments( plot_alignments(waveform, emission, segments, word_segments)
segments,
word_segments,
waveform,
emission.shape[1],
1,
)
plt.show()
###################################################################### ######################################################################
# A trick to embed the resulting audio to the generated file. def display_segment(i, waveform, word_segments, frame_alignment, sample_rate=bundle.sample_rate):
# `IPython.display.Audio` has to be the last call in a cell, ratio = waveform.size(1) / len(frame_alignment)
# and there should be only one call par cell.
def display_segment(i, waveform, word_segments, frame_alignment):
ratio = waveform.size(1) / frame_alignment.size(1)
word = word_segments[i] word = word_segments[i]
x0 = int(ratio * word.start) x0 = int(ratio * word.start)
x1 = int(ratio * word.end) x1 = int(ratio * word.end)
...@@ -437,8 +450,10 @@ def display_segment(i, waveform, word_segments, frame_alignment): ...@@ -437,8 +450,10 @@ def display_segment(i, waveform, word_segments, frame_alignment):
return IPython.display.Audio(segment.numpy(), rate=sample_rate) return IPython.display.Audio(segment.numpy(), rate=sample_rate)
######################################################################
# Generate the audio for each segment # Generate the audio for each segment
print(transcript) print(TRANSCRIPT)
IPython.display.Audio(SPEECH_FILE) IPython.display.Audio(SPEECH_FILE)
###################################################################### ######################################################################
...@@ -488,62 +503,71 @@ display_segment(8, waveform, word_segments, frame_alignment) ...@@ -488,62 +503,71 @@ display_segment(8, waveform, word_segments, frame_alignment)
###################################################################### ######################################################################
# Advanced usage: Dealing with missing transcripts using the <star> token # Advanced: Handling transcripts with ``<star>`` token
# --------------------------------------------------------------------------- # ----------------------------------------------------
# #
# Now let’s look at when the transcript is partially missing, how can we # Now let’s look at when the transcript is partially missing, how can we
# improve alignment quality using the <star> token, which is capable of modeling # improve alignment quality using the ``<star>`` token, which is capable of modeling
# any token. # any token.
# #
# Here we use the same English example as used above. But we remove the # Here we use the same English example as used above. But we remove the
# beginning text “i had that curiosity beside me at” from the transcript. # beginning text ``“i had that curiosity beside me at”`` from the transcript.
# Aligning audio with such transcript results in wrong alignments of the # Aligning audio with such transcript results in wrong alignments of the
# existing word “this”. However, this issue can be mitigated by using the # existing word “this”. However, this issue can be mitigated by using the
# <star> token to model the missing text. # ``<star>`` token to model the missing text.
# #
# Reload the emission tensor in order to add the extra dimension corresponding to the <star> token. ######################################################################
with torch.inference_mode(): # First, we extend the dictionary to include the ``<star>`` token.
waveform, _ = torchaudio.load(SPEECH_FILE)
emissions, _ = model(waveform.to(device))
emissions = torch.log_softmax(emissions, dim=-1)
# Append the extra dimension corresponding to the <star> token DICTIONARY["*"] = len(DICTIONARY)
extra_dim = torch.zeros(emissions.shape[0], emissions.shape[1], 1)
emissions = torch.cat((emissions.cpu(), extra_dim), 2)
emission = emissions.detach()
# Extend the dictionary to include the <star> token. ######################################################################
dictionary["*"] = 29 # Next, we extend the emission tensor with the extra dimension
# corresponding to the ``<star>`` token.
#
assert len(dictionary) == emission.shape[2] extra_dim = torch.zeros(emission.shape[0], emission.shape[1], 1, device=device)
emission = torch.cat((emission, extra_dim), 2)
assert len(DICTIONARY) == emission.shape[2]
######################################################################
# The following function combines all the processes, and computes
# word segments from the emission in one go.
def compute_and_plot_alignments(transcript, dictionary, emission, waveform): def compute_and_plot_alignments(transcript, dictionary, emission, waveform):
frames, frame_alignment, _ = compute_alignments(transcript, dictionary, emission) tokens = [dictionary[c] for c in transcript]
alignment, scores = align(emission, tokens)
frames = obtain_token_level_alignments(alignment, scores)
segments = merge_repeats(frames, transcript) segments = merge_repeats(frames, transcript)
word_segments = merge_words(transcript, segments, "|") word_segments = merge_words(transcript, segments, "|")
plot_alignments(segments, word_segments, waveform, emission.shape[1], 1) plot_alignments(waveform, emission, segments, word_segments)
plt.show() plt.xlim([0, None])
return word_segments, frame_alignment
# original:
word_segments, frame_alignment = compute_and_plot_alignments(transcript, dictionary, emission, waveform)
###################################################################### ######################################################################
# **Original**
# Demonstrate the effect of <star> token for dealing with deletion errors compute_and_plot_alignments(TRANSCRIPT, DICTIONARY, emission, waveform)
# ("i had that curiosity beside me at" missing from the transcript):
transcript = "THIS|MOMENT"
word_segments, frame_alignment = compute_and_plot_alignments(transcript, dictionary, emission, waveform)
###################################################################### ######################################################################
# **With <star> token**
#
# Now we replace the first part of the transcript with the ``<star>`` token.
# Replacing the missing transcript with the <star> token: compute_and_plot_alignments("*|THIS|MOMENT", DICTIONARY, emission, waveform)
transcript = "*|THIS|MOMENT"
word_segments, frame_alignment = compute_and_plot_alignments(transcript, dictionary, emission, waveform) ######################################################################
# **Without <star> token**
#
# As a comparison, the following aligns the partial transcript
# without using ``<star>`` token.
# It demonstrates the effect of ``<star>`` token for dealing with deletion errors.
compute_and_plot_alignments("THIS|MOMENT", DICTIONARY, emission, waveform)
###################################################################### ######################################################################
# Conclusion # Conclusion
...@@ -551,7 +575,7 @@ word_segments, frame_alignment = compute_and_plot_alignments(transcript, diction ...@@ -551,7 +575,7 @@ word_segments, frame_alignment = compute_and_plot_alignments(transcript, diction
# #
# In this tutorial, we looked at how to use torchaudio’s forced alignment # In this tutorial, we looked at how to use torchaudio’s forced alignment
# API to align and segment speech files, and demonstrated one advanced usage: # API to align and segment speech files, and demonstrated one advanced usage:
# How introducing a <star> token could improve alignment accuracy when # How introducing a ``<star>`` token could improve alignment accuracy when
# transcription errors exist. # transcription errors exist.
# #
......
...@@ -69,7 +69,7 @@ import torchvision ...@@ -69,7 +69,7 @@ import torchvision
# ------------------- # -------------------
# #
# Firstly, we define the function to collect videos from microphone and # Firstly, we define the function to collect videos from microphone and
# camera. To be specific, we use :py:func:`~torchaudio.io.StreamReader` # camera. To be specific, we use :py:class:`~torchaudio.io.StreamReader`
# class for the purpose of data collection, which supports capturing # class for the purpose of data collection, which supports capturing
# audio/video from microphone and camera. For the detailed usage of this # audio/video from microphone and camera. For the detailed usage of this
# class, please refer to the # class, please refer to the
......
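Since the hunk above only fixes the cross-reference, here is a minimal sketch of the ``StreamReader`` usage it refers to, capturing video chunks from a capture device. The device name and format are placeholders that vary by OS (e.g. ``avfoundation`` on macOS, ``v4l2`` on Linux):

    from torchaudio.io import StreamReader

    reader = StreamReader(src="0", format="avfoundation")  # placeholder device/format
    reader.add_basic_video_stream(frames_per_chunk=30, frame_rate=30, width=320, height=240)

    for (video_chunk,) in reader.stream():
        print(video_chunk.shape)  # (frames, channels, height, width)
        break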
...@@ -89,7 +89,7 @@ def plot_sinc_ir(irs, cutoff): ...@@ -89,7 +89,7 @@ def plot_sinc_ir(irs, cutoff):
num_filts, window_size = irs.shape num_filts, window_size = irs.shape
half = window_size // 2 half = window_size // 2
fig, axes = plt.subplots(num_filts, 1, sharex=True, figsize=(6.4, 4.8 * 1.5)) fig, axes = plt.subplots(num_filts, 1, sharex=True, figsize=(9.6, 8))
t = torch.linspace(-half, half - 1, window_size) t = torch.linspace(-half, half - 1, window_size)
for ax, ir, coff, color in zip(axes, irs, cutoff, plt.cm.tab10.colors): for ax, ir, coff, color in zip(axes, irs, cutoff, plt.cm.tab10.colors):
ax.plot(t, ir, linewidth=1.2, color=color, zorder=4, label=f"Cutoff: {coff}") ax.plot(t, ir, linewidth=1.2, color=color, zorder=4, label=f"Cutoff: {coff}")
...@@ -100,7 +100,7 @@ def plot_sinc_ir(irs, cutoff): ...@@ -100,7 +100,7 @@ def plot_sinc_ir(irs, cutoff):
"(Frequencies are relative to Nyquist frequency)" "(Frequencies are relative to Nyquist frequency)"
) )
axes[-1].set_xticks([i * half // 4 for i in range(-4, 5)]) axes[-1].set_xticks([i * half // 4 for i in range(-4, 5)])
plt.tight_layout() fig.tight_layout()
###################################################################### ######################################################################
...@@ -130,7 +130,7 @@ def plot_sinc_fr(frs, cutoff, band=False): ...@@ -130,7 +130,7 @@ def plot_sinc_fr(frs, cutoff, band=False):
num_filts, num_fft = frs.shape num_filts, num_fft = frs.shape
num_ticks = num_filts + 1 if band else num_filts num_ticks = num_filts + 1 if band else num_filts
fig, axes = plt.subplots(num_filts, 1, sharex=True, sharey=True, figsize=(6.4, 4.8 * 1.5)) fig, axes = plt.subplots(num_filts, 1, sharex=True, sharey=True, figsize=(9.6, 8))
for ax, fr, coff, color in zip(axes, frs, cutoff, plt.cm.tab10.colors): for ax, fr, coff, color in zip(axes, frs, cutoff, plt.cm.tab10.colors):
ax.grid(True) ax.grid(True)
ax.semilogy(fr, color=color, zorder=4, label=f"Cutoff: {coff}") ax.semilogy(fr, color=color, zorder=4, label=f"Cutoff: {coff}")
...@@ -146,7 +146,7 @@ def plot_sinc_fr(frs, cutoff, band=False): ...@@ -146,7 +146,7 @@ def plot_sinc_fr(frs, cutoff, band=False):
"Frequency response of sinc low-pass filter for different cut-off frequencies\n" "Frequency response of sinc low-pass filter for different cut-off frequencies\n"
"(Frequencies are relative to Nyquist frequency)" "(Frequencies are relative to Nyquist frequency)"
) )
plt.tight_layout() fig.tight_layout()
###################################################################### ######################################################################
...@@ -275,7 +275,7 @@ def plot_ir(magnitudes, ir, num_fft=2048): ...@@ -275,7 +275,7 @@ def plot_ir(magnitudes, ir, num_fft=2048):
axes[i].grid(True) axes[i].grid(True)
axes[1].set(title="Frequency Response") axes[1].set(title="Frequency Response")
axes[2].set(title="Frequency Response (log-scale)", xlabel="Frequency") axes[2].set(title="Frequency Response (log-scale)", xlabel="Frequency")
axes[2].legend(loc="lower right") axes[2].legend(loc="center right")
fig.tight_layout() fig.tight_layout()
......
...@@ -6,15 +6,14 @@ Forced alignment for multilingual data ...@@ -6,15 +6,14 @@ Forced alignment for multilingual data
This tutorial shows how to compute forced alignments for speech data This tutorial shows how to compute forced alignments for speech data
from multiple non-English languages using ``torchaudio``'s CTC forced alignment from multiple non-English languages using ``torchaudio``'s CTC forced alignment
API described in `“CTC forced alignment API described in `CTC forced alignment tutorial <./forced_alignment_tutorial.html>`__
tutorial” <https://pytorch.org/audio/stable/tutorials/forced_alignment_tutorial.html>`__ and the multilingual Wav2vec2 model proposed in the paper `Scaling
and the multilingual Wav2vec2 model proposed in the paper `“Scaling
Speech Technology to 1,000+ Speech Technology to 1,000+
Languages” <https://research.facebook.com/publications/scaling-speech-technology-to-1000-languages/>`__. Languages <https://research.facebook.com/publications/scaling-speech-technology-to-1000-languages/>`__.
The model was trained on 23K of audio data from 1100+ languages using The model was trained on 23K of audio data from 1100+ languages using
the `uroman vocabulary <https://www.isi.edu/~ulf/uroman.html>`__ the `uroman vocabulary <https://www.isi.edu/~ulf/uroman.html>`__
as targets. as targets.
""" """
import torch import torch
...@@ -23,53 +22,46 @@ import torchaudio ...@@ -23,53 +22,46 @@ import torchaudio
print(torch.__version__) print(torch.__version__)
print(torchaudio.__version__) print(torchaudio.__version__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device) print(device)
try:
from torchaudio.functional import forced_align
except ModuleNotFoundError:
print(
"Failed to import the forced alignment API. "
"Please install torchaudio nightly builds. "
"Please refer to https://pytorch.org/get-started/locally "
"for instructions to install a nightly build."
)
raise
###################################################################### ######################################################################
# Preparation # Preparation
# ----------- # -----------
# #
# Here we import necessary packages, and define utility functions for
# computing the frame-level alignments (using the API
# ``functional.forced_align``), token-level and word-level alignments, and
# also alignment visualization utilities.
#
# %matplotlib inline
from dataclasses import dataclass from dataclasses import dataclass
import IPython import IPython
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from torchaudio.functional import forced_align
torch.random.manual_seed(0)
sample_rate = 16000 ######################################################################
#
SAMPLE_RATE = 16000
######################################################################
#
# Here we define utility functions for computing the frame-level
# alignments (using the API :py:func:`torchaudio.functional.forced_align`),
# token-level and word-level alignments.
# For the details of these functions, please refer to
# `CTC forced alignment API tutorial <./ctc_forced_alignment_api_tutorial.html>`__.
#
@dataclass @dataclass
class Frame: class Frame:
# This is the index of each token in the transcript,
# i.e. the current frame aligns to the N-th character from the transcript.
token_index: int token_index: int
time_index: int time_index: int
score: float score: float
######################################################################
#
@dataclass @dataclass
class Segment: class Segment:
label: str label: str
...@@ -78,39 +70,42 @@ class Segment: ...@@ -78,39 +70,42 @@ class Segment:
score: float score: float
def __repr__(self): def __repr__(self):
return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})" return f"{self.label:2s} ({self.score:4.2f}): [{self.start:4d}, {self.end:4d})"
@property def __len__(self):
def length(self):
return self.end - self.start return self.end - self.start
# compute frame-level and word-level alignments using torchaudio's forced alignment API ######################################################################
#
def compute_alignments(transcript, dictionary, emission): def compute_alignments(transcript, dictionary, emission):
frames = []
tokens = [dictionary[c] for c in transcript.replace(" ", "")] tokens = [dictionary[c] for c in transcript.replace(" ", "")]
targets = torch.tensor(tokens, dtype=torch.int32).unsqueeze(0) targets = torch.tensor([tokens], dtype=torch.int32, device=emission.device)
input_lengths = torch.tensor([emission.shape[1]]) input_lengths = torch.tensor([emission.shape[1]], device=emission.device)
target_lengths = torch.tensor([targets.shape[1]]) target_lengths = torch.tensor([targets.shape[1]], device=emission.device)
alignment, scores = forced_align(emission, targets, input_lengths, target_lengths, 0)
# This is the key step, where we call the forced alignment API functional.forced_align to compute frame alignments. scores = scores.exp() # convert back to probability
frame_alignment, frame_scores = forced_align(emission, targets, input_lengths, target_lengths, 0) alignment, scores = alignment[0].tolist(), scores[0].tolist()
assert frame_alignment.shape[1] == input_lengths[0].item() assert len(alignment) == len(scores) == emission.size(1)
assert targets.shape[1] == target_lengths[0].item()
token_index = -1 token_index = -1
prev_hyp = 0 prev_hyp = 0
for i in range(frame_alignment.shape[1]): frames = []
if frame_alignment[0][i].item() == 0: for i, (ali, score) in enumerate(zip(alignment, scores)):
if ali == 0:
prev_hyp = 0 prev_hyp = 0
continue continue
if frame_alignment[0][i].item() != prev_hyp: if ali != prev_hyp:
token_index += 1 token_index += 1
frames.append(Frame(token_index, i, frame_scores[0][i].exp().item())) frames.append(Frame(token_index, i, score))
prev_hyp = frame_alignment[0][i].item() prev_hyp = ali
# compute frame alignments from token alignments # compute frame alignments from token alignments
transcript_nospace = transcript.replace(" ", "") transcript_nospace = transcript.replace(" ", "")
...@@ -140,52 +135,59 @@ def compute_alignments(transcript, dictionary, emission): ...@@ -140,52 +135,59 @@ def compute_alignments(transcript, dictionary, emission):
if i1 != i2: if i1 != i2:
if i3 == len(transcript) - 1: if i3 == len(transcript) - 1:
i2 += 1 i2 += 1
s = 0 segs = segments[i1:i2]
segs = segments[i1 + s : i2 + s] word = "".join([s.label for s in segs])
word = "".join([seg.label for seg in segs]) score = sum(s.score * len(s) for s in segs) / sum(len(s) for s in segs)
score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs) words.append(Segment(word, segs[0].start, segs[-1].end + 1, score))
words.append(Segment(word, segments[i1 + s].start, segments[i2 + s - 1].end, score))
i1 = i2 i1 = i2
else: else:
i2 += 1 i2 += 1
i3 += 1 i3 += 1
return segments, words
num_frames = frame_alignment.shape[1]
return segments, words, num_frames
######################################################################
#
def plot_emission(emission):
fig, ax = plt.subplots()
ax.imshow(emission.T, aspect="auto")
ax.set_title("Emission")
fig.tight_layout()
# utility function for plotting word alignments
def plot_alignments(segments, word_segments, waveform, input_lengths, scale=10):
fig, ax2 = plt.subplots(figsize=(64, 12))
plt.rcParams.update({"font.size": 30})
# The original waveform ######################################################################
ratio = waveform.size(1) / input_lengths #
ax2.plot(waveform)
ax2.set_ylim(-1.0 * scale, 1.0 * scale) # utility function for plotting word alignments
ax2.set_xlim(0, waveform.size(-1)) def plot_alignments(waveform, emission, segments, word_segments, sample_rate=SAMPLE_RATE):
fig, ax = plt.subplots()
ax.specgram(waveform[0], Fs=sample_rate)
xlim = ax.get_xlim()
ratio = waveform.size(1) / sample_rate / emission.size(1)
for word in word_segments: for word in word_segments:
x0 = ratio * word.start t0, t1 = word.start * ratio, word.end * ratio
x1 = ratio * word.end ax.axvspan(t0, t1, facecolor="None", hatch="/", edgecolor="white")
ax2.axvspan(x0, x1, alpha=0.1, color="red") ax.annotate(f"{word.score:.2f}", (t0, sample_rate * 0.51), annotation_clip=False)
ax2.annotate(f"{word.score:.2f}", (x0, 0.8 * scale))
for seg in segments: for seg in segments:
if seg.label != "|": if seg.label != "|":
ax2.annotate(seg.label, (seg.start * ratio, 0.9 * scale)) ax.annotate(seg.label, (seg.start * ratio, sample_rate * 0.53), annotation_clip=False)
xticks = ax2.get_xticks() ax.set_xlabel("time [second]")
plt.xticks(xticks, xticks / sample_rate, fontsize=50) ax.set_xlim(xlim)
ax2.set_xlabel("time [second]", fontsize=40) fig.tight_layout()
ax2.set_yticks([])
return IPython.display.Audio(waveform, rate=sample_rate)
######################################################################
#
# utility function for playing audio segments. # utility function for playing audio segments.
# A trick to embed the resulting audio in the generated file. def display_segment(i, waveform, word_segments, num_frames, sample_rate=SAMPLE_RATE):
# `IPython.display.Audio` has to be the last call in a cell,
# and there should be only one call per cell.
def display_segment(i, waveform, word_segments, num_frames):
ratio = waveform.size(1) / num_frames ratio = waveform.size(1) / num_frames
word = word_segments[i] word = word_segments[i]
x0 = int(ratio * word.start) x0 = int(ratio * word.start)
...@@ -241,26 +243,21 @@ model.load_state_dict( ...@@ -241,26 +243,21 @@ model.load_state_dict(
) )
) )
model.eval() model.eval()
model.to(device)
def get_emission(waveform): def get_emission(waveform):
# NOTE: this step is essential with torch.inference_mode():
waveform = torch.nn.functional.layer_norm(waveform, waveform.shape) # NOTE: this step is essential
waveform = torch.nn.functional.layer_norm(waveform, waveform.shape)
emissions, _ = model(waveform) emission, _ = model(waveform)
emissions = torch.log_softmax(emissions, dim=-1) return torch.log_softmax(emission, dim=-1)
emission = emissions.cpu().detach()
# Append the extra dimension corresponding to the <star> token
extra_dim = torch.zeros(emissions.shape[0], emissions.shape[1], 1)
emissions = torch.cat((emissions.cpu(), extra_dim), 2)
emission = emissions.detach()
return emission, waveform
# Construct the dictionary # Construct the dictionary
# '@' represents the OOV token, '*' represents the <star> token. # '@' represents the OOV token
# <pad> and </s> are fairseq's legacy tokens, which are not used. # <pad> and </s> are fairseq's legacy tokens, which are not used.
# <star> token is omitted as we do not use it in this tutorial
dictionary = { dictionary = {
"<blank>": 0, "<blank>": 0,
"<pad>": 1, "<pad>": 1,
...@@ -293,7 +290,6 @@ dictionary = { ...@@ -293,7 +290,6 @@ dictionary = {
"'": 28, "'": 28,
"q": 29, "q": 29,
"x": 30, "x": 30,
"*": 31,
} }
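# A quick, hedged illustration of how this dictionary is used: a normalized,
# romanized transcript is mapped to token IDs by dropping spaces and looking up
# each character, exactly as done inside ``compute_alignments`` above (the
# transcript here is arbitrary):
example_transcript = "aber seit ich"
example_tokens = [dictionary[c] for c in example_transcript.replace(" ", "")]
print(example_tokens)  # token IDs for "aberseitich"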
...@@ -304,11 +300,8 @@ dictionary = { ...@@ -304,11 +300,8 @@ dictionary = {
# romanizer and using it to obtain romanized transcripts, and Python # romanizer and using it to obtain romanized transcripts, and Python
# commands required for further normalizing the romanized transcript. # commands required for further normalizing the romanized transcript.
# #
# %%
# .. code-block:: bash # .. code-block:: bash
# #
# %%bash
# Save the raw transcript to a file # Save the raw transcript to a file
# echo 'raw text' > text.txt # echo 'raw text' > text.txt
# git clone https://github.com/isi-nlp/uroman # git clone https://github.com/isi-nlp/uroman
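#
# The Python normalization step referenced above is elided by this diff; as a
# hedged sketch (the helper name and exact rules are assumptions), the romanized
# output can be lower-cased, stripped of characters outside the dictionary's
# alphabet, and whitespace-collapsed along these lines:
#
#     import re
#
#     def normalize_uroman(text):
#         text = text.lower()
#         text = text.replace("’", "'")
#         text = re.sub("([^a-z' ])", " ", text)
#         text = re.sub(" +", " ", text)
#         return text.strip()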
...@@ -334,141 +327,77 @@ dictionary = { ...@@ -334,141 +327,77 @@ dictionary = {
###################################################################### ######################################################################
# German example: # German
# ~~~~~~~~~~~~~~~~ # ~~~~~~
text_raw = (
"aber seit ich bei ihnen das brot hole brauch ich viel weniger schulze wandte sich ab die kinder taten ihm leid"
)
text_normalized = (
"aber seit ich bei ihnen das brot hole brauch ich viel weniger schulze wandte sich ab die kinder taten ihm leid"
)
speech_file = torchaudio.utils.download_asset("tutorial-assets/10349_8674_000087.flac", progress=False) speech_file = torchaudio.utils.download_asset("tutorial-assets/10349_8674_000087.flac", progress=False)
waveform, _ = torchaudio.load(speech_file)
emission, waveform = get_emission(waveform)
assert len(dictionary) == emission.shape[2]
transcript = text_normalized text_raw = "aber seit ich bei ihnen das brot hole"
text_normalized = "aber seit ich bei ihnen das brot hole"
segments, word_segments, num_frames = compute_alignments(transcript, dictionary, emission)
plot_alignments(segments, word_segments, waveform, emission.shape[1])
print("Raw Transcript: ", text_raw) print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized) print("Normalized Transcript: ", text_normalized)
IPython.display.Audio(waveform, rate=sample_rate)
###################################################################### ######################################################################
# #
display_segment(0, waveform, word_segments, num_frames) waveform, _ = torchaudio.load(speech_file, frame_offset=int(0.5 * SAMPLE_RATE), num_frames=int(2.5 * SAMPLE_RATE))
emission = get_emission(waveform.to(device))
num_frames = emission.size(1)
plot_emission(emission[0].cpu())
###################################################################### ######################################################################
# #
display_segment(1, waveform, word_segments, num_frames) segments, word_segments = compute_alignments(text_normalized, dictionary, emission)
###################################################################### plot_alignments(waveform, emission, segments, word_segments)
#
display_segment(2, waveform, word_segments, num_frames)
######################################################################
#
display_segment(3, waveform, word_segments, num_frames)
######################################################################
#
display_segment(4, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# #
display_segment(5, waveform, word_segments, num_frames) display_segment(0, waveform, word_segments, num_frames)
######################################################################
#
display_segment(6, waveform, word_segments, num_frames)
######################################################################
#
display_segment(7, waveform, word_segments, num_frames)
######################################################################
#
display_segment(8, waveform, word_segments, num_frames)
######################################################################
#
display_segment(9, waveform, word_segments, num_frames)
######################################################################
#
display_segment(10, waveform, word_segments, num_frames)
######################################################################
#
display_segment(11, waveform, word_segments, num_frames)
######################################################################
#
display_segment(12, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# #
display_segment(13, waveform, word_segments, num_frames) display_segment(1, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# #
display_segment(14, waveform, word_segments, num_frames) display_segment(2, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# #
display_segment(15, waveform, word_segments, num_frames) display_segment(3, waveform, word_segments, num_frames)
######################################################################
#
display_segment(16, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# #
display_segment(17, waveform, word_segments, num_frames) display_segment(4, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# #
display_segment(18, waveform, word_segments, num_frames) display_segment(5, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# #
display_segment(19, waveform, word_segments, num_frames) display_segment(6, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# #
display_segment(20, waveform, word_segments, num_frames) display_segment(7, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# Chinese example: # Chinese
# ~~~~~~~~~~~~~~~~ # ~~~~~~~
# #
# Chinese is a character-based language, and there is no explicit word-level # Chinese is a character-based language, and there is no explicit word-level
# tokenization (separated by spaces) in its raw written form. In order to # tokenization (separated by spaces) in its raw written form. In order to
...@@ -478,98 +407,36 @@ display_segment(20, waveform, word_segments, num_frames) ...@@ -478,98 +407,36 @@ display_segment(20, waveform, word_segments, num_frames)
# However this is not needed if you only want character-level alignments. # However this is not needed if you only want character-level alignments.
# #
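# As a hedged aside: one way to obtain such a romanized transcript is the
# third-party ``pypinyin`` package (not a torchaudio dependency, and not
# necessarily what the tutorial authors used; the uroman workflow shown earlier
# works as well):
from pypinyin import lazy_pinyin

text_chinese = "关 服务 高端 产品"
text_pinyin = " ".join("".join(lazy_pinyin(word)) for word in text_chinese.split())
print(text_pinyin)  # "guan fuwu gaoduan chanpin"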
text_raw = "关 服务 高端 产品 仍 处于 供不应求 的 局面"
text_normalized = "guan fuwu gaoduan chanpin reng chuyu gongbuyingqiu de jumian"
speech_file = torchaudio.utils.download_asset("tutorial-assets/mvdr/clean_speech.wav", progress=False) speech_file = torchaudio.utils.download_asset("tutorial-assets/mvdr/clean_speech.wav", progress=False)
waveform, _ = torchaudio.load(speech_file)
waveform = waveform[0:1]
emission, waveform = get_emission(waveform)
transcript = text_normalized
segments, word_segments, num_frames = compute_alignments(transcript, dictionary, emission) text_raw = "关 服务 高端 产品 仍 处于 供不应求 的 局面"
plot_alignments(segments, word_segments, waveform, emission.shape[1]) text_normalized = "guan fuwu gaoduan chanpin reng chuyu gongbuyingqiu de jumian"
print("Raw Transcript: ", text_raw) print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized) print("Normalized Transcript: ", text_normalized)
IPython.display.Audio(waveform, rate=sample_rate)
###################################################################### ######################################################################
# #
display_segment(0, waveform, word_segments, num_frames) waveform, _ = torchaudio.load(speech_file)
waveform = waveform[0:1]
######################################################################
#
display_segment(1, waveform, word_segments, num_frames)
######################################################################
#
display_segment(2, waveform, word_segments, num_frames)
######################################################################
#
display_segment(3, waveform, word_segments, num_frames)
######################################################################
#
display_segment(4, waveform, word_segments, num_frames)
######################################################################
#
display_segment(5, waveform, word_segments, num_frames)
######################################################################
#
display_segment(6, waveform, word_segments, num_frames)
######################################################################
#
display_segment(7, waveform, word_segments, num_frames) emission = get_emission(waveform.to(device))
num_frames = emission.size(1)
plot_emission(emission[0].cpu())
###################################################################### ######################################################################
# #
display_segment(8, waveform, word_segments, num_frames) segments, word_segments = compute_alignments(text_normalized, dictionary, emission)
######################################################################
# Polish example:
# ~~~~~~~~~~~~~~~
text_raw = "wtedy ujrzałem na jego brzuchu okrągłą czarną ranę dlaczego mi nie powiedziałeś szepnąłem ze łzami"
text_normalized = "wtedy ujrzalem na jego brzuchu okragla czarna rane dlaczego mi nie powiedziales szepnalem ze lzami"
speech_file = torchaudio.utils.download_asset("tutorial-assets/5090_1447_000088.flac", progress=False)
waveform, _ = torchaudio.load(speech_file)
emission, waveform = get_emission(waveform)
transcript = text_normalized
segments, word_segments, num_frames = compute_alignments(transcript, dictionary, emission) plot_alignments(waveform, emission, segments, word_segments)
plot_alignments(segments, word_segments, waveform, emission.shape[1])
print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized)
IPython.display.Audio(waveform, rate=sample_rate)
###################################################################### ######################################################################
# #
display_segment(0, waveform, word_segments, num_frames) display_segment(0, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# #
...@@ -585,7 +452,6 @@ display_segment(2, waveform, word_segments, num_frames) ...@@ -585,7 +452,6 @@ display_segment(2, waveform, word_segments, num_frames)
display_segment(3, waveform, word_segments, num_frames) display_segment(3, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# #
...@@ -611,68 +477,40 @@ display_segment(7, waveform, word_segments, num_frames) ...@@ -611,68 +477,40 @@ display_segment(7, waveform, word_segments, num_frames)
display_segment(8, waveform, word_segments, num_frames) display_segment(8, waveform, word_segments, num_frames)
######################################################################
#
display_segment(9, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# # Polish
# ~~~~~~
display_segment(10, waveform, word_segments, num_frames) speech_file = torchaudio.utils.download_asset("tutorial-assets/5090_1447_000088.flac", progress=False)
###################################################################### text_raw = "wtedy ujrzałem na jego brzuchu okrągłą czarną ranę"
# text_normalized = "wtedy ujrzalem na jego brzuchu okragla czarna rane"
display_segment(11, waveform, word_segments, num_frames) print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized)
###################################################################### ######################################################################
# #
display_segment(12, waveform, word_segments, num_frames) waveform, _ = torchaudio.load(speech_file, num_frames=int(4.5 * SAMPLE_RATE))
######################################################################
#
display_segment(13, waveform, word_segments, num_frames) emission = get_emission(waveform.to(device))
num_frames = emission.size(1)
plot_emission(emission[0].cpu())
###################################################################### ######################################################################
# #
display_segment(14, waveform, word_segments, num_frames) segments, word_segments = compute_alignments(text_normalized, dictionary, emission)
plot_alignments(waveform, emission, segments, word_segments)
######################################################################
# Portuguese example:
# ~~~~~~~~~~~~~~~~~~~
text_raw = (
"mas na imensa extensão onde se esconde o inconsciente imortal só me responde um bramido um queixume e nada mais"
)
text_normalized = (
"mas na imensa extensao onde se esconde o inconsciente imortal so me responde um bramido um queixume e nada mais"
)
speech_file = torchaudio.utils.download_asset("tutorial-assets/6566_5323_000027.flac", progress=False)
waveform, _ = torchaudio.load(speech_file)
emission, waveform = get_emission(waveform)
transcript = text_normalized
segments, word_segments, num_frames = compute_alignments(transcript, dictionary, emission)
plot_alignments(segments, word_segments, waveform, emission.shape[1])
print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized)
IPython.display.Audio(waveform, rate=sample_rate)
###################################################################### ######################################################################
# #
display_segment(0, waveform, word_segments, num_frames) display_segment(0, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# #
...@@ -688,7 +526,6 @@ display_segment(2, waveform, word_segments, num_frames) ...@@ -688,7 +526,6 @@ display_segment(2, waveform, word_segments, num_frames)
display_segment(3, waveform, word_segments, num_frames) display_segment(3, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# #
...@@ -710,94 +547,38 @@ display_segment(6, waveform, word_segments, num_frames) ...@@ -710,94 +547,38 @@ display_segment(6, waveform, word_segments, num_frames)
display_segment(7, waveform, word_segments, num_frames) display_segment(7, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# # Portuguese
# ~~~~~~~~~~
display_segment(8, waveform, word_segments, num_frames)
######################################################################
#
display_segment(9, waveform, word_segments, num_frames)
######################################################################
#
display_segment(10, waveform, word_segments, num_frames)
######################################################################
#
display_segment(11, waveform, word_segments, num_frames)
######################################################################
#
display_segment(12, waveform, word_segments, num_frames)
###################################################################### speech_file = torchaudio.utils.download_asset("tutorial-assets/6566_5323_000027.flac", progress=False)
#
display_segment(13, waveform, word_segments, num_frames)
######################################################################
#
display_segment(14, waveform, word_segments, num_frames)
######################################################################
#
display_segment(15, waveform, word_segments, num_frames)
###################################################################### text_raw = "na imensa extensão onde se esconde o inconsciente imortal"
# text_normalized = "na imensa extensao onde se esconde o inconsciente imortal"
display_segment(16, waveform, word_segments, num_frames) print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized)
###################################################################### ######################################################################
# #
display_segment(17, waveform, word_segments, num_frames) waveform, _ = torchaudio.load(speech_file, frame_offset=int(SAMPLE_RATE), num_frames=int(4.6 * SAMPLE_RATE))
###################################################################### emission = get_emission(waveform.to(device))
# num_frames = emission.size(1)
plot_emission(emission[0].cpu())
display_segment(18, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# #
display_segment(19, waveform, word_segments, num_frames) segments, word_segments = compute_alignments(text_normalized, dictionary, emission)
plot_alignments(waveform, emission, segments, word_segments)
######################################################################
# Italian example:
# ~~~~~~~~~~~~~~~~
text_raw = "elle giacean per terra tutte quante fuor d'una ch'a seder si levò ratto ch'ella ci vide passarsi davante"
text_normalized = (
"elle giacean per terra tutte quante fuor d'una ch'a seder si levo ratto ch'ella ci vide passarsi davante"
)
speech_file = torchaudio.utils.download_asset("tutorial-assets/642_529_000025.flac", progress=False)
waveform, _ = torchaudio.load(speech_file)
emission, waveform = get_emission(waveform)
transcript = text_normalized
segments, word_segments, num_frames = compute_alignments(transcript, dictionary, emission)
plot_alignments(segments, word_segments, waveform, emission.shape[1])
print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized)
IPython.display.Audio(waveform, rate=sample_rate)
###################################################################### ######################################################################
# #
display_segment(0, waveform, word_segments, num_frames) display_segment(0, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# #
...@@ -813,7 +594,6 @@ display_segment(2, waveform, word_segments, num_frames) ...@@ -813,7 +594,6 @@ display_segment(2, waveform, word_segments, num_frames)
display_segment(3, waveform, word_segments, num_frames) display_segment(3, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# #
...@@ -840,50 +620,62 @@ display_segment(7, waveform, word_segments, num_frames) ...@@ -840,50 +620,62 @@ display_segment(7, waveform, word_segments, num_frames)
display_segment(8, waveform, word_segments, num_frames) display_segment(8, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# # Italian
# ~~~~~~~
speech_file = torchaudio.utils.download_asset("tutorial-assets/642_529_000025.flac", progress=False)
text_raw = "elle giacean per terra tutte quante"
text_normalized = "elle giacean per terra tutte quante"
display_segment(9, waveform, word_segments, num_frames) print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized)
###################################################################### ######################################################################
# #
display_segment(10, waveform, word_segments, num_frames) waveform, _ = torchaudio.load(speech_file, num_frames=int(4 * SAMPLE_RATE))
emission = get_emission(waveform.to(device))
num_frames = emission.size(1)
plot_emission(emission[0].cpu())
###################################################################### ######################################################################
# #
display_segment(11, waveform, word_segments, num_frames) segments, word_segments = compute_alignments(text_normalized, dictionary, emission)
plot_alignments(waveform, emission, segments, word_segments)
###################################################################### ######################################################################
# #
display_segment(12, waveform, word_segments, num_frames) display_segment(0, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# #
display_segment(13, waveform, word_segments, num_frames) display_segment(1, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# #
display_segment(14, waveform, word_segments, num_frames) display_segment(2, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# #
display_segment(15, waveform, word_segments, num_frames) display_segment(3, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# #
display_segment(16, waveform, word_segments, num_frames) display_segment(4, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# #
display_segment(17, waveform, word_segments, num_frames) display_segment(5, waveform, word_segments, num_frames)
###################################################################### ######################################################################
# Conclusion # Conclusion
...@@ -894,7 +686,6 @@ display_segment(17, waveform, word_segments, num_frames) ...@@ -894,7 +686,6 @@ display_segment(17, waveform, word_segments, num_frames)
# speech data to transcripts in five languages. # speech data to transcripts in five languages.
# #
###################################################################### ######################################################################
# Acknowledgement # Acknowledgement
# --------------- # ---------------
......
...@@ -56,16 +56,11 @@ print(device) ...@@ -56,16 +56,11 @@ print(device)
# First, we import the necessary packages and fetch the data we will work on. # First, we import the necessary packages and fetch the data we will work on.
# #
# %matplotlib inline
from dataclasses import dataclass from dataclasses import dataclass
import IPython import IPython
import matplotlib
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
matplotlib.rcParams["figure.figsize"] = [16.0, 4.8]
torch.random.manual_seed(0) torch.random.manual_seed(0)
SPEECH_FILE = torchaudio.utils.download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav") SPEECH_FILE = torchaudio.utils.download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
...@@ -99,17 +94,22 @@ with torch.inference_mode(): ...@@ -99,17 +94,22 @@ with torch.inference_mode():
emission = emissions[0].cpu().detach() emission = emissions[0].cpu().detach()
print(labels)
################################################################################ ################################################################################
# Visualization # Visualization
################################################################################ # ~~~~~~~~~~~~~
print(labels)
plt.imshow(emission.T)
plt.colorbar()
plt.title("Frame-wise class probability")
plt.xlabel("Time")
plt.ylabel("Labels")
plt.show()
def plot():
plt.imshow(emission.T)
plt.colorbar()
plt.title("Frame-wise class probability")
plt.xlabel("Time")
plt.ylabel("Labels")
plot()
###################################################################### ######################################################################
# Generate alignment probability (trellis) # Generate alignment probability (trellis)
...@@ -181,12 +181,17 @@ trellis = get_trellis(emission, tokens) ...@@ -181,12 +181,17 @@ trellis = get_trellis(emission, tokens)
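# The ``get_trellis`` function referenced in the hunk above is not shown in this
# diff. As a rough, simplified sketch of the dynamic program it implements: each
# cell holds the best log-probability of having emitted the first ``j + 1`` tokens
# by frame ``t``, where a frame either stays on the current token (scored with the
# blank class, a common simplification) or advances to the next token. The function
# name, initialization, and index conventions here are assumptions.
def get_trellis_sketch(emission, tokens, blank_id=0):
    # emission: (num_frames, num_classes) log-probabilities; tokens: target label IDs
    num_frames, num_tokens = emission.size(0), len(tokens)
    trellis = torch.full((num_frames, num_tokens), -float("inf"))
    trellis[0, 0] = emission[0, tokens[0]]
    token_probs = emission[:, tokens]  # (num_frames, num_tokens)
    for t in range(1, num_frames):
        stay = trellis[t - 1] + emission[t, blank_id]       # keep emitting blank
        move = torch.roll(trellis[t - 1], 1)                 # come from the previous token
        move[0] = -float("inf")
        trellis[t] = torch.maximum(stay, move + token_probs[t])
    return trellis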
################################################################################ ################################################################################
# Visualization # Visualization
################################################################################ # ~~~~~~~~~~~~~
plt.imshow(trellis.T, origin="lower")
plt.annotate("- Inf", (trellis.size(1) / 5, trellis.size(1) / 1.5))
plt.annotate("+ Inf", (trellis.size(0) - trellis.size(1) / 5, trellis.size(1) / 3)) def plot():
plt.colorbar() plt.imshow(trellis.T, origin="lower")
plt.show() plt.annotate("- Inf", (trellis.size(1) / 5, trellis.size(1) / 1.5))
plt.annotate("+ Inf", (trellis.size(0) - trellis.size(1) / 5, trellis.size(1) / 3))
plt.colorbar()
plot()
###################################################################### ######################################################################
# In the above visualization, we can see that there is a trace of high # In the above visualization, we can see that there is a trace of high
...@@ -266,7 +271,9 @@ for p in path: ...@@ -266,7 +271,9 @@ for p in path:
################################################################################ ################################################################################
# Visualization # Visualization
################################################################################ # ~~~~~~~~~~~~~
def plot_trellis_with_path(trellis, path): def plot_trellis_with_path(trellis, path):
# To plot trellis with path, we take advantage of 'nan' value # To plot trellis with path, we take advantage of 'nan' value
trellis_with_path = trellis.clone() trellis_with_path = trellis.clone()
...@@ -277,10 +284,14 @@ def plot_trellis_with_path(trellis, path): ...@@ -277,10 +284,14 @@ def plot_trellis_with_path(trellis, path):
plot_trellis_with_path(trellis, path) plot_trellis_with_path(trellis, path)
plt.title("The path found by backtracking") plt.title("The path found by backtracking")
plt.show()
###################################################################### ######################################################################
# Looking good. Now this path contains repetitions for the same labels, so # Looking good.
######################################################################
# Segment the path
# ----------------
# Now this path contains repetitions for the same labels, so
# let’s merge them to make it close to the original transcript. # let’s merge them to make it close to the original transcript.
# #
# When merging the multiple path points, we simply take the average # When merging the multiple path points, we simply take the average
...@@ -297,7 +308,7 @@ class Segment: ...@@ -297,7 +308,7 @@ class Segment:
score: float score: float
def __repr__(self): def __repr__(self):
return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})" return f"{self.label} ({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"
@property @property
def length(self): def length(self):
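# The ``merge_repeats`` helper that turns the backtracked path into these
# ``Segment`` objects is elided by the diff. A hedged sketch of the averaging merge
# described above, assuming each path entry exposes ``token_index``, ``time_index``
# and ``score`` fields (as the dataclasses in these tutorials do):
def merge_repeats_sketch(path, transcript):
    segments = []
    i1, i2 = 0, 0
    while i1 < len(path):
        # Advance i2 while the path keeps pointing at the same transcript character.
        while i2 < len(path) and path[i1].token_index == path[i2].token_index:
            i2 += 1
        score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
        segments.append(
            Segment(transcript[path[i1].token_index], path[i1].time_index, path[i2 - 1].time_index + 1, score)
        )
        i1 = i2
    return segments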
...@@ -330,7 +341,9 @@ for seg in segments: ...@@ -330,7 +341,9 @@ for seg in segments:
################################################################################ ################################################################################
# Visualization # Visualization
################################################################################ # ~~~~~~~~~~~~~
def plot_trellis_with_segments(trellis, segments, transcript): def plot_trellis_with_segments(trellis, segments, transcript):
# To plot trellis with path, we take advantage of 'nan' value # To plot trellis with path, we take advantage of 'nan' value
trellis_with_path = trellis.clone() trellis_with_path = trellis.clone()
...@@ -338,15 +351,14 @@ def plot_trellis_with_segments(trellis, segments, transcript): ...@@ -338,15 +351,14 @@ def plot_trellis_with_segments(trellis, segments, transcript):
if seg.label != "|": if seg.label != "|":
trellis_with_path[seg.start : seg.end, i] = float("nan") trellis_with_path[seg.start : seg.end, i] = float("nan")
fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5)) fig, [ax1, ax2] = plt.subplots(2, 1, sharex=True)
ax1.set_title("Path, label and probability for each label") ax1.set_title("Path, label and probability for each label")
ax1.imshow(trellis_with_path.T, origin="lower") ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")
ax1.set_xticks([])
for i, seg in enumerate(segments): for i, seg in enumerate(segments):
if seg.label != "|": if seg.label != "|":
ax1.annotate(seg.label, (seg.start, i - 0.7), weight="bold") ax1.annotate(seg.label, (seg.start, i - 0.7), size="small")
ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3)) ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), size="small")
ax2.set_title("Label probability with and without repetation") ax2.set_title("Label probability with and without repetation")
xs, hs, ws = [], [], [] xs, hs, ws = [], [], []
...@@ -355,7 +367,7 @@ def plot_trellis_with_segments(trellis, segments, transcript): ...@@ -355,7 +367,7 @@ def plot_trellis_with_segments(trellis, segments, transcript):
xs.append((seg.end + seg.start) / 2 + 0.4) xs.append((seg.end + seg.start) / 2 + 0.4)
hs.append(seg.score) hs.append(seg.score)
ws.append(seg.end - seg.start) ws.append(seg.end - seg.start)
ax2.annotate(seg.label, (seg.start + 0.8, -0.07), weight="bold") ax2.annotate(seg.label, (seg.start + 0.8, -0.07))
ax2.bar(xs, hs, width=ws, color="gray", alpha=0.5, edgecolor="black") ax2.bar(xs, hs, width=ws, color="gray", alpha=0.5, edgecolor="black")
xs, hs = [], [] xs, hs = [], []
...@@ -367,17 +379,21 @@ def plot_trellis_with_segments(trellis, segments, transcript): ...@@ -367,17 +379,21 @@ def plot_trellis_with_segments(trellis, segments, transcript):
ax2.bar(xs, hs, width=0.5, alpha=0.5) ax2.bar(xs, hs, width=0.5, alpha=0.5)
ax2.axhline(0, color="black") ax2.axhline(0, color="black")
ax2.set_xlim(ax1.get_xlim()) ax2.grid(True, axis="y")
ax2.set_ylim(-0.1, 1.1) ax2.set_ylim(-0.1, 1.1)
fig.tight_layout()
plot_trellis_with_segments(trellis, segments, transcript) plot_trellis_with_segments(trellis, segments, transcript)
plt.tight_layout()
plt.show()
###################################################################### ######################################################################
# Looks good. Now let’s merge the words. The Wav2Vec2 model uses ``'|'`` # Looks good.
######################################################################
# Merge the segments into words
# -----------------------------
# Now let’s merge the words. The Wav2Vec2 model uses ``'|'``
# as the word boundary, so we merge the segments before each occurrence of # as the word boundary, so we merge the segments before each occurrence of
# ``'|'``. # ``'|'``.
# #
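# The ``merge_words`` helper that performs this grouping is elided by the diff. A
# hedged sketch (the name is an assumption) that joins the character segments
# between ``'|'`` boundaries and weights each character's score by its duration via
# the ``Segment.length`` property defined above:
def merge_words_sketch(segments, separator="|"):
    words = []
    i1, i2 = 0, 0
    while i1 < len(segments):
        if i2 >= len(segments) or segments[i2].label == separator:
            if i1 != i2:
                segs = segments[i1:i2]
                word = "".join(s.label for s in segs)
                score = sum(s.score * s.length for s in segs) / sum(s.length for s in segs)
                words.append(Segment(word, segs[0].start, segs[-1].end, score))
            i1 = i2 + 1
            i2 = i1
        else:
            i2 += 1
    return words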
...@@ -410,16 +426,16 @@ for word in word_segments: ...@@ -410,16 +426,16 @@ for word in word_segments:
################################################################################ ################################################################################
# Visualization # Visualization
################################################################################ # ~~~~~~~~~~~~~
def plot_alignments(trellis, segments, word_segments, waveform): def plot_alignments(trellis, segments, word_segments, waveform):
trellis_with_path = trellis.clone() trellis_with_path = trellis.clone()
for i, seg in enumerate(segments): for i, seg in enumerate(segments):
if seg.label != "|": if seg.label != "|":
trellis_with_path[seg.start : seg.end, i] = float("nan") trellis_with_path[seg.start : seg.end, i] = float("nan")
fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5)) fig, [ax1, ax2] = plt.subplots(2, 1)
ax1.imshow(trellis_with_path.T, origin="lower") ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")
ax1.set_xticks([]) ax1.set_xticks([])
ax1.set_yticks([]) ax1.set_yticks([])
...@@ -429,8 +445,8 @@ def plot_alignments(trellis, segments, word_segments, waveform): ...@@ -429,8 +445,8 @@ def plot_alignments(trellis, segments, word_segments, waveform):
for i, seg in enumerate(segments): for i, seg in enumerate(segments):
if seg.label != "|": if seg.label != "|":
ax1.annotate(seg.label, (seg.start, i - 0.7)) ax1.annotate(seg.label, (seg.start, i - 0.7), size="small")
ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), fontsize=8) ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), size="small")
# The original waveform # The original waveform
ratio = waveform.size(0) / trellis.size(0) ratio = waveform.size(0) / trellis.size(0)
...@@ -450,6 +466,7 @@ def plot_alignments(trellis, segments, word_segments, waveform): ...@@ -450,6 +466,7 @@ def plot_alignments(trellis, segments, word_segments, waveform):
ax2.set_yticks([]) ax2.set_yticks([])
ax2.set_ylim(-1.0, 1.0) ax2.set_ylim(-1.0, 1.0)
ax2.set_xlim(0, waveform.size(-1)) ax2.set_xlim(0, waveform.size(-1))
fig.tight_layout()
plot_alignments( plot_alignments(
...@@ -458,7 +475,6 @@ plot_alignments( ...@@ -458,7 +475,6 @@ plot_alignments(
word_segments, word_segments,
waveform[0], waveform[0],
) )
plt.show()
################################################################################ ################################################################################
......
...@@ -162,11 +162,10 @@ def separate_sources( ...@@ -162,11 +162,10 @@ def separate_sources(
def plot_spectrogram(stft, title="Spectrogram"): def plot_spectrogram(stft, title="Spectrogram"):
magnitude = stft.abs() magnitude = stft.abs()
spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy() spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy()
figure, axis = plt.subplots(1, 1) _, axis = plt.subplots(1, 1)
img = axis.imshow(spectrogram, cmap="viridis", vmin=-60, vmax=0, origin="lower", aspect="auto") axis.imshow(spectrogram, cmap="viridis", vmin=-60, vmax=0, origin="lower", aspect="auto")
figure.suptitle(title) axis.set_title(title)
plt.colorbar(img, ax=axis) plt.tight_layout()
plt.show()
###################################################################### ######################################################################
...@@ -252,7 +251,7 @@ def output_results(original_source: torch.Tensor, predicted_source: torch.Tensor ...@@ -252,7 +251,7 @@ def output_results(original_source: torch.Tensor, predicted_source: torch.Tensor
"SDR score is:", "SDR score is:",
separation.bss_eval_sources(original_source.detach().numpy(), predicted_source.detach().numpy())[0].mean(), separation.bss_eval_sources(original_source.detach().numpy(), predicted_source.detach().numpy())[0].mean(),
) )
plot_spectrogram(stft(predicted_source)[0], f"Spectrogram {source}") plot_spectrogram(stft(predicted_source)[0], f"Spectrogram - {source}")
return Audio(predicted_source, rate=sample_rate) return Audio(predicted_source, rate=sample_rate)
...@@ -294,7 +293,7 @@ mix_spec = mixture[:, frame_start:frame_end].cpu() ...@@ -294,7 +293,7 @@ mix_spec = mixture[:, frame_start:frame_end].cpu()
# #
# Mixture Clip # Mixture Clip
plot_spectrogram(stft(mix_spec)[0], "Spectrogram Mixture") plot_spectrogram(stft(mix_spec)[0], "Spectrogram - Mixture")
Audio(mix_spec, rate=sample_rate) Audio(mix_spec, rate=sample_rate)
###################################################################### ######################################################################
......
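# The ``stft`` helper used by the spectrogram calls above is elided by this diff; a
# hedged sketch of a typical definition (the n_fft/hop_length values are assumptions,
# not necessarily what the tutorial uses):
from torchaudio.transforms import Spectrogram

N_FFT = 4096
N_HOP = 1024
stft = Spectrogram(n_fft=N_FFT, hop_length=N_HOP, power=None)  # power=None keeps complex output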
...@@ -98,23 +98,21 @@ SAMPLE_NOISE = download_asset("tutorial-assets/mvdr/noise.wav") ...@@ -98,23 +98,21 @@ SAMPLE_NOISE = download_asset("tutorial-assets/mvdr/noise.wav")
# #
def plot_spectrogram(stft, title="Spectrogram", xlim=None): def plot_spectrogram(stft, title="Spectrogram"):
magnitude = stft.abs() magnitude = stft.abs()
spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy() spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy()
figure, axis = plt.subplots(1, 1) figure, axis = plt.subplots(1, 1)
img = axis.imshow(spectrogram, cmap="viridis", vmin=-100, vmax=0, origin="lower", aspect="auto") img = axis.imshow(spectrogram, cmap="viridis", vmin=-100, vmax=0, origin="lower", aspect="auto")
figure.suptitle(title) axis.set_title(title)
plt.colorbar(img, ax=axis) plt.colorbar(img, ax=axis)
plt.show()
def plot_mask(mask, title="Mask", xlim=None): def plot_mask(mask, title="Mask"):
mask = mask.numpy() mask = mask.numpy()
figure, axis = plt.subplots(1, 1) figure, axis = plt.subplots(1, 1)
img = axis.imshow(mask, cmap="viridis", origin="lower", aspect="auto") img = axis.imshow(mask, cmap="viridis", origin="lower", aspect="auto")
figure.suptitle(title) axis.set_title(title)
plt.colorbar(img, ax=axis) plt.colorbar(img, ax=axis)
plt.show()
def si_snr(estimate, reference, epsilon=1e-8): def si_snr(estimate, reference, epsilon=1e-8):
......
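# The body of ``si_snr`` above is elided by this diff; a hedged, simplified sketch
# of the scale-invariant SNR it computes (the tutorial's actual implementation may
# handle batching and reduction differently):
def si_snr_sketch(estimate, reference, epsilon=1e-8):
    # Remove DC offset, project the estimate onto the reference (the scale-invariant
    # target), then compare the target energy to the residual energy in dB.
    estimate = estimate - estimate.mean()
    reference = reference - reference.mean()
    scale = (estimate * reference).sum() / (reference.pow(2).sum() + epsilon)
    target = scale * reference
    residual = estimate - target
    return (10 * torch.log10(target.pow(2).sum() / (residual.pow(2).sum() + epsilon))).item()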
...@@ -33,12 +33,9 @@ print(torchaudio.__version__) ...@@ -33,12 +33,9 @@ print(torchaudio.__version__)
import os import os
import time import time
import matplotlib
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from torchaudio.io import StreamReader from torchaudio.io import StreamReader
matplotlib.rcParams["image.interpolation"] = "none"
###################################################################### ######################################################################
# #
# Check the prerequisites # Check the prerequisites
......
...@@ -160,8 +160,7 @@ for i, feats in enumerate(features): ...@@ -160,8 +160,7 @@ for i, feats in enumerate(features):
ax[i].set_title(f"Feature from transformer layer {i+1}") ax[i].set_title(f"Feature from transformer layer {i+1}")
ax[i].set_xlabel("Feature dimension") ax[i].set_xlabel("Feature dimension")
ax[i].set_ylabel("Frame (time-axis)") ax[i].set_ylabel("Frame (time-axis)")
plt.tight_layout() fig.tight_layout()
plt.show()
###################################################################### ######################################################################
...@@ -190,7 +189,7 @@ plt.imshow(emission[0].cpu().T, interpolation="nearest") ...@@ -190,7 +189,7 @@ plt.imshow(emission[0].cpu().T, interpolation="nearest")
plt.title("Classification result") plt.title("Classification result")
plt.xlabel("Frame (time-axis)") plt.xlabel("Frame (time-axis)")
plt.ylabel("Class") plt.ylabel("Class")
plt.show() plt.tight_layout()
print("Class labels:", bundle.get_labels()) print("Class labels:", bundle.get_labels())
......
...@@ -82,19 +82,23 @@ try: ...@@ -82,19 +82,23 @@ try:
from pystoi import stoi from pystoi import stoi
from torchaudio.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE from torchaudio.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE
except ImportError: except ImportError:
import google.colab # noqa: F401 try:
import google.colab # noqa: F401
print(
""" print(
To enable running this notebook in Google Colab, install nightly """
torch and torchaudio builds by adding the following code block to the top To enable running this notebook in Google Colab, install nightly
of the notebook before running it: torch and torchaudio builds by adding the following code block to the top
!pip3 uninstall -y torch torchvision torchaudio of the notebook before running it:
!pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu !pip3 uninstall -y torch torchvision torchaudio
!pip3 install pesq !pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu
!pip3 install pystoi !pip3 install pesq
""" !pip3 install pystoi
) """
)
except Exception:
pass
raise
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
...@@ -128,27 +132,17 @@ def si_snr(estimate, reference, epsilon=1e-8): ...@@ -128,27 +132,17 @@ def si_snr(estimate, reference, epsilon=1e-8):
return si_snr.item() return si_snr.item()
def plot_waveform(waveform, title): def plot(waveform, title, sample_rate=16000):
wav_numpy = waveform.numpy() wav_numpy = waveform.numpy()
sample_size = waveform.shape[1] sample_size = waveform.shape[1]
time_axis = torch.arange(0, sample_size) / 16000 time_axis = torch.arange(0, sample_size) / sample_rate
figure, axes = plt.subplots(1, 1)
axes = figure.gca()
axes.plot(time_axis, wav_numpy[0], linewidth=1)
axes.grid(True)
figure.suptitle(title)
plt.show(block=False)
figure, axes = plt.subplots(2, 1)
def plot_specgram(waveform, sample_rate, title): axes[0].plot(time_axis, wav_numpy[0], linewidth=1)
wav_numpy = waveform.numpy() axes[0].grid(True)
figure, axes = plt.subplots(1, 1) axes[1].specgram(wav_numpy[0], Fs=sample_rate)
axes = figure.gca()
axes.specgram(wav_numpy[0], Fs=sample_rate)
figure.suptitle(title) figure.suptitle(title)
plt.show(block=False)
###################################################################### ######################################################################
...@@ -238,32 +232,28 @@ Audio(WAVEFORM_DISTORTED.numpy()[1], rate=16000) ...@@ -238,32 +232,28 @@ Audio(WAVEFORM_DISTORTED.numpy()[1], rate=16000)
# Visualize speech sample # Visualize speech sample
# #
plot_waveform(WAVEFORM_SPEECH, "Clean Speech") plot(WAVEFORM_SPEECH, "Clean Speech")
plot_specgram(WAVEFORM_SPEECH, 16000, "Clean Speech Spectrogram")
###################################################################### ######################################################################
# Visualize noise sample # Visualize noise sample
# #
plot_waveform(WAVEFORM_NOISE, "Noise") plot(WAVEFORM_NOISE, "Noise")
plot_specgram(WAVEFORM_NOISE, 16000, "Noise Spectrogram")
###################################################################### ######################################################################
# Visualize distorted speech with 20dB SNR # Visualize distorted speech with 20dB SNR
# #
plot_waveform(WAVEFORM_DISTORTED[0:1], f"Distorted Speech with {snr_dbs[0]}dB SNR") plot(WAVEFORM_DISTORTED[0:1], f"Distorted Speech with {snr_dbs[0]}dB SNR")
plot_specgram(WAVEFORM_DISTORTED[0:1], 16000, f"Distorted Speech with {snr_dbs[0]}dB SNR")
###################################################################### ######################################################################
# Visualize distorted speech with -5dB SNR # Visualize distorted speech with -5dB SNR
# #
plot_waveform(WAVEFORM_DISTORTED[1:2], f"Distorted Speech with {snr_dbs[1]}dB SNR") plot(WAVEFORM_DISTORTED[1:2], f"Distorted Speech with {snr_dbs[1]}dB SNR")
plot_specgram(WAVEFORM_DISTORTED[1:2], 16000, f"Distorted Speech with {snr_dbs[1]}dB SNR")
###################################################################### ######################################################################
......