Commit fd2be89a authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Update audio resampling tutorial (#2386)

Summary:
- Replace misuse of plot_specgram with plot_sweep, and remove plot_specgram
- Move `benchmark_resample` to later section

https://output.circle-artifacts.com/output/job/9f7af187-777d-4d75-840f-2630a36295b7/artifacts/0/docs/tutorials/audio_resampling_tutorial.html

Pull Request resolved: https://github.com/pytorch/audio/pull/2386

Reviewed By: carolineechen

Differential Revision: D36404403

Pulled By: mthrok

fbshipit-source-id: f9df8453e3f531bdc4549b0134e5dbba90653bf7
parent 8e20d546
...@@ -3,14 +3,9 @@ ...@@ -3,14 +3,9 @@
Audio Resampling Audio Resampling
================ ================
Here, we will walk through resampling audio waveforms using ``torchaudio``. This tutorial shows how to use torchaudio's resampling API.
""" """
# When running this tutorial in Google Colab, install the required packages
# with the following.
# !pip install torchaudio librosa
import torch import torch
import torchaudio import torchaudio
import torchaudio.functional as F import torchaudio.functional as F
...@@ -20,18 +15,18 @@ print(torch.__version__) ...@@ -20,18 +15,18 @@ print(torch.__version__)
print(torchaudio.__version__) print(torchaudio.__version__)
###################################################################### ######################################################################
# Preparing data and utility functions (skip this section) # Preparation
# -------------------------------------------------------- # -----------
# #
# First, we import the modules and define the helper functions.
# @title Prepare data and utility functions. {display-mode: "form"} #
# @markdown # .. note::
# @markdown You do not need to look into this cell. # When running this tutorial in Google Colab, install the required packages
# @markdown Just execute once and you are good to go. # with the following.
#
# ------------------------------------------------------------------------------- # .. code::
# Preparation of data and helper functions. #
# ------------------------------------------------------------------------------- # !pip install librosa
import math import math
import time import time
...@@ -41,12 +36,10 @@ import matplotlib.pyplot as plt ...@@ -41,12 +36,10 @@ import matplotlib.pyplot as plt
import pandas as pd import pandas as pd
from IPython.display import Audio, display from IPython.display import Audio, display
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
DEFAULT_OFFSET = 201 DEFAULT_OFFSET = 201
SWEEP_MAX_SAMPLE_RATE = 48000
DEFAULT_LOWPASS_FILTER_WIDTH = 6
DEFAULT_ROLLOFF = 0.99
DEFAULT_RESAMPLING_METHOD = "sinc_interpolation"
def _get_log_freq(sample_rate, max_sweep_rate, offset): def _get_log_freq(sample_rate, max_sweep_rate, offset):
...@@ -95,7 +88,7 @@ def plot_sweep( ...@@ -95,7 +88,7 @@ def plot_sweep(
waveform, waveform,
sample_rate, sample_rate,
title, title,
max_sweep_rate=SWEEP_MAX_SAMPLE_RATE, max_sweep_rate=48000,
offset=DEFAULT_OFFSET, offset=DEFAULT_OFFSET,
): ):
x_ticks = [100, 500, 1000, 5000, 10000, 20000, max_sweep_rate // 2] x_ticks = [100, 500, 1000, 5000, 10000, 20000, max_sweep_rate // 2]
...@@ -103,10 +96,10 @@ def plot_sweep( ...@@ -103,10 +96,10 @@ def plot_sweep(
time, freq = _get_freq_ticks(max_sweep_rate, offset, sample_rate // 2) time, freq = _get_freq_ticks(max_sweep_rate, offset, sample_rate // 2)
freq_x = [f if f in x_ticks and f <= max_sweep_rate // 2 else None for f in freq] freq_x = [f if f in x_ticks and f <= max_sweep_rate // 2 else None for f in freq]
freq_y = [f for f in freq if f >= 1000 and f in y_ticks and f <= sample_rate // 2] freq_y = [f for f in freq if f in y_ticks and 1000 <= f <= sample_rate // 2]
figure, axis = plt.subplots(1, 1) figure, axis = plt.subplots(1, 1)
axis.specgram(waveform[0].numpy(), Fs=sample_rate) _, _, _, cax = axis.specgram(waveform[0].numpy(), Fs=sample_rate)
plt.xticks(time, freq_x) plt.xticks(time, freq_x)
plt.yticks(freq_y, freq_y) plt.yticks(freq_y, freq_y)
axis.set_xlabel("Original Signal Frequency (Hz, log scale)") axis.set_xlabel("Original Signal Frequency (Hz, log scale)")
...@@ -114,87 +107,10 @@ def plot_sweep( ...@@ -114,87 +107,10 @@ def plot_sweep(
axis.xaxis.grid(True, alpha=0.67) axis.xaxis.grid(True, alpha=0.67)
axis.yaxis.grid(True, alpha=0.67) axis.yaxis.grid(True, alpha=0.67)
figure.suptitle(f"{title} (sample rate: {sample_rate} Hz)") figure.suptitle(f"{title} (sample rate: {sample_rate} Hz)")
plt.colorbar(cax)
plt.show(block=True) plt.show(block=True)
def play_audio(waveform, sample_rate):
    """Show an inline audio player for a mono or stereo waveform.

    Args:
        waveform: Object whose ``numpy()`` method returns a 2-D array. The
            first axis is treated as channels, the second as frames.
        sample_rate: Playback rate, forwarded to ``Audio`` as ``rate``.

    Raises:
        ValueError: If the converted array is not 2-D, or if its channel
            count is anything other than one or two.
    """
    samples = waveform.numpy()
    # Check the layout explicitly. Unpacking ``shape`` into two names also
    # raised ValueError for non-2D input, but with an opaque unpacking
    # message that did not say what was actually wrong.
    if samples.ndim != 2:
        raise ValueError(
            f"Expected a 2D (channel, frame) waveform, but got {samples.ndim}D input."
        )
    num_channels = samples.shape[0]
    if num_channels == 1:
        display(Audio(samples[0], rate=sample_rate))
    elif num_channels == 2:
        # Stereo data is handed over as a (left, right) pair of 1-D arrays.
        display(Audio((samples[0], samples[1]), rate=sample_rate))
    else:
        raise ValueError("Waveform with more than 2 channels are not supported.")
def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
    """Draw one spectrogram per channel, stacked in a single column.

    Args:
        waveform: Object whose ``numpy()`` method returns a 2-D array; each
            row along the first axis is plotted as one channel.
        sample_rate: Sampling frequency passed to ``specgram`` as ``Fs``.
        title: Figure-level title.
        xlim: Optional x-axis limits applied to every subplot when truthy.
    """
    data = waveform.numpy()
    num_channels, _ = data.shape
    figure, axes = plt.subplots(num_channels, 1)
    # With a single channel ``subplots`` hands back one Axes rather than a
    # sequence, so wrap it to keep the loop below uniform.
    panels = [axes] if num_channels == 1 else axes
    label_channels = num_channels > 1
    for number, (panel, channel) in enumerate(zip(panels, data), start=1):
        panel.specgram(channel, Fs=sample_rate)
        if label_channels:
            panel.set_ylabel(f"Channel {number}")
        if xlim:
            panel.set_xlim(xlim)
    figure.suptitle(title)
    # Non-blocking, so the calling script keeps running after the figure is shown.
    plt.show(block=False)
def benchmark_resample(
    method,
    waveform,
    sample_rate,
    resample_rate,
    lowpass_filter_width=DEFAULT_LOWPASS_FILTER_WIDTH,
    rolloff=DEFAULT_ROLLOFF,
    resampling_method=DEFAULT_RESAMPLING_METHOD,
    beta=None,
    librosa_type=None,
    iters=5,
):
    """Measure the average wall time of one resampling call.

    Args:
        method: Backend to time: ``"functional"`` (``F.resample``),
            ``"transforms"`` (``T.Resample``) or ``"librosa"``
            (``librosa.resample``).
        waveform: Input passed to the selected backend. For ``"librosa"`` it
            is converted with ``waveform.squeeze().numpy()`` before timing.
        sample_rate: Source sampling rate.
        resample_rate: Target sampling rate.
        lowpass_filter_width: Forwarded to the torchaudio backends.
        rolloff: Forwarded to the torchaudio backends.
        resampling_method: Forwarded to the torchaudio backends.
        beta: Forwarded to the torchaudio backends.
        librosa_type: Passed to ``librosa.resample`` as ``res_type``.
        iters: Number of timed repetitions to average over.

    Returns:
        Mean elapsed time per call, in seconds.

    Raises:
        ValueError: If ``method`` is not one of the supported backends.
    """

    def _mean_seconds(run_once):
        # perf_counter is the clock intended for timing short durations;
        # time.time() can jump if the system clock is adjusted mid-run.
        start = time.perf_counter()
        for _ in range(iters):
            run_once()
        return (time.perf_counter() - start) / iters

    if method == "functional":
        return _mean_seconds(
            lambda: F.resample(
                waveform,
                sample_rate,
                resample_rate,
                lowpass_filter_width=lowpass_filter_width,
                rolloff=rolloff,
                resampling_method=resampling_method,
                # Previously accepted but silently dropped.
                beta=beta,
            )
        )
    if method == "transforms":
        # Construction happens outside the timed region; only the forward
        # calls are measured.
        resampler = T.Resample(
            sample_rate,
            resample_rate,
            lowpass_filter_width=lowpass_filter_width,
            rolloff=rolloff,
            resampling_method=resampling_method,
            beta=beta,
            dtype=waveform.dtype,
        )
        return _mean_seconds(lambda: resampler(waveform))
    if method == "librosa":
        # The tensor-to-array conversion is excluded from the measurement.
        waveform_np = waveform.squeeze().numpy()
        return _mean_seconds(
            lambda: librosa.resample(
                waveform_np, orig_sr=sample_rate, target_sr=resample_rate, res_type=librosa_type
            )
        )
    # Falling through used to return None, which only failed later and far
    # from the actual mistake.
    raise ValueError(
        f"Unknown benchmark method {method!r}; expected 'functional', 'transforms' or 'librosa'."
    )
###################################################################### ######################################################################
# Resampling Overview # Resampling Overview
# ------------------- # -------------------
...@@ -211,11 +127,14 @@ def benchmark_resample( ...@@ -211,11 +127,14 @@ def benchmark_resample(
# interpolation <https://ccrma.stanford.edu/~jos/resample/>`__ to compute # interpolation <https://ccrma.stanford.edu/~jos/resample/>`__ to compute
# signal values at arbitrary time steps. The implementation involves # signal values at arbitrary time steps. The implementation involves
# convolution, so we can take advantage of GPU / multithreading for # convolution, so we can take advantage of GPU / multithreading for
# performance improvements. When using resampling in multiple # performance improvements.
# subprocesses, such as data loading with multiple worker processes, your #
# application might create more threads than your system can handle # .. note::
# efficiently. Setting ``torch.set_num_threads(1)`` might help in this #
# case. # When using resampling in multiple subprocesses, such as data loading
# with multiple worker processes, your application might create more
# threads than your system can handle efficiently.
# Setting ``torch.set_num_threads(1)`` might help in this case.
# #
# Because a finite number of samples can only represent a finite number of # Because a finite number of samples can only represent a finite number of
# frequencies, resampling does not produce perfect results, and a variety # frequencies, resampling does not produce perfect results, and a variety
...@@ -230,19 +149,25 @@ def benchmark_resample( ...@@ -230,19 +149,25 @@ def benchmark_resample(
# plotted waveform, and color intensity the amplitude. # plotted waveform, and color intensity the amplitude.
# #
sample_rate = 48000 sample_rate = 48000
resample_rate = 32000
waveform = get_sine_sweep(sample_rate) waveform = get_sine_sweep(sample_rate)
plot_sweep(waveform, sample_rate, title="Original Waveform") plot_sweep(waveform, sample_rate, title="Original Waveform")
play_audio(waveform, sample_rate) Audio(waveform.numpy()[0], rate=sample_rate)
######################################################################
#
# Now we resample (downsample) it.
#
# We see that in the spectrogram of the resampled waveform, there is an
# artifact, which was not present in the original waveform.
resample_rate = 32000
resampler = T.Resample(sample_rate, resample_rate, dtype=waveform.dtype) resampler = T.Resample(sample_rate, resample_rate, dtype=waveform.dtype)
resampled_waveform = resampler(waveform) resampled_waveform = resampler(waveform)
plot_sweep(resampled_waveform, resample_rate, title="Resampled Waveform")
play_audio(waveform, sample_rate)
plot_sweep(resampled_waveform, resample_rate, title="Resampled Waveform")
Audio(resampled_waveform.numpy()[0], rate=resample_rate)
###################################################################### ######################################################################
# Controling resampling quality with parameters # Controling resampling quality with parameters
...@@ -260,17 +185,18 @@ play_audio(waveform, sample_rate) ...@@ -260,17 +185,18 @@ play_audio(waveform, sample_rate)
# expensive. # expensive.
# #
sample_rate = 48000 sample_rate = 48000
resample_rate = 32000 resample_rate = 32000
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, lowpass_filter_width=6) resampled_waveform = F.resample(waveform, sample_rate, resample_rate, lowpass_filter_width=6)
plot_sweep(resampled_waveform, resample_rate, title="lowpass_filter_width=6") plot_sweep(resampled_waveform, resample_rate, title="lowpass_filter_width=6")
######################################################################
#
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, lowpass_filter_width=128) resampled_waveform = F.resample(waveform, sample_rate, resample_rate, lowpass_filter_width=128)
plot_sweep(resampled_waveform, resample_rate, title="lowpass_filter_width=128") plot_sweep(resampled_waveform, resample_rate, title="lowpass_filter_width=128")
###################################################################### ######################################################################
# Rolloff # Rolloff
# ~~~~~~~ # ~~~~~~~
...@@ -291,6 +217,9 @@ resample_rate = 32000 ...@@ -291,6 +217,9 @@ resample_rate = 32000
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, rolloff=0.99) resampled_waveform = F.resample(waveform, sample_rate, resample_rate, rolloff=0.99)
plot_sweep(resampled_waveform, resample_rate, title="rolloff=0.99") plot_sweep(resampled_waveform, resample_rate, title="rolloff=0.99")
######################################################################
#
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, rolloff=0.8) resampled_waveform = F.resample(waveform, sample_rate, resample_rate, rolloff=0.8)
plot_sweep(resampled_waveform, resample_rate, title="rolloff=0.8") plot_sweep(resampled_waveform, resample_rate, title="rolloff=0.8")
...@@ -314,6 +243,9 @@ resample_rate = 32000 ...@@ -314,6 +243,9 @@ resample_rate = 32000
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="sinc_interpolation") resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="sinc_interpolation")
plot_sweep(resampled_waveform, resample_rate, title="Hann Window Default") plot_sweep(resampled_waveform, resample_rate, title="Hann Window Default")
######################################################################
#
resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="kaiser_window") resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="kaiser_window")
plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Default") plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Default")
...@@ -326,11 +258,13 @@ plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Default") ...@@ -326,11 +258,13 @@ plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Default")
# that of librosa (resampy)’s kaiser window resampling, with some noise # that of librosa (resampy)’s kaiser window resampling, with some noise
# #
sample_rate = 48000 sample_rate = 48000
resample_rate = 32000 resample_rate = 32000
######################################################################
# kaiser_best # kaiser_best
# ~~~~~~~~~~~
#
resampled_waveform = F.resample( resampled_waveform = F.resample(
waveform, waveform,
sample_rate, sample_rate,
...@@ -342,15 +276,24 @@ resampled_waveform = F.resample( ...@@ -342,15 +276,24 @@ resampled_waveform = F.resample(
) )
plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Best (torchaudio)") plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Best (torchaudio)")
######################################################################
#
librosa_resampled_waveform = torch.from_numpy( librosa_resampled_waveform = torch.from_numpy(
librosa.resample(waveform.squeeze().numpy(), orig_sr=sample_rate, target_sr=resample_rate, res_type="kaiser_best") librosa.resample(waveform.squeeze().numpy(), orig_sr=sample_rate, target_sr=resample_rate, res_type="kaiser_best")
).unsqueeze(0) ).unsqueeze(0)
plot_sweep(librosa_resampled_waveform, resample_rate, title="Kaiser Window Best (librosa)") plot_sweep(librosa_resampled_waveform, resample_rate, title="Kaiser Window Best (librosa)")
######################################################################
#
mse = torch.square(resampled_waveform - librosa_resampled_waveform).mean().item() mse = torch.square(resampled_waveform - librosa_resampled_waveform).mean().item()
print("torchaudio and librosa kaiser best MSE:", mse) print("torchaudio and librosa kaiser best MSE:", mse)
######################################################################
# kaiser_fast # kaiser_fast
# ~~~~~~~~~~~
#
resampled_waveform = F.resample( resampled_waveform = F.resample(
waveform, waveform,
sample_rate, sample_rate,
...@@ -360,17 +303,22 @@ resampled_waveform = F.resample( ...@@ -360,17 +303,22 @@ resampled_waveform = F.resample(
resampling_method="kaiser_window", resampling_method="kaiser_window",
beta=8.555504641634386, beta=8.555504641634386,
) )
plot_specgram(resampled_waveform, resample_rate, title="Kaiser Window Fast (torchaudio)") plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Fast (torchaudio)")
######################################################################
#
librosa_resampled_waveform = torch.from_numpy( librosa_resampled_waveform = torch.from_numpy(
librosa.resample(waveform.squeeze().numpy(), orig_sr=sample_rate, target_sr=resample_rate, res_type="kaiser_fast") librosa.resample(waveform.squeeze().numpy(), orig_sr=sample_rate, target_sr=resample_rate, res_type="kaiser_fast")
).unsqueeze(0) ).unsqueeze(0)
plot_sweep(librosa_resampled_waveform, resample_rate, title="Kaiser Window Fast (librosa)") plot_sweep(librosa_resampled_waveform, resample_rate, title="Kaiser Window Fast (librosa)")
######################################################################
#
mse = torch.square(resampled_waveform - librosa_resampled_waveform).mean().item() mse = torch.square(resampled_waveform - librosa_resampled_waveform).mean().item()
print("torchaudio and librosa kaiser fast MSE:", mse) print("torchaudio and librosa kaiser fast MSE:", mse)
###################################################################### ######################################################################
# Performance Benchmarking # Performance Benchmarking
# ------------------------ # ------------------------
...@@ -394,6 +342,57 @@ print("torchaudio and librosa kaiser fast MSE:", mse) ...@@ -394,6 +342,57 @@ print("torchaudio and librosa kaiser fast MSE:", mse)
# #
def benchmark_resample(
    method,
    waveform,
    sample_rate,
    resample_rate,
    lowpass_filter_width=6,
    rolloff=0.99,
    resampling_method="sinc_interpolation",
    beta=None,
    librosa_type=None,
    iters=5,
):
    """Measure the average wall time of one resampling call.

    Args:
        method: Backend to time: ``"functional"`` (``F.resample``),
            ``"transforms"`` (``T.Resample``) or ``"librosa"``
            (``librosa.resample``).
        waveform: Input passed to the selected backend. For ``"librosa"`` it
            is converted with ``waveform.squeeze().numpy()`` before timing.
        sample_rate: Source sampling rate.
        resample_rate: Target sampling rate.
        lowpass_filter_width: Forwarded to the torchaudio backends.
        rolloff: Forwarded to the torchaudio backends.
        resampling_method: Forwarded to the torchaudio backends.
        beta: Forwarded to the torchaudio backends.
        librosa_type: Passed to ``librosa.resample`` as ``res_type``.
        iters: Number of timed repetitions to average over.

    Returns:
        Mean elapsed time per call, in seconds.

    Raises:
        ValueError: If ``method`` is not one of the supported backends.
    """

    def _mean_seconds(run_once):
        # perf_counter is the highest-resolution clock available for short
        # durations; time.monotonic() can be much coarser on some platforms.
        start = time.perf_counter()
        for _ in range(iters):
            run_once()
        return (time.perf_counter() - start) / iters

    if method == "functional":
        return _mean_seconds(
            lambda: F.resample(
                waveform,
                sample_rate,
                resample_rate,
                lowpass_filter_width=lowpass_filter_width,
                rolloff=rolloff,
                resampling_method=resampling_method,
                # Previously accepted but silently dropped.
                beta=beta,
            )
        )
    if method == "transforms":
        # Construction happens outside the timed region; only the forward
        # calls are measured.
        resampler = T.Resample(
            sample_rate,
            resample_rate,
            lowpass_filter_width=lowpass_filter_width,
            rolloff=rolloff,
            resampling_method=resampling_method,
            beta=beta,
            dtype=waveform.dtype,
        )
        return _mean_seconds(lambda: resampler(waveform))
    if method == "librosa":
        # The tensor-to-array conversion is excluded from the measurement.
        waveform_np = waveform.squeeze().numpy()
        return _mean_seconds(
            lambda: librosa.resample(
                waveform_np, orig_sr=sample_rate, target_sr=resample_rate, res_type=librosa_type
            )
        )
    # Falling through used to return None, which only failed later and far
    # from the actual mistake.
    raise ValueError(
        f"Unknown benchmark method {method!r}; expected 'functional', 'transforms' or 'librosa'."
    )
######################################################################
#
configs = { configs = {
"downsample (48 -> 44.1 kHz)": [48000, 44100], "downsample (48 -> 44.1 kHz)": [48000, 44100],
"downsample (16 -> 8 kHz)": [16000, 8000], "downsample (16 -> 8 kHz)": [16000, 8000],
...@@ -471,4 +470,7 @@ for label in configs: ...@@ -471,4 +470,7 @@ for label in configs:
df = pd.DataFrame(times, columns=["librosa", "functional", "transforms"], index=rows) df = pd.DataFrame(times, columns=["librosa", "functional", "transforms"], index=rows)
df.columns = pd.MultiIndex.from_product([[f"{label} time (ms)"], df.columns]) df.columns = pd.MultiIndex.from_product([[f"{label} time (ms)"], df.columns])
print(f"torchaudio: {torchaudio.__version__}")
print(f"librosa: {librosa.__version__}")
display(df.round(2)) display(df.round(2))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment