Commit 84b12306 authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Set and tweak global matplotlib configuration in tutorials (#3515)

Summary:
- Set global matplotlib rc params
- Fix style check
- Fix and updates FA tutorial plots
- Add av-asr index cards

Pull Request resolved: https://github.com/pytorch/audio/pull/3515

Reviewed By: huangruizhe

Differential Revision: D47894156

Pulled By: mthrok

fbshipit-source-id: b40d8d31f12ffc2b337e35e632afc216e9d59a6e
parent 8497ee91
...@@ -127,6 +127,22 @@ def _get_pattern(): ...@@ -127,6 +127,22 @@ def _get_pattern():
return ret return ret
def reset_mpl(gallery_conf, fname):
    """Sphinx-Gallery reset hook run before each tutorial executes.

    First restores matplotlib to its stock state via Sphinx-Gallery's own
    resetter, then layers the tutorial-wide rc overrides on top so every
    example starts from the same plotting configuration.
    """
    from sphinx_gallery.scrapers import _reset_matplotlib

    # Undo any state leaked by a previously executed example.
    _reset_matplotlib(gallery_conf, fname)

    import matplotlib

    # Global tweaks shared by all tutorial plots.
    overrides = {
        "image.interpolation": "none",
        "figure.figsize": (9.6, 4.8),
        "font.size": 8.0,
        "axes.axisbelow": True,
    }
    matplotlib.rcParams.update(overrides)
sphinx_gallery_conf = { sphinx_gallery_conf = {
"examples_dirs": [ "examples_dirs": [
"../../examples/tutorials", "../../examples/tutorials",
...@@ -139,6 +155,7 @@ sphinx_gallery_conf = { ...@@ -139,6 +155,7 @@ sphinx_gallery_conf = {
"promote_jupyter_magic": True, "promote_jupyter_magic": True,
"first_notebook_cell": None, "first_notebook_cell": None,
"doc_module": ("torchaudio",), "doc_module": ("torchaudio",),
"reset_modules": (reset_mpl, "seaborn"),
} }
autosummary_generate = True autosummary_generate = True
......
...@@ -71,8 +71,8 @@ model implementations and application components. ...@@ -71,8 +71,8 @@ model implementations and application components.
tutorials/online_asr_tutorial tutorials/online_asr_tutorial
tutorials/device_asr tutorials/device_asr
tutorials/device_avsr tutorials/device_avsr
tutorials/forced_alignment_for_multilingual_data_tutorial
tutorials/forced_alignment_tutorial tutorials/forced_alignment_tutorial
tutorials/forced_alignment_for_multilingual_data_tutorial
tutorials/tacotron2_pipeline_tutorial tutorials/tacotron2_pipeline_tutorial
tutorials/mvdr_tutorial tutorials/mvdr_tutorial
tutorials/hybrid_demucs_tutorial tutorials/hybrid_demucs_tutorial
...@@ -147,6 +147,13 @@ Tutorials ...@@ -147,6 +147,13 @@ Tutorials
.. customcardstart:: .. customcardstart::
.. customcarditem::
:header: On device audio-visual automatic speech recognition
:card_description: Learn how to stream audio and video from laptop webcam and perform audio-visual automatic speech recognition using Emformer-RNNT model.
:image: https://download.pytorch.org/torchaudio/doc-assets/avsr/transformed.gif
:link: tutorials/device_avsr.html
:tags: I/O,Pipelines,RNNT
.. customcarditem:: .. customcarditem::
:header: Loading waveform Tensors from files and saving them :header: Loading waveform Tensors from files and saving them
:card_description: Learn how to query/load audio files and save waveform tensors to files, using <code>torchaudio.info</code>, <code>torchaudio.load</code> and <code>torchaudio.save</code> functions. :card_description: Learn how to query/load audio files and save waveform tensors to files, using <code>torchaudio.info</code>, <code>torchaudio.load</code> and <code>torchaudio.save</code> functions.
......
...@@ -85,7 +85,7 @@ NUM_FRAMES = int(DURATION * SAMPLE_RATE) ...@@ -85,7 +85,7 @@ NUM_FRAMES = int(DURATION * SAMPLE_RATE)
# #
def show(freq, amp, waveform, sample_rate, zoom=None, vol=0.1): def plot(freq, amp, waveform, sample_rate, zoom=None, vol=0.1):
t = torch.arange(waveform.size(0)) / sample_rate t = torch.arange(waveform.size(0)) / sample_rate
fig, axes = plt.subplots(4, 1, sharex=True) fig, axes = plt.subplots(4, 1, sharex=True)
...@@ -101,7 +101,7 @@ def show(freq, amp, waveform, sample_rate, zoom=None, vol=0.1): ...@@ -101,7 +101,7 @@ def show(freq, amp, waveform, sample_rate, zoom=None, vol=0.1):
for i in range(4): for i in range(4):
axes[i].grid(True) axes[i].grid(True)
pos = axes[2].get_position() pos = axes[2].get_position()
plt.tight_layout() fig.tight_layout()
if zoom is not None: if zoom is not None:
ax = fig.add_axes([pos.x0 + 0.02, pos.y0 + 0.03, pos.width / 2.5, pos.height / 2.0]) ax = fig.add_axes([pos.x0 + 0.02, pos.y0 + 0.03, pos.width / 2.5, pos.height / 2.0])
...@@ -168,7 +168,7 @@ def sawtooth_wave(freq0, amp0, num_pitches, sample_rate): ...@@ -168,7 +168,7 @@ def sawtooth_wave(freq0, amp0, num_pitches, sample_rate):
freq0 = torch.full((NUM_FRAMES, 1), F0) freq0 = torch.full((NUM_FRAMES, 1), F0)
amp0 = torch.ones((NUM_FRAMES, 1)) amp0 = torch.ones((NUM_FRAMES, 1))
freq, amp, waveform = sawtooth_wave(freq0, amp0, int(SAMPLE_RATE / F0), SAMPLE_RATE) freq, amp, waveform = sawtooth_wave(freq0, amp0, int(SAMPLE_RATE / F0), SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0)) plot(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
###################################################################### ######################################################################
# #
...@@ -183,7 +183,7 @@ phase = torch.linspace(0, fm * PI2 * DURATION, NUM_FRAMES) ...@@ -183,7 +183,7 @@ phase = torch.linspace(0, fm * PI2 * DURATION, NUM_FRAMES)
freq0 = F0 + f_dev * torch.sin(phase).unsqueeze(-1) freq0 = F0 + f_dev * torch.sin(phase).unsqueeze(-1)
freq, amp, waveform = sawtooth_wave(freq0, amp0, int(SAMPLE_RATE / F0), SAMPLE_RATE) freq, amp, waveform = sawtooth_wave(freq0, amp0, int(SAMPLE_RATE / F0), SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0)) plot(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
###################################################################### ######################################################################
# Square wave # Square wave
...@@ -220,7 +220,7 @@ def square_wave(freq0, amp0, num_pitches, sample_rate): ...@@ -220,7 +220,7 @@ def square_wave(freq0, amp0, num_pitches, sample_rate):
freq0 = torch.full((NUM_FRAMES, 1), F0) freq0 = torch.full((NUM_FRAMES, 1), F0)
amp0 = torch.ones((NUM_FRAMES, 1)) amp0 = torch.ones((NUM_FRAMES, 1))
freq, amp, waveform = square_wave(freq0, amp0, int(SAMPLE_RATE / F0 / 2), SAMPLE_RATE) freq, amp, waveform = square_wave(freq0, amp0, int(SAMPLE_RATE / F0 / 2), SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0)) plot(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
###################################################################### ######################################################################
# Triangle wave # Triangle wave
...@@ -256,7 +256,7 @@ def triangle_wave(freq0, amp0, num_pitches, sample_rate): ...@@ -256,7 +256,7 @@ def triangle_wave(freq0, amp0, num_pitches, sample_rate):
# #
freq, amp, waveform = triangle_wave(freq0, amp0, int(SAMPLE_RATE / F0 / 2), SAMPLE_RATE) freq, amp, waveform = triangle_wave(freq0, amp0, int(SAMPLE_RATE / F0 / 2), SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0)) plot(freq, amp, waveform, SAMPLE_RATE, zoom=(1 / F0, 3 / F0))
###################################################################### ######################################################################
# Inharmonic Partials # Inharmonic Partials
...@@ -296,7 +296,7 @@ amp = torch.stack([amp * (0.5**i) for i in range(num_tones)], dim=-1) ...@@ -296,7 +296,7 @@ amp = torch.stack([amp * (0.5**i) for i in range(num_tones)], dim=-1)
waveform = oscillator_bank(freq, amp, sample_rate=SAMPLE_RATE) waveform = oscillator_bank(freq, amp, sample_rate=SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE, vol=0.4) plot(freq, amp, waveform, SAMPLE_RATE, vol=0.4)
###################################################################### ######################################################################
# #
...@@ -308,7 +308,7 @@ show(freq, amp, waveform, SAMPLE_RATE, vol=0.4) ...@@ -308,7 +308,7 @@ show(freq, amp, waveform, SAMPLE_RATE, vol=0.4)
freq = extend_pitch(freq0, num_tones) freq = extend_pitch(freq0, num_tones)
waveform = oscillator_bank(freq, amp, sample_rate=SAMPLE_RATE) waveform = oscillator_bank(freq, amp, sample_rate=SAMPLE_RATE)
show(freq, amp, waveform, SAMPLE_RATE) plot(freq, amp, waveform, SAMPLE_RATE)
###################################################################### ######################################################################
# References # References
......
...@@ -407,30 +407,45 @@ print(timesteps, timesteps.shape[0]) ...@@ -407,30 +407,45 @@ print(timesteps, timesteps.shape[0])
# #
def plot_alignments(waveform, emission, tokens, timesteps): def plot_alignments(waveform, emission, tokens, timesteps, sample_rate):
fig, ax = plt.subplots(figsize=(32, 10))
t = torch.arange(waveform.size(0)) / sample_rate
ax.plot(waveform) ratio = waveform.size(0) / emission.size(1) / sample_rate
ratio = waveform.shape[0] / emission.shape[1] chars = []
word_start = 0 words = []
word_start = None
for i in range(len(tokens)): for token, timestep in zip(tokens, timesteps * ratio):
if i != 0 and tokens[i - 1] == "|": if token == "|":
word_start = timesteps[i] if word_start is not None:
if tokens[i] != "|": words.append((word_start, timestep))
plt.annotate(tokens[i].upper(), (timesteps[i] * ratio, waveform.max() * 1.02), size=14) word_start = None
elif i != 0: else:
word_end = timesteps[i] chars.append((token, timestep))
ax.axvspan(word_start * ratio, word_end * ratio, alpha=0.1, color="red") if word_start is None:
word_start = timestep
xticks = ax.get_xticks()
plt.xticks(xticks, xticks / bundle.sample_rate) fig, axes = plt.subplots(3, 1)
ax.set_xlabel("time (sec)")
ax.set_xlim(0, waveform.shape[0]) def _plot(ax, xlim):
ax.plot(t, waveform)
for token, timestep in chars:
plot_alignments(waveform[0], emission, predicted_tokens, timesteps) ax.annotate(token.upper(), (timestep, 0.5))
for word_start, word_end in words:
ax.axvspan(word_start, word_end, alpha=0.1, color="red")
ax.set_ylim(-0.6, 0.7)
ax.set_yticks([0])
ax.grid(True, axis="y")
ax.set_xlim(xlim)
_plot(axes[0], (0.3, 2.5))
_plot(axes[1], (2.5, 4.7))
_plot(axes[2], (4.7, 6.9))
axes[2].set_xlabel("time (sec)")
fig.tight_layout()
plot_alignments(waveform[0], emission, predicted_tokens, timesteps, bundle.sample_rate)
###################################################################### ######################################################################
......
...@@ -100,7 +100,6 @@ def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None): ...@@ -100,7 +100,6 @@ def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None):
if xlim: if xlim:
axes[c].set_xlim(xlim) axes[c].set_xlim(xlim)
figure.suptitle(title) figure.suptitle(title)
plt.show(block=False)
###################################################################### ######################################################################
...@@ -122,7 +121,6 @@ def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None): ...@@ -122,7 +121,6 @@ def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
if xlim: if xlim:
axes[c].set_xlim(xlim) axes[c].set_xlim(xlim)
figure.suptitle(title) figure.suptitle(title)
plt.show(block=False)
###################################################################### ######################################################################
......
# -*- coding: utf-8 -*-
""" """
Audio Datasets Audio Datasets
============== ==============
...@@ -10,10 +9,6 @@ datasets. Please refer to the official documentation for the list of ...@@ -10,10 +9,6 @@ datasets. Please refer to the official documentation for the list of
available datasets. available datasets.
""" """
# When running this tutorial in Google Colab, install the required packages
# with the following.
# !pip install torchaudio
import torch import torch
import torchaudio import torchaudio
...@@ -21,22 +16,13 @@ print(torch.__version__) ...@@ -21,22 +16,13 @@ print(torch.__version__)
print(torchaudio.__version__) print(torchaudio.__version__)
###################################################################### ######################################################################
# Preparing data and utility functions (skip this section)
# --------------------------------------------------------
# #
# @title Prepare data and utility functions. {display-mode: "form"}
# @markdown
# @markdown You do not need to look into this cell.
# @markdown Just execute once and you are good to go.
# -------------------------------------------------------------------------------
# Preparation of data and helper functions.
# -------------------------------------------------------------------------------
import os import os
import IPython
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from IPython.display import Audio, display
_SAMPLE_DIR = "_assets" _SAMPLE_DIR = "_assets"
...@@ -44,34 +30,13 @@ YESNO_DATASET_PATH = os.path.join(_SAMPLE_DIR, "yes_no") ...@@ -44,34 +30,13 @@ YESNO_DATASET_PATH = os.path.join(_SAMPLE_DIR, "yes_no")
os.makedirs(YESNO_DATASET_PATH, exist_ok=True) os.makedirs(YESNO_DATASET_PATH, exist_ok=True)
def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None): def plot_specgram(waveform, sample_rate, title="Spectrogram"):
waveform = waveform.numpy() waveform = waveform.numpy()
num_channels, _ = waveform.shape figure, ax = plt.subplots()
ax.specgram(waveform[0], Fs=sample_rate)
figure, axes = plt.subplots(num_channels, 1)
if num_channels == 1:
axes = [axes]
for c in range(num_channels):
axes[c].specgram(waveform[c], Fs=sample_rate)
if num_channels > 1:
axes[c].set_ylabel(f"Channel {c+1}")
if xlim:
axes[c].set_xlim(xlim)
figure.suptitle(title) figure.suptitle(title)
plt.show(block=False) figure.tight_layout()
def play_audio(waveform, sample_rate):
waveform = waveform.numpy()
num_channels, _ = waveform.shape
if num_channels == 1:
display(Audio(waveform[0], rate=sample_rate))
elif num_channels == 2:
display(Audio((waveform[0], waveform[1]), rate=sample_rate))
else:
raise ValueError("Waveform with more than 2 channels are not supported.")
###################################################################### ######################################################################
...@@ -79,10 +44,25 @@ def play_audio(waveform, sample_rate): ...@@ -79,10 +44,25 @@ def play_audio(waveform, sample_rate):
# :py:class:`torchaudio.datasets.YESNO` dataset. # :py:class:`torchaudio.datasets.YESNO` dataset.
# #
dataset = torchaudio.datasets.YESNO(YESNO_DATASET_PATH, download=True) dataset = torchaudio.datasets.YESNO(YESNO_DATASET_PATH, download=True)
for i in [1, 3, 5]: ######################################################################
waveform, sample_rate, label = dataset[i] #
plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}") i = 1
play_audio(waveform, sample_rate) waveform, sample_rate, label = dataset[i]
plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}")
IPython.display.Audio(waveform, rate=sample_rate)
######################################################################
#
i = 3
waveform, sample_rate, label = dataset[i]
plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}")
IPython.display.Audio(waveform, rate=sample_rate)
######################################################################
#
i = 5
waveform, sample_rate, label = dataset[i]
plot_specgram(waveform, sample_rate, title=f"Sample {i}: {label}")
IPython.display.Audio(waveform, rate=sample_rate)
...@@ -19,25 +19,19 @@ print(torch.__version__) ...@@ -19,25 +19,19 @@ print(torch.__version__)
print(torchaudio.__version__) print(torchaudio.__version__)
###################################################################### ######################################################################
# Preparing data and utility functions (skip this section) # Preparation
# -------------------------------------------------------- # -----------
# #
# @title Prepare data and utility functions. {display-mode: "form"}
# @markdown
# @markdown You do not need to look into this cell.
# @markdown Just execute once and you are good to go.
# @markdown
# @markdown In this tutorial, we will use a speech data from [VOiCES dataset](https://iqtlabs.github.io/voices/),
# @markdown which is licensed under Creative Commos BY 4.0.
# -------------------------------------------------------------------------------
# Preparation of data and helper functions.
# -------------------------------------------------------------------------------
import librosa import librosa
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from torchaudio.utils import download_asset from torchaudio.utils import download_asset
######################################################################
# In this tutorial, we will use a speech data from
# `VOiCES dataset <https://iqtlabs.github.io/voices/>`__,
# which is licensed under Creative Commons BY 4.0.
SAMPLE_WAV_SPEECH_PATH = download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav") SAMPLE_WAV_SPEECH_PATH = download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
...@@ -75,16 +69,9 @@ def get_spectrogram( ...@@ -75,16 +69,9 @@ def get_spectrogram(
return spectrogram(waveform) return spectrogram(waveform)
def plot_spectrogram(spec, title=None, ylabel="freq_bin", aspect="auto", xmax=None): def plot_spec(ax, spec, title, ylabel="freq_bin"):
fig, axs = plt.subplots(1, 1) ax.set_title(title)
axs.set_title(title or "Spectrogram (db)") ax.imshow(librosa.power_to_db(spec), origin="lower", aspect="auto")
axs.set_ylabel(ylabel)
axs.set_xlabel("frame")
im = axs.imshow(librosa.power_to_db(spec), origin="lower", aspect=aspect)
if xmax:
axs.set_xlim((0, xmax))
fig.colorbar(im, ax=axs)
plt.show(block=False)
###################################################################### ######################################################################
...@@ -108,43 +95,47 @@ def plot_spectrogram(spec, title=None, ylabel="freq_bin", aspect="auto", xmax=No ...@@ -108,43 +95,47 @@ def plot_spectrogram(spec, title=None, ylabel="freq_bin", aspect="auto", xmax=No
spec = get_spectrogram(power=None) spec = get_spectrogram(power=None)
stretch = T.TimeStretch() stretch = T.TimeStretch()
rate = 1.2 spec_12 = stretch(spec, overriding_rate=1.2)
spec_ = stretch(spec, rate) spec_09 = stretch(spec, overriding_rate=0.9)
plot_spectrogram(torch.abs(spec_[0]), title=f"Stretched x{rate}", aspect="equal", xmax=304)
plot_spectrogram(torch.abs(spec[0]), title="Original", aspect="equal", xmax=304)
rate = 0.9
spec_ = stretch(spec, rate)
plot_spectrogram(torch.abs(spec_[0]), title=f"Stretched x{rate}", aspect="equal", xmax=304)
###################################################################### ######################################################################
# TimeMasking
# -----------
# #
torch.random.manual_seed(4)
spec = get_spectrogram() def plot():
plot_spectrogram(spec[0], title="Original") fig, axes = plt.subplots(3, 1, sharex=True, sharey=True)
plot_spec(axes[0], torch.abs(spec_12[0]), title="Stretched x1.2")
plot_spec(axes[1], torch.abs(spec[0]), title="Original")
plot_spec(axes[2], torch.abs(spec_09[0]), title="Stretched x0.9")
fig.tight_layout()
masking = T.TimeMasking(time_mask_param=80)
spec = masking(spec)
plot_spectrogram(spec[0], title="Masked along time axis") plot()
###################################################################### ######################################################################
# FrequencyMasking # Time and Frequency Masking
# ---------------- # --------------------------
# #
torch.random.manual_seed(4) torch.random.manual_seed(4)
time_masking = T.TimeMasking(time_mask_param=80)
freq_masking = T.FrequencyMasking(freq_mask_param=80)
spec = get_spectrogram() spec = get_spectrogram()
plot_spectrogram(spec[0], title="Original") time_masked = time_masking(spec)
freq_masked = freq_masking(spec)
######################################################################
#
def plot():
fig, axes = plt.subplots(3, 1, sharex=True, sharey=True)
plot_spec(axes[0], spec[0], title="Original")
plot_spec(axes[1], time_masked[0], title="Masked along time axis")
plot_spec(axes[2], freq_masked[0], title="Masked along frequency axis")
fig.tight_layout()
masking = T.FrequencyMasking(freq_mask_param=80)
spec = masking(spec)
plot_spectrogram(spec[0], title="Masked along frequency axis") plot()
...@@ -75,7 +75,6 @@ def plot_waveform(waveform, sr, title="Waveform", ax=None): ...@@ -75,7 +75,6 @@ def plot_waveform(waveform, sr, title="Waveform", ax=None):
ax.grid(True) ax.grid(True)
ax.set_xlim([0, time_axis[-1]]) ax.set_xlim([0, time_axis[-1]])
ax.set_title(title) ax.set_title(title)
plt.show(block=False)
def plot_spectrogram(specgram, title=None, ylabel="freq_bin", ax=None): def plot_spectrogram(specgram, title=None, ylabel="freq_bin", ax=None):
...@@ -85,7 +84,6 @@ def plot_spectrogram(specgram, title=None, ylabel="freq_bin", ax=None): ...@@ -85,7 +84,6 @@ def plot_spectrogram(specgram, title=None, ylabel="freq_bin", ax=None):
ax.set_title(title) ax.set_title(title)
ax.set_ylabel(ylabel) ax.set_ylabel(ylabel)
ax.imshow(librosa.power_to_db(specgram), origin="lower", aspect="auto", interpolation="nearest") ax.imshow(librosa.power_to_db(specgram), origin="lower", aspect="auto", interpolation="nearest")
plt.show(block=False)
def plot_fbank(fbank, title=None): def plot_fbank(fbank, title=None):
...@@ -94,7 +92,6 @@ def plot_fbank(fbank, title=None): ...@@ -94,7 +92,6 @@ def plot_fbank(fbank, title=None):
axs.imshow(fbank, aspect="auto") axs.imshow(fbank, aspect="auto")
axs.set_ylabel("frequency bin") axs.set_ylabel("frequency bin")
axs.set_xlabel("mel bin") axs.set_xlabel("mel bin")
plt.show(block=False)
###################################################################### ######################################################################
...@@ -486,7 +483,6 @@ def plot_pitch(waveform, sr, pitch): ...@@ -486,7 +483,6 @@ def plot_pitch(waveform, sr, pitch):
axis2.plot(time_axis, pitch[0], linewidth=2, label="Pitch", color="green") axis2.plot(time_axis, pitch[0], linewidth=2, label="Pitch", color="green")
axis2.legend(loc=0) axis2.legend(loc=0)
plt.show(block=False)
plot_pitch(SPEECH_WAVEFORM, SAMPLE_RATE, pitch) plot_pitch(SPEECH_WAVEFORM, SAMPLE_RATE, pitch)
...@@ -181,7 +181,6 @@ def plot_waveform(waveform, sample_rate): ...@@ -181,7 +181,6 @@ def plot_waveform(waveform, sample_rate):
if num_channels > 1: if num_channels > 1:
axes[c].set_ylabel(f"Channel {c+1}") axes[c].set_ylabel(f"Channel {c+1}")
figure.suptitle("waveform") figure.suptitle("waveform")
plt.show(block=False)
###################################################################### ######################################################################
...@@ -204,7 +203,6 @@ def plot_specgram(waveform, sample_rate, title="Spectrogram"): ...@@ -204,7 +203,6 @@ def plot_specgram(waveform, sample_rate, title="Spectrogram"):
if num_channels > 1: if num_channels > 1:
axes[c].set_ylabel(f"Channel {c+1}") axes[c].set_ylabel(f"Channel {c+1}")
figure.suptitle(title) figure.suptitle(title)
plt.show(block=False)
###################################################################### ######################################################################
......
...@@ -105,7 +105,6 @@ def plot_sweep( ...@@ -105,7 +105,6 @@ def plot_sweep(
axis.yaxis.grid(True, alpha=0.67) axis.yaxis.grid(True, alpha=0.67)
figure.suptitle(f"{title} (sample rate: {sample_rate} Hz)") figure.suptitle(f"{title} (sample rate: {sample_rate} Hz)")
plt.colorbar(cax) plt.colorbar(cax)
plt.show(block=True)
###################################################################### ######################################################################
......
...@@ -69,7 +69,7 @@ import torchvision ...@@ -69,7 +69,7 @@ import torchvision
# ------------------- # -------------------
# #
# Firstly, we define the function to collect videos from microphone and # Firstly, we define the function to collect videos from microphone and
# camera. To be specific, we use :py:func:`~torchaudio.io.StreamReader` # camera. To be specific, we use :py:class:`~torchaudio.io.StreamReader`
# class for the purpose of data collection, which supports capturing # class for the purpose of data collection, which supports capturing
# audio/video from microphone and camera. For the detailed usage of this # audio/video from microphone and camera. For the detailed usage of this
# class, please refer to the # class, please refer to the
......
...@@ -89,7 +89,7 @@ def plot_sinc_ir(irs, cutoff): ...@@ -89,7 +89,7 @@ def plot_sinc_ir(irs, cutoff):
num_filts, window_size = irs.shape num_filts, window_size = irs.shape
half = window_size // 2 half = window_size // 2
fig, axes = plt.subplots(num_filts, 1, sharex=True, figsize=(6.4, 4.8 * 1.5)) fig, axes = plt.subplots(num_filts, 1, sharex=True, figsize=(9.6, 8))
t = torch.linspace(-half, half - 1, window_size) t = torch.linspace(-half, half - 1, window_size)
for ax, ir, coff, color in zip(axes, irs, cutoff, plt.cm.tab10.colors): for ax, ir, coff, color in zip(axes, irs, cutoff, plt.cm.tab10.colors):
ax.plot(t, ir, linewidth=1.2, color=color, zorder=4, label=f"Cutoff: {coff}") ax.plot(t, ir, linewidth=1.2, color=color, zorder=4, label=f"Cutoff: {coff}")
...@@ -100,7 +100,7 @@ def plot_sinc_ir(irs, cutoff): ...@@ -100,7 +100,7 @@ def plot_sinc_ir(irs, cutoff):
"(Frequencies are relative to Nyquist frequency)" "(Frequencies are relative to Nyquist frequency)"
) )
axes[-1].set_xticks([i * half // 4 for i in range(-4, 5)]) axes[-1].set_xticks([i * half // 4 for i in range(-4, 5)])
plt.tight_layout() fig.tight_layout()
###################################################################### ######################################################################
...@@ -130,7 +130,7 @@ def plot_sinc_fr(frs, cutoff, band=False): ...@@ -130,7 +130,7 @@ def plot_sinc_fr(frs, cutoff, band=False):
num_filts, num_fft = frs.shape num_filts, num_fft = frs.shape
num_ticks = num_filts + 1 if band else num_filts num_ticks = num_filts + 1 if band else num_filts
fig, axes = plt.subplots(num_filts, 1, sharex=True, sharey=True, figsize=(6.4, 4.8 * 1.5)) fig, axes = plt.subplots(num_filts, 1, sharex=True, sharey=True, figsize=(9.6, 8))
for ax, fr, coff, color in zip(axes, frs, cutoff, plt.cm.tab10.colors): for ax, fr, coff, color in zip(axes, frs, cutoff, plt.cm.tab10.colors):
ax.grid(True) ax.grid(True)
ax.semilogy(fr, color=color, zorder=4, label=f"Cutoff: {coff}") ax.semilogy(fr, color=color, zorder=4, label=f"Cutoff: {coff}")
...@@ -146,7 +146,7 @@ def plot_sinc_fr(frs, cutoff, band=False): ...@@ -146,7 +146,7 @@ def plot_sinc_fr(frs, cutoff, band=False):
"Frequency response of sinc low-pass filter for different cut-off frequencies\n" "Frequency response of sinc low-pass filter for different cut-off frequencies\n"
"(Frequencies are relative to Nyquist frequency)" "(Frequencies are relative to Nyquist frequency)"
) )
plt.tight_layout() fig.tight_layout()
###################################################################### ######################################################################
...@@ -275,7 +275,7 @@ def plot_ir(magnitudes, ir, num_fft=2048): ...@@ -275,7 +275,7 @@ def plot_ir(magnitudes, ir, num_fft=2048):
axes[i].grid(True) axes[i].grid(True)
axes[1].set(title="Frequency Response") axes[1].set(title="Frequency Response")
axes[2].set(title="Frequency Response (log-scale)", xlabel="Frequency") axes[2].set(title="Frequency Response (log-scale)", xlabel="Frequency")
axes[2].legend(loc="lower right") axes[2].legend(loc="center right")
fig.tight_layout() fig.tight_layout()
......
...@@ -56,16 +56,11 @@ print(device) ...@@ -56,16 +56,11 @@ print(device)
# First we import the necessary packages, and fetch data that we work on. # First we import the necessary packages, and fetch data that we work on.
# #
# %matplotlib inline
from dataclasses import dataclass from dataclasses import dataclass
import IPython import IPython
import matplotlib
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
matplotlib.rcParams["figure.figsize"] = [16.0, 4.8]
torch.random.manual_seed(0) torch.random.manual_seed(0)
SPEECH_FILE = torchaudio.utils.download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav") SPEECH_FILE = torchaudio.utils.download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
...@@ -99,17 +94,22 @@ with torch.inference_mode(): ...@@ -99,17 +94,22 @@ with torch.inference_mode():
emission = emissions[0].cpu().detach() emission = emissions[0].cpu().detach()
print(labels)
################################################################################ ################################################################################
# Visualization # Visualization
################################################################################ # ~~~~~~~~~~~~~
print(labels)
plt.imshow(emission.T)
plt.colorbar()
plt.title("Frame-wise class probability")
plt.xlabel("Time")
plt.ylabel("Labels")
plt.show()
def plot():
plt.imshow(emission.T)
plt.colorbar()
plt.title("Frame-wise class probability")
plt.xlabel("Time")
plt.ylabel("Labels")
plot()
###################################################################### ######################################################################
# Generate alignment probability (trellis) # Generate alignment probability (trellis)
...@@ -181,12 +181,17 @@ trellis = get_trellis(emission, tokens) ...@@ -181,12 +181,17 @@ trellis = get_trellis(emission, tokens)
################################################################################ ################################################################################
# Visualization # Visualization
################################################################################ # ~~~~~~~~~~~~~
plt.imshow(trellis.T, origin="lower")
plt.annotate("- Inf", (trellis.size(1) / 5, trellis.size(1) / 1.5))
plt.annotate("+ Inf", (trellis.size(0) - trellis.size(1) / 5, trellis.size(1) / 3)) def plot():
plt.colorbar() plt.imshow(trellis.T, origin="lower")
plt.show() plt.annotate("- Inf", (trellis.size(1) / 5, trellis.size(1) / 1.5))
plt.annotate("+ Inf", (trellis.size(0) - trellis.size(1) / 5, trellis.size(1) / 3))
plt.colorbar()
plot()
###################################################################### ######################################################################
# In the above visualization, we can see that there is a trace of high # In the above visualization, we can see that there is a trace of high
...@@ -266,7 +271,9 @@ for p in path: ...@@ -266,7 +271,9 @@ for p in path:
################################################################################ ################################################################################
# Visualization # Visualization
################################################################################ # ~~~~~~~~~~~~~
def plot_trellis_with_path(trellis, path): def plot_trellis_with_path(trellis, path):
# To plot trellis with path, we take advantage of 'nan' value # To plot trellis with path, we take advantage of 'nan' value
trellis_with_path = trellis.clone() trellis_with_path = trellis.clone()
...@@ -277,10 +284,14 @@ def plot_trellis_with_path(trellis, path): ...@@ -277,10 +284,14 @@ def plot_trellis_with_path(trellis, path):
plot_trellis_with_path(trellis, path) plot_trellis_with_path(trellis, path)
plt.title("The path found by backtracking") plt.title("The path found by backtracking")
plt.show()
###################################################################### ######################################################################
# Looking good. Now this path contains repetations for the same labels, so # Looking good.
######################################################################
# Segment the path
# ----------------
# Now this path contains repetitions for the same labels, so
# let’s merge them to make it close to the original transcript. # let’s merge them to make it close to the original transcript.
# #
# When merging the multiple path points, we simply take the average # When merging the multiple path points, we simply take the average
...@@ -297,7 +308,7 @@ class Segment: ...@@ -297,7 +308,7 @@ class Segment:
score: float score: float
def __repr__(self): def __repr__(self):
return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})" return f"{self.label} ({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"
@property @property
def length(self): def length(self):
...@@ -330,7 +341,9 @@ for seg in segments: ...@@ -330,7 +341,9 @@ for seg in segments:
################################################################################ ################################################################################
# Visualization # Visualization
################################################################################ # ~~~~~~~~~~~~~
def plot_trellis_with_segments(trellis, segments, transcript): def plot_trellis_with_segments(trellis, segments, transcript):
# To plot trellis with path, we take advantage of 'nan' value # To plot trellis with path, we take advantage of 'nan' value
trellis_with_path = trellis.clone() trellis_with_path = trellis.clone()
...@@ -338,15 +351,14 @@ def plot_trellis_with_segments(trellis, segments, transcript): ...@@ -338,15 +351,14 @@ def plot_trellis_with_segments(trellis, segments, transcript):
if seg.label != "|": if seg.label != "|":
trellis_with_path[seg.start : seg.end, i] = float("nan") trellis_with_path[seg.start : seg.end, i] = float("nan")
fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5)) fig, [ax1, ax2] = plt.subplots(2, 1, sharex=True)
ax1.set_title("Path, label and probability for each label") ax1.set_title("Path, label and probability for each label")
ax1.imshow(trellis_with_path.T, origin="lower") ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")
ax1.set_xticks([])
for i, seg in enumerate(segments): for i, seg in enumerate(segments):
if seg.label != "|": if seg.label != "|":
ax1.annotate(seg.label, (seg.start, i - 0.7), weight="bold") ax1.annotate(seg.label, (seg.start, i - 0.7), size="small")
ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3)) ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), size="small")
ax2.set_title("Label probability with and without repetation") ax2.set_title("Label probability with and without repetation")
xs, hs, ws = [], [], [] xs, hs, ws = [], [], []
...@@ -355,7 +367,7 @@ def plot_trellis_with_segments(trellis, segments, transcript): ...@@ -355,7 +367,7 @@ def plot_trellis_with_segments(trellis, segments, transcript):
xs.append((seg.end + seg.start) / 2 + 0.4) xs.append((seg.end + seg.start) / 2 + 0.4)
hs.append(seg.score) hs.append(seg.score)
ws.append(seg.end - seg.start) ws.append(seg.end - seg.start)
ax2.annotate(seg.label, (seg.start + 0.8, -0.07), weight="bold") ax2.annotate(seg.label, (seg.start + 0.8, -0.07))
ax2.bar(xs, hs, width=ws, color="gray", alpha=0.5, edgecolor="black") ax2.bar(xs, hs, width=ws, color="gray", alpha=0.5, edgecolor="black")
xs, hs = [], [] xs, hs = [], []
...@@ -367,17 +379,21 @@ def plot_trellis_with_segments(trellis, segments, transcript): ...@@ -367,17 +379,21 @@ def plot_trellis_with_segments(trellis, segments, transcript):
ax2.bar(xs, hs, width=0.5, alpha=0.5) ax2.bar(xs, hs, width=0.5, alpha=0.5)
ax2.axhline(0, color="black") ax2.axhline(0, color="black")
ax2.set_xlim(ax1.get_xlim()) ax2.grid(True, axis="y")
ax2.set_ylim(-0.1, 1.1) ax2.set_ylim(-0.1, 1.1)
fig.tight_layout()
plot_trellis_with_segments(trellis, segments, transcript) plot_trellis_with_segments(trellis, segments, transcript)
plt.tight_layout()
plt.show()
###################################################################### ######################################################################
# Looks good. Now let’s merge the words. The Wav2Vec2 model uses ``'|'`` # Looks good.
######################################################################
# Merge the segments into words
# -----------------------------
# Now let’s merge the words. The Wav2Vec2 model uses ``'|'``
# as the word boundary, so we merge the segments before each occurance of # as the word boundary, so we merge the segments before each occurance of
# ``'|'``. # ``'|'``.
# #
...@@ -410,16 +426,16 @@ for word in word_segments: ...@@ -410,16 +426,16 @@ for word in word_segments:
################################################################################ ################################################################################
# Visualization # Visualization
################################################################################ # ~~~~~~~~~~~~~
def plot_alignments(trellis, segments, word_segments, waveform): def plot_alignments(trellis, segments, word_segments, waveform):
trellis_with_path = trellis.clone() trellis_with_path = trellis.clone()
for i, seg in enumerate(segments): for i, seg in enumerate(segments):
if seg.label != "|": if seg.label != "|":
trellis_with_path[seg.start : seg.end, i] = float("nan") trellis_with_path[seg.start : seg.end, i] = float("nan")
fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(16, 9.5)) fig, [ax1, ax2] = plt.subplots(2, 1)
ax1.imshow(trellis_with_path.T, origin="lower") ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")
ax1.set_xticks([]) ax1.set_xticks([])
ax1.set_yticks([]) ax1.set_yticks([])
...@@ -429,8 +445,8 @@ def plot_alignments(trellis, segments, word_segments, waveform): ...@@ -429,8 +445,8 @@ def plot_alignments(trellis, segments, word_segments, waveform):
for i, seg in enumerate(segments): for i, seg in enumerate(segments):
if seg.label != "|": if seg.label != "|":
ax1.annotate(seg.label, (seg.start, i - 0.7)) ax1.annotate(seg.label, (seg.start, i - 0.7), size="small")
ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), fontsize=8) ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), size="small")
# The original waveform # The original waveform
ratio = waveform.size(0) / trellis.size(0) ratio = waveform.size(0) / trellis.size(0)
...@@ -450,6 +466,7 @@ def plot_alignments(trellis, segments, word_segments, waveform): ...@@ -450,6 +466,7 @@ def plot_alignments(trellis, segments, word_segments, waveform):
ax2.set_yticks([]) ax2.set_yticks([])
ax2.set_ylim(-1.0, 1.0) ax2.set_ylim(-1.0, 1.0)
ax2.set_xlim(0, waveform.size(-1)) ax2.set_xlim(0, waveform.size(-1))
fig.tight_layout()
plot_alignments( plot_alignments(
...@@ -458,7 +475,6 @@ plot_alignments( ...@@ -458,7 +475,6 @@ plot_alignments(
word_segments, word_segments,
waveform[0], waveform[0],
) )
plt.show()
################################################################################ ################################################################################
......
...@@ -162,11 +162,10 @@ def separate_sources( ...@@ -162,11 +162,10 @@ def separate_sources(
def plot_spectrogram(stft, title="Spectrogram"): def plot_spectrogram(stft, title="Spectrogram"):
magnitude = stft.abs() magnitude = stft.abs()
spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy() spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy()
figure, axis = plt.subplots(1, 1) _, axis = plt.subplots(1, 1)
img = axis.imshow(spectrogram, cmap="viridis", vmin=-60, vmax=0, origin="lower", aspect="auto") axis.imshow(spectrogram, cmap="viridis", vmin=-60, vmax=0, origin="lower", aspect="auto")
figure.suptitle(title) axis.set_title(title)
plt.colorbar(img, ax=axis) plt.tight_layout()
plt.show()
###################################################################### ######################################################################
...@@ -252,7 +251,7 @@ def output_results(original_source: torch.Tensor, predicted_source: torch.Tensor ...@@ -252,7 +251,7 @@ def output_results(original_source: torch.Tensor, predicted_source: torch.Tensor
"SDR score is:", "SDR score is:",
separation.bss_eval_sources(original_source.detach().numpy(), predicted_source.detach().numpy())[0].mean(), separation.bss_eval_sources(original_source.detach().numpy(), predicted_source.detach().numpy())[0].mean(),
) )
plot_spectrogram(stft(predicted_source)[0], f"Spectrogram {source}") plot_spectrogram(stft(predicted_source)[0], f"Spectrogram - {source}")
return Audio(predicted_source, rate=sample_rate) return Audio(predicted_source, rate=sample_rate)
...@@ -294,7 +293,7 @@ mix_spec = mixture[:, frame_start:frame_end].cpu() ...@@ -294,7 +293,7 @@ mix_spec = mixture[:, frame_start:frame_end].cpu()
# #
# Mixture Clip # Mixture Clip
plot_spectrogram(stft(mix_spec)[0], "Spectrogram Mixture") plot_spectrogram(stft(mix_spec)[0], "Spectrogram - Mixture")
Audio(mix_spec, rate=sample_rate) Audio(mix_spec, rate=sample_rate)
###################################################################### ######################################################################
......
...@@ -98,23 +98,21 @@ SAMPLE_NOISE = download_asset("tutorial-assets/mvdr/noise.wav") ...@@ -98,23 +98,21 @@ SAMPLE_NOISE = download_asset("tutorial-assets/mvdr/noise.wav")
# #
def plot_spectrogram(stft, title="Spectrogram", xlim=None): def plot_spectrogram(stft, title="Spectrogram"):
magnitude = stft.abs() magnitude = stft.abs()
spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy() spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy()
figure, axis = plt.subplots(1, 1) figure, axis = plt.subplots(1, 1)
img = axis.imshow(spectrogram, cmap="viridis", vmin=-100, vmax=0, origin="lower", aspect="auto") img = axis.imshow(spectrogram, cmap="viridis", vmin=-100, vmax=0, origin="lower", aspect="auto")
figure.suptitle(title) axis.set_title(title)
plt.colorbar(img, ax=axis) plt.colorbar(img, ax=axis)
plt.show()
def plot_mask(mask, title="Mask", xlim=None): def plot_mask(mask, title="Mask"):
mask = mask.numpy() mask = mask.numpy()
figure, axis = plt.subplots(1, 1) figure, axis = plt.subplots(1, 1)
img = axis.imshow(mask, cmap="viridis", origin="lower", aspect="auto") img = axis.imshow(mask, cmap="viridis", origin="lower", aspect="auto")
figure.suptitle(title) axis.set_title(title)
plt.colorbar(img, ax=axis) plt.colorbar(img, ax=axis)
plt.show()
def si_snr(estimate, reference, epsilon=1e-8): def si_snr(estimate, reference, epsilon=1e-8):
......
...@@ -33,12 +33,9 @@ print(torchaudio.__version__) ...@@ -33,12 +33,9 @@ print(torchaudio.__version__)
import os import os
import time import time
import matplotlib
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from torchaudio.io import StreamReader from torchaudio.io import StreamReader
matplotlib.rcParams["image.interpolation"] = "none"
###################################################################### ######################################################################
# #
# Check the prerequisites # Check the prerequisites
......
...@@ -160,8 +160,7 @@ for i, feats in enumerate(features): ...@@ -160,8 +160,7 @@ for i, feats in enumerate(features):
ax[i].set_title(f"Feature from transformer layer {i+1}") ax[i].set_title(f"Feature from transformer layer {i+1}")
ax[i].set_xlabel("Feature dimension") ax[i].set_xlabel("Feature dimension")
ax[i].set_ylabel("Frame (time-axis)") ax[i].set_ylabel("Frame (time-axis)")
plt.tight_layout() fig.tight_layout()
plt.show()
###################################################################### ######################################################################
...@@ -190,7 +189,7 @@ plt.imshow(emission[0].cpu().T, interpolation="nearest") ...@@ -190,7 +189,7 @@ plt.imshow(emission[0].cpu().T, interpolation="nearest")
plt.title("Classification result") plt.title("Classification result")
plt.xlabel("Frame (time-axis)") plt.xlabel("Frame (time-axis)")
plt.ylabel("Class") plt.ylabel("Class")
plt.show() plt.tight_layout()
print("Class labels:", bundle.get_labels()) print("Class labels:", bundle.get_labels())
......
...@@ -82,6 +82,7 @@ try: ...@@ -82,6 +82,7 @@ try:
from pystoi import stoi from pystoi import stoi
from torchaudio.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE from torchaudio.pipelines import SQUIM_OBJECTIVE, SQUIM_SUBJECTIVE
except ImportError: except ImportError:
try:
import google.colab # noqa: F401 import google.colab # noqa: F401
print( print(
...@@ -95,6 +96,9 @@ except ImportError: ...@@ -95,6 +96,9 @@ except ImportError:
!pip3 install pystoi !pip3 install pystoi
""" """
) )
except Exception:
pass
raise
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
...@@ -128,27 +132,17 @@ def si_snr(estimate, reference, epsilon=1e-8): ...@@ -128,27 +132,17 @@ def si_snr(estimate, reference, epsilon=1e-8):
return si_snr.item() return si_snr.item()
def plot_waveform(waveform, title): def plot(waveform, title, sample_rate=16000):
wav_numpy = waveform.numpy() wav_numpy = waveform.numpy()
sample_size = waveform.shape[1] sample_size = waveform.shape[1]
time_axis = torch.arange(0, sample_size) / 16000 time_axis = torch.arange(0, sample_size) / sample_rate
figure, axes = plt.subplots(1, 1) figure, axes = plt.subplots(2, 1)
axes = figure.gca() axes[0].plot(time_axis, wav_numpy[0], linewidth=1)
axes.plot(time_axis, wav_numpy[0], linewidth=1) axes[0].grid(True)
axes.grid(True) axes[1].specgram(wav_numpy[0], Fs=sample_rate)
figure.suptitle(title) figure.suptitle(title)
plt.show(block=False)
def plot_specgram(waveform, sample_rate, title):
wav_numpy = waveform.numpy()
figure, axes = plt.subplots(1, 1)
axes = figure.gca()
axes.specgram(wav_numpy[0], Fs=sample_rate)
figure.suptitle(title)
plt.show(block=False)
###################################################################### ######################################################################
...@@ -238,32 +232,28 @@ Audio(WAVEFORM_DISTORTED.numpy()[1], rate=16000) ...@@ -238,32 +232,28 @@ Audio(WAVEFORM_DISTORTED.numpy()[1], rate=16000)
# Visualize speech sample # Visualize speech sample
# #
plot_waveform(WAVEFORM_SPEECH, "Clean Speech") plot(WAVEFORM_SPEECH, "Clean Speech")
plot_specgram(WAVEFORM_SPEECH, 16000, "Clean Speech Spectrogram")
###################################################################### ######################################################################
# Visualize noise sample # Visualize noise sample
# #
plot_waveform(WAVEFORM_NOISE, "Noise") plot(WAVEFORM_NOISE, "Noise")
plot_specgram(WAVEFORM_NOISE, 16000, "Noise Spectrogram")
###################################################################### ######################################################################
# Visualize distorted speech with 20dB SNR # Visualize distorted speech with 20dB SNR
# #
plot_waveform(WAVEFORM_DISTORTED[0:1], f"Distorted Speech with {snr_dbs[0]}dB SNR") plot(WAVEFORM_DISTORTED[0:1], f"Distorted Speech with {snr_dbs[0]}dB SNR")
plot_specgram(WAVEFORM_DISTORTED[0:1], 16000, f"Distorted Speech with {snr_dbs[0]}dB SNR")
###################################################################### ######################################################################
# Visualize distorted speech with -5dB SNR # Visualize distorted speech with -5dB SNR
# #
plot_waveform(WAVEFORM_DISTORTED[1:2], f"Distorted Speech with {snr_dbs[1]}dB SNR") plot(WAVEFORM_DISTORTED[1:2], f"Distorted Speech with {snr_dbs[1]}dB SNR")
plot_specgram(WAVEFORM_DISTORTED[1:2], 16000, f"Distorted Speech with {snr_dbs[1]}dB SNR")
###################################################################### ######################################################################
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment