Commit 8e20d546 authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Update audio feature extraction tutorial (#2391)

Summary:
- Adopt torchaudio.utils.download_asset to simplify asset management.
- Break down the first section about helper functions.
- Reduce the number of helper functions

Pull Request resolved: https://github.com/pytorch/audio/pull/2391

Reviewed By: carolineechen, nateanl

Differential Revision: D36885626

Pulled By: mthrok

fbshipit-source-id: 1306f22ab70ab1e7f74ed7e43bf43150015448b6
parent f0bc00c9
......@@ -11,20 +11,10 @@ domain. They are available in ``torchaudio.functional`` and
They are stateless.
``transforms`` implements features as objects,
using implementations from ``functional`` and ``torch.nn.Module``. Because all
transforms are subclasses of ``torch.nn.Module``, they can be serialized
using TorchScript.
For the complete list of available features, please refer to the
documentation. In this tutorial, we will look into converting between the
time domain and frequency domain (``Spectrogram``, ``GriffinLim``,
``MelSpectrogram``).
using implementations from ``functional`` and ``torch.nn.Module``.
They can be serialized using TorchScript.
"""
# When running this tutorial in Google Colab, install the required packages
# with the following.
# !pip install torchaudio librosa
import torch
import torchaudio
import torchaudio.functional as F
......@@ -34,131 +24,51 @@ print(torch.__version__)
print(torchaudio.__version__)
######################################################################
# Preparing data and utility functions (skip this section)
# --------------------------------------------------------
# Preparation
# -----------
#
# @title Prepare data and utility functions. {display-mode: "form"}
# @markdown
# @markdown You do not need to look into this cell.
# @markdown Just execute once and you are good to go.
# @markdown
# @markdown In this tutorial, we will use a speech data from [VOiCES dataset](https://iqtlabs.github.io/voices/),
# @markdown which is licensed under Creative Commos BY 4.0.
# -------------------------------------------------------------------------------
# Preparation of data and helper functions.
# -------------------------------------------------------------------------------
import os
# .. note::
#
# When running this tutorial in Google Colab, install the required packages
#
# .. code::
#
# !pip install librosa
#
from IPython.display import Audio
import librosa
import matplotlib.pyplot as plt
import requests
from IPython.display import Audio, display
_SAMPLE_DIR = "_assets"
SAMPLE_WAV_SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav" # noqa: E501
SAMPLE_WAV_SPEECH_PATH = os.path.join(_SAMPLE_DIR, "speech.wav")
os.makedirs(_SAMPLE_DIR, exist_ok=True)
def _fetch_data():
uri = [
(SAMPLE_WAV_SPEECH_URL, SAMPLE_WAV_SPEECH_PATH),
]
for url, path in uri:
with open(path, "wb") as file_:
file_.write(requests.get(url).content)
_fetch_data()
from torchaudio.utils import download_asset
torch.random.manual_seed(0)
def _get_sample(path, resample=None):
effects = [["remix", "1"]]
if resample:
effects.extend(
[
["lowpass", f"{resample // 2}"],
["rate", f"{resample}"],
]
)
return torchaudio.sox_effects.apply_effects_file(path, effects=effects)
SAMPLE_SPEECH = download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
def get_speech_sample(*, resample=None):
return _get_sample(SAMPLE_WAV_SPEECH_PATH, resample=resample)
def plot_waveform(waveform, sr, title="Waveform"):
waveform = waveform.numpy()
num_channels, num_frames = waveform.shape
time_axis = torch.arange(0, num_frames) / sr
def print_stats(waveform, sample_rate=None, src=None):
if src:
print("-" * 10)
print("Source:", src)
print("-" * 10)
if sample_rate:
print("Sample Rate:", sample_rate)
print("Shape:", tuple(waveform.shape))
print("Dtype:", waveform.dtype)
print(f" - Max: {waveform.max().item():6.3f}")
print(f" - Min: {waveform.min().item():6.3f}")
print(f" - Mean: {waveform.mean().item():6.3f}")
print(f" - Std Dev: {waveform.std().item():6.3f}")
print()
print(waveform)
print()
figure, axes = plt.subplots(num_channels, 1)
axes.plot(time_axis, waveform[0], linewidth=1)
axes.grid(True)
figure.suptitle(title)
plt.show(block=False)
def plot_spectrogram(spec, title=None, ylabel="freq_bin", aspect="auto", xmax=None):
def plot_spectrogram(specgram, title=None, ylabel="freq_bin"):
fig, axs = plt.subplots(1, 1)
axs.set_title(title or "Spectrogram (db)")
axs.set_ylabel(ylabel)
axs.set_xlabel("frame")
im = axs.imshow(librosa.power_to_db(spec), origin="lower", aspect=aspect)
if xmax:
axs.set_xlim((0, xmax))
im = axs.imshow(librosa.power_to_db(specgram), origin="lower", aspect="auto")
fig.colorbar(im, ax=axs)
plt.show(block=False)
def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None, ylim=None):
waveform = waveform.numpy()
num_channels, num_frames = waveform.shape
time_axis = torch.arange(0, num_frames) / sample_rate
figure, axes = plt.subplots(num_channels, 1)
if num_channels == 1:
axes = [axes]
for c in range(num_channels):
axes[c].plot(time_axis, waveform[c], linewidth=1)
axes[c].grid(True)
if num_channels > 1:
axes[c].set_ylabel(f"Channel {c+1}")
if xlim:
axes[c].set_xlim(xlim)
if ylim:
axes[c].set_ylim(ylim)
figure.suptitle(title)
plt.show(block=False)
def play_audio(waveform, sample_rate):
waveform = waveform.numpy()
num_channels, num_frames = waveform.shape
if num_channels == 1:
display(Audio(waveform[0], rate=sample_rate))
elif num_channels == 2:
display(Audio((waveform[0], waveform[1]), rate=sample_rate))
else:
raise ValueError("Waveform with more than 2 channels are not supported.")
def plot_mel_fbank(fbank, title=None):
def plot_fbank(fbank, title=None):
fig, axs = plt.subplots(1, 1)
axs.set_title(title or "Filter bank")
axs.imshow(fbank, aspect="auto")
......@@ -167,44 +77,18 @@ def plot_mel_fbank(fbank, title=None):
plt.show(block=False)
def plot_pitch(waveform, sample_rate, pitch):
figure, axis = plt.subplots(1, 1)
axis.set_title("Pitch Feature")
axis.grid(True)
end_time = waveform.shape[1] / sample_rate
time_axis = torch.linspace(0, end_time, waveform.shape[1])
axis.plot(time_axis, waveform[0], linewidth=1, color="gray", alpha=0.3)
axis2 = axis.twinx()
time_axis = torch.linspace(0, end_time, pitch.shape[1])
axis2.plot(time_axis, pitch[0], linewidth=2, label="Pitch", color="green")
axis2.legend(loc=0)
plt.show(block=False)
def plot_kaldi_pitch(waveform, sample_rate, pitch, nfcc):
figure, axis = plt.subplots(1, 1)
axis.set_title("Kaldi Pitch Feature")
axis.grid(True)
end_time = waveform.shape[1] / sample_rate
time_axis = torch.linspace(0, end_time, waveform.shape[1])
axis.plot(time_axis, waveform[0], linewidth=1, color="gray", alpha=0.3)
time_axis = torch.linspace(0, end_time, pitch.shape[1])
ln1 = axis.plot(time_axis, pitch[0], linewidth=2, label="Pitch", color="green")
axis.set_ylim((-1.3, 1.3))
axis2 = axis.twinx()
time_axis = torch.linspace(0, end_time, nfcc.shape[1])
ln2 = axis2.plot(time_axis, nfcc[0], linewidth=2, label="NFCC", color="blue", linestyle="--")
lns = ln1 + ln2
labels = [l.get_label() for l in lns]
axis.legend(lns, labels, loc=0)
plt.show(block=False)
######################################################################
# Overview of audio features
# --------------------------
#
# The following diagram shows the relationship between common audio features
# and torchaudio APIs to generate them.
#
# .. image:: https://download.pytorch.org/torchaudio/tutorial-assets/torchaudio_feature_extractions.png
#
# For the complete list of available features, please refer to the
# documentation.
#
######################################################################
......@@ -215,14 +99,20 @@ def plot_kaldi_pitch(waveform, sample_rate, pitch, nfcc):
# you can use :py:func:`torchaudio.transforms.Spectrogram`.
#
SPEECH_WAVEFORM, SAMPLE_RATE = torchaudio.load(SAMPLE_SPEECH)
plot_waveform(SPEECH_WAVEFORM, SAMPLE_RATE, title="Original waveform")
Audio(SPEECH_WAVEFORM.numpy(), rate=SAMPLE_RATE)
waveform, sample_rate = get_speech_sample()
######################################################################
#
n_fft = 1024
win_length = None
hop_length = 512
# define transformation
# Define transform
spectrogram = T.Spectrogram(
n_fft=n_fft,
win_length=win_length,
......@@ -231,10 +121,16 @@ spectrogram = T.Spectrogram(
pad_mode="reflect",
power=2.0,
)
# Perform transformation
spec = spectrogram(waveform)
print_stats(spec)
######################################################################
#
# Perform transform
spec = spectrogram(SPEECH_WAVEFORM)
######################################################################
#
plot_spectrogram(spec[0], title="torchaudio")
######################################################################
......@@ -244,11 +140,7 @@ plot_spectrogram(spec[0], title="torchaudio")
# To recover a waveform from a spectrogram, you can use ``GriffinLim``.
#
torch.random.manual_seed(0)
waveform, sample_rate = get_speech_sample()
plot_waveform(waveform, sample_rate, title="Original")
play_audio(waveform, sample_rate)
n_fft = 1024
win_length = None
......@@ -258,17 +150,27 @@ spec = T.Spectrogram(
n_fft=n_fft,
win_length=win_length,
hop_length=hop_length,
)(waveform)
)(SPEECH_WAVEFORM)
######################################################################
#
griffin_lim = T.GriffinLim(
n_fft=n_fft,
win_length=win_length,
hop_length=hop_length,
)
waveform = griffin_lim(spec)
plot_waveform(waveform, sample_rate, title="Reconstructed")
play_audio(waveform, sample_rate)
######################################################################
#
reconstructed_waveform = griffin_lim(spec)
######################################################################
#
plot_waveform(reconstructed_waveform, SAMPLE_RATE, title="Reconstructed")
Audio(reconstructed_waveform, rate=SAMPLE_RATE)
######################################################################
# Mel Filter Bank
......@@ -281,7 +183,6 @@ play_audio(waveform, sample_rate)
# equivalent transform in :py:func:`torchaudio.transforms`.
#
n_fft = 256
n_mels = 64
sample_rate = 6000
......@@ -294,7 +195,11 @@ mel_filters = F.melscale_fbanks(
sample_rate=sample_rate,
norm="slaney",
)
plot_mel_fbank(mel_filters, "Mel Filter Bank - torchaudio")
######################################################################
#
plot_fbank(mel_filters, "Mel Filter Bank - torchaudio")
######################################################################
# Comparison against librosa
......@@ -304,7 +209,6 @@ plot_mel_fbank(mel_filters, "Mel Filter Bank - torchaudio")
# with ``librosa``.
#
mel_filters_librosa = librosa.filters.mel(
sr=sample_rate,
n_fft=n_fft,
......@@ -315,7 +219,10 @@ mel_filters_librosa = librosa.filters.mel(
htk=True,
).T
plot_mel_fbank(mel_filters_librosa, "Mel Filter Bank - librosa")
######################################################################
#
plot_fbank(mel_filters_librosa, "Mel Filter Bank - librosa")
mse = torch.square(mel_filters - mel_filters_librosa).mean().item()
print("Mean Square Difference: ", mse)
......@@ -330,9 +237,6 @@ print("Mean Square Difference: ", mse)
# this functionality.
#
waveform, sample_rate = get_speech_sample()
n_fft = 1024
win_length = None
hop_length = 512
......@@ -352,7 +256,11 @@ mel_spectrogram = T.MelSpectrogram(
mel_scale="htk",
)
melspec = mel_spectrogram(waveform)
melspec = mel_spectrogram(SPEECH_WAVEFORM)
######################################################################
#
plot_spectrogram(melspec[0], title="MelSpectrogram - torchaudio", ylabel="mel freq")
######################################################################
......@@ -363,9 +271,8 @@ plot_spectrogram(melspec[0], title="MelSpectrogram - torchaudio", ylabel="mel fr
# spectrograms with ``librosa``.
#
melspec_librosa = librosa.feature.melspectrogram(
y=waveform.numpy()[0],
y=SPEECH_WAVEFORM.numpy()[0],
sr=sample_rate,
n_fft=n_fft,
hop_length=hop_length,
......@@ -377,6 +284,10 @@ melspec_librosa = librosa.feature.melspectrogram(
norm="slaney",
htk=True,
)
######################################################################
#
plot_spectrogram(melspec_librosa, title="MelSpectrogram - librosa", ylabel="mel freq")
mse = torch.square(melspec - melspec_librosa).mean().item()
......@@ -387,8 +298,6 @@ print("Mean Square Difference: ", mse)
# ----
#
waveform, sample_rate = get_speech_sample()
n_fft = 2048
win_length = None
hop_length = 512
......@@ -406,18 +315,20 @@ mfcc_transform = T.MFCC(
},
)
mfcc = mfcc_transform(waveform)
mfcc = mfcc_transform(SPEECH_WAVEFORM)
######################################################################
#
plot_spectrogram(mfcc[0])
######################################################################
# Comparing against librosa
# ~~~~~~~~~~~~~~~~~~~~~~~~~
# Comparison against librosa
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
#
melspec = librosa.feature.melspectrogram(
y=waveform.numpy()[0],
y=SPEECH_WAVEFORM.numpy()[0],
sr=sample_rate,
n_fft=n_fft,
win_length=win_length,
......@@ -434,22 +345,65 @@ mfcc_librosa = librosa.feature.mfcc(
norm="ortho",
)
######################################################################
#
plot_spectrogram(mfcc_librosa)
mse = torch.square(mfcc - mfcc_librosa).mean().item()
print("Mean Square Difference: ", mse)
######################################################################
# LFCC
# ----
#
n_fft = 2048
win_length = None
hop_length = 512
n_lfcc = 256
lfcc_transform = T.LFCC(
sample_rate=sample_rate,
n_lfcc=n_lfcc,
speckwargs={
"n_fft": n_fft,
"win_length": win_length,
"hop_length": hop_length,
},
)
lfcc = lfcc_transform(SPEECH_WAVEFORM)
plot_spectrogram(lfcc[0])
######################################################################
# Pitch
# -----
#
pitch = F.detect_pitch_frequency(SPEECH_WAVEFORM, SAMPLE_RATE)
######################################################################
#
def plot_pitch(waveform, sr, pitch):
figure, axis = plt.subplots(1, 1)
axis.set_title("Pitch Feature")
axis.grid(True)
end_time = waveform.shape[1] / sr
time_axis = torch.linspace(0, end_time, waveform.shape[1])
axis.plot(time_axis, waveform[0], linewidth=1, color="gray", alpha=0.3)
axis2 = axis.twinx()
time_axis = torch.linspace(0, end_time, pitch.shape[1])
axis2.plot(time_axis, pitch[0], linewidth=2, label="Pitch", color="green")
axis2.legend(loc=0)
plt.show(block=False)
waveform, sample_rate = get_speech_sample()
pitch = F.detect_pitch_frequency(waveform, sample_rate)
plot_pitch(waveform, sample_rate, pitch)
play_audio(waveform, sample_rate)
plot_pitch(SPEECH_WAVEFORM, SAMPLE_RATE, pitch)
######################################################################
# Kaldi Pitch (beta)
......@@ -471,11 +425,33 @@ play_audio(waveform, sample_rate)
# [`paper <https://danielpovey.com/files/2014_icassp_pitch.pdf>`__]
#
pitch_feature = F.compute_kaldi_pitch(SPEECH_WAVEFORM, SAMPLE_RATE)
pitch, nfcc = pitch_feature[..., 0], pitch_feature[..., 1]
waveform, sample_rate = get_speech_sample(resample=16000)
######################################################################
#
def plot_kaldi_pitch(waveform, sr, pitch, nfcc):
_, axis = plt.subplots(1, 1)
axis.set_title("Kaldi Pitch Feature")
axis.grid(True)
end_time = waveform.shape[1] / sr
time_axis = torch.linspace(0, end_time, waveform.shape[1])
axis.plot(time_axis, waveform[0], linewidth=1, color="gray", alpha=0.3)
time_axis = torch.linspace(0, end_time, pitch.shape[1])
ln1 = axis.plot(time_axis, pitch[0], linewidth=2, label="Pitch", color="green")
axis.set_ylim((-1.3, 1.3))
axis2 = axis.twinx()
time_axis = torch.linspace(0, end_time, nfcc.shape[1])
ln2 = axis2.plot(time_axis, nfcc[0], linewidth=2, label="NFCC", color="blue", linestyle="--")
lns = ln1 + ln2
labels = [l.get_label() for l in lns]
axis.legend(lns, labels, loc=0)
plt.show(block=False)
pitch_feature = F.compute_kaldi_pitch(waveform, sample_rate)
pitch, nfcc = pitch_feature[..., 0], pitch_feature[..., 1]
plot_kaldi_pitch(waveform, sample_rate, pitch, nfcc)
play_audio(waveform, sample_rate)
plot_kaldi_pitch(SPEECH_WAVEFORM, SAMPLE_RATE, pitch, nfcc)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment