Commit 84b12306 authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Set and tweak global matplotlib configuration in tutorials (#3515)

Summary:
- Set global matplotlib rc params
- Fix style check
- Fix and updates FA tutorial plots
- Add av-asr index cards

Pull Request resolved: https://github.com/pytorch/audio/pull/3515

Reviewed By: huangruizhe

Differential Revision: D47894156

Pulled By: mthrok

fbshipit-source-id: b40d8d31f12ffc2b337e35e632afc216e9d59a6e
parent 8497ee91
...@@ -127,6 +127,22 @@ def _get_pattern(): ...@@ -127,6 +127,22 @@ def _get_pattern():
return ret return ret
def reset_mpl(gallery_conf, fname):
    """Sphinx-Gallery reset hook: restore matplotlib defaults, then apply
    the tutorial-wide rc overrides so every gallery example renders
    consistently.

    Parameters follow the sphinx-gallery ``reset_modules`` callable
    contract: ``gallery_conf`` is the gallery configuration dict and
    ``fname`` is the example file being processed.
    """
    # Delegate the baseline reset to sphinx-gallery's own helper so we
    # inherit its backend/state cleanup, then layer our settings on top.
    from sphinx_gallery.scrapers import _reset_matplotlib

    _reset_matplotlib(gallery_conf, fname)

    # Imported lazily: matplotlib is only needed when the hook runs.
    import matplotlib

    overrides = {
        "image.interpolation": "none",
        "figure.figsize": (9.6, 4.8),
        "font.size": 8.0,
        "axes.axisbelow": True,  # draw grid lines beneath plot elements
    }
    matplotlib.rcParams.update(overrides)
sphinx_gallery_conf = { sphinx_gallery_conf = {
"examples_dirs": [ "examples_dirs": [
"../../examples/tutorials", "../../examples/tutorials",
...@@ -139,6 +155,7 @@ sphinx_gallery_conf = { ...@@ -139,6 +155,7 @@ sphinx_gallery_conf = {
"promote_jupyter_magic": True, "promote_jupyter_magic": True,
"first_notebook_cell": None, "first_notebook_cell": None,
"doc_module": ("torchaudio",), "doc_module": ("torchaudio",),
"reset_modules": (reset_mpl, "seaborn"),
} }
autosummary_generate = True autosummary_generate = True
......
...@@ -71,8 +71,8 @@ model implementations and application components. ...@@ -71,8 +71,8 @@ model implementations and application components.
tutorials/online_asr_tutorial tutorials/online_asr_tutorial
tutorials/device_asr tutorials/device_asr
tutorials/device_avsr tutorials/device_avsr
tutorials/forced_alignment_for_multilingual_data_tutorial
tutorials/forced_alignment_tutorial tutorials/forced_alignment_tutorial
tutorials/forced_alignment_for_multilingual_data_tutorial
tutorials/tacotron2_pipeline_tutorial tutorials/tacotron2_pipeline_tutorial
tutorials/mvdr_tutorial tutorials/mvdr_tutorial
tutorials/hybrid_demucs_tutorial tutorials/hybrid_demucs_tutorial
...@@ -147,6 +147,13 @@ Tutorials ...@@ -147,6 +147,13 @@ Tutorials
.. customcardstart:: .. customcardstart::
.. customcarditem::
:header: On device audio-visual automatic speech recognition
:card_description: Learn how to stream audio and video from laptop webcam and perform audio-visual automatic speech recognition using Emformer-RNNT model.
:image: https://download.pytorch.org/torchaudio/doc-assets/avsr/transformed.gif
:link: tutorials/device_avsr.html
:tags: I/O,Pipelines,RNNT
.. customcarditem:: .. customcarditem::
:header: Loading waveform Tensors from files and saving them :header: Loading waveform Tensors from files and saving them
:card_description: Learn how to query/load audio files and save waveform tensors to files, using <code>torchaudio.info</code>, <code>torchaudio.load</code> and <code>torchaudio.save</code> functions. :card_description: Learn how to query/load audio files and save waveform tensors to files, using <code>torchaudio.info</code>, <code>torchaudio.load</code> and <code>torchaudio.save</code> functions.
......
...@@ -85,7 +85,7 @@ NUM_FRAMES = int(DURATION * SAMPLE_RATE) ...@@ -85,7 +85,7 @@ NUM_FRAMES = int(DURATION * SAMPLE_RATE)
# #
def show(freq, amp, waveform, sample_rate, zoom=None, vol=0.1): def plot(freq, amp, waveform, sample_rate, zoom=None, vol=0.1):
t = torch.arange(waveform.size(0)) / sample_rate t = torch.arange(waveform.size(0)) / sample_rate
fig, axes = plt.subplots(4, 1, sharex=True) fig, axes = plt.subplots(4, 1, sharex=True)
...@@ -101,7 +101,7 @@ def show(freq, amp, waveform, sample_rate, zoom=None, vol=0.1): ...@@ -101,7 +101,7 @@ def show(freq, amp, waveform, sample_rate, zoom=None, vol=0.1):
for i in range(4): for i in range(4):
axes[i].grid(True) axes[i].grid(True)
pos = axes[2].get_position() pos = axes[2].get_position()
plt.tight_layout() fig.tight_layout()
if zoom is not None: if zoom is not None:
ax = fig.add_axes([pos.x0 + 0.02, pos.y0 + 0.03, pos.width / 2.5, pos.height / 2.0]) ax = fig.add_axes([pos.x0 + 0.02, pos.y0 + 0.03, pos.width / 2.5, pos.height / 2.0])
...@@ -168,7 +168,7 @@ def sawtooth_wave(freq0, amp0, num_pitches, sample_rate): ...@@ -168,7 +168,7 @@ def sawtooth_wave(freq0, amp0, num_pitches, sample_rate):
freq0 = torch.full((NUM_FRAMES, 1), F0) freq0 = torch.full((NUM_FRAMES, 1), F0)
amp0 = torch.ones((NUM_FRAMES, 1)) amp0 = torch.ones((NUM_FRAMES, 1))
freq, amp, waveform = sawtooth_wave(freq0, amp0, int(SAMPLE_RATE / F0), SAMPLE_RATE) freq, amp, waveform = sawtooth_wave(freq0, amp0, int(SAMPLE_RATE / F0), SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0)) plot(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
###################################################################### ######################################################################
# #
...@@ -183,7 +183,7 @@ phase = torch.linspace(0, fm * PI2 * DURATION, NUM_FRAMES) ...@@ -183,7 +183,7 @@ phase = torch.linspace(0, fm * PI2 * DURATION, NUM_FRAMES)
freq0 = F0 + f_dev * torch.sin(phase).unsqueeze(-1) freq0 = F0 + f_dev * torch.sin(phase).unsqueeze(-1)
freq, amp, waveform = sawtooth_wave(freq0, amp0, int(SAMPLE_RATE / F0), SAMPLE_RATE) freq, amp, waveform = sawtooth_wave(freq0, amp0, int(SAMPLE_RATE / F0), SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0)) plot(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
###################################################################### ######################################################################
# Square wave # Square wave
...@@ -220,7 +220,7 @@ def square_wave(freq0, amp0, num_pitches, sample_rate): ...@@ -220,7 +220,7 @@ def square_wave(freq0, amp0, num_pitches, sample_rate):
freq0 = torch.full((NUM_FRAMES, 1), F0) freq0 = torch.full((NUM_FRAMES, 1), F0)
amp0 = torch.ones((NUM_FRAMES, 1)) amp0 = torch.ones((NUM_FRAMES, 1))
freq, amp, waveform = square_wave(freq0, amp0, int(SAMPLE_RATE / F0 / 2), SAMPLE_RATE) freq, amp, waveform = square_wave(freq0, amp0, int(SAMPLE_RATE / F0 / 2), SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0)) plot(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
###################################################################### ######################################################################
# Triangle wave # Triangle wave
...@@ -256,7 +256,7 @@ def triangle_wave(freq0, amp0, num_pitches, sample_rate): ...@@ -256,7 +256,7 @@ def triangle_wave(freq0, amp0, num_pitches, sample_rate):
# #
freq, amp, waveform = triangle_wave(freq0, amp0, int(SAMPLE_RATE / F0 / 2), SAMPLE_RATE) freq, amp, waveform = triangle_wave(freq0, amp0, int(SAMPLE_RATE / F0 / 2), SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0)) plot(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
###################################################################### ######################################################################
# Inharmonic Partials # Inharmonic Partials
...@@ -296,7 +296,7 @@ amp = torch.stack([amp * (0.5**i) for i in range(num_tones)], dim=-1) ...@@ -296,7 +296,7 @@ amp = torch.stack([amp * (0.5**i) for i in range(num_tones)], dim=-1)
waveform = oscillator_bank(freq, amp, sample_rate=SAMPLE_RATE) waveform = oscillator_bank(freq, amp, sample_rate=SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE, vol=0.4) plot(freq, amp, waveform, SAMPLE_RATE, vol=0.4)
###################################################################### ######################################################################
# #
...@@ -308,7 +308,7 @@ show(freq, amp, waveform, SAMPLE_RATE, vol=0.4) ...@@ -308,7 +308,7 @@ show(freq, amp, waveform, SAMPLE_RATE, vol=0.4)
freq = extend_pitch(freq0, num_tones) freq = extend_pitch(freq0, num_tones)
waveform = oscillator_bank(freq, amp, sample_rate=SAMPLE_RATE) waveform = oscillator_bank(freq, amp, sample_rate=SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE) plot(freq, amp, waveform, SAMPLE_RATE)
###################################################################### ######################################################################
# References # References
......
...@@ -407,30 +407,45 @@ print(timesteps, timesteps.shape[0]) ...@@ -407,30 +407,45 @@ print(timesteps, timesteps.shape[0])
# #
def plot_alignments(waveform, emission, tokens, timesteps): def plot_alignments(waveform, emission, tokens, timesteps, sample_rate):
fig, ax = plt.subplots(figsize=(32, 10))
t = torch.arange(waveform.size(0)) / sample_rate
ax.plot(waveform) ratio = waveform.size(0) / emission.size(1) / sample_rate
ratio = waveform.shape[0] / emission.shape[1] chars = []
word_start = 0 words = []
word_start = None
for i in range(len(tokens)): for token, timestep in zip(tokens, timesteps * ratio):
if i != 0 and tokens[i - 1] == "|": if token == "|":
word_start = timesteps[i] if word_start is not None:
if tokens[i] != "|": words.append((word_start, timestep))
plt.annotate(tokens[i].upper(), (timesteps[i] * ratio, waveform.max() * 1.02), size=14) word_start = None
elif i != 0: else:
word_end = timesteps[i] chars.append((token, timestep))
ax.axvspan(word_start * ratio, word_end * ratio, alpha=0.1, color="red") if word_start is None:
word_start = timestep
xticks = ax.get_xticks()
plt.xticks(xticks, xticks / bundle.sample_rate) fig, axes = plt.subplots(3, 1)
ax.set_xlabel("time (sec)")
ax.set_xlim(0, waveform.shape[0]) def _plot(ax, xlim):
ax.plot(t, waveform)
for token, timestep in chars:
plot_alignments(waveform[0], emission, predicted_tokens, timesteps) ax.annotate(token.upper(), (timestep, 0.5))
for word_start, word_end in words:
ax.axvspan(word_start, word_end, alpha=0.1, color="red")
ax.set_ylim(-0.6, 0.7)
ax.set_yticks([0])
ax.grid(True, axis="y")
ax.set_xlim(xlim)
_plot(axes[0], (0.3, 2.5))
_plot(axes[1], (2.5, 4.7))
_plot(axes[2], (4.7, 6.9))
axes[2].set_xlabel("time (sec)")
fig.tight_layout()
plot_alignments(waveform[0], emission, predicted_tokens, timesteps, bundle.sample_rate)
###################################################################### ######################################################################
......
...@@ -100,7 +100,6 @@ def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None): ...@@ -100,7 +100,6 @@ def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None):
if xlim: if xlim:
axes[c].set_xlim(xlim) axes[c].set_xlim(xlim)
figure.suptitle(title) figure.suptitle(title)
plt.show(block=False)
###################################################################### ######################################################################
...@@ -122,7 +121,6 @@ def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None): ...@@ -122,7 +121,6 @@ def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
if xlim: if xlim:
axes[c].set_xlim(xlim) axes[c].set_xlim(xlim)
figure.suptitle(title) figure.suptitle(title)
plt.show(block=False)
###################################################################### ######################################################################
......
# -*- coding: utf-8 -*-
""" """
Audio Datasets Audio Datasets
============== ==============
...@@ -10,10 +9,6 @@ datasets. Please refer to the official documentation for the list of ...@@ -10,10 +9,6 @@ datasets. Please refer to the official documentation for the list of
available datasets. available datasets.
""" """
# When running this tutorial in Google Colab, install the required packages
# with the following.
# !pip install torchaudio
import torch import torch
import torchaudio import torchaudio
...@@ -21,22 +16,13 @@ print(torch.__version__) ...@@ -21,22 +16,13 @@ print(torch.__version__)
print(torchaudio.__version__) print(torchaudio.__version__)
###################################################################### ######################################################################
# Preparing data and utility functions (skip this section)
# --------------------------------------------------------
# #
# @title Prepare data and utility functions. {display-mode: "form"}
# @markdown
# @markdown You do not need to look into this cell.
# @markdown Just execute once and you are good to go.
# -------------------------------------------------------------------------------
# Preparation of data and helper functions.
# -------------------------------------------------------------------------------
import os import os
import IPython
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from IPython.display import Audio, display
_SAMPLE_DIR = "_assets" _SAMPLE_DIR = "_assets"
...@@ -44,34 +30,13 @@ YESNO_DATASET_PATH = os.path.join(_SAMPLE_DIR, "yes_no") ...@@ -44,34 +30,13 @@ YESNO_DATASET_PATH = os.path.join(_SAMPLE_DIR, "yes_no")
os.makedirs(YESNO_DATASET_PATH, exist_ok=True) os.makedirs(YESNO_DATASET_PATH, exist_ok=True)
def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None): def plot_specgram(waveform, sample_rate, title="Spectrogram"):
waveform = waveform.numpy() waveform = waveform.numpy()
num_channels, _ = waveform.shape figure, ax = plt.subplots()
ax.specgram(waveform[0], Fs=sample_rate)
figure, axes = plt.subplots(num_channels, 1)
if num_channels == 1:
axes = [axes]
for c in range(num_channels):
axes[c].specgram(waveform[c], Fs=sample_rate)
if num_channels > 1:
axes[c].set_ylabel(f"Channel {c+1}")
if xlim:
axes[c].set_xlim(xlim)
figure.suptitle(title) figure.suptitle(title)
plt.show(block=False) figure.tight_layout()
def play_audio(waveform, sample_rate):
waveform = waveform.numpy()
num_channels, _ = waveform.shape
if num_channels == 1:
display(Audio(waveform[0], rate=sample_rate))
elif num_channels == 2:
display(Audio((waveform[0], waveform[1]), rate=sample_rate))
else:
raise ValueError("Waveform with more than 2 channels are not supported.")
###################################################################### ######################################################################
...@@ -79,10 +44,25 @@ def play_audio(waveform, sample_rate): ...@@ -79,10 +44,25 @@ def play_audio(waveform, sample_rate):
# :py:class:`torchaudio.datasets.YESNO` dataset. # :py:class:`torchaudio.datasets.YESNO` dataset.
# #
dataset = torchaudio.datasets.YESNO(YESNO_DATASET_PATH, download=True) dataset = torchaudio.datasets.YESNO(YESNO_DATASET_PATH, download=True)
for i in [1, 3, 5]: ######################################################################
waveform, sample_rate, label = dataset[i] #
plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}") i = 1
play_audio(waveform, sample_rate) waveform, sample_rate, label = dataset[i]
plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}")
IPython.display.Audio(waveform, rate=sample_rate)
######################################################################
#
i = 3
waveform, sample_rate, label = dataset[i]
plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}")
IPython.display.Audio(waveform, rate=sample_rate)
######################################################################
#
i = 5
waveform, sample_rate, label = dataset[i]
plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}")
IPython.display.Audio(waveform, rate=sample_rate)
...@@ -19,25 +19,19 @@ print(torch.__version__) ...@@ -19,25 +19,19 @@ print(torch.__version__)
print(torchaudio.__version__) print(torchaudio.__version__)
###################################################################### ######################################################################
# Preparing data and utility functions (skip this section) # Preparation
# -------------------------------------------------------- # -----------
# #
# @title Prepare data and utility functions. {display-mode: "form"}
# @markdown
# @markdown You do not need to look into this cell.
# @markdown Just execute once and you are good to go.
# @markdown
# @markdown In this tutorial, we will use a speech data from [VOiCES dataset](https://iqtlabs.github.io/voices/),
# @markdown which is licensed under Creative Commons BY 4.0.
# -------------------------------------------------------------------------------
# Preparation of data and helper functions.
# -------------------------------------------------------------------------------
import librosa import librosa
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from torchaudio.utils import download_asset from torchaudio.utils import download_asset
######################################################################
# In this tutorial, we will use a speech data from
# `VOiCES dataset <https://iqtlabs.github.io/voices/>`__,
# which is licensed under Creative Commos BY 4.0.
SAMPLE_WAV_SPEECH_PATH = download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav") SAMPLE_WAV_SPEECH_PATH = download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
...@@ -75,16 +69,9 @@ def get_spectrogram( ...@@ -75,16 +69,9 @@ def get_spectrogram(
return spectrogram(waveform) return spectrogram(waveform)
def plot_spectrogram(spec, title=None, ylabel="freq_bin", aspect="auto", xmax=None): def plot_spec(ax, spec, title, ylabel="freq_bin"):
fig, axs = plt.subplots(1, 1) ax.set_title(title)
axs.set_title(title or "Spectrogram (db)") ax.imshow(librosa.power_to_db(spec), origin="lower", aspect="auto")
axs.set_ylabel(ylabel)
axs.set_xlabel("frame")
im = axs.imshow(librosa.power_to_db(spec), origin="lower", aspect=aspect)
if xmax:
axs.set_xlim((0, xmax))
fig.colorbar(im, ax=axs)
plt.show(block=False)
###################################################################### ######################################################################
...@@ -108,43 +95,47 @@ def plot_spectrogram(spec, title=None, ylabel="freq_bin", aspect="auto", xmax=No ...@@ -108,43 +95,47 @@ def plot_spectrogram(spec, title=None, ylabel="freq_bin", aspect="auto", xmax=No
spec = get_spectrogram(power=None) spec = get_spectrogram(power=None)
stretch = T.TimeStretch() stretch = T.TimeStretch()
rate = 1.2 spec_12 = stretch(spec, overriding_rate=1.2)
spec_ = stretch(spec, rate) spec_09 = stretch(spec, overriding_rate=0.9)
plot_spectrogram(torch.abs(spec_[0]), title=f"Stretched x{rate}", aspect="equal", xmax=304)
plot_spectrogram(torch.abs(spec[0]), title="Original", aspect="equal", xmax=304)
rate = 0.9
spec_ = stretch(spec, rate)
plot_spectrogram(torch.abs(spec_[0]), title=f"Stretched x{rate}", aspect="equal", xmax=304)
###################################################################### ######################################################################
# TimeMasking
# -----------
# #
torch.random.manual_seed(4)
spec = get_spectrogram() def plot():
plot_spectrogram(spec[0], title="Original") fig, axes = plt.subplots(3, 1, sharex=True, sharey=True)
plot_spec(axes[0], torch.abs(spec_12[0]), title="Stretched x1.2")
plot_spec(axes[1], torch.abs(spec[0]), title="Original")
plot_spec(axes[2], torch.abs(spec_09[0]), title="Stretched x0.9")
fig.tight_layout()
masking = T.TimeMasking(time_mask_param=80)
spec = masking(spec)
plot_spectrogram(spec[0], title="Masked along time axis") plot()
###################################################################### ######################################################################
# FrequencyMasking # Time and Frequency Masking
# ---------------- # --------------------------
# #
torch.random.manual_seed(4) torch.random.manual_seed(4)
time_masking = T.TimeMasking(time_mask_param=80)
freq_masking = T.FrequencyMasking(freq_mask_param=80)
spec = get_spectrogram() spec = get_spectrogram()
plot_spectrogram(spec[0], title="Original") time_masked = time_masking(spec)
freq_masked = freq_masking(spec)
######################################################################
#
def plot():
fig, axes = plt.subplots(3, 1, sharex=True, sharey=True)
plot_spec(axes[0], spec[0], title="Original")
plot_spec(axes[1], time_masked[0], title="Masked along time axis")
plot_spec(axes[2], freq_masked[0], title="Masked along frequency axis")
fig.tight_layout()
masking = T.FrequencyMasking(freq_mask_param=80)
spec = masking(spec)
plot_spectrogram(spec[0], title="Masked along frequency axis") plot()
...@@ -75,7 +75,6 @@ def plot_waveform(waveform, sr, title="Waveform", ax=None): ...@@ -75,7 +75,6 @@ def plot_waveform(waveform, sr, title="Waveform", ax=None):
ax.grid(True) ax.grid(True)
ax.set_xlim([0, time_axis[-1]]) ax.set_xlim([0, time_axis[-1]])
ax.set_title(title) ax.set_title(title)
plt.show(block=False)
def plot_spectrogram(specgram, title=None, ylabel="freq_bin", ax=None): def plot_spectrogram(specgram, title=None, ylabel="freq_bin", ax=None):
...@@ -85,7 +84,6 @@ def plot_spectrogram(specgram, title=None, ylabel="freq_bin", ax=None): ...@@ -85,7 +84,6 @@ def plot_spectrogram(specgram, title=None, ylabel="freq_bin", ax=None):
ax.set_title(title) ax.set_title(title)
ax.set_ylabel(ylabel) ax.set_ylabel(ylabel)
ax.imshow(librosa.power_to_db(specgram), origin="lower", aspect="auto", interpolation="nearest") ax.imshow(librosa.power_to_db(specgram), origin="lower", aspect="auto", interpolation="nearest")
plt.show(block=False)
def plot_fbank(fbank, title=None): def plot_fbank(fbank, title=None):
...@@ -94,7 +92,6 @@ def plot_fbank(fbank, title=None): ...@@ -94,7 +92,6 @@ def plot_fbank(fbank, title=None):
axs.imshow(fbank, aspect="auto") axs.imshow(fbank, aspect="auto")
axs.set_ylabel("frequency bin") axs.set_ylabel("frequency bin")
axs.set_xlabel("mel bin") axs.set_xlabel("mel bin")
plt.show(block=False)
###################################################################### ######################################################################
...@@ -486,7 +483,6 @@ def plot_pitch(waveform, sr, pitch): ...@@ -486,7 +483,6 @@ def plot_pitch(waveform, sr, pitch):
axis2.plot(time_axis, pitch[0], linewidth=2, label="Pitch", color="green") axis2.plot(time_axis, pitch[0], linewidth=2, label="Pitch", color="green")
axis2.legend(loc=0) axis2.legend(loc=0)
plt.show(block=False)
plot_pitch(SPEECH_WAVEFORM, SAMPLE_RATE, pitch) plot_pitch(SPEECH_WAVEFORM, SAMPLE_RATE, pitch)
...@@ -181,7 +181,6 @@ def plot_waveform(waveform, sample_rate): ...@@ -181,7 +181,6 @@ def plot_waveform(waveform, sample_rate):
if num_channels > 1: if num_channels > 1:
axes[c].set_ylabel(f"Channel {c+1}") axes[c].set_ylabel(f"Channel {c+1}")
figure.suptitle("waveform") figure.suptitle("waveform")
plt.show(block=False)
###################################################################### ######################################################################
...@@ -204,7 +203,6 @@ def plot_specgram(waveform, sample_rate, title="Spectrogram"): ...@@ -204,7 +203,6 @@ def plot_specgram(waveform, sample_rate, title="Spectrogram"):
if num_channels > 1: if num_channels > 1:
axes[c].set_ylabel(f"Channel {c+1}") axes[c].set_ylabel(f"Channel {c+1}")
figure.suptitle(title) figure.suptitle(title)
plt.show(block=False)
###################################################################### ######################################################################
......
...@@ -105,7 +105,6 @@ def plot_sweep( ...@@ -105,7 +105,6 @@ def plot_sweep(
axis.yaxis.grid(True, alpha=0.67) axis.yaxis.grid(True, alpha=0.67)
figure.suptitle(f"{title} (sample rate: {sample_rate} Hz)") figure.suptitle(f"{title} (sample rate: {sample_rate} Hz)")
plt.colorbar(cax) plt.colorbar(cax)
plt.show(block=True)
###################################################################### ######################################################################
......
...@@ -69,7 +69,7 @@ import torchvision ...@@ -69,7 +69,7 @@ import torchvision
# ------------------- # -------------------
# #
# Firstly, we define the function to collect videos from microphone and # Firstly, we define the function to collect videos from microphone and
# camera. To be specific, we use :py:func:`~torchaudio.io.StreamReader` # camera. To be specific, we use :py:class:`~torchaudio.io.StreamReader`
# class for the purpose of data collection, which supports capturing # class for the purpose of data collection, which supports capturing
# audio/video from microphone and camera. For the detailed usage of this # audio/video from microphone and camera. For the detailed usage of this
# class, please refer to the # class, please refer to the
......
...@@ -89,7 +89,7 @@ def plot_sinc_ir(irs, cutoff): ...@@ -89,7 +89,7 @@ def plot_sinc_ir(irs, cutoff):
num_filts, window_size = irs.shape num_filts, window_size = irs.shape
half = window_size // 2 half = window_size // 2
fig, axes = plt.subplots(num_filts, 1, sharex=True, figsize=(6.4, 4.8 * 1.5)) fig, axes = plt.subplots(num_filts, 1, sharex=True, figsize=(9.6, 8))
t = torch.linspace(-half, half - 1, window_size) t = torch.linspace(-half, half - 1, window_size)
for ax, ir, coff, color in zip(axes, irs, cutoff, plt.cm.tab10.colors): for ax, ir, coff, color in zip(axes, irs, cutoff, plt.cm.tab10.colors):
ax.plot(t, ir, linewidth=1.2, color=color, zorder=4, label=f"Cutoff: {coff}") ax.plot(t, ir, linewidth=1.2, color=color, zorder=4, label=f"Cutoff: {coff}")
...@@ -100,7 +100,7 @@ def plot_sinc_ir(irs, cutoff): ...@@ -100,7 +100,7 @@ def plot_sinc_ir(irs, cutoff):
"(Frequencies are relative to Nyquist frequency)" "(Frequencies are relative to Nyquist frequency)"
) )
axes[-1].set_xticks([i * half // 4 for i in range(-4, 5)]) axes[-1].set_xticks([i * half // 4 for i in range(-4, 5)])
plt.tight_layout() fig.tight_layout()
###################################################################### ######################################################################
...@@ -130,7 +130,7 @@ def plot_sinc_fr(frs, cutoff, band=False): ...@@ -130,7 +130,7 @@ def plot_sinc_fr(frs, cutoff, band=False):
num_filts, num_fft = frs.shape num_filts, num_fft = frs.shape
num_ticks = num_filts + 1 if band else num_filts num_ticks = num_filts + 1 if band else num_filts
fig, axes = plt.subplots(num_filts, 1, sharex=True, sharey=True, figsize=(6.4, 4.8 * 1.5)) fig, axes = plt.subplots(num_filts, 1, sharex=True, sharey=True, figsize=(9.6, 8))
for ax, fr, coff, color in zip(axes, frs, cutoff, plt.cm.tab10.colors): for ax, fr, coff, color in zip(axes, frs, cutoff, plt.cm.tab10.colors):
ax.grid(True) ax.grid(True)
ax.semilogy(fr, color=color, zorder=4, label=f"Cutoff: {coff}") ax.semilogy(fr, color=color, zorder=4, label=f"Cutoff: {coff}")
...@@ -146,7 +146,7 @@ def plot_sinc_fr(frs, cutoff, band=False): ...@@ -146,7 +146,7 @@ def plot_sinc_fr(frs, cutoff, band=False):
"Frequency response of sinc low-pass filter for different cut-off frequencies\n" "Frequency response of sinc low-pass filter for different cut-off frequencies\n"
"(Frequencies are relative to Nyquist frequency)" "(Frequencies are relative to Nyquist frequency)"
) )
plt.tight_layout() fig.tight_layout()
###################################################################### ######################################################################
...@@ -275,7 +275,7 @@ def plot_ir(magnitudes, ir, num_fft=2048): ...@@ -275,7 +275,7 @@ def plot_ir(magnitudes, ir, num_fft=2048):
axes[i].grid(True) axes[i].grid(True)
axes[1].set(title="Frequency Response") axes[1].set(title="Frequency Response")
axes[2].set(title="Frequency Response (log-scale)", xlabel="Frequency") axes[2].set(title="Frequency Response (log-scale)", xlabel="Frequency")
axes[2].legend(loc="lower right") axes[2].legend(loc="center right")
fig.tight_layout() fig.tight_layout()
......
...@@ -56,16 +56,11 @@ print(device) ...@@ -56,16 +56,11 @@ print(device)
# First we import the necessary packages, and fetch data that we work on. # First we import the necessary packages, and fetch data that we work on.
# #
# %matplotlib inline
from dataclasses import dataclass from dataclasses import dataclass
import IPython import IPython
import matplotlib
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
matplotlib.rcParams["figure.figsize"] = [16.0, 4.8]
torch.random.manual_seed(0) torch.random.manual_seed(0)
SPEECH_FILE = torchaudio.utils.download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav") SPEECH_FILE = torchaudio.utils.download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
...@@ -99,17 +94,22 @@ with torch.inference_mode(): ...@@ -99,17 +94,22 @@ with torch.inference_mode():
emission = emissions[0].cpu().detach() emission = emissions[0].cpu().detach()
print(labels)
################################################################################ ################################################################################
# Visualization # Visualization
################################################################################ # ~~~~~~~~~~~~~
print(labels)
plt.imshow(emission.T)
plt.colorbar()
plt.title("Frame-wise class probability")
plt.xlabel("Time")
plt.ylabel("Labels")
plt.show()
def plot():
plt.imshow(emission.T)
plt.colorbar()
plt.title("Frame-wise class probability")
plt.xlabel("Time")
plt.ylabel("Labels")
plot()
###################################################################### ######################################################################
# Generate alignment probability (trellis) # Generate alignment probability (trellis)
...@@ -181,12 +181,17 @@ trellis = get_trellis(emission, tokens) ...@@ -181,12 +181,17 @@ trellis = get_trellis(emission, tokens)
################################################################################ ################################################################################
# Visualization # Visualization
################################################################################ # ~~~~~~~~~~~~~
plt.imshow(trellis.T, origin="lower")
plt.annotate("- Inf", (trellis.size(1) / 5, trellis.size(1) / 1.5))
plt.annotate("+ Inf", (trellis.size(0) - trellis.size(1) / 5, trellis.size(1) / 3)) def plot():
plt.colorbar() plt.imshow(trellis.T, origin="lower")
plt.show() plt.annotate("- Inf", (trellis.size(1) / 5, trellis.size(1) / 1.5))
plt.annotate("+ Inf", (trellis.size(0) - trellis.size(1) / 5, trellis.size(1) / 3))
plt.colorbar()
plot()
###################################################################### ######################################################################
# In the above visualization, we can see that there is a trace of high # In the above visualization, we can see that there is a trace of high
...@@ -266,7 +271,9 @@ for p in path: ...@@ -266,7 +271,9 @@ for p in path:
################################################################################ ################################################################################
# Visualization # Visualization
################################################################################ # ~~~~~~~~~~~~~
def plot_trellis_with_path(trellis, path): def plot_trellis_with_path(trellis, path):
# To plot trellis with path, we take advantage of 'nan' value # To plot trellis with path, we take advantage of 'nan' value
trellis_with_path = trellis.clone() trellis_with_path = trellis.clone()
...@@ -277,10 +284,14 @@ def plot_trellis_with_path(trellis, path): ...@@ -277,10 +284,14 @@ def plot_trellis_with_path(trellis, path):
plot_trellis_with_path(trellis, path) plot_trellis_with_path(trellis, path)
plt.title("The path found by backtracking") plt.title("The path found by backtracking")
plt.show()
###################################################################### ######################################################################
# Looking good. Now this path contains repetitions for the same labels, so # Looking good.
######################################################################
# Segment the path
# ----------------
# Now this path contains repetitions for the same labels, so
# let’s merge them to make it close to the original transcript. # let’s merge them to make it close to the original transcript.
# #
# When merging the multiple path points, we simply take the average # When merging the multiple path points, we simply take the average
...@@ -297,7 +308,7 @@ class Segment: ...@@ -297,7 +308,7 @@ class Segment:
score: float score: float
def __repr__(self): def __repr__(self):
return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})" return f"{self.label} ({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"
@property @property
def length(self): def length(self):
...@@ -330,7 +341,9 @@ for seg in segments: ...@@ -330,7 +341,9 @@ for seg in segments:
################################################################################ ################################################################################
# Visualization # Visualization
################################################################################ # ~~~~~~~~~~~~~
def plot_trellis_with_segments(trellis, segments, transcript): def plot_trellis_with_segments(trellis, segments, transcript):
# To plot trellis with path, we take advantage of 'nan' value # To plot trellis with path, we take advantage of 'nan' value
trellis_with_path = trellis.clone() trellis_with_path = trellis.clone()
...@@ -338,15 +351,14 @@ def plot_trellis_with_segments(trellis, segments, transcript): ...@@ -338,15 +351,14 @@ def plot_trellis_with_segments(trellis, segments, transcript):
if seg.label != "|": if seg.label != "|":
trellis_with_path[seg.start : seg.end, i] = float("nan") trellis_with_path[seg.start : seg.end, i] = float("nan")
fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5)) fig, [ax1, ax2] = plt.subplots(2, 1, sharex=True)
ax1.set_title("Path, label and probability for each label") ax1.set_title("Path, label and probability for each label")
ax1.imshow(trellis_with_path.T, origin="lower") ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")
ax1.set_xticks([])
for i, seg in enumerate(segments): for i, seg in enumerate(segments):
if seg.label != "|": if seg.label != "|":
ax1.annotate(seg.label, (seg.start, i - 0.7), weight="bold") ax1.annotate(seg.label, (seg.start, i - 0.7), size="small")
ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3)) ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), size="small")
ax2.set_title("Label probability with and without repetation") ax2.set_title("Label probability with and without repetation")
xs, hs, ws = [], [], [] xs, hs, ws = [], [], []
...@@ -355,7 +367,7 @@ def plot_trellis_with_segments(trellis, segments, transcript): ...@@ -355,7 +367,7 @@ def plot_trellis_with_segments(trellis, segments, transcript):
xs.append((seg.end + seg.start) / 2 + 0.4) xs.append((seg.end + seg.start) / 2 + 0.4)
hs.append(seg.score) hs.append(seg.score)
ws.append(seg.end - seg.start) ws.append(seg.end - seg.start)
ax2.annotate(seg.label, (seg.start + 0.8, -0.07), weight="bold") ax2.annotate(seg.label, (seg.start + 0.8, -0.07))
ax2.bar(xs, hs, width=ws, color="gray", alpha=0.5, edgecolor="black") ax2.bar(xs, hs, width=ws, color="gray", alpha=0.5, edgecolor="black")
xs, hs = [], [] xs, hs = [], []
...@@ -367,17 +379,21 @@ def plot_trellis_with_segments(trellis, segments, transcript): ...@@ -367,17 +379,21 @@ def plot_trellis_with_segments(trellis, segments, transcript):
ax2.bar(xs, hs, width=0.5, alpha=0.5) ax2.bar(xs, hs, width=0.5, alpha=0.5)
ax2.axhline(0, color="black") ax2.axhline(0, color="black")
ax2.set_xlim(ax1.get_xlim()) ax2.grid(True, axis="y")
ax2.set_ylim(-0.1, 1.1) ax2.set_ylim(-0.1, 1.1)
fig.tight_layout()
plot_trellis_with_segments(trellis, segments, transcript) plot_trellis_with_segments(trellis, segments, transcript)
plt.tight_layout()
plt.show()
###################################################################### ######################################################################
# Looks good. Now let’s merge the words. The Wav2Vec2 model uses ``'|'`` # Looks good.
######################################################################
# Merge the segments into words
# -----------------------------
# Now let’s merge the words. The Wav2Vec2 model uses ``'|'``
# as the word boundary, so we merge the segments before each occurance of # as the word boundary, so we merge the segments before each occurance of
# ``'|'``. # ``'|'``.
# #
...@@ -410,16 +426,16 @@ for word in word_segments: ...@@ -410,16 +426,16 @@ for word in word_segments:
################################################################################ ################################################################################
# Visualization # Visualization
################################################################################ # ~~~~~~~~~~~~~
def plot_alignments(trellis, segments, word_segments, waveform): def plot_alignments(trellis, segments, word_segments, waveform):
trellis_with_path = trellis.clone() trellis_with_path = trellis.clone()
for i, seg in enumerate(segments): for i, seg in enumerate(segments):
if seg.label != "|": if seg.label != "|":
trellis_with_path[seg.start : seg.end, i] = float("nan") trellis_with_path[seg.start : seg.end, i] = float("nan")
fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5)) fig, [ax1, ax2] = plt.subplots(2, 1)
ax1.imshow(trellis_with_path.T, origin="lower") ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")
ax1.set_xticks([]) ax1.set_xticks([])
ax1.set_yticks([]) ax1.set_yticks([])
...@@ -429,8 +445,8 @@ def plot_alignments(trellis, segments, word_segments, waveform): ...@@ -429,8 +445,8 @@ def plot_alignments(trellis, segments, word_segments, waveform):
for i, seg in enumerate(segments): for i, seg in enumerate(segments):
if seg.label != "|": if seg.label != "|":
ax1.annotate(seg.label, (seg.start, i - 0.7)) ax1.annotate(seg.label, (seg.start, i - 0.7), size="small")
ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), fontsize=8) ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), size="small")
# The original waveform # The original waveform
ratio = waveform.size(0) / trellis.size(0) ratio = waveform.size(0) / trellis.size(0)
...@@ -450,6 +466,7 @@ def plot_alignments(trellis, segments, word_segments, waveform): ...@@ -450,6 +466,7 @@ def plot_alignments(trellis, segments, word_segments, waveform):
ax2.set_yticks([]) ax2.set_yticks([])
ax2.set_ylim(-1.0, 1.0) ax2.set_ylim(-1.0, 1.0)
ax2.set_xlim(0, waveform.size(-1)) ax2.set_xlim(0, waveform.size(-1))
fig.tight_layout()
plot_alignments( plot_alignments(
...@@ -458,7 +475,6 @@ plot_alignments( ...@@ -458,7 +475,6 @@ plot_alignments(
word_segments, word_segments,
waveform[0], waveform[0],
) )
plt.show()
################################################################################ ################################################################################
......
...@@ -162,11 +162,10 @@ def separate_sources( ...@@ -162,11 +162,10 @@ def separate_sources(
def plot_spectrogram(stft, title="Spectrogram"): def plot_spectrogram(stft, title="Spectrogram"):
magnitude = stft.abs() magnitude = stft.abs()
spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy() spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy()
figure, axis = plt.subplots(1, 1) _, axis = plt.subplots(1, 1)
img = axis.imshow(spectrogram, cmap="viridis", vmin=-60, vmax=0, origin="lower", aspect="auto") axis.imshow(spectrogram, cmap="viridis", vmin=-60, vmax=0, origin="lower", aspect="auto")
figure.suptitle(title) axis.set_title(title)
plt.colorbar(img, ax=axis) plt.tight_layout()
plt.show()
###################################################################### ######################################################################
...@@ -252,7 +251,7 @@ def output_results(original_source: torch.Tensor, predicted_source: torch.Tensor ...@@ -252,7 +251,7 @@ def output_results(original_source: torch.Tensor, predicted_source: torch.Tensor
"SDR score is:", "SDR score is:",
separation.bss_eval_sources(original_source.detach().numpy(), predicted_source.detach().numpy())[0].mean(), separation.bss_eval_sources(original_source.detach().numpy(), predicted_source.detach().numpy())[0].mean(),
) )
plot_spectrogram(stft(predicted_source)[0], f"Spectrogram {source}") plot_spectrogram(stft(predicted_source)[0], f"Spectrogram - {source}")
return Audio(predicted_source, rate=sample_rate) return Audio(predicted_source, rate=sample_rate)
...@@ -294,7 +293,7 @@ mix_spec = mixture[:, frame_start:frame_end].cpu() ...@@ -294,7 +293,7 @@ mix_spec = mixture[:, frame_start:frame_end].cpu()
# #
# Mixture Clip # Mixture Clip
plot_spectrogram(stft(mix_spec)[0], "Spectrogram Mixture") plot_spectrogram(stft(mix_spec)[0], "Spectrogram - Mixture")
Audio(mix_spec, rate=sample_rate) Audio(mix_spec, rate=sample_rate)
###################################################################### ######################################################################
......
...@@ -98,23 +98,21 @@ SAMPLE_NOISE = download_asset("tutorial-assets/mvdr/noise.wav") ...@@ -98,23 +98,21 @@ SAMPLE_NOISE = download_asset("tutorial-assets/mvdr/noise.wav")
# #
def plot_spectrogram(stft, title="Spectrogram", xlim=None): def plot_spectrogram(stft, title="Spectrogram"):
magnitude = stft.abs() magnitude = stft.abs()
spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy() spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy()
figure, axis = plt.subplots(1, 1) figure, axis = plt.subplots(1, 1)
img = axis.imshow(spectrogram, cmap="viridis", vmin=-100, vmax=0, origin="lower", aspect="auto") img = axis.imshow(spectrogram, cmap="viridis", vmin=-100, vmax=0, origin="lower", aspect="auto")
figure.suptitle(title) axis.set_title(title)
plt.colorbar(img, ax=axis) plt.colorbar(img, ax=axis)
plt.show()
def plot_mask(mask, title="Mask", xlim=None): def plot_mask(mask, title="Mask"):
mask = mask.numpy() mask = mask.numpy()
figure, axis = plt.subplots(1, 1) figure, axis = plt.subplots(1, 1)
img = axis.imshow(mask, cmap="viridis", origin="lower", aspect="auto") img = axis.imshow(mask, cmap="viridis", origin="lower", aspect="auto")
figure.suptitle(title) axis.set_title(title)
plt.colorbar(img, ax=axis) plt.colorbar(img, ax=axis)
plt.show()
def si_snr(estimate, reference, epsilon=1e-8): def si_snr(estimate, reference, epsilon=1e-8):
......
...@@ -33,12 +33,9 @@ print(torchaudio.__version__) ...@@ -33,12 +33,9 @@ print(torchaudio.__version__)
import os import os
import time import time
import matplotlib
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from torchaudio.io import StreamReader from torchaudio.io import StreamReader
matplotlib.rcParams["image.interpolation"] = "none"
###################################################################### ######################################################################
# #
# Check the prerequisites # Check the prerequisites
......
...@@ -160,8 +160,7 @@ for i, feats in enumerate(features): ...@@ -160,8 +160,7 @@ for i, feats in enumerate(features):
ax[i].set_title(f"Feature from transformer layer {i+1}") ax[i].set_title(f"Feature from transformer layer {i+1}")
ax[i].set_xlabel("Feature dimension") ax[i].set_xlabel("Feature dimension")
ax[i].set_ylabel("Frame (time-axis)") ax[i].set_ylabel("Frame (time-axis)")
plt.tight_layout() fig.tight_layout()
plt.show()
###################################################################### ######################################################################
...@@ -190,7 +189,7 @@ plt.imshow(emission[0].cpu().T, interpolation="nearest") ...@@ -190,7 +189,7 @@ plt.imshow(emission[0].cpu().T, interpolation="nearest")
plt.title("Classification result") plt.title("Classification result")
plt.xlabel("Frame (time-axis)") plt.xlabel("Frame (time-axis)")
plt.ylabel("Class") plt.ylabel("Class")
plt.show() plt.tight_layout()
print("Class labels:", bundle.get_labels()) print("Class labels:", bundle.get_labels())
......
...@@ -82,19 +82,23 @@ try: ...@@ -82,19 +82,23 @@ try:
from pystoi import stoi from pystoi import stoi
from torchaudio.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE from torchaudio.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE
except ImportError: except ImportError:
import google.colab # noqa: F401 try:
import google.colab # noqa: F401
print(
""" print(
To enable running this notebook in Google Colab, install nightly """
torch and torchaudio builds by adding the following code block to the top To enable running this notebook in Google Colab, install nightly
of the notebook before running it: torch and torchaudio builds by adding the following code block to the top
!pip3 uninstall -y torch torchvision torchaudio of the notebook before running it:
!pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu !pip3 uninstall -y torch torchvision torchaudio
!pip3 install pesq !pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu
!pip3 install pystoi !pip3 install pesq
""" !pip3 install pystoi
) """
)
except Exception:
pass
raise
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
...@@ -128,27 +132,17 @@ def si_snr(estimate, reference, epsilon=1e-8): ...@@ -128,27 +132,17 @@ def si_snr(estimate, reference, epsilon=1e-8):
return si_snr.item() return si_snr.item()
def plot_waveform(waveform, title): def plot(waveform, title, sample_rate=16000):
wav_numpy = waveform.numpy() wav_numpy = waveform.numpy()
sample_size = waveform.shape[1] sample_size = waveform.shape[1]
time_axis = torch.arange(0, sample_size) / 16000 time_axis = torch.arange(0, sample_size) / sample_rate
figure, axes = plt.subplots(1, 1)
axes = figure.gca()
axes.plot(time_axis, wav_numpy[0], linewidth=1)
axes.grid(True)
figure.suptitle(title)
plt.show(block=False)
figure, axes = plt.subplots(2, 1)
def plot_specgram(waveform, sample_rate, title): axes[0].plot(time_axis, wav_numpy[0], linewidth=1)
wav_numpy = waveform.numpy() axes[0].grid(True)
figure, axes = plt.subplots(1, 1) axes[1].specgram(wav_numpy[0], Fs=sample_rate)
axes = figure.gca()
axes.specgram(wav_numpy[0], Fs=sample_rate)
figure.suptitle(title) figure.suptitle(title)
plt.show(block=False)
###################################################################### ######################################################################
...@@ -238,32 +232,28 @@ Audio(WAVEFORM_DISTORTED.numpy()[1], rate=16000) ...@@ -238,32 +232,28 @@ Audio(WAVEFORM_DISTORTED.numpy()[1], rate=16000)
# Visualize speech sample # Visualize speech sample
# #
plot_waveform(WAVEFORM_SPEECH, "Clean Speech") plot(WAVEFORM_SPEECH, "Clean Speech")
plot_specgram(WAVEFORM_SPEECH, 16000, "Clean Speech Spectrogram")
###################################################################### ######################################################################
# Visualize noise sample # Visualize noise sample
# #
plot_waveform(WAVEFORM_NOISE, "Noise") plot(WAVEFORM_NOISE, "Noise")
plot_specgram(WAVEFORM_NOISE, 16000, "Noise Spectrogram")
###################################################################### ######################################################################
# Visualize distorted speech with 20dB SNR # Visualize distorted speech with 20dB SNR
# #
plot_waveform(WAVEFORM_DISTORTED[0:1], f"Distorted Speech with {snr_dbs[0]}dB SNR") plot(WAVEFORM_DISTORTED[0:1], f"Distorted Speech with {snr_dbs[0]}dB SNR")
plot_specgram(WAVEFORM_DISTORTED[0:1], 16000, f"Distorted Speech with {snr_dbs[0]}dB SNR")
###################################################################### ######################################################################
# Visualize distorted speech with -5dB SNR # Visualize distorted speech with -5dB SNR
# #
plot_waveform(WAVEFORM_DISTORTED[1:2], f"Distorted Speech with {snr_dbs[1]}dB SNR") plot(WAVEFORM_DISTORTED[1:2], f"Distorted Speech with {snr_dbs[1]}dB SNR")
plot_specgram(WAVEFORM_DISTORTED[1:2], 16000, f"Distorted Speech with {snr_dbs[1]}dB SNR")
###################################################################### ######################################################################
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment