Commit 84b12306 authored by moto, committed by Facebook GitHub Bot

Set and tweak global matplotlib configuration in tutorials (#3515)

Summary:
- Set global matplotlib rc params
- Fix style check
- Fix and update FA tutorial plots
- Add av-asr index card

Pull Request resolved: https://github.com/pytorch/audio/pull/3515

Reviewed By: huangruizhe

Differential Revision: D47894156

Pulled By: mthrok

fbshipit-source-id: b40d8d31f12ffc2b337e35e632afc216e9d59a6e
parent 8497ee91
......@@ -127,6 +127,22 @@ def _get_pattern():
return ret
def reset_mpl(gallery_conf, fname):
from sphinx_gallery.scrapers import _reset_matplotlib
_reset_matplotlib(gallery_conf, fname)
import matplotlib
matplotlib.rcParams.update(
{
"image.interpolation": "none",
"figure.figsize": (9.6, 4.8),
"font.size": 8.0,
"axes.axisbelow": True,
}
)
sphinx_gallery_conf = {
"examples_dirs": [
"../../examples/tutorials",
......@@ -139,6 +155,7 @@ sphinx_gallery_conf = {
"promote_jupyter_magic": True,
"first_notebook_cell": None,
"doc_module": ("torchaudio",),
"reset_modules": (reset_mpl, "seaborn"),
}
autosummary_generate = True
......
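# As a quick illustration of what these new defaults do (a standalone sketch,
# not part of the commit), the same rcParams can be applied in any script:

import matplotlib
import matplotlib.pyplot as plt

matplotlib.rcParams.update(
    {
        "image.interpolation": "none",  # no smoothing when imshow scales images
        "figure.figsize": (9.6, 4.8),   # wider default figure
        "font.size": 8.0,
        "axes.axisbelow": True,         # draw grids and ticks behind plot elements
    }
)

fig, ax = plt.subplots()  # picks up figure.figsize
ax.grid(True)             # grid renders below the data because of axes.axisbelow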
......@@ -71,8 +71,8 @@ model implementations and application components.
tutorials/online_asr_tutorial
tutorials/device_asr
tutorials/device_avsr
tutorials/forced_alignment_for_multilingual_data_tutorial
tutorials/forced_alignment_tutorial
tutorials/forced_alignment_for_multilingual_data_tutorial
tutorials/tacotron2_pipeline_tutorial
tutorials/mvdr_tutorial
tutorials/hybrid_demucs_tutorial
......@@ -147,6 +147,13 @@ Tutorials
.. customcardstart::
.. customcarditem::
:header: On device audio-visual automatic speech recognition
:card_description: Learn how to stream audio and video from a laptop webcam and perform audio-visual automatic speech recognition using the Emformer-RNNT model.
:image: https://download.pytorch.org/torchaudio/doc-assets/avsr/transformed.gif
:link: tutorials/device_avsr.html
:tags: I/O,Pipelines,RNNT
.. customcarditem::
:header: Loading waveform Tensors from files and saving them
:card_description: Learn how to query/load audio files and save waveform tensors to files, using <code>torchaudio.info</code>, <code>torchaudio.load</code> and <code>torchaudio.save</code> functions.
......
......@@ -85,7 +85,7 @@ NUM_FRAMES = int(DURATION * SAMPLE_RATE)
#
def show(freq, amp, waveform, sample_rate, zoom=None, vol=0.1):
def plot(freq, amp, waveform, sample_rate, zoom=None, vol=0.1):
t = torch.arange(waveform.size(0)) / sample_rate
fig, axes = plt.subplots(4, 1, sharex=True)
......@@ -101,7 +101,7 @@ def show(freq, amp, waveform, sample_rate, zoom=None, vol=0.1):
for i in range(4):
axes[i].grid(True)
pos = axes[2].get_position()
plt.tight_layout()
fig.tight_layout()
if zoom is not None:
ax = fig.add_axes([pos.x0 + 0.02, pos.y0 + 0.03, pos.width / 2.5, pos.height / 2.0])
......@@ -168,7 +168,7 @@ def sawtooth_wave(freq0, amp0, num_pitches, sample_rate):
freq0 = torch.full((NUM_FRAMES, 1), F0)
amp0 = torch.ones((NUM_FRAMES, 1))
freq, amp, waveform = sawtooth_wave(freq0, amp0, int(SAMPLE_RATE / F0), SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
plot(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
######################################################################
#
......@@ -183,7 +183,7 @@ phase = torch.linspace(0, fm * PI2 * DURATION, NUM_FRAMES)
freq0 = F0 + f_dev * torch.sin(phase).unsqueeze(-1)
freq, amp, waveform = sawtooth_wave(freq0, amp0, int(SAMPLE_RATE / F0), SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
plot(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
######################################################################
# Square wave
......@@ -220,7 +220,7 @@ def square_wave(freq0, amp0, num_pitches, sample_rate):
freq0 = torch.full((NUM_FRAMES, 1), F0)
amp0 = torch.ones((NUM_FRAMES, 1))
freq, amp, waveform = square_wave(freq0, amp0, int(SAMPLE_RATE / F0 / 2), SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
plot(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
######################################################################
# Triangle wave
......@@ -256,7 +256,7 @@ def triangle_wave(freq0, amp0, num_pitches, sample_rate):
#
freq, amp, waveform = triangle_wave(freq0, amp0, int(SAMPLE_RATE / F0 / 2), SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
plot(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
######################################################################
# Inharmonic Partials
......@@ -296,7 +296,7 @@ amp = torch.stack([amp * (0.5**i) for i in range(num_tones)], dim=-1)
waveform = oscillator_bank(freq, amp, sample_rate=SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE, vol=0.4)
plot(freq, amp, waveform, SAMPLE_RATE, vol=0.4)
######################################################################
#
......@@ -308,7 +308,7 @@ show(freq, amp, waveform, SAMPLE_RATE, vol=0.4)
freq = extend_pitch(freq0, num_tones)
waveform = oscillator_bank(freq, amp, sample_rate=SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE)
plot(freq, amp, waveform, SAMPLE_RATE)
######################################################################
# References
......
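# For context, the renamed plot() helper visualizes signals built with the
# prototype oscillator API. A minimal sketch of that API (assuming
# torchaudio.prototype.functional, as used by this tutorial):

import torch
from torchaudio.prototype.functional import extend_pitch, oscillator_bank

SAMPLE_RATE = 16000
NUM_FRAMES = SAMPLE_RATE  # one second of per-frame parameters
F0 = 344.0

freq0 = torch.full((NUM_FRAMES, 1), F0)  # constant fundamental frequency
amp0 = torch.ones((NUM_FRAMES, 1))       # constant amplitude

# Sawtooth-like tone: harmonics at integer multiples of F0, amplitudes ~ 1/n.
num_tones = int(SAMPLE_RATE / F0 / 2)    # keep all harmonics below Nyquist
freq = extend_pitch(freq0, num_tones)
amp = extend_pitch(amp0, [1.0 / n for n in range(1, num_tones + 1)])
waveform = oscillator_bank(freq, amp, sample_rate=SAMPLE_RATE)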
......@@ -407,30 +407,45 @@ print(timesteps, timesteps.shape[0])
#
def plot_alignments(waveform, emission, tokens, timesteps):
fig, ax = plt.subplots(figsize=(32, 10))
ax.plot(waveform)
ratio = waveform.shape[0] / emission.shape[1]
word_start = 0
for i in range(len(tokens)):
if i != 0 and tokens[i - 1] == "|":
word_start = timesteps[i]
if tokens[i] != "|":
plt.annotate(tokens[i].upper(), (timesteps[i] * ratio, waveform.max() * 1.02), size=14)
elif i != 0:
word_end = timesteps[i]
ax.axvspan(word_start * ratio, word_end * ratio, alpha=0.1, color="red")
xticks = ax.get_xticks()
plt.xticks(xticks, xticks / bundle.sample_rate)
ax.set_xlabel("time (sec)")
ax.set_xlim(0, waveform.shape[0])
plot_alignments(waveform[0], emission, predicted_tokens, timesteps)
def plot_alignments(waveform, emission, tokens, timesteps, sample_rate):
t = torch.arange(waveform.size(0)) / sample_rate
ratio = waveform.size(0) / emission.size(1) / sample_rate
chars = []
words = []
word_start = None
for token, timestep in zip(tokens, timesteps * ratio):
if token == "|":
if word_start is not None:
words.append((word_start, timestep))
word_start = None
else:
chars.append((token, timestep))
if word_start is None:
word_start = timestep
fig, axes = plt.subplots(3, 1)
def _plot(ax, xlim):
ax.plot(t, waveform)
for token, timestep in chars:
ax.annotate(token.upper(), (timestep, 0.5))
for word_start, word_end in words:
ax.axvspan(word_start, word_end, alpha=0.1, color="red")
ax.set_ylim(-0.6, 0.7)
ax.set_yticks([0])
ax.grid(True, axis="y")
ax.set_xlim(xlim)
_plot(axes[0], (0.3, 2.5))
_plot(axes[1], (2.5, 4.7))
_plot(axes[2], (4.7, 6.9))
axes[2].set_xlabel("time (sec)")
fig.tight_layout()
plot_alignments(waveform[0], emission, predicted_tokens, timesteps, bundle.sample_rate)
######################################################################
......
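# The grouping logic in the new plot_alignments() can be exercised on its own;
# the token/timestep values below are made up for illustration:

tokens = ["i", "|", "h", "a", "d", "|"]
timesteps = [0.3, 0.7, 0.9, 1.2, 1.5, 1.9]  # already scaled to seconds

chars, words, word_start = [], [], None
for token, timestep in zip(tokens, timesteps):
    if token == "|":  # word boundary
        if word_start is not None:
            words.append((word_start, timestep))
            word_start = None
    else:
        chars.append((token, timestep))
        if word_start is None:  # first character of a new word
            word_start = timestep

print(words)  # [(0.3, 0.7), (0.9, 1.9)]
print(chars)  # [('i', 0.3), ('h', 0.9), ('a', 1.2), ('d', 1.5)]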
......@@ -100,7 +100,6 @@ def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None):
if xlim:
axes[c].set_xlim(xlim)
figure.suptitle(title)
plt.show(block=False)
######################################################################
......@@ -122,7 +121,6 @@ def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
if xlim:
axes[c].set_xlim(xlim)
figure.suptitle(title)
plt.show(block=False)
######################################################################
......
# -*- coding: utf-8 -*-
"""
Audio Datasets
==============
......@@ -10,10 +9,6 @@ datasets. Please refer to the official documentation for the list of
available datasets.
"""
# When running this tutorial in Google Colab, install the required packages
# with the following.
# !pip install torchaudio
import torch
import torchaudio
......@@ -21,22 +16,13 @@ print(torch.__version__)
print(torchaudio.__version__)
######################################################################
# Preparing data and utility functions (skip this section)
# --------------------------------------------------------
#
# @title Prepare data and utility functions. {display-mode: "form"}
# @markdown
# @markdown You do not need to look into this cell.
# @markdown Just execute once and you are good to go.
# -------------------------------------------------------------------------------
# Preparation of data and helper functions.
# -------------------------------------------------------------------------------
import os
import IPython
import matplotlib.pyplot as plt
from IPython.display import Audio, display
_SAMPLE_DIR = "_assets"
......@@ -44,34 +30,13 @@ YESNO_DATASET_PATH = os.path.join(_SAMPLE_DIR, "yes_no")
os.makedirs(YESNO_DATASET_PATH, exist_ok=True)
def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
def plot_specgram(waveform, sample_rate, title="Spectrogram"):
waveform = waveform.numpy()
num_channels, _ = waveform.shape
figure, axes = plt.subplots(num_channels, 1)
if num_channels == 1:
axes = [axes]
for c in range(num_channels):
axes[c].specgram(waveform[c], Fs=sample_rate)
if num_channels > 1:
axes[c].set_ylabel(f"Channel {c+1}")
if xlim:
axes[c].set_xlim(xlim)
figure, ax = plt.subplots()
ax.specgram(waveform[0], Fs=sample_rate)
figure.suptitle(title)
plt.show(block=False)
def play_audio(waveform, sample_rate):
waveform = waveform.numpy()
num_channels, _ = waveform.shape
if num_channels == 1:
display(Audio(waveform[0], rate=sample_rate))
elif num_channels == 2:
display(Audio((waveform[0], waveform[1]), rate=sample_rate))
else:
raise ValueError("Waveform with more than 2 channels are not supported.")
figure.tight_layout()
######################################################################
......@@ -79,10 +44,25 @@ def play_audio(waveform, sample_rate):
# :py:class:`torchaudio.datasets.YESNO` dataset.
#
dataset = torchaudio.datasets.YESNO(YESNO_DATASET_PATH, download=True)
for i in [1, 3, 5]:
waveform, sample_rate, label = dataset[i]
plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}")
play_audio(waveform, sample_rate)
######################################################################
#
i = 1
waveform, sample_rate, label = dataset[i]
plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}")
IPython.display.Audio(waveform, rate=sample_rate)
######################################################################
#
i = 3
waveform, sample_rate, label = dataset[i]
plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}")
IPython.display.Audio(waveform, rate=sample_rate)
######################################################################
#
i = 5
waveform, sample_rate, label = dataset[i]
plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}")
IPython.display.Audio(waveform, rate=sample_rate)
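# For reference, each YESNO item is a (waveform, sample_rate, labels) tuple;
# a minimal standalone sketch:

import torchaudio

dataset = torchaudio.datasets.YESNO("_assets/yes_no", download=True)
waveform, sample_rate, label = dataset[0]
print(waveform.shape)  # (1, num_frames) -- mono audio
print(sample_rate)     # 8000 Hz for this dataset
print(label)           # eight 0/1 values, one per spoken yes/no word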
......@@ -19,25 +19,19 @@ print(torch.__version__)
print(torchaudio.__version__)
######################################################################
# Preparing data and utility functions (skip this section)
# --------------------------------------------------------
# Preparation
# -----------
#
# @title Prepare data and utility functions. {display-mode: "form"}
# @markdown
# @markdown You do not need to look into this cell.
# @markdown Just execute once and you are good to go.
# @markdown
# @markdown In this tutorial, we will use a speech data from [VOiCES dataset](https://iqtlabs.github.io/voices/),
# @markdown which is licensed under Creative Commos BY 4.0.
# -------------------------------------------------------------------------------
# Preparation of data and helper functions.
# -------------------------------------------------------------------------------
import librosa
import matplotlib.pyplot as plt
from torchaudio.utils import download_asset
######################################################################
# In this tutorial, we will use speech data from the
# `VOiCES dataset <https://iqtlabs.github.io/voices/>`__,
# which is licensed under Creative Commons BY 4.0.
SAMPLE_WAV_SPEECH_PATH = download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
......@@ -75,16 +69,9 @@ def get_spectrogram(
return spectrogram(waveform)
def plot_spectrogram(spec, title=None, ylabel="freq_bin", aspect="auto", xmax=None):
fig, axs = plt.subplots(1, 1)
axs.set_title(title or "Spectrogram (db)")
axs.set_ylabel(ylabel)
axs.set_xlabel("frame")
im = axs.imshow(librosa.power_to_db(spec), origin="lower", aspect=aspect)
if xmax:
axs.set_xlim((0, xmax))
fig.colorbar(im, ax=axs)
plt.show(block=False)
def plot_spec(ax, spec, title, ylabel="freq_bin"):
ax.set_title(title)
ax.imshow(librosa.power_to_db(spec), origin="lower", aspect="auto")
######################################################################
......@@ -108,43 +95,47 @@ def plot_spectrogram(spec, title=None, ylabel="freq_bin", aspect="auto", xmax=No
spec = get_spectrogram(power=None)
stretch = T.TimeStretch()
rate = 1.2
spec_ = stretch(spec, rate)
plot_spectrogram(torch.abs(spec_[0]), title=f"Stretched x{rate}", aspect="equal", xmax=304)
plot_spectrogram(torch.abs(spec[0]), title="Original", aspect="equal", xmax=304)
rate = 0.9
spec_ = stretch(spec, rate)
plot_spectrogram(torch.abs(spec_[0]), title=f"Stretched x{rate}", aspect="equal", xmax=304)
spec_12 = stretch(spec, overriding_rate=1.2)
spec_09 = stretch(spec, overriding_rate=0.9)
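# TimeStretch operates on complex-valued spectrograms; a minimal sketch of the
# transform on a dummy signal (shapes here are assumptions for illustration):

import torch
import torchaudio.transforms as T

spectrogram = T.Spectrogram(power=None)   # power=None keeps complex values
stretch_demo = T.TimeStretch(n_freq=201)  # n_freq = n_fft // 2 + 1

dummy = torch.randn(1, 16000)
spec_demo = spectrogram(dummy)                          # complex (1, 201, frames)
faster = stretch_demo(spec_demo, overriding_rate=1.2)   # fewer frames
slower = stretch_demo(spec_demo, overriding_rate=0.9)   # more frames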
######################################################################
# TimeMasking
# -----------
#
torch.random.manual_seed(4)
spec = get_spectrogram()
plot_spectrogram(spec[0], title="Original")
def plot():
fig, axes = plt.subplots(3, 1, sharex=True, sharey=True)
plot_spec(axes[0], torch.abs(spec_12[0]), title="Stretched x1.2")
plot_spec(axes[1], torch.abs(spec[0]), title="Original")
plot_spec(axes[2], torch.abs(spec_09[0]), title="Stretched x0.9")
fig.tight_layout()
masking = T.TimeMasking(time_mask_param=80)
spec = masking(spec)
plot_spectrogram(spec[0], title="Masked along time axis")
plot()
######################################################################
# FrequencyMasking
# ----------------
# Time and Frequency Masking
# --------------------------
#
torch.random.manual_seed(4)
time_masking = T.TimeMasking(time_mask_param=80)
freq_masking = T.FrequencyMasking(freq_mask_param=80)
spec = get_spectrogram()
plot_spectrogram(spec[0], title="Original")
time_masked = time_masking(spec)
freq_masked = freq_masking(spec)
######################################################################
#
def plot():
fig, axes = plt.subplots(3, 1, sharex=True, sharey=True)
plot_spec(axes[0], spec[0], title="Original")
plot_spec(axes[1], time_masked[0], title="Masked along time axis")
plot_spec(axes[2], freq_masked[0], title="Masked along frequency axis")
fig.tight_layout()
masking = T.FrequencyMasking(freq_mask_param=80)
spec = masking(spec)
plot_spectrogram(spec[0], title="Masked along frequency axis")
plot()
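# The masking transforms zero out a random span of frames or frequency bins;
# a standalone sketch on a dummy magnitude spectrogram:

import torch
import torchaudio.transforms as T

torch.random.manual_seed(4)
dummy_spec = torch.randn(1, 201, 400).abs()

time_mask = T.TimeMasking(time_mask_param=80)       # mask width drawn from [0, 80)
freq_mask = T.FrequencyMasking(freq_mask_param=80)  # mask height drawn from [0, 80)

masked_t = time_mask(dummy_spec)  # a random span of time frames set to zero
masked_f = freq_mask(dummy_spec)  # a random span of frequency bins set to zero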
......@@ -75,7 +75,6 @@ def plot_waveform(waveform, sr, title="Waveform", ax=None):
ax.grid(True)
ax.set_xlim([0, time_axis[-1]])
ax.set_title(title)
plt.show(block=False)
def plot_spectrogram(specgram, title=None, ylabel="freq_bin", ax=None):
......@@ -85,7 +84,6 @@ def plot_spectrogram(specgram, title=None, ylabel="freq_bin", ax=None):
ax.set_title(title)
ax.set_ylabel(ylabel)
ax.imshow(librosa.power_to_db(specgram), origin="lower", aspect="auto", interpolation="nearest")
plt.show(block=False)
def plot_fbank(fbank, title=None):
......@@ -94,7 +92,6 @@ def plot_fbank(fbank, title=None):
axs.imshow(fbank, aspect="auto")
axs.set_ylabel("frequency bin")
axs.set_xlabel("mel bin")
plt.show(block=False)
######################################################################
......@@ -486,7 +483,6 @@ def plot_pitch(waveform, sr, pitch):
axis2.plot(time_axis, pitch[0], linewidth=2, label="Pitch", color="green")
axis2.legend(loc=0)
plt.show(block=False)
plot_pitch(SPEECH_WAVEFORM, SAMPLE_RATE, pitch)
......@@ -181,7 +181,6 @@ def plot_waveform(waveform, sample_rate):
if num_channels > 1:
axes[c].set_ylabel(f"Channel {c+1}")
figure.suptitle("waveform")
plt.show(block=False)
######################################################################
......@@ -204,7 +203,6 @@ def plot_specgram(waveform, sample_rate, title="Spectrogram"):
if num_channels > 1:
axes[c].set_ylabel(f"Channel {c+1}")
figure.suptitle(title)
plt.show(block=False)
######################################################################
......
......@@ -105,7 +105,6 @@ def plot_sweep(
axis.yaxis.grid(True, alpha=0.67)
figure.suptitle(f"{title} (sample rate: {sample_rate} Hz)")
plt.colorbar(cax)
plt.show(block=True)
######################################################################
......
......@@ -69,7 +69,7 @@ import torchvision
# -------------------
#
# Firstly, we define the function to collect videos from microphone and
# camera. To be specific, we use :py:func:`~torchaudio.io.StreamReader`
# camera. To be specific, we use :py:class:`~torchaudio.io.StreamReader`
# class for the purpose of data collection, which supports capturing
# audio/video from microphone and camera. For the detailed usage of this
# class, please refer to the
......
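# A minimal sketch of capturing from hardware with StreamReader; the device
# string and format are platform-specific assumptions (macOS avfoundation here):

from torchaudio.io import StreamReader

streamer = StreamReader(src="0:0", format="avfoundation")
streamer.add_basic_audio_stream(frames_per_chunk=1600, sample_rate=16000)
streamer.add_basic_video_stream(frames_per_chunk=1, frame_rate=30)

for audio_chunk, video_chunk in streamer.stream(timeout=-1):
    break  # process one chunk pair, then stop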
......@@ -89,7 +89,7 @@ def plot_sinc_ir(irs, cutoff):
num_filts, window_size = irs.shape
half = window_size // 2
fig, axes = plt.subplots(num_filts, 1, sharex=True, figsize=(6.4, 4.8 * 1.5))
fig, axes = plt.subplots(num_filts, 1, sharex=True, figsize=(9.6, 8))
t = torch.linspace(-half, half - 1, window_size)
for ax, ir, coff, color in zip(axes, irs, cutoff, plt.cm.tab10.colors):
ax.plot(t, ir, linewidth=1.2, color=color, zorder=4, label=f"Cutoff: {coff}")
......@@ -100,7 +100,7 @@ def plot_sinc_ir(irs, cutoff):
"(Frequencies are relative to Nyquist frequency)"
)
axes[-1].set_xticks([i * half // 4 for i in range(-4, 5)])
plt.tight_layout()
fig.tight_layout()
######################################################################
......@@ -130,7 +130,7 @@ def plot_sinc_fr(frs, cutoff, band=False):
num_filts, num_fft = frs.shape
num_ticks = num_filts + 1 if band else num_filts
fig, axes = plt.subplots(num_filts, 1, sharex=True, sharey=True, figsize=(6.4, 4.8 * 1.5))
fig, axes = plt.subplots(num_filts, 1, sharex=True, sharey=True, figsize=(9.6, 8))
for ax, fr, coff, color in zip(axes, frs, cutoff, plt.cm.tab10.colors):
ax.grid(True)
ax.semilogy(fr, color=color, zorder=4, label=f"Cutoff: {coff}")
......@@ -146,7 +146,7 @@ def plot_sinc_fr(frs, cutoff, band=False):
"Frequency response of sinc low-pass filter for different cut-off frequencies\n"
"(Frequencies are relative to Nyquist frequency)"
)
plt.tight_layout()
fig.tight_layout()
######################################################################
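# The impulse responses plotted above come from the prototype DSP API; a
# minimal sketch (assuming torchaudio.prototype.functional):

import torch
from torchaudio.prototype.functional import sinc_impulse_response

cutoff = torch.tensor([0.1, 0.3, 0.5])  # relative to the Nyquist frequency
irs = sinc_impulse_response(cutoff, window_size=513)  # (3, 513)
frs = torch.fft.rfft(irs).abs()  # magnitude frequency response per filter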
......@@ -275,7 +275,7 @@ def plot_ir(magnitudes, ir, num_fft=2048):
axes[i].grid(True)
axes[1].set(title="Frequency Response")
axes[2].set(title="Frequency Response (log-scale)", xlabel="Frequency")
axes[2].legend(loc="lower right")
axes[2].legend(loc="center right")
fig.tight_layout()
......
......@@ -56,16 +56,11 @@ print(device)
# First we import the necessary packages, and fetch data that we work on.
#
# %matplotlib inline
from dataclasses import dataclass
import IPython
import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams["figure.figsize"] = [16.0, 4.8]
torch.random.manual_seed(0)
SPEECH_FILE = torchaudio.utils.download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
......@@ -99,17 +94,22 @@ with torch.inference_mode():
emission = emissions[0].cpu().detach()
print(labels)
################################################################################
# Visualization
################################################################################
print(labels)
plt.imshow(emission.T)
plt.colorbar()
plt.title("Frame-wise class probability")
plt.xlabel("Time")
plt.ylabel("Labels")
plt.show()
# ~~~~~~~~~~~~~
def plot():
plt.imshow(emission.T)
plt.colorbar()
plt.title("Frame-wise class probability")
plt.xlabel("Time")
plt.ylabel("Labels")
plot()
######################################################################
# Generate alignment probability (trellis)
......@@ -181,12 +181,17 @@ trellis = get_trellis(emission, tokens)
################################################################################
# Visualization
################################################################################
plt.imshow(trellis.T, origin="lower")
plt.annotate("- Inf", (trellis.size(1) / 5, trellis.size(1) / 1.5))
plt.annotate("+ Inf", (trellis.size(0) - trellis.size(1) / 5, trellis.size(1) / 3))
plt.colorbar()
plt.show()
# ~~~~~~~~~~~~~
def plot():
plt.imshow(trellis.T, origin="lower")
plt.annotate("- Inf", (trellis.size(1) / 5, trellis.size(1) / 1.5))
plt.annotate("+ Inf", (trellis.size(0) - trellis.size(1) / 5, trellis.size(1) / 3))
plt.colorbar()
plot()
######################################################################
# In the above visualization, we can see that there is a trace of high
......@@ -266,7 +271,9 @@ for p in path:
################################################################################
# Visualization
################################################################################
# ~~~~~~~~~~~~~
def plot_trellis_with_path(trellis, path):
# To plot trellis with path, we take advantage of 'nan' value
trellis_with_path = trellis.clone()
......@@ -277,10 +284,14 @@ def plot_trellis_with_path(trellis, path):
plot_trellis_with_path(trellis, path)
plt.title("The path found by backtracking")
plt.show()
######################################################################
# Looking good. Now this path contains repetations for the same labels, so
# Looking good.
######################################################################
# Segment the path
# ----------------
# Now this path contains repetitions of the same labels, so
# let’s merge them to make it close to the original transcript.
#
# When merging the multiple path points, we simply take the average
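# A simplified sketch of that merging step, on a made-up path of
# (token_index, score) pairs (indices here are positions in the path):

from itertools import groupby

path = [(0, 0.9), (0, 0.8), (1, 0.7), (1, 0.9), (1, 0.8), (2, 1.0)]
transcript = "cat"

segments, start = [], 0
for token_index, group in groupby(path, key=lambda p: p[0]):
    scores = [score for _, score in group]
    end = start + len(scores)
    # average the per-frame probabilities over the repeated span
    segments.append((transcript[token_index], sum(scores) / len(scores), start, end))
    start = end
print(segments)  # [('c', 0.85, 0, 2), ('a', ~0.8, 2, 5), ('t', 1.0, 5, 6)]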
......@@ -297,7 +308,7 @@ class Segment:
score: float
def __repr__(self):
return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"
return f"{self.label} ({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"
@property
def length(self):
......@@ -330,7 +341,9 @@ for seg in segments:
################################################################################
# Visualization
################################################################################
# ~~~~~~~~~~~~~
def plot_trellis_with_segments(trellis, segments, transcript):
# To plot trellis with path, we take advantage of 'nan' value
trellis_with_path = trellis.clone()
......@@ -338,15 +351,14 @@ def plot_trellis_with_segments(trellis, segments, transcript):
if seg.label != "|":
trellis_with_path[seg.start : seg.end, i] = float("nan")
fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5))
fig, [ax1, ax2] = plt.subplots(2, 1, sharex=True)
ax1.set_title("Path, label and probability for each label")
ax1.imshow(trellis_with_path.T, origin="lower")
ax1.set_xticks([])
ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")
for i, seg in enumerate(segments):
if seg.label != "|":
ax1.annotate(seg.label, (seg.start, i - 0.7), weight="bold")
ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3))
ax1.annotate(seg.label, (seg.start, i - 0.7), size="small")
ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), size="small")
ax2.set_title("Label probability with and without repetition")
xs, hs, ws = [], [], []
......@@ -355,7 +367,7 @@ def plot_trellis_with_segments(trellis, segments, transcript):
xs.append((seg.end + seg.start) / 2 + 0.4)
hs.append(seg.score)
ws.append(seg.end - seg.start)
ax2.annotate(seg.label, (seg.start + 0.8, -0.07), weight="bold")
ax2.annotate(seg.label, (seg.start + 0.8, -0.07))
ax2.bar(xs, hs, width=ws, color="gray", alpha=0.5, edgecolor="black")
xs, hs = [], []
......@@ -367,17 +379,21 @@ def plot_trellis_with_segments(trellis, segments, transcript):
ax2.bar(xs, hs, width=0.5, alpha=0.5)
ax2.axhline(0, color="black")
ax2.set_xlim(ax1.get_xlim())
ax2.grid(True, axis="y")
ax2.set_ylim(-0.1, 1.1)
fig.tight_layout()
plot_trellis_with_segments(trellis, segments, transcript)
plt.tight_layout()
plt.show()
######################################################################
# Looks good. Now let’s merge the words. The Wav2Vec2 model uses ``'|'``
# Looks good.
######################################################################
# Merge the segments into words
# -----------------------------
# Now let’s merge the words. The Wav2Vec2 model uses ``'|'``
# as the word boundary, so we merge the segments before each occurrence of
# ``'|'``.
#
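# A simplified sketch of the word merge, using hypothetical character
# segments of the form (label, score, start_frame, end_frame); the score
# here is an unweighted mean (the tutorial weights by segment length):

segments = [
    ("i", 0.9, 3, 5), ("|", 1.0, 5, 7),
    ("h", 0.8, 9, 11), ("a", 0.7, 11, 13), ("d", 0.9, 13, 15), ("|", 1.0, 15, 19),
]

words, word = [], []
for label, score, start, end in segments:
    if label == "|":  # word boundary token
        if word:
            words.append((
                "".join(s[0] for s in word),          # word text
                sum(s[1] for s in word) / len(word),  # mean character score
                word[0][2],                           # word start frame
                word[-1][3],                          # word end frame
            ))
            word = []
    else:
        word.append((label, score, start, end))
print(words)  # [('i', 0.9, 3, 5), ('had', ~0.8, 9, 15)]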
......@@ -410,16 +426,16 @@ for word in word_segments:
################################################################################
# Visualization
################################################################################
# ~~~~~~~~~~~~~
def plot_alignments(trellis, segments, word_segments, waveform):
trellis_with_path = trellis.clone()
for i, seg in enumerate(segments):
if seg.label != "|":
trellis_with_path[seg.start : seg.end, i] = float("nan")
fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5))
fig, [ax1, ax2] = plt.subplots(2, 1)
ax1.imshow(trellis_with_path.T, origin="lower")
ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")
ax1.set_xticks([])
ax1.set_yticks([])
......@@ -429,8 +445,8 @@ def plot_alignments(trellis, segments, word_segments, waveform):
for i, seg in enumerate(segments):
if seg.label != "|":
ax1.annotate(seg.label, (seg.start, i - 0.7))
ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), fontsize=8)
ax1.annotate(seg.label, (seg.start, i - 0.7), size="small")
ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), size="small")
# The original waveform
ratio = waveform.size(0) / trellis.size(0)
......@@ -450,6 +466,7 @@ def plot_alignments(trellis, segments, word_segments, waveform):
ax2.set_yticks([])
ax2.set_ylim(-1.0, 1.0)
ax2.set_xlim(0, waveform.size(-1))
fig.tight_layout()
plot_alignments(
......@@ -458,7 +475,6 @@ plot_alignments(
word_segments,
waveform[0],
)
plt.show()
################################################################################
......
......@@ -162,11 +162,10 @@ def separate_sources(
def plot_spectrogram(stft, title="Spectrogram"):
magnitude = stft.abs()
spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy()
figure, axis = plt.subplots(1, 1)
img = axis.imshow(spectrogram, cmap="viridis", vmin=-60, vmax=0, origin="lower", aspect="auto")
figure.suptitle(title)
plt.colorbar(img, ax=axis)
plt.show()
_, axis = plt.subplots(1, 1)
axis.imshow(spectrogram, cmap="viridis", vmin=-60, vmax=0, origin="lower", aspect="auto")
axis.set_title(title)
plt.tight_layout()
######################################################################
......@@ -252,7 +251,7 @@ def output_results(original_source: torch.Tensor, predicted_source: torch.Tensor
"SDR score is:",
separation.bss_eval_sources(original_source.detach().numpy(), predicted_source.detach().numpy())[0].mean(),
)
plot_spectrogram(stft(predicted_source)[0], f"Spectrogram {source}")
plot_spectrogram(stft(predicted_source)[0], f"Spectrogram - {source}")
return Audio(predicted_source, rate=sample_rate)
......@@ -294,7 +293,7 @@ mix_spec = mixture[:, frame_start:frame_end].cpu()
#
# Mixture Clip
plot_spectrogram(stft(mix_spec)[0], "Spectrogram Mixture")
plot_spectrogram(stft(mix_spec)[0], "Spectrogram - Mixture")
Audio(mix_spec, rate=sample_rate)
######################################################################
......
......@@ -98,23 +98,21 @@ SAMPLE_NOISE = download_asset("tutorial-assets/mvdr/noise.wav")
#
def plot_spectrogram(stft, title="Spectrogram", xlim=None):
def plot_spectrogram(stft, title="Spectrogram"):
magnitude = stft.abs()
spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy()
figure, axis = plt.subplots(1, 1)
img = axis.imshow(spectrogram, cmap="viridis", vmin=-100, vmax=0, origin="lower", aspect="auto")
figure.suptitle(title)
axis.set_title(title)
plt.colorbar(img, ax=axis)
plt.show()
def plot_mask(mask, title="Mask", xlim=None):
def plot_mask(mask, title="Mask"):
mask = mask.numpy()
figure, axis = plt.subplots(1, 1)
img = axis.imshow(mask, cmap="viridis", origin="lower", aspect="auto")
figure.suptitle(title)
axis.set_title(title)
plt.colorbar(img, ax=axis)
plt.show()
def si_snr(estimate, reference, epsilon=1e-8):
......
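# The si_snr body is elided in this hunk; a common scale-invariant SNR
# formulation looks roughly like the sketch below (an assumption, not
# necessarily the tutorial's exact code):

import torch

def si_snr_sketch(estimate, reference, epsilon=1e-8):
    # Remove DC offset so the measure is invariant to constant shifts.
    estimate = estimate - estimate.mean(dim=-1, keepdim=True)
    reference = reference - reference.mean(dim=-1, keepdim=True)
    # Project the estimate onto the reference to get the scale-invariant target.
    scale = (estimate * reference).sum(dim=-1, keepdim=True) / (
        reference.pow(2).sum(dim=-1, keepdim=True) + epsilon
    )
    target = scale * reference
    noise = estimate - target
    ratio = target.pow(2).sum(-1) / (noise.pow(2).sum(-1) + epsilon)
    return 10 * torch.log10(ratio + epsilon)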
......@@ -33,12 +33,9 @@ print(torchaudio.__version__)
import os
import time
import matplotlib
import matplotlib.pyplot as plt
from torchaudio.io import StreamReader
matplotlib.rcParams["image.interpolation"] = "none"
######################################################################
#
# Check the prerequisites
......
......@@ -160,8 +160,7 @@ for i, feats in enumerate(features):
ax[i].set_title(f"Feature from transformer layer {i+1}")
ax[i].set_xlabel("Feature dimension")
ax[i].set_ylabel("Frame (time-axis)")
plt.tight_layout()
plt.show()
fig.tight_layout()
######################################################################
......@@ -190,7 +189,7 @@ plt.imshow(emission[0].cpu().T, interpolation="nearest")
plt.title("Classification result")
plt.xlabel("Frame (time-axis)")
plt.ylabel("Class")
plt.show()
plt.tight_layout()
print("Class labels:", bundle.get_labels())
......
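# For context, this tutorial drives a Wav2Vec2 pipeline bundle; a minimal
# sketch of that API (the file path is a placeholder):

import torch
import torchaudio

bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model()

waveform, sample_rate = torchaudio.load("speech.wav")
waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)

with torch.inference_mode():
    features, _ = model.extract_features(waveform)  # per-layer transformer features
    emission, _ = model(waveform)                   # frame-wise class logits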
......@@ -82,19 +82,23 @@ try:
from pystoi import stoi
from torchaudio.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE
except ImportError:
import google.colab # noqa: F401
print(
"""
To enable running this notebook in Google Colab, install nightly
torch and torchaudio builds by adding the following code block to the top
of the notebook before running it:
!pip3 uninstall -y torch torchvision torchaudio
!pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu
!pip3 install pesq
!pip3 install pystoi
"""
)
try:
import google.colab # noqa: F401
print(
"""
To enable running this notebook in Google Colab, install nightly
torch and torchaudio builds by adding the following code block to the top
of the notebook before running it:
!pip3 uninstall -y torch torchvision torchaudio
!pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu
!pip3 install pesq
!pip3 install pystoi
"""
)
except Exception:
pass
raise
import matplotlib.pyplot as plt
......@@ -128,27 +132,17 @@ def si_snr(estimate, reference, epsilon=1e-8):
return si_snr.item()
def plot_waveform(waveform, title):
def plot(waveform, title, sample_rate=16000):
wav_numpy = waveform.numpy()
sample_size = waveform.shape[1]
time_axis = torch.arange(0, sample_size) / 16000
figure, axes = plt.subplots(1, 1)
axes = figure.gca()
axes.plot(time_axis, wav_numpy[0], linewidth=1)
axes.grid(True)
figure.suptitle(title)
plt.show(block=False)
time_axis = torch.arange(0, sample_size) / sample_rate
def plot_specgram(waveform, sample_rate, title):
wav_numpy = waveform.numpy()
figure, axes = plt.subplots(1, 1)
axes = figure.gca()
axes.specgram(wav_numpy[0], Fs=sample_rate)
figure, axes = plt.subplots(2, 1)
axes[0].plot(time_axis, wav_numpy[0], linewidth=1)
axes[0].grid(True)
axes[1].specgram(wav_numpy[0], Fs=sample_rate)
figure.suptitle(title)
plt.show(block=False)
######################################################################
......@@ -238,32 +232,28 @@ Audio(WAVEFORM_DISTORTED.numpy()[1], rate=16000)
# Visualize speech sample
#
plot_waveform(WAVEFORM_SPEECH, "Clean Speech")
plot_specgram(WAVEFORM_SPEECH, 16000, "Clean Speech Spectrogram")
plot(WAVEFORM_SPEECH, "Clean Speech")
######################################################################
# Visualize noise sample
#
plot_waveform(WAVEFORM_NOISE, "Noise")
plot_specgram(WAVEFORM_NOISE, 16000, "Noise Spectrogram")
plot(WAVEFORM_NOISE, "Noise")
######################################################################
# Visualize distorted speech with 20dB SNR
#
plot_waveform(WAVEFORM_DISTORTED[0:1], f"Distorted Speech with {snr_dbs[0]}dB SNR")
plot_specgram(WAVEFORM_DISTORTED[0:1], 16000, f"Distorted Speech with {snr_dbs[0]}dB SNR")
plot(WAVEFORM_DISTORTED[0:1], f"Distorted Speech with {snr_dbs[0]}dB SNR")
######################################################################
# Visualize distorted speech with -5dB SNR
#
plot_waveform(WAVEFORM_DISTORTED[1:2], f"Distorted Speech with {snr_dbs[1]}dB SNR")
plot_specgram(WAVEFORM_DISTORTED[1:2], 16000, f"Distorted Speech with {snr_dbs[1]}dB SNR")
plot(WAVEFORM_DISTORTED[1:2], f"Distorted Speech with {snr_dbs[1]}dB SNR")
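# The distorted waveforms above can be produced with torchaudio's add_noise;
# a sketch assuming the WAVEFORM_SPEECH / WAVEFORM_NOISE tensors from earlier
# in this tutorial (mono, 16 kHz, trimmed to equal length):

import torch
import torchaudio.functional as F

snr_dbs = torch.tensor([20.0, -5.0])
noise = WAVEFORM_NOISE[:, : WAVEFORM_SPEECH.size(1)]
WAVEFORM_DISTORTED = F.add_noise(WAVEFORM_SPEECH, noise, snr_dbs)  # (2, L): one row per SNR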
######################################################################
......