Commit 8e20d546 authored by moto, committed by Facebook GitHub Bot

Update audio feature extraction tutorial (#2391)

Summary:
- Adopt torchaudio.utils.download_asset to simplify asset management.
- Break down the first section about helper functions.
- Reduce the number of helper functions.

Pull Request resolved: https://github.com/pytorch/audio/pull/2391

Reviewed By: carolineechen, nateanl

Differential Revision: D36885626

Pulled By: mthrok

fbshipit-source-id: 1306f22ab70ab1e7f74ed7e43bf43150015448b6
parent f0bc00c9
@@ -11,20 +11,10 @@ domain. They are available in ``torchaudio.functional`` and
They are stateless.
``transforms`` implements features as objects,
using implementations from ``functional`` and ``torch.nn.Module``.
They can be serialized using TorchScript.
"""
import torch
import torchaudio
import torchaudio.functional as F
@@ -34,131 +24,51 @@ print(torch.__version__)
print(torchaudio.__version__)
######################################################################
# Preparation
# -----------
#
# .. note::
#
#    When running this tutorial in Google Colab, install the required packages
#
#    .. code::
#
#       !pip install librosa
#
from IPython.display import Audio
import librosa
import matplotlib.pyplot as plt
from torchaudio.utils import download_asset

SAMPLE_SPEECH = download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")

def plot_waveform(waveform, sr, title="Waveform"):
    waveform = waveform.numpy()

    num_channels, num_frames = waveform.shape
    time_axis = torch.arange(0, num_frames) / sr

    figure, axes = plt.subplots(num_channels, 1)
    axes.plot(time_axis, waveform[0], linewidth=1)
    axes.grid(True)
    figure.suptitle(title)
    plt.show(block=False)

def plot_spectrogram(specgram, title=None, ylabel="freq_bin"):
    fig, axs = plt.subplots(1, 1)
    axs.set_title(title or "Spectrogram (db)")
    axs.set_ylabel(ylabel)
    axs.set_xlabel("frame")
    im = axs.imshow(librosa.power_to_db(specgram), origin="lower", aspect="auto")
    fig.colorbar(im, ax=axs)
    plt.show(block=False)

def plot_fbank(fbank, title=None):
    fig, axs = plt.subplots(1, 1)
    axs.set_title(title or "Filter bank")
    axs.imshow(fbank, aspect="auto")
@@ -167,44 +77,18 @@ def plot_mel_fbank(fbank, title=None):
    plt.show(block=False)
######################################################################
# Overview of audio features
# --------------------------
#
# The following diagram shows the relationship between common audio features
# and torchaudio APIs to generate them.
#
# .. image:: https://download.pytorch.org/torchaudio/tutorial-assets/torchaudio_feature_extractions.png
#
# For the complete list of available features, please refer to the
# documentation.
#
######################################################################
@@ -215,14 +99,20 @@ def plot_kaldi_pitch(waveform, sample_rate, pitch, nfcc):
# you can use :py:func:`torchaudio.transforms.Spectrogram`.
#
SPEECH_WAVEFORM, SAMPLE_RATE = torchaudio.load(SAMPLE_SPEECH)
plot_waveform(SPEECH_WAVEFORM, SAMPLE_RATE, title="Original waveform")
Audio(SPEECH_WAVEFORM.numpy(), rate=SAMPLE_RATE)
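######################################################################
# A quick look at the loaded tensor (a small inspection sketch, not part
# of the tutorial's helper set; ``SPEECH_WAVEFORM`` has shape
# ``(channel, time)``):

print("Shape:", tuple(SPEECH_WAVEFORM.shape))
print("Dtype:", SPEECH_WAVEFORM.dtype)
print(f"Max: {SPEECH_WAVEFORM.max().item():.3f}, Min: {SPEECH_WAVEFORM.min().item():.3f}")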
######################################################################
#
n_fft = 1024
win_length = None
hop_length = 512

# Define transform
spectrogram = T.Spectrogram(
    n_fft=n_fft,
    win_length=win_length,
@@ -231,10 +121,16 @@ spectrogram = T.Spectrogram(
    pad_mode="reflect",
    power=2.0,
)

######################################################################
#
# Perform transform
spec = spectrogram(SPEECH_WAVEFORM)

######################################################################
#
plot_spectrogram(spec[0], title="torchaudio")
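######################################################################
# The same values can be computed with the stateless
# :py:func:`torchaudio.functional.spectrogram`, and, because the
# transform is a ``torch.nn.Module``, it can be serialized with
# TorchScript. A minimal sketch (the explicit Hann window is an
# assumption; ``T.Spectrogram`` uses one by default):

window = torch.hann_window(n_fft)
spec_f = F.spectrogram(
    SPEECH_WAVEFORM,
    pad=0,
    window=window,
    n_fft=n_fft,
    hop_length=hop_length,
    win_length=n_fft,  # ``win_length=None`` above defaults to ``n_fft``
    power=2.0,
    normalized=False,
)
print("functional matches transform:", torch.allclose(spec, spec_f))

scripted = torch.jit.script(spectrogram)
print("TorchScript matches eager:", torch.allclose(spec, scripted(SPEECH_WAVEFORM)))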
######################################################################
@@ -244,11 +140,7 @@ plot_spectrogram(spec[0], title="torchaudio")
# To recover a waveform from a spectrogram, you can use ``GriffinLim``.
#
torch.random.manual_seed(0)

n_fft = 1024
win_length = None
@@ -258,17 +150,27 @@ spec = T.Spectrogram(
    n_fft=n_fft,
    win_length=win_length,
    hop_length=hop_length,
)(SPEECH_WAVEFORM)
######################################################################
#
griffin_lim = T.GriffinLim(
    n_fft=n_fft,
    win_length=win_length,
    hop_length=hop_length,
)

######################################################################
#
reconstructed_waveform = griffin_lim(spec)
######################################################################
#
plot_waveform(reconstructed_waveform, SAMPLE_RATE, title="Reconstructed")
Audio(reconstructed_waveform, rate=SAMPLE_RATE)
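######################################################################
# Griffin-Lim only estimates the phase, so the result is not
# sample-identical to the original. A rough sanity check (a sketch,
# reusing the spectrogram parameters from above) is to compare the
# spectrogram of the reconstruction with the one that was inverted:

spec_rec = T.Spectrogram(n_fft=n_fft, win_length=win_length, hop_length=hop_length)(reconstructed_waveform)
num_frames = min(spec.shape[-1], spec_rec.shape[-1])  # lengths may differ slightly
print("Spectrogram MSE:", torch.square(spec[..., :num_frames] - spec_rec[..., :num_frames]).mean().item())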
######################################################################
# Mel Filter Bank
@@ -281,7 +183,6 @@ play_audio(waveform, sample_rate)
# equivalent transform in :py:func:`torchaudio.transforms`.
#
n_fft = 256
n_mels = 64
sample_rate = 6000

@@ -294,7 +195,11 @@ mel_filters = F.melscale_fbanks(
    sample_rate=sample_rate,
    norm="slaney",
)
######################################################################
#
plot_fbank(mel_filters, "Mel Filter Bank - torchaudio")
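######################################################################
# The filter bank is just a matrix of shape ``(n_freqs, n_mels)``, so it
# can be applied to a power spectrogram with a plain matrix
# multiplication. A hypothetical illustration of the mechanics only
# (note this bank was built for ``sample_rate = 6000``, not the 16 kHz
# speech sample):

spec_small = T.Spectrogram(n_fft=n_fft)(SPEECH_WAVEFORM)  # (channel, n_fft // 2 + 1, time)
mel_by_hand = torch.matmul(spec_small.transpose(-1, -2), mel_filters).transpose(-1, -2)
print(mel_by_hand.shape)  # (channel, n_mels, time)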
######################################################################
# Comparison against librosa
@@ -304,7 +209,6 @@ plot_mel_fbank(mel_filters, "Mel Filter Bank - torchaudio")
# with ``librosa``.
#
mel_filters_librosa = librosa.filters.mel(
    sr=sample_rate,
    n_fft=n_fft,
@@ -315,7 +219,10 @@ mel_filters_librosa = librosa.filters.mel(
    htk=True,
).T
######################################################################
#
plot_fbank(mel_filters_librosa, "Mel Filter Bank - librosa")

mse = torch.square(mel_filters - mel_filters_librosa).mean().item()
print("Mean Square Difference: ", mse)
@@ -330,9 +237,6 @@ print("Mean Square Difference: ", mse)
# this functionality.
#
n_fft = 1024
win_length = None
hop_length = 512
@@ -352,7 +256,11 @@ mel_spectrogram = T.MelSpectrogram(
    mel_scale="htk",
)

melspec = mel_spectrogram(SPEECH_WAVEFORM)
######################################################################
#
plot_spectrogram(melspec[0], title="MelSpectrogram - torchaudio", ylabel="mel freq")

######################################################################
@@ -363,9 +271,8 @@ plot_spectrogram(melspec[0], title="MelSpectrogram - torchaudio", ylabel="mel freq")
# spectrograms with ``librosa``.
#
melspec_librosa = librosa.feature.melspectrogram(
    y=SPEECH_WAVEFORM.numpy()[0],
    sr=sample_rate,
    n_fft=n_fft,
    hop_length=hop_length,
@@ -377,6 +284,10 @@ melspec_librosa = librosa.feature.melspectrogram(
    norm="slaney",
    htk=True,
)
######################################################################
#
plot_spectrogram(melspec_librosa, title="MelSpectrogram - librosa", ylabel="mel freq")

mse = torch.square(melspec - melspec_librosa).mean().item()
@@ -387,8 +298,6 @@ print("Mean Square Difference: ", mse)
# ----
#
n_fft = 2048
win_length = None
hop_length = 512
@@ -406,18 +315,20 @@ mfcc_transform = T.MFCC(
    },
)

mfcc = mfcc_transform(SPEECH_WAVEFORM)
######################################################################
#
plot_spectrogram(mfcc[0])

######################################################################
# Comparison against librosa
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
#
melspec = librosa.feature.melspectrogram(
    y=SPEECH_WAVEFORM.numpy()[0],
    sr=sample_rate,
    n_fft=n_fft,
    win_length=win_length,
@@ -434,22 +345,65 @@ mfcc_librosa = librosa.feature.mfcc(
    norm="ortho",
)
######################################################################
#
plot_spectrogram(mfcc_librosa)

mse = torch.square(mfcc - mfcc_librosa).mean().item()
print("Mean Square Difference: ", mse)
######################################################################
# LFCC
# ----
#
n_fft = 2048
win_length = None
hop_length = 512
n_lfcc = 256
lfcc_transform = T.LFCC(
    sample_rate=sample_rate,
    n_lfcc=n_lfcc,
    speckwargs={
        "n_fft": n_fft,
        "win_length": win_length,
        "hop_length": hop_length,
    },
)

lfcc = lfcc_transform(SPEECH_WAVEFORM)
plot_spectrogram(lfcc[0])
######################################################################
# Pitch
# -----
#
pitch = F.detect_pitch_frequency(SPEECH_WAVEFORM, SAMPLE_RATE)
######################################################################
#
def plot_pitch(waveform, sr, pitch):
    figure, axis = plt.subplots(1, 1)
    axis.set_title("Pitch Feature")
    axis.grid(True)

    end_time = waveform.shape[1] / sr
    time_axis = torch.linspace(0, end_time, waveform.shape[1])
    axis.plot(time_axis, waveform[0], linewidth=1, color="gray", alpha=0.3)

    axis2 = axis.twinx()
    time_axis = torch.linspace(0, end_time, pitch.shape[1])
    axis2.plot(time_axis, pitch[0], linewidth=2, label="Pitch", color="green")

    axis2.legend(loc=0)
    plt.show(block=False)

plot_pitch(SPEECH_WAVEFORM, SAMPLE_RATE, pitch)
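######################################################################
# To listen to the sample while reading the pitch contour, the waveform
# can be played back directly (a sketch using ``IPython.display.Audio``):

Audio(SPEECH_WAVEFORM.numpy()[0], rate=SAMPLE_RATE)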
######################################################################
# Kaldi Pitch (beta)
@@ -471,11 +425,33 @@ play_audio(waveform, sample_rate)
# [`paper <https://danielpovey.com/files/2014_icassp_pitch.pdf>`__]
#
pitch_feature = F.compute_kaldi_pitch(SPEECH_WAVEFORM, SAMPLE_RATE)
pitch, nfcc = pitch_feature[..., 0], pitch_feature[..., 1]
######################################################################
#
def plot_kaldi_pitch(waveform, sr, pitch, nfcc):
    _, axis = plt.subplots(1, 1)
    axis.set_title("Kaldi Pitch Feature")
    axis.grid(True)

    end_time = waveform.shape[1] / sr
    time_axis = torch.linspace(0, end_time, waveform.shape[1])
    axis.plot(time_axis, waveform[0], linewidth=1, color="gray", alpha=0.3)

    time_axis = torch.linspace(0, end_time, pitch.shape[1])
    ln1 = axis.plot(time_axis, pitch[0], linewidth=2, label="Pitch", color="green")
    axis.set_ylim((-1.3, 1.3))

    axis2 = axis.twinx()
    time_axis = torch.linspace(0, end_time, nfcc.shape[1])
    ln2 = axis2.plot(time_axis, nfcc[0], linewidth=2, label="NFCC", color="blue", linestyle="--")

    lns = ln1 + ln2
    labels = [l.get_label() for l in lns]
    axis.legend(lns, labels, loc=0)

    plt.show(block=False)

plot_kaldi_pitch(SPEECH_WAVEFORM, SAMPLE_RATE, pitch, nfcc)